1use std::io::{Seek, SeekFrom, Write};
3
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 CharArray, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::io::filetext::registry;
16use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
17use runmat_filesystem::File;
18
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21 name: "fwrite",
22 op_kind: GpuOpKind::Custom("file-io-write"),
23 supported_precisions: &[],
24 broadcast: BroadcastSemantics::None,
25 provider_hooks: &[],
26 constant_strategy: ConstantStrategy::InlineLiteral,
27 residency: ResidencyPolicy::GatherImmediately,
28 nan_mode: ReductionNaN::Include,
29 two_pass_threshold: None,
30 workgroup_size: None,
31 accepts_nan_mode: false,
32 notes: "Host-only binary file I/O; GPU arguments are gathered to the CPU prior to writing.",
33};
34
35#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
36pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
37 name: "fwrite",
38 shape: ShapeRequirements::Any,
39 constant_strategy: ConstantStrategy::InlineLiteral,
40 elementwise: None,
41 reduction: None,
42 emits_nan: false,
43 notes: "File I/O is never fused; metadata recorded for completeness.",
44};
45
46const BUILTIN_NAME: &str = "fwrite";
47
48const FWRITE_OUTPUT_COUNT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
49 name: "count",
50 ty: BuiltinParamType::NumericScalar,
51 arity: BuiltinParamArity::Required,
52 default: None,
53 description: "Number of elements successfully written.",
54}];
55const FWRITE_INPUTS_FID_DATA: [BuiltinParamDescriptor; 2] = [
56 BuiltinParamDescriptor {
57 name: "fid",
58 ty: BuiltinParamType::NumericScalar,
59 arity: BuiltinParamArity::Required,
60 default: None,
61 description: "File identifier opened by fopen.",
62 },
63 BuiltinParamDescriptor {
64 name: "data",
65 ty: BuiltinParamType::Any,
66 arity: BuiltinParamArity::Required,
67 default: None,
68 description: "Numeric/logical/text payload to write.",
69 },
70];
71const FWRITE_INPUTS_FID_DATA_PRECISION: [BuiltinParamDescriptor; 3] = [
72 BuiltinParamDescriptor {
73 name: "fid",
74 ty: BuiltinParamType::NumericScalar,
75 arity: BuiltinParamArity::Required,
76 default: None,
77 description: "File identifier opened by fopen.",
78 },
79 BuiltinParamDescriptor {
80 name: "data",
81 ty: BuiltinParamType::Any,
82 arity: BuiltinParamArity::Required,
83 default: None,
84 description: "Numeric/logical/text payload to write.",
85 },
86 BuiltinParamDescriptor {
87 name: "precision",
88 ty: BuiltinParamType::StringScalar,
89 arity: BuiltinParamArity::Optional,
90 default: Some("\"uint8\""),
91 description: "Write precision label (for example \"uint8\", \"double\").",
92 },
93];
94const FWRITE_INPUTS_FID_DATA_PRECISION_SKIP: [BuiltinParamDescriptor; 4] = [
95 BuiltinParamDescriptor {
96 name: "fid",
97 ty: BuiltinParamType::NumericScalar,
98 arity: BuiltinParamArity::Required,
99 default: None,
100 description: "File identifier opened by fopen.",
101 },
102 BuiltinParamDescriptor {
103 name: "data",
104 ty: BuiltinParamType::Any,
105 arity: BuiltinParamArity::Required,
106 default: None,
107 description: "Numeric/logical/text payload to write.",
108 },
109 BuiltinParamDescriptor {
110 name: "precision",
111 ty: BuiltinParamType::StringScalar,
112 arity: BuiltinParamArity::Optional,
113 default: Some("\"uint8\""),
114 description: "Write precision label (for example \"uint8\", \"double\").",
115 },
116 BuiltinParamDescriptor {
117 name: "skip",
118 ty: BuiltinParamType::NumericScalar,
119 arity: BuiltinParamArity::Optional,
120 default: Some("0"),
121 description: "Bytes skipped after each element written.",
122 },
123];
124const FWRITE_INPUTS_FID_DATA_PRECISION_MACHINEFMT: [BuiltinParamDescriptor; 4] = [
125 BuiltinParamDescriptor {
126 name: "fid",
127 ty: BuiltinParamType::NumericScalar,
128 arity: BuiltinParamArity::Required,
129 default: None,
130 description: "File identifier opened by fopen.",
131 },
132 BuiltinParamDescriptor {
133 name: "data",
134 ty: BuiltinParamType::Any,
135 arity: BuiltinParamArity::Required,
136 default: None,
137 description: "Numeric/logical/text payload to write.",
138 },
139 BuiltinParamDescriptor {
140 name: "precision",
141 ty: BuiltinParamType::StringScalar,
142 arity: BuiltinParamArity::Optional,
143 default: Some("\"uint8\""),
144 description: "Write precision label (for example \"uint8\", \"double\").",
145 },
146 BuiltinParamDescriptor {
147 name: "machinefmt",
148 ty: BuiltinParamType::StringScalar,
149 arity: BuiltinParamArity::Optional,
150 default: Some("\"native\""),
151 description: "Machine format label (native/little-endian/big-endian aliases).",
152 },
153];
154const FWRITE_INPUTS_FID_DATA_PRECISION_SKIP_MACHINEFMT: [BuiltinParamDescriptor; 5] = [
155 BuiltinParamDescriptor {
156 name: "fid",
157 ty: BuiltinParamType::NumericScalar,
158 arity: BuiltinParamArity::Required,
159 default: None,
160 description: "File identifier opened by fopen.",
161 },
162 BuiltinParamDescriptor {
163 name: "data",
164 ty: BuiltinParamType::Any,
165 arity: BuiltinParamArity::Required,
166 default: None,
167 description: "Numeric/logical/text payload to write.",
168 },
169 BuiltinParamDescriptor {
170 name: "precision",
171 ty: BuiltinParamType::StringScalar,
172 arity: BuiltinParamArity::Optional,
173 default: Some("\"uint8\""),
174 description: "Write precision label (for example \"uint8\", \"double\").",
175 },
176 BuiltinParamDescriptor {
177 name: "skip",
178 ty: BuiltinParamType::NumericScalar,
179 arity: BuiltinParamArity::Optional,
180 default: Some("0"),
181 description: "Bytes skipped after each element written.",
182 },
183 BuiltinParamDescriptor {
184 name: "machinefmt",
185 ty: BuiltinParamType::StringScalar,
186 arity: BuiltinParamArity::Optional,
187 default: Some("\"native\""),
188 description: "Machine format label (native/little-endian/big-endian aliases).",
189 },
190];
191const FWRITE_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
192 BuiltinSignatureDescriptor {
193 label: "count = fwrite(fid, data)",
194 inputs: &FWRITE_INPUTS_FID_DATA,
195 outputs: &FWRITE_OUTPUT_COUNT,
196 },
197 BuiltinSignatureDescriptor {
198 label: "count = fwrite(fid, data, precision)",
199 inputs: &FWRITE_INPUTS_FID_DATA_PRECISION,
200 outputs: &FWRITE_OUTPUT_COUNT,
201 },
202 BuiltinSignatureDescriptor {
203 label: "count = fwrite(fid, data, precision, skip)",
204 inputs: &FWRITE_INPUTS_FID_DATA_PRECISION_SKIP,
205 outputs: &FWRITE_OUTPUT_COUNT,
206 },
207 BuiltinSignatureDescriptor {
208 label: "count = fwrite(fid, data, precision, machinefmt)",
209 inputs: &FWRITE_INPUTS_FID_DATA_PRECISION_MACHINEFMT,
210 outputs: &FWRITE_OUTPUT_COUNT,
211 },
212 BuiltinSignatureDescriptor {
213 label: "count = fwrite(fid, data, precision, skip, machinefmt)",
214 inputs: &FWRITE_INPUTS_FID_DATA_PRECISION_SKIP_MACHINEFMT,
215 outputs: &FWRITE_OUTPUT_COUNT,
216 },
217];
218
219const FWRITE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
220 code: "RM.FWRITE.INVALID_INPUT",
221 identifier: Some("RunMat:fwrite:InvalidInput"),
222 when: "Identifier, payload, or argument cardinality/type constraints are violated.",
223 message: "fwrite: invalid input arguments",
224};
225const FWRITE_ERROR_INVALID_IDENTIFIER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
226 code: "RM.FWRITE.INVALID_IDENTIFIER",
227 identifier: Some("RunMat:fwrite:InvalidIdentifier"),
228 when: "Identifier does not refer to a writable open file.",
229 message: "fwrite: invalid file identifier. Use fopen to generate a valid file ID.",
230};
231const FWRITE_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
232 code: "RM.FWRITE.INVALID_OPTION",
233 identifier: Some("RunMat:fwrite:InvalidOption"),
234 when: "Precision, skip, or machine format options are invalid.",
235 message: "fwrite: invalid option configuration",
236};
237const FWRITE_ERROR_IO: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
238 code: "RM.FWRITE.IO",
239 identifier: Some("RunMat:fwrite:IoFailure"),
240 when: "Write/seek operation fails.",
241 message: "fwrite: file write failed",
242};
243const FWRITE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
244 code: "RM.FWRITE.INTERNAL",
245 identifier: None,
246 when: "Internal runtime control-flow conversion fails.",
247 message: "fwrite: internal error",
248};
249const FWRITE_ERRORS: [BuiltinErrorDescriptor; 5] = [
250 FWRITE_ERROR_INVALID_INPUT,
251 FWRITE_ERROR_INVALID_IDENTIFIER,
252 FWRITE_ERROR_INVALID_OPTION,
253 FWRITE_ERROR_IO,
254 FWRITE_ERROR_INTERNAL,
255];
256pub const FWRITE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
257 signatures: &FWRITE_SIGNATURES,
258 output_mode: BuiltinOutputMode::Fixed,
259 completion_policy: BuiltinCompletionPolicy::Public,
260 errors: &FWRITE_ERRORS,
261};
262
263fn fwrite_error_with_detail(
264 error: &'static BuiltinErrorDescriptor,
265 detail: impl AsRef<str>,
266) -> RuntimeError {
267 let detail = detail.as_ref();
268 let detail = detail.strip_prefix("fwrite: ").unwrap_or(detail);
269 fwrite_error_with_message(format!("{}: {}", error.message, detail), error)
270}
271
272fn fwrite_error_with_message(
273 message: impl Into<String>,
274 error: &'static BuiltinErrorDescriptor,
275) -> RuntimeError {
276 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
277 if let Some(identifier) = error.identifier {
278 builder = builder.with_identifier(identifier);
279 }
280 builder.build()
281}
282
283fn map_control_flow(err: RuntimeError) -> RuntimeError {
284 let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {}", err.message()))
285 .with_builtin(BUILTIN_NAME)
286 .with_source(err);
287 if let Some(identifier) = FWRITE_ERROR_INTERNAL.identifier {
288 builder = builder.with_identifier(identifier);
289 }
290 builder.build()
291}
292
293fn map_string_result<T>(
294 result: Result<T, String>,
295 error: &'static BuiltinErrorDescriptor,
296) -> BuiltinResult<T> {
297 result.map_err(|detail| fwrite_error_with_detail(error, detail))
298}
299
300#[runtime_builtin(
301 name = "fwrite",
302 category = "io/filetext",
303 summary = "Write binary data to file identifiers.",
304 keywords = "fwrite,file,io,binary,precision",
305 accel = "cpu",
306 type_resolver(crate::builtins::io::type_resolvers::fwrite_type),
307 descriptor(crate::builtins::io::filetext::fwrite::FWRITE_DESCRIPTOR),
308 builtin_path = "crate::builtins::io::filetext::fwrite"
309)]
310async fn fwrite_builtin(fid: Value, data: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
311 let eval = evaluate(&fid, &data, &rest).await?;
312 Ok(Value::Num(eval.count as f64))
313}
314
/// Outcome of a successful `fwrite` evaluation.
#[derive(Debug, Clone)]
pub struct FwriteEval {
    count: usize,
}

