Skip to main content

runmat_runtime/builtins/io/tabular/
csvwrite.rs

1//! MATLAB-compatible `csvwrite` builtin for RunMat.
2//!
3//! `csvwrite` is an older convenience wrapper that persists numeric matrices to
4//! comma-separated text files. Modern MATLAB code typically prefers
5//! `writematrix`, but many legacy scripts still depend on `csvwrite`'s terse
6//! API and zero-based offset arguments. This implementation mirrors those
7//! semantics while integrating with RunMat's builtin framework.
8
9use std::io::Write;
10use std::path::{Path, PathBuf};
11
12use runmat_builtins::{Tensor, Value};
13use runmat_filesystem::OpenOptions;
14use runmat_macros::runtime_builtin;
15
16use crate::builtins::common::fs::expand_user_path;
17use crate::builtins::common::spec::{
18    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
19    ReductionNaN, ResidencyPolicy, ShapeRequirements,
20};
21use crate::builtins::common::tensor;
22use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
23
24const BUILTIN_NAME: &str = "csvwrite";
25
/// GPU execution metadata registered for the `csvwrite` builtin.
///
/// The spec declares no supported precisions and no provider hooks, and uses
/// the `GatherImmediately` residency policy. As the `notes` field states, the
/// builtin runs on the host and gpuArray inputs are gathered before they are
/// serialised.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::tabular::csvwrite")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "csvwrite",
    op_kind: GpuOpKind::Custom("io-csvwrite"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Runs entirely on the host; gpuArray inputs are gathered before serialisation.",
};
41
/// Fusion metadata registered for the `csvwrite` builtin.
///
/// No elementwise or reduction descriptor is provided; per the `notes` field
/// the builtin is not eligible for fusion because it performs host-side file
/// I/O.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::tabular::csvwrite")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "csvwrite",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Not eligible for fusion; performs host-side file I/O.",
};
52
53fn csvwrite_error(message: impl Into<String>) -> RuntimeError {
54    build_runtime_error(message)
55        .with_builtin(BUILTIN_NAME)
56        .build()
57}
58
59fn csvwrite_error_with_source<E>(message: impl Into<String>, source: E) -> RuntimeError
60where
61    E: std::error::Error + Send + Sync + 'static,
62{
63    build_runtime_error(message)
64        .with_builtin(BUILTIN_NAME)
65        .with_source(source)
66        .build()
67}
68
69fn map_control_flow(err: RuntimeError) -> RuntimeError {
70    let identifier = err.identifier().map(|value| value.to_string());
71    let message = err.message().to_string();
72    let mut builder = build_runtime_error(message)
73        .with_builtin(BUILTIN_NAME)
74        .with_source(err);
75    if let Some(identifier) = identifier {
76        builder = builder.with_identifier(identifier);
77    }
78    builder.build()
79}
80
81#[runtime_builtin(
82    name = "csvwrite",
83    category = "io/tabular",
84    summary = "Write numeric matrices to comma-separated text files using MATLAB-compatible offsets.",
85    keywords = "csvwrite,csv,write,row offset,column offset",
86    accel = "cpu",
87    type_resolver(crate::builtins::io::type_resolvers::num_type),
88    builtin_path = "crate::builtins::io::tabular::csvwrite"
89)]
90async fn csvwrite_builtin(
91    filename: Value,
92    data: Value,
93    rest: Vec<Value>,
94) -> crate::BuiltinResult<Value> {
95    let filename_value = gather_if_needed_async(&filename)
96        .await
97        .map_err(map_control_flow)?;
98    let path = resolve_path(&filename_value)?;
99
100    let mut gathered_offsets = Vec::with_capacity(rest.len());
101    for value in &rest {
102        gathered_offsets.push(
103            gather_if_needed_async(value)
104                .await
105                .map_err(map_control_flow)?,
106        );
107    }
108    let (row_offset, col_offset) = parse_offsets(&gathered_offsets)?;
109
110    let gathered_data = gather_if_needed_async(&data)
111        .await
112        .map_err(map_control_flow)?;
113    let tensor =
114        tensor::value_into_tensor_for("csvwrite", gathered_data).map_err(csvwrite_error)?;
115    ensure_matrix_shape(&tensor)?;
116
117    let bytes = write_csv(&path, &tensor, row_offset, col_offset)?;
118    Ok(Value::Num(bytes as f64))
119}
120
121fn resolve_path(value: &Value) -> BuiltinResult<PathBuf> {
122    let raw = match value {
123        Value::String(s) => s.clone(),
124        Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
125        Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
126        _ => {
127            return Err(csvwrite_error(
128                "csvwrite: filename must be a string scalar or character vector",
129            ))
130        }
131    };
132
133    if raw.trim().is_empty() {
134        return Err(csvwrite_error("csvwrite: filename must not be empty"));
135    }
136
137    let expanded = expand_user_path(&raw, BUILTIN_NAME).map_err(csvwrite_error)?;
138    Ok(Path::new(&expanded).to_path_buf())
139}
140
141fn parse_offsets(args: &[Value]) -> BuiltinResult<(usize, usize)> {
142    match args.len() {
143        0 => Ok((0, 0)),
144        2 => {
145            let row = parse_offset(&args[0], "row offset")?;
146            let col = parse_offset(&args[1], "column offset")?;
147            Ok((row, col))
148        }
149        _ => Err(csvwrite_error(
150            "csvwrite: offsets must be provided as two numeric arguments (row, column)",
151        )),
152    }
153}
154
155fn parse_offset(value: &Value, context: &str) -> BuiltinResult<usize> {
156    match value {
157        Value::Int(i) => {
158            let raw = i.to_i64();
159            if raw < 0 {
160                return Err(csvwrite_error(format!("csvwrite: {context} must be >= 0")));
161            }
162            Ok(raw as usize)
163        }
164        Value::Num(n) => coerce_offset_from_float(*n, context),
165        Value::Bool(b) => Ok(if *b { 1 } else { 0 }),
166        Value::Tensor(t) => {
167            if t.data.len() != 1 {
168                return Err(csvwrite_error(format!(
169                    "csvwrite: {context} must be a scalar, got {} elements",
170                    t.data.len()
171                )));
172            }
173            coerce_offset_from_float(t.data[0], context)
174        }
175        Value::LogicalArray(logical) => {
176            if logical.data.len() != 1 {
177                return Err(csvwrite_error(format!(
178                    "csvwrite: {context} must be a scalar, got {} elements",
179                    logical.data.len()
180                )));
181            }
182            Ok(if logical.data[0] != 0 { 1 } else { 0 })
183        }
184        other => Err(csvwrite_error(format!(
185            "csvwrite: {context} must be numeric, got {:?}",
186            other
187        ))),
188    }
189}
190
191fn coerce_offset_from_float(value: f64, context: &str) -> BuiltinResult<usize> {
192    if !value.is_finite() {
193        return Err(csvwrite_error(format!(
194            "csvwrite: {context} must be finite"
195        )));
196    }
197    let rounded = value.round();
198    if (rounded - value).abs() > 1e-9 {
199        return Err(csvwrite_error(format!(
200            "csvwrite: {context} must be an integer"
201        )));
202    }
203    if rounded < 0.0 {
204        return Err(csvwrite_error(format!("csvwrite: {context} must be >= 0")));
205    }
206    Ok(rounded as usize)
207}
208
209fn ensure_matrix_shape(tensor: &Tensor) -> BuiltinResult<()> {
210    if tensor.shape.len() <= 2 {
211        return Ok(());
212    }
213    if tensor.shape[2..].iter().all(|&dim| dim == 1) {
214        return Ok(());
215    }
216    Err(csvwrite_error(
217        "csvwrite: input must be 2-D; reshape before writing",
218    ))
219}
220
221fn write_csv(
222    path: &Path,
223    tensor: &Tensor,
224    row_offset: usize,
225    col_offset: usize,
226) -> BuiltinResult<usize> {
227    let mut options = OpenOptions::new();
228    options.create(true).write(true).truncate(true);
229    let mut file = options.open(path).map_err(|err| {
230        csvwrite_error_with_source(
231            format!(
232                "csvwrite: unable to open \"{}\" for writing ({err})",
233                path.display()
234            ),
235            err,
236        )
237    })?;
238
239    let line_ending = default_line_ending();
240    let rows = tensor.rows();
241    let cols = tensor.cols();
242
243    let mut bytes_written = 0usize;
244
245    for _ in 0..row_offset {
246        file.write_all(line_ending.as_bytes()).map_err(|err| {
247            csvwrite_error_with_source(
248                format!("csvwrite: failed to write line ending ({err})"),
249                err,
250            )
251        })?;
252        bytes_written += line_ending.len();
253    }
254
255    if rows == 0 || cols == 0 {
256        file.flush().map_err(|err| {
257            csvwrite_error_with_source(format!("csvwrite: failed to flush output ({err})"), err)
258        })?;
259        return Ok(bytes_written);
260    }
261
262    for row in 0..rows {
263        let mut fields = Vec::with_capacity(col_offset + cols);
264        for _ in 0..col_offset {
265            fields.push(String::new());
266        }
267        for col in 0..cols {
268            let idx = row + col * rows;
269            let value = tensor.data[idx];
270            fields.push(format_numeric(value));
271        }
272        let line = fields.join(",");
273        if !line.is_empty() {
274            file.write_all(line.as_bytes()).map_err(|err| {
275                csvwrite_error_with_source(format!("csvwrite: failed to write value ({err})"), err)
276            })?;
277            bytes_written += line.len();
278        }
279        file.write_all(line_ending.as_bytes()).map_err(|err| {
280            csvwrite_error_with_source(
281                format!("csvwrite: failed to write line ending ({err})"),
282                err,
283            )
284        })?;
285        bytes_written += line_ending.len();
286    }
287
288    file.flush().map_err(|err| {
289        csvwrite_error_with_source(format!("csvwrite: failed to flush output ({err})"), err)
290    })?;
291
292    Ok(bytes_written)
293}
294
/// Returns the line terminator used for output: `"\r\n"` when compiled for
/// Windows, `"\n"` everywhere else.
fn default_line_ending() -> &'static str {
    const CRLF: &str = "\r\n";
    const LF: &str = "\n";

    if cfg!(not(windows)) {
        LF
    } else {
        CRLF
    }
}
302
/// Formats `value` the way C's `%.5g` conversion would: five significant
/// digits, trailing zeros removed, and scientific notation with a signed
/// exponent of at least two digits when the decimal exponent is below -4 or
/// at least 5.
///
/// Special cases: NaN becomes `"NaN"`, infinities become `"Inf"` / `"-Inf"`,
/// and both zeros become `"0"`.
fn format_numeric(value: f64) -> String {
    /// Number of significant digits emitted.
    const PRECISION: i32 = 5;

    if value.is_nan() {
        return "NaN".to_string();
    }
    if value.is_infinite() {
        let text = if value.is_sign_negative() { "-Inf" } else { "Inf" };
        return text.to_string();
    }
    if value == 0.0 {
        // Also covers negative zero, which must not print as "-0".
        return "0".to_string();
    }

    // Format in scientific notation first and read the exponent back from the
    // *rounded* result. `%g` bases the fixed-vs-scientific choice on that
    // exponent; deriving it from `log10` of the unrounded value gets the
    // boundaries wrong (e.g. 99999.6 must print as 1e+05, not 100000).
    let mantissa_digits = (PRECISION - 1).max(0) as usize;
    let scientific = format!("{:.*e}", mantissa_digits, value);
    let exp10 = scientific
        .rsplit_once('e')
        .and_then(|(_, exponent)| exponent.parse::<i32>().ok())
        .unwrap_or(0);

    let raw = if exp10 < -4 || exp10 >= PRECISION {
        scientific
    } else {
        let decimals = (PRECISION - 1 - exp10).max(0) as usize;
        format!("{:.*}", decimals, value)
    };

    // Five significant digits are always kept, so a non-zero input can never
    // collapse to "-0" here.
    trim_trailing_zeros(raw)
}

