1use std::io::Write;
10use std::path::{Path, PathBuf};
11
12use runmat_builtins::{Tensor, Value};
13use runmat_filesystem::OpenOptions;
14use runmat_macros::runtime_builtin;
15
16use crate::builtins::common::fs::expand_user_path;
17use crate::builtins::common::spec::{
18 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
19 ReductionNaN, ResidencyPolicy, ShapeRequirements,
20};
21use crate::builtins::common::tensor;
22use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
23
/// Name of this builtin. It is attached to the runtime errors built by the
/// helpers in this file and passed to `expand_user_path` when resolving the
/// target filename.
const BUILTIN_NAME: &str = "csvwrite";
25
// GPU metadata for `csvwrite`, registered through `register_gpu_spec`.
// The spec declares no supported precisions and no provider hooks and uses
// `ResidencyPolicy::GatherImmediately`; as stated in `notes`, the builtin runs
// on the host and gpuArray inputs are gathered before serialisation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::tabular::csvwrite")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "csvwrite",
    op_kind: GpuOpKind::Custom("io-csvwrite"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Runs entirely on the host; gpuArray inputs are gathered before serialisation.",
};
41
// Fusion metadata for `csvwrite`, registered through `register_fusion_spec`.
// No elementwise or reduction descriptor is supplied; as stated in `notes`,
// the builtin is not eligible for fusion because it performs host-side file I/O.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::tabular::csvwrite")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "csvwrite",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Not eligible for fusion; performs host-side file I/O.",
};
52
53fn csvwrite_error(message: impl Into<String>) -> RuntimeError {
54 build_runtime_error(message)
55 .with_builtin(BUILTIN_NAME)
56 .build()
57}
58
59fn csvwrite_error_with_source<E>(message: impl Into<String>, source: E) -> RuntimeError
60where
61 E: std::error::Error + Send + Sync + 'static,
62{
63 build_runtime_error(message)
64 .with_builtin(BUILTIN_NAME)
65 .with_source(source)
66 .build()
67}
68
69fn map_control_flow(err: RuntimeError) -> RuntimeError {
70 let identifier = err.identifier().map(|value| value.to_string());
71 let message = err.message().to_string();
72 let mut builder = build_runtime_error(message)
73 .with_builtin(BUILTIN_NAME)
74 .with_source(err);
75 if let Some(identifier) = identifier {
76 builder = builder.with_identifier(identifier);
77 }
78 builder.build()
79}
80
81#[runtime_builtin(
82 name = "csvwrite",
83 category = "io/tabular",
84 summary = "Write numeric matrices to comma-separated text files using MATLAB-compatible offsets.",
85 keywords = "csvwrite,csv,write,row offset,column offset",
86 accel = "cpu",
87 type_resolver(crate::builtins::io::type_resolvers::num_type),
88 builtin_path = "crate::builtins::io::tabular::csvwrite"
89)]
90async fn csvwrite_builtin(
91 filename: Value,
92 data: Value,
93 rest: Vec<Value>,
94) -> crate::BuiltinResult<Value> {
95 let filename_value = gather_if_needed_async(&filename)
96 .await
97 .map_err(map_control_flow)?;
98 let path = resolve_path(&filename_value)?;
99
100 let mut gathered_offsets = Vec::with_capacity(rest.len());
101 for value in &rest {
102 gathered_offsets.push(
103 gather_if_needed_async(value)
104 .await
105 .map_err(map_control_flow)?,
106 );
107 }
108 let (row_offset, col_offset) = parse_offsets(&gathered_offsets)?;
109
110 let gathered_data = gather_if_needed_async(&data)
111 .await
112 .map_err(map_control_flow)?;
113 let tensor =
114 tensor::value_into_tensor_for("csvwrite", gathered_data).map_err(csvwrite_error)?;
115 ensure_matrix_shape(&tensor)?;
116
117 let bytes = write_csv(&path, &tensor, row_offset, col_offset)?;
118 Ok(Value::Num(bytes as f64))
119}
120
121fn resolve_path(value: &Value) -> BuiltinResult<PathBuf> {
122 let raw = match value {
123 Value::String(s) => s.clone(),
124 Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
125 Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
126 _ => {
127 return Err(csvwrite_error(
128 "csvwrite: filename must be a string scalar or character vector",
129 ))
130 }
131 };
132
133 if raw.trim().is_empty() {
134 return Err(csvwrite_error("csvwrite: filename must not be empty"));
135 }
136
137 let expanded = expand_user_path(&raw, BUILTIN_NAME).map_err(csvwrite_error)?;
138 Ok(Path::new(&expanded).to_path_buf())
139}
140
141fn parse_offsets(args: &[Value]) -> BuiltinResult<(usize, usize)> {
142 match args.len() {
143 0 => Ok((0, 0)),
144 2 => {
145 let row = parse_offset(&args[0], "row offset")?;
146 let col = parse_offset(&args[1], "column offset")?;
147 Ok((row, col))
148 }
149 _ => Err(csvwrite_error(
150 "csvwrite: offsets must be provided as two numeric arguments (row, column)",
151 )),
152 }
153}
154
155fn parse_offset(value: &Value, context: &str) -> BuiltinResult<usize> {
156 match value {
157 Value::Int(i) => {
158 let raw = i.to_i64();
159 if raw < 0 {
160 return Err(csvwrite_error(format!("csvwrite: {context} must be >= 0")));
161 }
162 Ok(raw as usize)
163 }
164 Value::Num(n) => coerce_offset_from_float(*n, context),
165 Value::Bool(b) => Ok(if *b { 1 } else { 0 }),
166 Value::Tensor(t) => {
167 if t.data.len() != 1 {
168 return Err(csvwrite_error(format!(
169 "csvwrite: {context} must be a scalar, got {} elements",
170 t.data.len()
171 )));
172 }
173 coerce_offset_from_float(t.data[0], context)
174 }
175 Value::LogicalArray(logical) => {
176 if logical.data.len() != 1 {
177 return Err(csvwrite_error(format!(
178 "csvwrite: {context} must be a scalar, got {} elements",
179 logical.data.len()
180 )));
181 }
182 Ok(if logical.data[0] != 0 { 1 } else { 0 })
183 }
184 other => Err(csvwrite_error(format!(
185 "csvwrite: {context} must be numeric, got {:?}",
186 other
187 ))),
188 }
189}
190
191fn coerce_offset_from_float(value: f64, context: &str) -> BuiltinResult<usize> {
192 if !value.is_finite() {
193 return Err(csvwrite_error(format!(
194 "csvwrite: {context} must be finite"
195 )));
196 }
197 let rounded = value.round();
198 if (rounded - value).abs() > 1e-9 {
199 return Err(csvwrite_error(format!(
200 "csvwrite: {context} must be an integer"
201 )));
202 }
203 if rounded < 0.0 {
204 return Err(csvwrite_error(format!("csvwrite: {context} must be >= 0")));
205 }
206 Ok(rounded as usize)
207}
208
209fn ensure_matrix_shape(tensor: &Tensor) -> BuiltinResult<()> {
210 if tensor.shape.len() <= 2 {
211 return Ok(());
212 }
213 if tensor.shape[2..].iter().all(|&dim| dim == 1) {
214 return Ok(());
215 }
216 Err(csvwrite_error(
217 "csvwrite: input must be 2-D; reshape before writing",
218 ))
219}
220
221fn write_csv(
222 path: &Path,
223 tensor: &Tensor,
224 row_offset: usize,
225 col_offset: usize,
226) -> BuiltinResult<usize> {
227 let mut options = OpenOptions::new();
228 options.create(true).write(true).truncate(true);
229 let mut file = options.open(path).map_err(|err| {
230 csvwrite_error_with_source(
231 format!(
232 "csvwrite: unable to open \"{}\" for writing ({err})",
233 path.display()
234 ),
235 err,
236 )
237 })?;
238
239 let line_ending = default_line_ending();
240 let rows = tensor.rows();
241 let cols = tensor.cols();
242
243 let mut bytes_written = 0usize;
244
245 for _ in 0..row_offset {
246 file.write_all(line_ending.as_bytes()).map_err(|err| {
247 csvwrite_error_with_source(
248 format!("csvwrite: failed to write line ending ({err})"),
249 err,
250 )
251 })?;
252 bytes_written += line_ending.len();
253 }
254
255 if rows == 0 || cols == 0 {
256 file.flush().map_err(|err| {
257 csvwrite_error_with_source(format!("csvwrite: failed to flush output ({err})"), err)
258 })?;
259 return Ok(bytes_written);
260 }
261
262 for row in 0..rows {
263 let mut fields = Vec::with_capacity(col_offset + cols);
264 for _ in 0..col_offset {
265 fields.push(String::new());
266 }
267 for col in 0..cols {
268 let idx = row + col * rows;
269 let value = tensor.data[idx];
270 fields.push(format_numeric(value));
271 }
272 let line = fields.join(",");
273 if !line.is_empty() {
274 file.write_all(line.as_bytes()).map_err(|err| {
275 csvwrite_error_with_source(format!("csvwrite: failed to write value ({err})"), err)
276 })?;
277 bytes_written += line.len();
278 }
279 file.write_all(line_ending.as_bytes()).map_err(|err| {
280 csvwrite_error_with_source(
281 format!("csvwrite: failed to write line ending ({err})"),
282 err,
283 )
284 })?;
285 bytes_written += line_ending.len();
286 }
287
288 file.flush().map_err(|err| {
289 csvwrite_error_with_source(format!("csvwrite: failed to flush output ({err})"), err)
290 })?;
291
292 Ok(bytes_written)
293}
294
/// Returns the line terminator used for output: `"\r\n"` when compiled for
/// Windows, `"\n"` otherwise.
fn default_line_ending() -> &'static str {
    match cfg!(windows) {
        true => "\r\n",
        false => "\n",
    }
}
302
/// Formats `value` with five significant digits in the style of C's `%.5g`.
///
/// * NaN is written as `NaN`, infinities as `Inf` / `-Inf`, and both zeros as `0`.
/// * Scientific notation is used when the decimal exponent is below -4 or at
///   least 5; otherwise fixed notation is used.
/// * Trailing zeros (and a dangling decimal point) are removed, and exponents
///   are written with a sign and at least two digits (`e+06`, `e-05`).
fn format_numeric(value: f64) -> String {
    if value.is_nan() {
        return "NaN".to_string();
    }
    if value.is_infinite() {
        return if value.is_sign_negative() {
            "-Inf".to_string()
        } else {
            "Inf".to_string()
        };
    }
    if value == 0.0 {
        return "0".to_string();
    }

    const SIGNIFICANT_DIGITS: i32 = 5;
    let mantissa_decimals = (SIGNIFICANT_DIGITS - 1) as usize;

    // Take the exponent from the already-rounded scientific rendering rather
    // than from `log10` of the raw value: rounding to five digits can carry into
    // the next power of ten (99999.6 -> 1.0000e5), and the notation must be
    // chosen from the exponent after that carry.
    let scientific = format!("{:.*e}", mantissa_decimals, value);
    let exp10 = scientific
        .split_once('e')
        .and_then(|(_, exponent)| exponent.parse::<i32>().ok())
        .unwrap_or_else(|| value.abs().log10().floor() as i32);

    let raw = if exp10 < -4 || exp10 >= SIGNIFICANT_DIGITS {
        scientific
    } else {
        // -4 <= exp10 <= 4 in this branch, so the subtraction cannot go negative.
        let decimals = (SIGNIFICANT_DIGITS - 1 - exp10) as usize;
        format!("{:.*}", decimals, value)
    };

    let trimmed = trim_trailing_zeros(raw);
    // Defensive: never emit a signed zero.
    if trimmed == "-0" {
        "0".to_string()
    } else {
        trimmed
    }
}