impl FwriteEval {
    fn new(count: usize) -> Self {
        FwriteEval { count }
    }

    /// Number of elements written to the file.
    pub fn count(&self) -> usize {
        let FwriteEval { count } = self;
        *count
    }
}
331
332pub async fn evaluate(
334 fid_value: &Value,
335 data_value: &Value,
336 rest: &[Value],
337) -> BuiltinResult<FwriteEval> {
338 let fid_host = gather_value(fid_value).await?;
339 let fid = map_string_result(parse_fid(&fid_host), &FWRITE_ERROR_INVALID_INPUT)?;
340 if fid < 0 {
341 return Err(fwrite_error_with_detail(
342 &FWRITE_ERROR_INVALID_INPUT,
343 "file identifier must be non-negative",
344 ));
345 }
346 if fid < 3 {
347 return Err(fwrite_error_with_detail(
348 &FWRITE_ERROR_INVALID_INPUT,
349 "standard input/output identifiers are not supported yet",
350 ));
351 }
352
353 let info = registry::info_for(fid).ok_or_else(|| {
354 fwrite_error_with_message(
355 FWRITE_ERROR_INVALID_IDENTIFIER.message,
356 &FWRITE_ERROR_INVALID_IDENTIFIER,
357 )
358 })?;
359 let handle = registry::take_handle(fid).ok_or_else(|| {
360 fwrite_error_with_message(
361 FWRITE_ERROR_INVALID_IDENTIFIER.message,
362 &FWRITE_ERROR_INVALID_IDENTIFIER,
363 )
364 })?;
365
366 let data_host = gather_value(data_value).await?;
367 let rest_host = gather_args(rest).await?;
368 let (precision_arg, skip_arg, machine_arg) =
369 map_string_result(classify_arguments(&rest_host), &FWRITE_ERROR_INVALID_INPUT)?;
370
371 let precision_spec =
372 map_string_result(parse_precision(precision_arg), &FWRITE_ERROR_INVALID_OPTION)?;
373 let skip_bytes = map_string_result(parse_skip(skip_arg), &FWRITE_ERROR_INVALID_OPTION)?;
374 let machine_format = map_string_result(
375 parse_machine_format(machine_arg, &info.machinefmt),
376 &FWRITE_ERROR_INVALID_OPTION,
377 )?;
378
379 let mut guard = handle.lock().map_err(|_| {
380 fwrite_error_with_detail(
381 &FWRITE_ERROR_INTERNAL,
382 "failed to lock file handle (poisoned mutex)",
383 )
384 })?;
385 let file = guard.as_mut().ok_or_else(|| {
386 fwrite_error_with_message(
387 FWRITE_ERROR_INVALID_IDENTIFIER.message,
388 &FWRITE_ERROR_INVALID_IDENTIFIER,
389 )
390 })?;
391
392 let elements = map_string_result(flatten_elements(&data_host), &FWRITE_ERROR_INVALID_INPUT)?;
393 let count = map_string_result(
394 write_elements(file, &elements, precision_spec, skip_bytes, machine_format),
395 &FWRITE_ERROR_IO,
396 )?;
397 Ok(FwriteEval::new(count))
398}
399
400async fn gather_value(value: &Value) -> BuiltinResult<Value> {
401 gather_if_needed_async(value)
402 .await
403 .map_err(map_control_flow)
404}
405
406async fn gather_args(args: &[Value]) -> BuiltinResult<Vec<Value>> {
407 let mut gathered = Vec::with_capacity(args.len());
408 for value in args {
409 gathered.push(
410 gather_if_needed_async(value)
411 .await
412 .map_err(map_control_flow)?,
413 );
414 }
415 Ok(gathered)
416}
417
418fn parse_fid(value: &Value) -> Result<i32, String> {
419 let scalar = match value {
420 Value::Num(n) => *n,
421 Value::Int(int) => int.to_f64(),
422 _ => return Err("fwrite: file identifier must be numeric".to_string()),
423 };
424 if !scalar.is_finite() {
425 return Err("fwrite: file identifier must be finite".to_string());
426 }
427 if scalar.fract().abs() > f64::EPSILON {
428 return Err("fwrite: file identifier must be an integer".to_string());
429 }
430 Ok(scalar as i32)
431}
432
433type FwriteArgs<'a> = (Option<&'a Value>, Option<&'a Value>, Option<&'a Value>);
434
435fn classify_arguments(args: &[Value]) -> Result<FwriteArgs<'_>, String> {
436 match args.len() {
437 0 => Ok((None, None, None)),
438 1 => {
439 if is_string_like(&args[0]) {
440 Ok((Some(&args[0]), None, None))
441 } else {
442 Err(
443 "fwrite: precision argument must be a string scalar or character vector"
444 .to_string(),
445 )
446 }
447 }
448 2 => {
449 if !is_string_like(&args[0]) {
450 return Err(
451 "fwrite: precision argument must be a string scalar or character vector"
452 .to_string(),
453 );
454 }
455 if is_numeric_like(&args[1]) {
456 Ok((Some(&args[0]), Some(&args[1]), None))
457 } else if is_string_like(&args[1]) {
458 Ok((Some(&args[0]), None, Some(&args[1])))
459 } else {
460 Err("fwrite: invalid argument combination (expected numeric skip or machine format string)".to_string())
461 }
462 }
463 3 => {
464 if !is_string_like(&args[0]) || !is_numeric_like(&args[1]) || !is_string_like(&args[2])
465 {
466 return Err("fwrite: expected arguments (precision, skip, machinefmt)".to_string());
467 }
468 Ok((Some(&args[0]), Some(&args[1]), Some(&args[2])))
469 }
470 _ => Err("fwrite: too many input arguments".to_string()),
471 }
472}
473
474fn is_string_like(value: &Value) -> bool {
475 match value {
476 Value::String(_) => true,
477 Value::CharArray(ca) => ca.rows == 1,
478 Value::StringArray(sa) => sa.data.len() == 1,
479 _ => false,
480 }
481}
482
483fn is_numeric_like(value: &Value) -> bool {
484 match value {
485 Value::Num(_) | Value::Int(_) | Value::Bool(_) => true,
486 Value::Tensor(t) => t.data.len() == 1,
487 Value::LogicalArray(la) => la.data.len() == 1,
488 _ => false,
489 }
490}
491
492#[derive(Clone, Copy, Debug)]
493struct WriteSpec {
494 input: InputType,
495}
496
497impl WriteSpec {
498 fn default() -> Self {
499 Self {
500 input: InputType::UInt8,
501 }
502 }
503}
504
505fn parse_precision(arg: Option<&Value>) -> Result<WriteSpec, String> {
506 match arg {
507 None => Ok(WriteSpec::default()),
508 Some(value) => {
509 let text = scalar_string(
510 value,
511 "fwrite: precision argument must be a string scalar or character vector",
512 )?;
513 parse_precision_string(&text)
514 }
515 }
516}
517
518fn parse_precision_string(raw: &str) -> Result<WriteSpec, String> {
519 let trimmed = raw.trim();
520 if trimmed.is_empty() {
521 return Err("fwrite: precision argument must not be empty".to_string());
522 }
523 let lower = trimmed.to_ascii_lowercase();
524 if let Some((lhs, rhs)) = lower.split_once("=>") {
525 let lhs = lhs.trim();
526 let rhs = rhs.trim();
527 let input = parse_input_label(lhs)?;
528 let output = parse_input_label(rhs)?;
529 if input != output {
530 return Err(
531 "fwrite: differing input/output precisions are not implemented yet".to_string(),
532 );
533 }
534 Ok(WriteSpec { input })
535 } else {
536 parse_input_label(lower.trim()).map(|input| WriteSpec { input })
537 }
538}
539
540fn parse_skip(arg: Option<&Value>) -> Result<usize, String> {
541 match arg {
542 None => Ok(0),
543 Some(value) => {
544 let scalar = numeric_scalar(value, "fwrite: skip must be numeric")?;
545 if !scalar.is_finite() {
546 return Err("fwrite: skip value must be finite".to_string());
547 }
548 if scalar < 0.0 {
549 return Err("fwrite: skip value must be non-negative".to_string());
550 }
551 let rounded = scalar.round();
552 if (rounded - scalar).abs() > f64::EPSILON {
553 return Err("fwrite: skip value must be an integer".to_string());
554 }
555 if rounded > i64::MAX as f64 {
556 return Err("fwrite: skip value is too large".to_string());
557 }
558 Ok(rounded as usize)
559 }
560 }
561}
562
/// Byte-order request parsed from a `machinefmt` label.
#[derive(Clone, Copy, Debug)]
enum MachineFormat {
    /// Byte order of the compilation target.
    Native,
    /// Force little-endian output.
    LittleEndian,
    /// Force big-endian output.
    BigEndian,
}