/// Removes insignificant trailing zeros (and a dangling decimal point) from a
/// formatted number.
///
/// When the text contains an exponent (`e`/`E`), only the mantissa is trimmed
/// and the exponent is rewritten by `normalize_exponent`. Zeros are stripped
/// only when a decimal point is present, so integer digits are never removed.
/// An empty fixed-notation result is replaced by `"0"`.
fn trim_trailing_zeros(mut value: String) -> String {
    /// Strips trailing fractional zeros in place; leaves integers untouched.
    fn strip_fraction_zeros(digits: &mut String) {
        if !digits.contains('.') {
            return;
        }
        while digits.ends_with('0') {
            digits.pop();
        }
        if digits.ends_with('.') {
            digits.pop();
        }
    }

    match value.find(['e', 'E']) {
        Some(exp_pos) => {
            let exponent = value.split_off(exp_pos);
            strip_fraction_zeros(&mut value);
            value.push_str(&normalize_exponent(&exponent));
            value
        }
        None => {
            strip_fraction_zeros(&mut value);
            if value.is_empty() {
                "0".to_string()
            } else {
                value
            }
        }
    }
}

/// Rewrites an exponent suffix such as `e6` or `e-4` with an explicit sign and
/// at least two digits (`e+06`, `e-04`).
///
/// The first character is kept as the exponent marker. Input whose remainder
/// does not parse as an integer (including the empty string and a bare
/// marker) is returned unchanged.
fn normalize_exponent(exponent: &str) -> String {
    let mut chars = exponent.chars();
    let Some(marker) = chars.next() else {
        return String::new();
    };
    match chars.as_str().parse::<i32>() {
        // Width 3 counts the sign, giving a sign plus two zero-padded digits.
        Ok(parsed) => format!("{marker}{parsed:+03}"),
        Err(_) => exponent.to_string(),
    }
}
378
// Unit tests for the `csvwrite` builtin. Each test writes to a uniquely named
// file and compares the file contents against the expected CSV text.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_time::unix_timestamp_ms;
    use std::fs;
    use std::sync::atomic::{AtomicU64, Ordering};

    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray};

    use crate::builtins::common::fs as fs_helpers;
    use crate::builtins::common::test_support;

    /// Synchronous wrapper that drives the async builtin to completion.
    fn csvwrite_builtin(filename: Value, data: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::csvwrite_builtin(filename, data, rest))
    }

    // Monotonic counter used to keep generated file names distinct.
    static NEXT_ID: AtomicU64 = AtomicU64::new(0);

    /// Builds a path in the system temp directory whose name combines the
    /// process id, the current timestamp in milliseconds, and a counter value.
    fn temp_path(ext: &str) -> PathBuf {
        let millis = unix_timestamp_ms();
        let unique = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        let mut path = std::env::temp_dir();
        path.push(format!(
            "runmat_csvwrite_{}_{}_{}.{}",
            std::process::id(),
            millis,
            unique,
            ext
        ));
        path
    }

    /// Line ending the implementation emits on this platform.
    fn line_ending() -> &'static str {
        default_line_ending()
    }

    /// A 2x3 matrix stored column-major is written row by row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_writes_basic_matrix() {
        let path = temp_path("csv");
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(Value::from(filename), Value::Tensor(tensor), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(contents, format!("1,2,3{le}4,5,6{le}", le = line_ending()));
        let _ = fs::remove_file(path);
    }

    /// Row offset 1 yields one empty line; column offset 2 yields two leading
    /// empty fields on every data row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_honours_offsets() {
        let path = temp_path("csv");
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(
            Value::from(filename),
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(2))],
        )
        .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(
            contents,
            format!("{le},,1,3{le},,2,4{le}", le = line_ending())
        );
        let _ = fs::remove_file(path);
    }

    /// A tensor uploaded through the test provider is gathered and written.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_handles_gpu_tensors() {
        test_support::with_test_provider(|provider| {
            let path = temp_path("csv");
            let tensor = Tensor::new(vec![0.5, 1.5], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let filename = path.to_string_lossy().into_owned();

            csvwrite_builtin(Value::from(filename), Value::GpuTensor(handle), Vec::new())
                .expect("csvwrite");

            let contents = fs::read_to_string(&path).expect("read contents");
            assert_eq!(contents, format!("0.5,1.5{le}", le = line_ending()));
            let _ = fs::remove_file(path);
        });
    }

    /// Values are rendered with five significant digits, switching to
    /// scientific notation for large magnitudes; negative zero prints as `0`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_formats_with_short_g_precision() {
        let path = temp_path("csv");
        let values =
            Tensor::new(vec![12.3456, 1_234_567.0, 0.000123456, -0.0], vec![1, 4]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(Value::from(filename), Value::Tensor(values), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(
            contents,
            format!("12.346,1.2346e+06,0.00012346,0{le}", le = line_ending())
        );
        let _ = fs::remove_file(path);
    }

    /// A negative row offset is rejected with a message naming the argument.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_rejects_negative_offsets() {
        let path = temp_path("csv");
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let filename = path.to_string_lossy().into_owned();
        let err = csvwrite_builtin(
            Value::from(filename),
            Value::Tensor(tensor),
            vec![Value::Num(-1.0), Value::Num(0.0)],
        )
        .expect_err("negative offsets should be rejected");
        let message = err.message().to_string();
        assert!(
            message.contains("row offset"),
            "unexpected error message: {message}"
        );
    }

    /// Same as the GPU test above, but through the wgpu provider (only built
    /// when the `wgpu` feature is enabled).
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_handles_wgpu_provider_gather() {
        // The registration result is deliberately ignored; the lookup below
        // is what decides whether a provider is available.
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let Some(provider) = runmat_accelerate_api::provider() else {
            panic!("wgpu provider not registered");
        };

        let path = temp_path("csv");
        let tensor = Tensor::new(vec![2.0, 4.0], vec![1, 2]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(Value::from(filename), Value::GpuTensor(handle), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(contents, format!("2,4{le}", le = line_ending()));
        let _ = fs::remove_file(path);
    }

    /// A `~/name` filename is written to the file `name` inside the home
    /// directory reported by the shared filesystem helpers.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_expands_home_directory() {
        let Some(mut home) = fs_helpers::home_directory() else {
            // Skip when home directory cannot be determined.
            return;
        };
        let filename = format!(
            "runmat_csvwrite_home_{}_{}.csv",
            std::process::id(),
            NEXT_ID.fetch_add(1, Ordering::Relaxed)
        );
        home.push(&filename);

        let tilde_path = format!("~/{}", filename);
        let tensor = Tensor::new(vec![42.0], vec![1, 1]).unwrap();

        csvwrite_builtin(Value::from(tilde_path), Value::Tensor(tensor), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&home).expect("read contents");
        assert_eq!(contents, format!("42{le}", le = line_ending()));
        let _ = fs::remove_file(home);
    }

    /// Passing a string as the data argument fails with a `csvwrite` error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_rejects_non_numeric_inputs() {
        let path = temp_path("csv");
        let filename = path.to_string_lossy().into_owned();
        let err = csvwrite_builtin(
            Value::from(filename),
            Value::String("abc".into()),
            Vec::new(),
        )
        .expect_err("csvwrite should fail");
        let message = err.message().to_string();
        assert!(
            message.contains("csvwrite"),
            "unexpected error message: {message}"
        );
    }

    /// Logical arrays are accepted as data and written as 1/0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_accepts_logical_arrays() {
        let path = temp_path("csv");
        let logical = LogicalArray::new(vec![1, 0, 1, 0], vec![2, 2]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(
            Value::from(filename),
            Value::LogicalArray(logical),
            Vec::new(),
        )
        .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(contents, format!("1,1{le}0,0{le}", le = line_ending()));
        let _ = fs::remove_file(path);
    }
}