/// Removes insignificant trailing zeros from a formatted number.
///
/// For text containing an exponent marker (`e` / `E`) only the mantissa is
/// trimmed and the exponent is rewritten by `normalize_exponent`. Zeros are
/// stripped only when the digits contain a decimal point, so integer digits are
/// never removed. An empty result is replaced by `"0"`.
fn trim_trailing_zeros(mut value: String) -> String {
    // Drops trailing fractional zeros and a dangling '.', leaving integers alone.
    fn strip_fraction_zeros(digits: &mut String) {
        if digits.contains('.') {
            let kept = digits.trim_end_matches('0').trim_end_matches('.').len();
            digits.truncate(kept);
        }
    }

    if let Some(exp_pos) = value.find(['e', 'E']) {
        let exponent = value.split_off(exp_pos);
        strip_fraction_zeros(&mut value);
        value.push_str(&normalize_exponent(&exponent));
        return value;
    }

    strip_fraction_zeros(&mut value);
    if value.is_empty() {
        "0".to_string()
    } else {
        value
    }
}

/// Rewrites an exponent suffix such as `e6` or `e-4` as `e+06` / `e-04`.
///
/// The first character is kept as the marker and the remainder is parsed as an
/// integer, then printed with an explicit sign and zero-padded to two digits.
/// Input whose remainder is not an integer is returned unchanged.
fn normalize_exponent(exponent: &str) -> String {
    let mut chars = exponent.chars();
    let Some(marker) = chars.next() else {
        return String::new();
    };
    match chars.as_str().parse::<i32>() {
        Ok(parsed) => format!("{}{:+03}", marker, parsed),
        Err(_) => exponent.to_string(),
    }
}
378
// Unit tests for the `csvwrite` builtin. Each test writes to a uniquely named
// file, reads it back, and compares the exact text that was produced.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_time::unix_timestamp_ms;
    use std::fs;
    use std::sync::atomic::{AtomicU64, Ordering};

    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray};

    use crate::builtins::common::fs as fs_helpers;
    use crate::builtins::common::test_support;

    // Synchronous wrapper: drives the async builtin to completion so tests can
    // call it like a plain function.
    fn csvwrite_builtin(filename: Value, data: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::csvwrite_builtin(filename, data, rest))
    }

    // Counter mixed into generated file names so that every call gets a distinct name.
    static NEXT_ID: AtomicU64 = AtomicU64::new(0);

    // Builds a path in the system temp directory whose name combines the
    // process id, the current timestamp, and the `NEXT_ID` counter.
    fn temp_path(ext: &str) -> PathBuf {
        let millis = unix_timestamp_ms();
        let unique = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        let mut path = std::env::temp_dir();
        path.push(format!(
            "runmat_csvwrite_{}_{}_{}.{}",
            std::process::id(),
            millis,
            unique,
            ext
        ));
        path
    }

    // Line terminator the implementation emits on this platform.
    fn line_ending() -> &'static str {
        default_line_ending()
    }

    // A 2x3 tensor (column-major data) is written as two comma-separated rows.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_writes_basic_matrix() {
        let path = temp_path("csv");
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![2, 3]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(Value::from(filename), Value::Tensor(tensor), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(contents, format!("1,2,3{le}4,5,6{le}", le = line_ending()));
        let _ = fs::remove_file(path);
    }

    // Row offset 1 produces one leading empty line; column offset 2 produces
    // two leading empty fields on every data row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_honours_offsets() {
        let path = temp_path("csv");
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(
            Value::from(filename),
            Value::Tensor(tensor),
            vec![Value::Int(IntValue::I32(1)), Value::Int(IntValue::I32(2))],
        )
        .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(
            contents,
            format!("{le},,1,3{le},,2,4{le}", le = line_ending())
        );
        let _ = fs::remove_file(path);
    }

    // A tensor uploaded to the test provider is written like a host tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_handles_gpu_tensors() {
        test_support::with_test_provider(|provider| {
            let path = temp_path("csv");
            let tensor = Tensor::new(vec![0.5, 1.5], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let filename = path.to_string_lossy().into_owned();

            csvwrite_builtin(Value::from(filename), Value::GpuTensor(handle), Vec::new())
                .expect("csvwrite");

            let contents = fs::read_to_string(&path).expect("read contents");
            assert_eq!(contents, format!("0.5,1.5{le}", le = line_ending()));
            let _ = fs::remove_file(path);
        });
    }

    // Values are rendered with five significant digits, switching to
    // scientific notation for large magnitudes; negative zero is written as `0`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_formats_with_short_g_precision() {
        let path = temp_path("csv");
        let values =
            Tensor::new(vec![12.3456, 1_234_567.0, 0.000123456, -0.0], vec![1, 4]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(Value::from(filename), Value::Tensor(values), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(
            contents,
            format!("12.346,1.2346e+06,0.00012346,0{le}", le = line_ending())
        );
        let _ = fs::remove_file(path);
    }

    // A negative row offset is rejected with a message naming the row offset.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_rejects_negative_offsets() {
        let path = temp_path("csv");
        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let filename = path.to_string_lossy().into_owned();
        let err = csvwrite_builtin(
            Value::from(filename),
            Value::Tensor(tensor),
            vec![Value::Num(-1.0), Value::Num(0.0)],
        )
        .expect_err("negative offsets should be rejected");
        let message = err.message().to_string();
        assert!(
            message.contains("row offset"),
            "unexpected error message: {message}"
        );
    }

    // Same round trip as the GPU test above, using the wgpu provider
    // (only compiled when the `wgpu` feature is enabled).
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_handles_wgpu_provider_gather() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let Some(provider) = runmat_accelerate_api::provider() else {
            panic!("wgpu provider not registered");
        };

        let path = temp_path("csv");
        let tensor = Tensor::new(vec![2.0, 4.0], vec![1, 2]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(Value::from(filename), Value::GpuTensor(handle), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(contents, format!("2,4{le}", le = line_ending()));
        let _ = fs::remove_file(path);
    }

    // A `~/name` filename is written to the file of that name inside the
    // directory returned by `home_directory()`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_expands_home_directory() {
        let Some(mut home) = fs_helpers::home_directory() else {
            // No home directory reported: nothing to verify.
            return;
        };
        let filename = format!(
            "runmat_csvwrite_home_{}_{}.csv",
            std::process::id(),
            NEXT_ID.fetch_add(1, Ordering::Relaxed)
        );
        home.push(&filename);

        let tilde_path = format!("~/{}", filename);
        let tensor = Tensor::new(vec![42.0], vec![1, 1]).unwrap();

        csvwrite_builtin(Value::from(tilde_path), Value::Tensor(tensor), Vec::new())
            .expect("csvwrite");

        let contents = fs::read_to_string(&home).expect("read contents");
        assert_eq!(contents, format!("42{le}", le = line_ending()));
        let _ = fs::remove_file(home);
    }

    // A string passed as the data argument is rejected with a csvwrite error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_rejects_non_numeric_inputs() {
        let path = temp_path("csv");
        let filename = path.to_string_lossy().into_owned();
        let err = csvwrite_builtin(
            Value::from(filename),
            Value::String("abc".into()),
            Vec::new(),
        )
        .expect_err("csvwrite should fail");
        let message = err.message().to_string();
        assert!(
            message.contains("csvwrite"),
            "unexpected error message: {message}"
        );
    }

    // Logical arrays are accepted as data and written as 0/1 values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn csvwrite_accepts_logical_arrays() {
        let path = temp_path("csv");
        let logical = LogicalArray::new(vec![1, 0, 1, 0], vec![2, 2]).unwrap();
        let filename = path.to_string_lossy().into_owned();

        csvwrite_builtin(
            Value::from(filename),
            Value::LogicalArray(logical),
            Vec::new(),
        )
        .expect("csvwrite");

        let contents = fs::read_to_string(&path).expect("read contents");
        assert_eq!(contents, format!("1,1{le}0,0{le}", le = line_ending()));
        let _ = fs::remove_file(path);
    }
}