impl MachineFormat {
    /// Resolve the request to a concrete byte order. `Native` follows `target_endian`.
    fn to_endianness(self) -> Endianness {
        match self {
            MachineFormat::LittleEndian => Endianness::Little,
            MachineFormat::BigEndian => Endianness::Big,
            MachineFormat::Native if cfg!(target_endian = "big") => Endianness::Big,
            MachineFormat::Native => Endianness::Little,
        }
    }
}

/// Concrete byte order used when encoding multi-byte elements.
#[derive(Clone, Copy, Debug)]
enum Endianness {
    Little,
    Big,
}
591
592fn parse_machine_format(arg: Option<&Value>, default_label: &str) -> Result<MachineFormat, String> {
593 match arg {
594 Some(value) => {
595 let text = scalar_string(
596 value,
597 "fwrite: machine format must be a string scalar or character vector",
598 )?;
599 machine_format_from_label(&text)
600 }
601 None => machine_format_from_label(default_label),
602 }
603}
604
605fn machine_format_from_label(label: &str) -> Result<MachineFormat, String> {
606 let trimmed = label.trim();
607 if trimmed.is_empty() {
608 return Err("fwrite: machine format must not be empty".to_string());
609 }
610 let lower = trimmed.to_ascii_lowercase();
611 let collapsed: String = lower
612 .chars()
613 .filter(|c| !matches!(c, '-' | '_' | ' '))
614 .collect();
615 if matches!(collapsed.as_str(), "native" | "n" | "system" | "default") {
616 return Ok(MachineFormat::Native);
617 }
618 if matches!(
619 collapsed.as_str(),
620 "l" | "le" | "littleendian" | "pc" | "intel"
621 ) {
622 return Ok(MachineFormat::LittleEndian);
623 }
624 if matches!(
625 collapsed.as_str(),
626 "b" | "be" | "bigendian" | "mac" | "motorola"
627 ) {
628 return Ok(MachineFormat::BigEndian);
629 }
630 if lower.starts_with("ieee-le") {
631 return Ok(MachineFormat::LittleEndian);
632 }
633 if lower.starts_with("ieee-be") {
634 return Ok(MachineFormat::BigEndian);
635 }
636 Err(format!("fwrite: unsupported machine format '{trimmed}'"))
637}
638
639fn scalar_string(value: &Value, err: &str) -> Result<String, String> {
640 match value {
641 Value::String(s) => Ok(s.clone()),
642 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
643 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
644 _ => Err(err.to_string()),
645 }
646}
647
648fn numeric_scalar(value: &Value, err: &str) -> Result<f64, String> {
649 match value {
650 Value::Num(n) => Ok(*n),
651 Value::Int(int) => Ok(int.to_f64()),
652 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
653 Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
654 Value::LogicalArray(la) if la.data.len() == 1 => {
655 Ok(if la.data[0] != 0 { 1.0 } else { 0.0 })
656 }
657 _ => Err(err.to_string()),
658 }
659}
660
661fn flatten_elements(value: &Value) -> Result<Vec<f64>, String> {
662 match value {
663 Value::Tensor(tensor) => Ok(tensor.data.clone()),
664 Value::Num(n) => Ok(vec![*n]),
665 Value::Int(int) => Ok(vec![int.to_f64()]),
666 Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
667 Value::LogicalArray(array) => Ok(array
668 .data
669 .iter()
670 .map(|bit| if *bit != 0 { 1.0 } else { 0.0 })
671 .collect()),
672 Value::CharArray(ca) => Ok(flatten_char_array(ca)),
673 Value::String(text) => Ok(text.chars().map(|ch| ch as u32 as f64).collect()),
674 Value::StringArray(sa) => Ok(flatten_string_array(sa)),
675 Value::GpuTensor(_) => Err("fwrite: expected host tensor data after gathering".to_string()),
676 Value::Complex(_, _) | Value::ComplexTensor(_) => {
677 Err("fwrite: complex values are not supported yet".to_string())
678 }
679 _ => Err(format!("fwrite: unsupported data type {:?}", value)),
680 }
681}
682
683fn flatten_char_array(ca: &CharArray) -> Vec<f64> {
684 let mut values = Vec::with_capacity(ca.rows.saturating_mul(ca.cols));
685 for c in 0..ca.cols {
686 for r in 0..ca.rows {
687 let idx = r * ca.cols + c;
688 values.push(ca.data[idx] as u32 as f64);
689 }
690 }
691 values
692}
693
694fn flatten_string_array(sa: &runmat_builtins::StringArray) -> Vec<f64> {
695 if sa.data.is_empty() {
696 return Vec::new();
697 }
698 let mut values = Vec::new();
699 for (idx, text) in sa.data.iter().enumerate() {
700 if idx > 0 {
701 values.push('\n' as u32 as f64);
702 }
703 values.extend(text.chars().map(|ch| ch as u32 as f64));
704 }
705 values
706}
707
708fn write_elements(
709 file: &mut File,
710 values: &[f64],
711 spec: WriteSpec,
712 skip: usize,
713 machine: MachineFormat,
714) -> Result<usize, String> {
715 let endianness = machine.to_endianness();
716 let skip_offset = skip as i64;
717 for &value in values {
718 match spec.input {
719 InputType::UInt8 => {
720 let byte = to_u8(value);
721 write_bytes(file, &[byte])?;
722 }
723 InputType::Int8 => {
724 let byte = to_i8(value) as u8;
725 write_bytes(file, &[byte])?;
726 }
727 InputType::UInt16 => {
728 let bytes = encode_u16(value, endianness);
729 write_bytes(file, &bytes)?;
730 }
731 InputType::Int16 => {
732 let bytes = encode_i16(value, endianness);
733 write_bytes(file, &bytes)?;
734 }
735 InputType::UInt32 => {
736 let bytes = encode_u32(value, endianness);
737 write_bytes(file, &bytes)?;
738 }
739 InputType::Int32 => {
740 let bytes = encode_i32(value, endianness);
741 write_bytes(file, &bytes)?;
742 }
743 InputType::UInt64 => {
744 let bytes = encode_u64(value, endianness);
745 write_bytes(file, &bytes)?;
746 }
747 InputType::Int64 => {
748 let bytes = encode_i64(value, endianness);
749 write_bytes(file, &bytes)?;
750 }
751 InputType::Float32 => {
752 let bytes = encode_f32(value, endianness);
753 write_bytes(file, &bytes)?;
754 }
755 InputType::Float64 => {
756 let bytes = encode_f64(value, endianness);
757 write_bytes(file, &bytes)?;
758 }
759 }
760
761 if skip > 0 {
762 file.seek(SeekFrom::Current(skip_offset))
763 .map_err(|err| format!("fwrite: failed to seek while applying skip ({err})"))?;
764 }
765 }
766 Ok(values.len())
767}
768
769fn write_bytes(file: &mut File, bytes: &[u8]) -> Result<(), String> {
770 file.write_all(bytes)
771 .map_err(|err| format!("fwrite: failed to write to file ({err})"))
772}
773
/// Round `value` (ties away from zero) and saturate it into `0..=255`.
///
/// NaN maps to `0`, matching MATLAB's integer conversion. `-Inf` and negative values
/// map to `0`; `+Inf` and values above 255 map to `255`.
fn to_u8(value: f64) -> u8 {
    // Check NaN before anything else. The previous version tested `!is_finite()`
    // first, which also caught NaN and sent positive NaN to 255, leaving its NaN
    // branch unreachable.
    if value.is_nan() {
        return 0;
    }
    // `round` leaves ±Inf unchanged, so the bounds below also saturate the infinities.
    let rounded = value.round();
    if rounded <= 0.0 {
        0
    } else if rounded >= f64::from(u8::MAX) {
        u8::MAX
    } else {
        rounded as u8
    }
}
790
791fn to_i8(value: f64) -> i8 {
792 saturating_round(value, i8::MIN as f64, i8::MAX as f64) as i8
793}
794
795fn encode_u16(value: f64, endianness: Endianness) -> [u8; 2] {
796 let rounded = saturating_round(value, 0.0, u16::MAX as f64) as u16;
797 match endianness {
798 Endianness::Little => rounded.to_le_bytes(),
799 Endianness::Big => rounded.to_be_bytes(),
800 }
801}
802
803fn encode_i16(value: f64, endianness: Endianness) -> [u8; 2] {
804 let rounded = saturating_round(value, i16::MIN as f64, i16::MAX as f64) as i16;
805 match endianness {
806 Endianness::Little => rounded.to_le_bytes(),
807 Endianness::Big => rounded.to_be_bytes(),
808 }
809}
810
811fn encode_u32(value: f64, endianness: Endianness) -> [u8; 4] {
812 let rounded = saturating_round(value, 0.0, u32::MAX as f64) as u32;
813 match endianness {
814 Endianness::Little => rounded.to_le_bytes(),
815 Endianness::Big => rounded.to_be_bytes(),
816 }
817}
818
819fn encode_i32(value: f64, endianness: Endianness) -> [u8; 4] {
820 let rounded = saturating_round(value, i32::MIN as f64, i32::MAX as f64) as i32;
821 match endianness {
822 Endianness::Little => rounded.to_le_bytes(),
823 Endianness::Big => rounded.to_be_bytes(),
824 }
825}
826
827fn encode_u64(value: f64, endianness: Endianness) -> [u8; 8] {
828 let rounded = saturating_round(value, 0.0, u64::MAX as f64);
829 let as_u64 = if rounded.is_finite() {
830 rounded as u64
831 } else if rounded.is_sign_negative() {
832 0
833 } else {
834 u64::MAX
835 };
836 match endianness {
837 Endianness::Little => as_u64.to_le_bytes(),
838 Endianness::Big => as_u64.to_be_bytes(),
839 }
840}
841
842fn encode_i64(value: f64, endianness: Endianness) -> [u8; 8] {
843 let rounded = saturating_round(value, i64::MIN as f64, i64::MAX as f64);
844 let as_i64 = if rounded.is_finite() {
845 rounded as i64
846 } else if rounded.is_sign_negative() {
847 i64::MIN
848 } else {
849 i64::MAX
850 };
851 match endianness {
852 Endianness::Little => as_i64.to_le_bytes(),
853 Endianness::Big => as_i64.to_be_bytes(),
854 }
855}
856
857fn encode_f32(value: f64, endianness: Endianness) -> [u8; 4] {
858 let as_f32 = value as f32;
859 let bits = as_f32.to_bits();
860 match endianness {
861 Endianness::Little => bits.to_le_bytes(),
862 Endianness::Big => bits.to_be_bytes(),
863 }
864}
865
866fn encode_f64(value: f64, endianness: Endianness) -> [u8; 8] {
867 let bits = value.to_bits();
868 match endianness {
869 Endianness::Little => bits.to_le_bytes(),
870 Endianness::Big => bits.to_be_bytes(),
871 }
872}
873
/// Round `value` to the nearest integer (ties away from zero) and clamp it to
/// `[min, max]`.
///
/// - NaN is treated as `0.0`, matching MATLAB's integer conversions, before clamping.
/// - `+Inf` maps to `max` and `-Inf` maps to `min`.
///
/// With finite bounds the result is always finite.
fn saturating_round(value: f64, min: f64, max: f64) -> f64 {
    // Check NaN first. The previous version tested `!is_finite()` first, which also
    // caught NaN and returned `max`, leaving its NaN branch unreachable.
    // `round` keeps ±Inf unchanged, so the clamps below send them to `min`/`max`.
    let mut rounded = if value.is_nan() { 0.0 } else { value.round() };
    if rounded < min {
        rounded = min;
    }
    if rounded > max {
        rounded = max;
    }
    rounded
}
890
/// Element encodings supported by `fwrite`'s `precision` argument.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InputType {
    UInt8,
    Int8,
    UInt16,
    Int16,
    UInt32,
    Int32,
    UInt64,
    Int64,
    Float32,
    Float64,
}

/// Map a lower-case precision label to its element type.
///
/// Accepts MATLAB class names plus the Fortran/C-style aliases MATLAB documents for
/// `fread`/`fwrite`. Multi-word C names are matched in their compact spelling (e.g.
/// `"unsignedchar"`, `"signedchar"`). `"int"`, `"uint"`, `"long"` and `"ulong"` are
/// 32-bit, consistent with the existing `"long"` mapping.
///
/// # Errors
/// Returns `"fwrite: unsupported precision '<label>'"` for unknown labels.
fn parse_input_label(label: &str) -> Result<InputType, String> {
    match label {
        "double" | "float64" | "real*8" => Ok(InputType::Float64),
        "single" | "float32" | "float" | "real*4" => Ok(InputType::Float32),
        "int8" | "schar" | "signedchar" | "integer*1" => Ok(InputType::Int8),
        "uint8" | "uchar" | "unsignedchar" | "char" | "char*1" | "byte" => {
            Ok(InputType::UInt8)
        }
        "int16" | "short" | "integer*2" => Ok(InputType::Int16),
        "uint16" | "ushort" | "unsignedshort" => Ok(InputType::UInt16),
        "int32" | "int" | "integer*4" | "long" => Ok(InputType::Int32),
        "uint32" | "uint" | "unsignedint" | "ulong" | "unsignedlong" => Ok(InputType::UInt32),
        "int64" | "integer*8" | "longlong" => Ok(InputType::Int64),
        "uint64" | "unsignedlonglong" => Ok(InputType::UInt64),
        other => Err(format!("fwrite: unsupported precision '{other}'")),
    }
}
920
921#[cfg(test)]
922pub(crate) mod tests {
923 use super::*;
924 use crate::builtins::common::test_support;
925 use crate::builtins::io::filetext::registry;
926 use crate::builtins::io::filetext::{fclose, fopen};
927 use crate::RuntimeError;
928 #[cfg(feature = "wgpu")]
929 use runmat_accelerate::backend::wgpu::provider;
930 #[cfg(feature = "wgpu")]
931 use runmat_accelerate_api::AccelProvider;
932 use runmat_accelerate_api::HostTensorView;
933 use runmat_builtins::Tensor;
934 use runmat_filesystem::File;
935 use runmat_time::system_time_now;
936 use std::io::Read;
937 use std::path::PathBuf;
938 use std::time::UNIX_EPOCH;
939
940 fn unwrap_error_message(err: RuntimeError) -> String {
941 err.message().to_string()
942 }
943
944 fn run_evaluate(
945 fid_value: &Value,
946 data_value: &Value,
947 rest: &[Value],
948 ) -> BuiltinResult<FwriteEval> {
949 futures::executor::block_on(evaluate(fid_value, data_value, rest))
950 }
951
952 fn run_fopen(args: &[Value]) -> BuiltinResult<fopen::FopenEval> {
953 futures::executor::block_on(fopen::evaluate(args))
954 }
955
956 fn run_fclose(args: &[Value]) -> BuiltinResult<fclose::FcloseEval> {
957 futures::executor::block_on(fclose::evaluate(args))
958 }
959
960 fn registry_guard() -> std::sync::MutexGuard<'static, ()> {
961 registry::test_guard()
962 }
963
964 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
965 #[test]
966 fn fwrite_descriptor_signatures_cover_core_forms() {
967 let labels: Vec<&str> = FWRITE_DESCRIPTOR
968 .signatures
969 .iter()
970 .map(|sig| sig.label)
971 .collect();
972 assert!(labels.contains(&"count = fwrite(fid, data)"));
973 assert!(labels.contains(&"count = fwrite(fid, data, precision, skip)"));
974 assert!(labels.contains(&"count = fwrite(fid, data, precision, machinefmt)"));
975 assert!(labels.contains(&"count = fwrite(fid, data, precision, skip, machinefmt)"));
976 }
977
978 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
979 #[test]
980 fn fwrite_default_uint8_bytes() {
981 let _guard = registry_guard();
982 registry::reset_for_tests();
983 let path = unique_path("fwrite_uint8");
984 let open = run_fopen(&[
985 Value::from(path.to_string_lossy().to_string()),
986 Value::from("w+b"),
987 ])
988 .expect("fopen");
989 let fid = open.as_open().unwrap().fid as i32;
990
991 let tensor = Tensor::new(vec![1.0, 2.0, 255.0], vec![3, 1]).unwrap();
992 let eval = run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &Vec::new())
993 .expect("fwrite");
994 assert_eq!(eval.count(), 3);
995
996 run_fclose(&[Value::Num(fid as f64)]).unwrap();
997
998 let bytes = test_support::fs::read(&path).expect("read");
999 assert_eq!(bytes, vec![1u8, 2, 255]);
1000 test_support::fs::remove_file(path).unwrap();
1001 }
1002
1003 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1004 #[test]
1005 fn fwrite_double_precision_writes_native_endian() {
1006 let _guard = registry_guard();
1007 registry::reset_for_tests();
1008 let path = unique_path("fwrite_double");
1009 let open = run_fopen(&[
1010 Value::from(path.to_string_lossy().to_string()),
1011 Value::from("w+b"),
1012 ])
1013 .expect("fopen");
1014 let fid = open.as_open().unwrap().fid as i32;
1015
1016 let tensor = Tensor::new(vec![1.5, -2.25], vec![2, 1]).unwrap();
1017 let args = vec![Value::from("double")];
1018 let eval =
1019 run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
1020 assert_eq!(eval.count(), 2);
1021
1022 run_fclose(&[Value::Num(fid as f64)]).unwrap();
1023
1024 let bytes = test_support::fs::read(&path).expect("read");
1025 let expected: Vec<u8> = if cfg!(target_endian = "little") {
1026 [1.5f64.to_le_bytes(), (-2.25f64).to_le_bytes()].concat()
1027 } else {
1028 [1.5f64.to_be_bytes(), (-2.25f64).to_be_bytes()].concat()
1029 };
1030 assert_eq!(bytes, expected);
1031 test_support::fs::remove_file(path).unwrap();
1032 }
1033
1034 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1035 #[test]
1036 fn fwrite_big_endian_uint16() {
1037 let _guard = registry_guard();
1038 registry::reset_for_tests();
1039 let path = unique_path("fwrite_be");
1040 let open = run_fopen(&[
1041 Value::from(path.to_string_lossy().to_string()),
1042 Value::from("w+b"),
1043 Value::from("ieee-be"),
1044 ])
1045 .expect("fopen");
1046 let fid = open.as_open().unwrap().fid as i32;
1047
1048 let tensor = Tensor::new(vec![258.0, 772.0], vec![2, 1]).unwrap();
1049 let args = vec![Value::from("uint16")];
1050 let eval =
1051 run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
1052 assert_eq!(eval.count(), 2);
1053
1054 run_fclose(&[Value::Num(fid as f64)]).unwrap();
1055
1056 let bytes = test_support::fs::read(&path).expect("read");
1057 assert_eq!(bytes, vec![0x01, 0x02, 0x03, 0x04]);
1058 test_support::fs::remove_file(path).unwrap();
1059 }
1060
1061 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1062 #[test]
1063 fn fwrite_skip_inserts_padding() {
1064 let _guard = registry_guard();
1065 registry::reset_for_tests();
1066 let path = unique_path("fwrite_skip");
1067 let open = run_fopen(&[
1068 Value::from(path.to_string_lossy().to_string()),
1069 Value::from("w+b"),
1070 ])
1071 .expect("fopen");
1072 let fid = open.as_open().unwrap().fid as i32;
1073
1074 let tensor = Tensor::new(vec![10.0, 20.0, 30.0], vec![3, 1]).unwrap();
1075 let args = vec![Value::from("uint8"), Value::Num(1.0)];
1076 let eval =
1077 run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
1078 assert_eq!(eval.count(), 3);
1079
1080 run_fclose(&[Value::Num(fid as f64)]).unwrap();
1081
1082 let bytes = test_support::fs::read(&path).expect("read");
1083 assert_eq!(bytes, vec![10u8, 0, 20, 0, 30]);
1084 test_support::fs::remove_file(path).unwrap();
1085 }
1086
1087 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1088 #[test]
1089 fn fwrite_gpu_tensor_gathers_before_write() {
1090 let _guard = registry_guard();
1091 registry::reset_for_tests();
1092 let path = unique_path("fwrite_gpu");
1093
1094 test_support::with_test_provider(|provider| {
1095 registry::reset_for_tests();
1096 let open = run_fopen(&[
1097 Value::from(path.to_string_lossy().to_string()),
1098 Value::from("w+b"),
1099 ])
1100 .expect("fopen");
1101 let fid = open.as_open().unwrap().fid as i32;
1102
1103 let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
1104 let view = HostTensorView {
1105 data: &tensor.data,
1106 shape: &tensor.shape,
1107 };
1108 let handle = provider.upload(&view).expect("upload");
1109 let args = vec![Value::from("uint16")];
1110 let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
1111 .expect("fwrite");
1112 assert_eq!(eval.count(), 4);
1113
1114 run_fclose(&[Value::Num(fid as f64)]).unwrap();
1115 });
1116
1117 let mut file = File::open(&path).expect("open");
1118 let mut bytes = Vec::new();
1119 file.read_to_end(&mut bytes).expect("read");
1120 assert_eq!(bytes.len(), 8);
1121 let mut decoded = Vec::new();
1122 for chunk in bytes.chunks_exact(2) {
1123 let value = if cfg!(target_endian = "little") {
1124 u16::from_le_bytes([chunk[0], chunk[1]])
1125 } else {
1126 u16::from_be_bytes([chunk[0], chunk[1]])
1127 };
1128 decoded.push(value);
1129 }
1130 assert_eq!(decoded, vec![1u16, 2, 3, 4]);
1131 test_support::fs::remove_file(path).unwrap();
1132 }
1133
1134 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1135 #[test]
1136 fn fwrite_invalid_precision_errors() {
1137 let _guard = registry_guard();
1138 registry::reset_for_tests();
1139 let path = unique_path("fwrite_invalid_precision");
1140 let open = run_fopen(&[
1141 Value::from(path.to_string_lossy().to_string()),
1142 Value::from("w+b"),
1143 ])
1144 .expect("fopen");
1145 let fid = open.as_open().unwrap().fid as i32;
1146
1147 let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
1148 let args = vec![Value::from("bogus-class")];
1149 let err = unwrap_error_message(
1150 run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
1151 );
1152 assert!(err.contains("unsupported precision"));
1153 let _ = run_fclose(&[Value::Num(fid as f64)]);
1154 test_support::fs::remove_file(path).unwrap();
1155 }
1156
1157 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1158 #[test]
1159 fn fwrite_negative_skip_errors() {
1160 let _guard = registry_guard();
1161 registry::reset_for_tests();
1162 let path = unique_path("fwrite_negative_skip");
1163 let open = run_fopen(&[
1164 Value::from(path.to_string_lossy().to_string()),
1165 Value::from("w+b"),
1166 ])
1167 .expect("fopen");
1168 let fid = open.as_open().unwrap().fid as i32;
1169
1170 let tensor = Tensor::new(vec![10.0], vec![1, 1]).unwrap();
1171 let args = vec![Value::from("uint8"), Value::Num(-1.0)];
1172 let err = unwrap_error_message(
1173 run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
1174 );
1175 assert!(err.contains("skip value must be non-negative"));
1176 let _ = run_fclose(&[Value::Num(fid as f64)]);
1177 test_support::fs::remove_file(path).unwrap();
1178 }
1179
1180 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1181 #[test]
1182 #[cfg(feature = "wgpu")]
1183 fn fwrite_wgpu_tensor_roundtrip() {
1184 let _guard = registry_guard();
1185 registry::reset_for_tests();
1186 let path = unique_path("fwrite_wgpu_roundtrip");
1187 let open = run_fopen(&[
1188 Value::from(path.to_string_lossy().to_string()),
1189 Value::from("w+b"),
1190 ])
1191 .expect("fopen");
1192 let fid = open.as_open().unwrap().fid as i32;
1193
1194 let provider = provider::register_wgpu_provider(provider::WgpuProviderOptions::default())
1195 .expect("wgpu provider");
1196
1197 let tensor = Tensor::new(vec![0.5, -1.25, 3.75], vec![3, 1]).unwrap();
1198 let expected = tensor.data.clone();
1199 let view = HostTensorView {
1200 data: &tensor.data,
1201 shape: &tensor.shape,
1202 };
1203 let handle = provider.upload(&view).expect("upload to gpu");
1204 let args = vec![Value::from("double")];
1205 let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
1206 .expect("fwrite");
1207 assert_eq!(eval.count(), 3);
1208
1209 run_fclose(&[Value::Num(fid as f64)]).unwrap();
1210
1211 let mut file = File::open(&path).expect("open");
1212 let mut bytes = Vec::new();
1213 file.read_to_end(&mut bytes).expect("read");
1214 assert_eq!(bytes.len(), 24);
1215 for (chunk, expected_value) in bytes.chunks_exact(8).zip(expected.iter()) {
1216 let mut buf = [0u8; 8];
1217 buf.copy_from_slice(chunk);
1218 let value = if cfg!(target_endian = "little") {
1219 f64::from_le_bytes(buf)
1220 } else {
1221 f64::from_be_bytes(buf)
1222 };
1223 assert!(
1224 (value - expected_value).abs() < 1e-12,
1225 "mismatch: {} vs {}",
1226 value,
1227 expected_value
1228 );
1229 }
1230 test_support::fs::remove_file(path).unwrap();
1231 }
1232
1233 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1234 #[test]
1235 fn fwrite_invalid_identifier_errors() {
1236 let _guard = registry_guard();
1237 registry::reset_for_tests();
1238 let err = unwrap_error_message(
1239 run_evaluate(&Value::Num(-1.0), &Value::Num(1.0), &Vec::new()).unwrap_err(),
1240 );
1241 assert!(err.contains("file identifier must be non-negative"));
1242 }
1243
1244 fn unique_path(prefix: &str) -> PathBuf {
1245 let now = system_time_now()
1246 .duration_since(UNIX_EPOCH)
1247 .expect("time went backwards");
1248 let filename = format!(
1249 "runmat_{prefix}_{}_{}.tmp",
1250 now.as_secs(),
1251 now.subsec_nanos()
1252 );
1253 std::env::temp_dir().join(filename)
1254 }
1255}