Skip to main content

runmat_runtime/builtins/io/filetext/
fwrite.rs

1//! MATLAB-compatible `fwrite` builtin for RunMat.
2use std::io::{Seek, SeekFrom, Write};
3
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    CharArray, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::io::filetext::registry;
16use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
17use runmat_filesystem::File;
18
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21    name: "fwrite",
22    op_kind: GpuOpKind::Custom("file-io-write"),
23    supported_precisions: &[],
24    broadcast: BroadcastSemantics::None,
25    provider_hooks: &[],
26    constant_strategy: ConstantStrategy::InlineLiteral,
27    residency: ResidencyPolicy::GatherImmediately,
28    nan_mode: ReductionNaN::Include,
29    two_pass_threshold: None,
30    workgroup_size: None,
31    accepts_nan_mode: false,
32    notes: "Host-only binary file I/O; GPU arguments are gathered to the CPU prior to writing.",
33};
34
35#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
36pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
37    name: "fwrite",
38    shape: ShapeRequirements::Any,
39    constant_strategy: ConstantStrategy::InlineLiteral,
40    elementwise: None,
41    reduction: None,
42    emits_nan: false,
43    notes: "File I/O is never fused; metadata recorded for completeness.",
44};
45
46const BUILTIN_NAME: &str = "fwrite";
47
48const FWRITE_OUTPUT_COUNT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
49    name: "count",
50    ty: BuiltinParamType::NumericScalar,
51    arity: BuiltinParamArity::Required,
52    default: None,
53    description: "Number of elements successfully written.",
54}];
55const FWRITE_INPUTS_FID_DATA: [BuiltinParamDescriptor; 2] = [
56    BuiltinParamDescriptor {
57        name: "fid",
58        ty: BuiltinParamType::NumericScalar,
59        arity: BuiltinParamArity::Required,
60        default: None,
61        description: "File identifier opened by fopen.",
62    },
63    BuiltinParamDescriptor {
64        name: "data",
65        ty: BuiltinParamType::Any,
66        arity: BuiltinParamArity::Required,
67        default: None,
68        description: "Numeric/logical/text payload to write.",
69    },
70];
71const FWRITE_INPUTS_FID_DATA_PRECISION: [BuiltinParamDescriptor; 3] = [
72    BuiltinParamDescriptor {
73        name: "fid",
74        ty: BuiltinParamType::NumericScalar,
75        arity: BuiltinParamArity::Required,
76        default: None,
77        description: "File identifier opened by fopen.",
78    },
79    BuiltinParamDescriptor {
80        name: "data",
81        ty: BuiltinParamType::Any,
82        arity: BuiltinParamArity::Required,
83        default: None,
84        description: "Numeric/logical/text payload to write.",
85    },
86    BuiltinParamDescriptor {
87        name: "precision",
88        ty: BuiltinParamType::StringScalar,
89        arity: BuiltinParamArity::Optional,
90        default: Some("\"uint8\""),
91        description: "Write precision label (for example \"uint8\", \"double\").",
92    },
93];
94const FWRITE_INPUTS_FID_DATA_PRECISION_SKIP: [BuiltinParamDescriptor; 4] = [
95    BuiltinParamDescriptor {
96        name: "fid",
97        ty: BuiltinParamType::NumericScalar,
98        arity: BuiltinParamArity::Required,
99        default: None,
100        description: "File identifier opened by fopen.",
101    },
102    BuiltinParamDescriptor {
103        name: "data",
104        ty: BuiltinParamType::Any,
105        arity: BuiltinParamArity::Required,
106        default: None,
107        description: "Numeric/logical/text payload to write.",
108    },
109    BuiltinParamDescriptor {
110        name: "precision",
111        ty: BuiltinParamType::StringScalar,
112        arity: BuiltinParamArity::Optional,
113        default: Some("\"uint8\""),
114        description: "Write precision label (for example \"uint8\", \"double\").",
115    },
116    BuiltinParamDescriptor {
117        name: "skip",
118        ty: BuiltinParamType::NumericScalar,
119        arity: BuiltinParamArity::Optional,
120        default: Some("0"),
121        description: "Bytes skipped after each element written.",
122    },
123];
124const FWRITE_INPUTS_FID_DATA_PRECISION_MACHINEFMT: [BuiltinParamDescriptor; 4] = [
125    BuiltinParamDescriptor {
126        name: "fid",
127        ty: BuiltinParamType::NumericScalar,
128        arity: BuiltinParamArity::Required,
129        default: None,
130        description: "File identifier opened by fopen.",
131    },
132    BuiltinParamDescriptor {
133        name: "data",
134        ty: BuiltinParamType::Any,
135        arity: BuiltinParamArity::Required,
136        default: None,
137        description: "Numeric/logical/text payload to write.",
138    },
139    BuiltinParamDescriptor {
140        name: "precision",
141        ty: BuiltinParamType::StringScalar,
142        arity: BuiltinParamArity::Optional,
143        default: Some("\"uint8\""),
144        description: "Write precision label (for example \"uint8\", \"double\").",
145    },
146    BuiltinParamDescriptor {
147        name: "machinefmt",
148        ty: BuiltinParamType::StringScalar,
149        arity: BuiltinParamArity::Optional,
150        default: Some("\"native\""),
151        description: "Machine format label (native/little-endian/big-endian aliases).",
152    },
153];
154const FWRITE_INPUTS_FID_DATA_PRECISION_SKIP_MACHINEFMT: [BuiltinParamDescriptor; 5] = [
155    BuiltinParamDescriptor {
156        name: "fid",
157        ty: BuiltinParamType::NumericScalar,
158        arity: BuiltinParamArity::Required,
159        default: None,
160        description: "File identifier opened by fopen.",
161    },
162    BuiltinParamDescriptor {
163        name: "data",
164        ty: BuiltinParamType::Any,
165        arity: BuiltinParamArity::Required,
166        default: None,
167        description: "Numeric/logical/text payload to write.",
168    },
169    BuiltinParamDescriptor {
170        name: "precision",
171        ty: BuiltinParamType::StringScalar,
172        arity: BuiltinParamArity::Optional,
173        default: Some("\"uint8\""),
174        description: "Write precision label (for example \"uint8\", \"double\").",
175    },
176    BuiltinParamDescriptor {
177        name: "skip",
178        ty: BuiltinParamType::NumericScalar,
179        arity: BuiltinParamArity::Optional,
180        default: Some("0"),
181        description: "Bytes skipped after each element written.",
182    },
183    BuiltinParamDescriptor {
184        name: "machinefmt",
185        ty: BuiltinParamType::StringScalar,
186        arity: BuiltinParamArity::Optional,
187        default: Some("\"native\""),
188        description: "Machine format label (native/little-endian/big-endian aliases).",
189    },
190];
191const FWRITE_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
192    BuiltinSignatureDescriptor {
193        label: "count = fwrite(fid, data)",
194        inputs: &FWRITE_INPUTS_FID_DATA,
195        outputs: &FWRITE_OUTPUT_COUNT,
196    },
197    BuiltinSignatureDescriptor {
198        label: "count = fwrite(fid, data, precision)",
199        inputs: &FWRITE_INPUTS_FID_DATA_PRECISION,
200        outputs: &FWRITE_OUTPUT_COUNT,
201    },
202    BuiltinSignatureDescriptor {
203        label: "count = fwrite(fid, data, precision, skip)",
204        inputs: &FWRITE_INPUTS_FID_DATA_PRECISION_SKIP,
205        outputs: &FWRITE_OUTPUT_COUNT,
206    },
207    BuiltinSignatureDescriptor {
208        label: "count = fwrite(fid, data, precision, machinefmt)",
209        inputs: &FWRITE_INPUTS_FID_DATA_PRECISION_MACHINEFMT,
210        outputs: &FWRITE_OUTPUT_COUNT,
211    },
212    BuiltinSignatureDescriptor {
213        label: "count = fwrite(fid, data, precision, skip, machinefmt)",
214        inputs: &FWRITE_INPUTS_FID_DATA_PRECISION_SKIP_MACHINEFMT,
215        outputs: &FWRITE_OUTPUT_COUNT,
216    },
217];
218
219const FWRITE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
220    code: "RM.FWRITE.INVALID_INPUT",
221    identifier: Some("RunMat:fwrite:InvalidInput"),
222    when: "Identifier, payload, or argument cardinality/type constraints are violated.",
223    message: "fwrite: invalid input arguments",
224};
225const FWRITE_ERROR_INVALID_IDENTIFIER: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
226    code: "RM.FWRITE.INVALID_IDENTIFIER",
227    identifier: Some("RunMat:fwrite:InvalidIdentifier"),
228    when: "Identifier does not refer to a writable open file.",
229    message: "fwrite: invalid file identifier. Use fopen to generate a valid file ID.",
230};
231const FWRITE_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
232    code: "RM.FWRITE.INVALID_OPTION",
233    identifier: Some("RunMat:fwrite:InvalidOption"),
234    when: "Precision, skip, or machine format options are invalid.",
235    message: "fwrite: invalid option configuration",
236};
237const FWRITE_ERROR_IO: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
238    code: "RM.FWRITE.IO",
239    identifier: Some("RunMat:fwrite:IoFailure"),
240    when: "Write/seek operation fails.",
241    message: "fwrite: file write failed",
242};
243const FWRITE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
244    code: "RM.FWRITE.INTERNAL",
245    identifier: None,
246    when: "Internal runtime control-flow conversion fails.",
247    message: "fwrite: internal error",
248};
249const FWRITE_ERRORS: [BuiltinErrorDescriptor; 5] = [
250    FWRITE_ERROR_INVALID_INPUT,
251    FWRITE_ERROR_INVALID_IDENTIFIER,
252    FWRITE_ERROR_INVALID_OPTION,
253    FWRITE_ERROR_IO,
254    FWRITE_ERROR_INTERNAL,
255];
256pub const FWRITE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
257    signatures: &FWRITE_SIGNATURES,
258    output_mode: BuiltinOutputMode::Fixed,
259    completion_policy: BuiltinCompletionPolicy::Public,
260    errors: &FWRITE_ERRORS,
261};
262
263fn fwrite_error_with_detail(
264    error: &'static BuiltinErrorDescriptor,
265    detail: impl AsRef<str>,
266) -> RuntimeError {
267    let detail = detail.as_ref();
268    let detail = detail.strip_prefix("fwrite: ").unwrap_or(detail);
269    fwrite_error_with_message(format!("{}: {}", error.message, detail), error)
270}
271
272fn fwrite_error_with_message(
273    message: impl Into<String>,
274    error: &'static BuiltinErrorDescriptor,
275) -> RuntimeError {
276    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
277    if let Some(identifier) = error.identifier {
278        builder = builder.with_identifier(identifier);
279    }
280    builder.build()
281}
282
283fn map_control_flow(err: RuntimeError) -> RuntimeError {
284    let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {}", err.message()))
285        .with_builtin(BUILTIN_NAME)
286        .with_source(err);
287    if let Some(identifier) = FWRITE_ERROR_INTERNAL.identifier {
288        builder = builder.with_identifier(identifier);
289    }
290    builder.build()
291}
292
293fn map_string_result<T>(
294    result: Result<T, String>,
295    error: &'static BuiltinErrorDescriptor,
296) -> BuiltinResult<T> {
297    result.map_err(|detail| fwrite_error_with_detail(error, detail))
298}
299
300#[runtime_builtin(
301    name = "fwrite",
302    category = "io/filetext",
303    summary = "Write binary data to file identifiers.",
304    keywords = "fwrite,file,io,binary,precision",
305    accel = "cpu",
306    type_resolver(crate::builtins::io::type_resolvers::fwrite_type),
307    descriptor(crate::builtins::io::filetext::fwrite::FWRITE_DESCRIPTOR),
308    builtin_path = "crate::builtins::io::filetext::fwrite"
309)]
310async fn fwrite_builtin(fid: Value, data: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
311    let eval = evaluate(&fid, &data, &rest).await?;
312    Ok(Value::Num(eval.count as f64))
313}
314
/// Outcome of one `fwrite` evaluation.
#[derive(Debug, Clone)]
pub struct FwriteEval {
    /// Elements of the payload that were encoded and written.
    count: usize,
}

impl FwriteEval {
    /// Wraps the element count reported by the writer.
    fn new(count: usize) -> FwriteEval {
        FwriteEval { count }
    }

    /// Number of elements successfully written.
    pub fn count(&self) -> usize {
        let FwriteEval { count } = *self;
        count
    }
}
331
332/// Evaluate the `fwrite` builtin without invoking the runtime dispatcher.
333pub async fn evaluate(
334    fid_value: &Value,
335    data_value: &Value,
336    rest: &[Value],
337) -> BuiltinResult<FwriteEval> {
338    let fid_host = gather_value(fid_value).await?;
339    let fid = map_string_result(parse_fid(&fid_host), &FWRITE_ERROR_INVALID_INPUT)?;
340    if fid < 0 {
341        return Err(fwrite_error_with_detail(
342            &FWRITE_ERROR_INVALID_INPUT,
343            "file identifier must be non-negative",
344        ));
345    }
346    if fid < 3 {
347        return Err(fwrite_error_with_detail(
348            &FWRITE_ERROR_INVALID_INPUT,
349            "standard input/output identifiers are not supported yet",
350        ));
351    }
352
353    let info = registry::info_for(fid).ok_or_else(|| {
354        fwrite_error_with_message(
355            FWRITE_ERROR_INVALID_IDENTIFIER.message,
356            &FWRITE_ERROR_INVALID_IDENTIFIER,
357        )
358    })?;
359    let handle = registry::take_handle(fid).ok_or_else(|| {
360        fwrite_error_with_message(
361            FWRITE_ERROR_INVALID_IDENTIFIER.message,
362            &FWRITE_ERROR_INVALID_IDENTIFIER,
363        )
364    })?;
365
366    let data_host = gather_value(data_value).await?;
367    let rest_host = gather_args(rest).await?;
368    let (precision_arg, skip_arg, machine_arg) =
369        map_string_result(classify_arguments(&rest_host), &FWRITE_ERROR_INVALID_INPUT)?;
370
371    let precision_spec =
372        map_string_result(parse_precision(precision_arg), &FWRITE_ERROR_INVALID_OPTION)?;
373    let skip_bytes = map_string_result(parse_skip(skip_arg), &FWRITE_ERROR_INVALID_OPTION)?;
374    let machine_format = map_string_result(
375        parse_machine_format(machine_arg, &info.machinefmt),
376        &FWRITE_ERROR_INVALID_OPTION,
377    )?;
378
379    let mut guard = handle.lock().map_err(|_| {
380        fwrite_error_with_detail(
381            &FWRITE_ERROR_INTERNAL,
382            "failed to lock file handle (poisoned mutex)",
383        )
384    })?;
385    let file = guard.as_mut().ok_or_else(|| {
386        fwrite_error_with_message(
387            FWRITE_ERROR_INVALID_IDENTIFIER.message,
388            &FWRITE_ERROR_INVALID_IDENTIFIER,
389        )
390    })?;
391
392    let elements = map_string_result(flatten_elements(&data_host), &FWRITE_ERROR_INVALID_INPUT)?;
393    let count = map_string_result(
394        write_elements(file, &elements, precision_spec, skip_bytes, machine_format),
395        &FWRITE_ERROR_IO,
396    )?;
397    Ok(FwriteEval::new(count))
398}
399
400async fn gather_value(value: &Value) -> BuiltinResult<Value> {
401    gather_if_needed_async(value)
402        .await
403        .map_err(map_control_flow)
404}
405
406async fn gather_args(args: &[Value]) -> BuiltinResult<Vec<Value>> {
407    let mut gathered = Vec::with_capacity(args.len());
408    for value in args {
409        gathered.push(
410            gather_if_needed_async(value)
411                .await
412                .map_err(map_control_flow)?,
413        );
414    }
415    Ok(gathered)
416}
417
418fn parse_fid(value: &Value) -> Result<i32, String> {
419    let scalar = match value {
420        Value::Num(n) => *n,
421        Value::Int(int) => int.to_f64(),
422        _ => return Err("fwrite: file identifier must be numeric".to_string()),
423    };
424    if !scalar.is_finite() {
425        return Err("fwrite: file identifier must be finite".to_string());
426    }
427    if scalar.fract().abs() > f64::EPSILON {
428        return Err("fwrite: file identifier must be an integer".to_string());
429    }
430    Ok(scalar as i32)
431}
432
433type FwriteArgs<'a> = (Option<&'a Value>, Option<&'a Value>, Option<&'a Value>);
434
435fn classify_arguments(args: &[Value]) -> Result<FwriteArgs<'_>, String> {
436    match args.len() {
437        0 => Ok((None, None, None)),
438        1 => {
439            if is_string_like(&args[0]) {
440                Ok((Some(&args[0]), None, None))
441            } else {
442                Err(
443                    "fwrite: precision argument must be a string scalar or character vector"
444                        .to_string(),
445                )
446            }
447        }
448        2 => {
449            if !is_string_like(&args[0]) {
450                return Err(
451                    "fwrite: precision argument must be a string scalar or character vector"
452                        .to_string(),
453                );
454            }
455            if is_numeric_like(&args[1]) {
456                Ok((Some(&args[0]), Some(&args[1]), None))
457            } else if is_string_like(&args[1]) {
458                Ok((Some(&args[0]), None, Some(&args[1])))
459            } else {
460                Err("fwrite: invalid argument combination (expected numeric skip or machine format string)".to_string())
461            }
462        }
463        3 => {
464            if !is_string_like(&args[0]) || !is_numeric_like(&args[1]) || !is_string_like(&args[2])
465            {
466                return Err("fwrite: expected arguments (precision, skip, machinefmt)".to_string());
467            }
468            Ok((Some(&args[0]), Some(&args[1]), Some(&args[2])))
469        }
470        _ => Err("fwrite: too many input arguments".to_string()),
471    }
472}
473
474fn is_string_like(value: &Value) -> bool {
475    match value {
476        Value::String(_) => true,
477        Value::CharArray(ca) => ca.rows == 1,
478        Value::StringArray(sa) => sa.data.len() == 1,
479        _ => false,
480    }
481}
482
483fn is_numeric_like(value: &Value) -> bool {
484    match value {
485        Value::Num(_) | Value::Int(_) | Value::Bool(_) => true,
486        Value::Tensor(t) => t.data.len() == 1,
487        Value::LogicalArray(la) => la.data.len() == 1,
488        _ => false,
489    }
490}
491
492#[derive(Clone, Copy, Debug)]
493struct WriteSpec {
494    input: InputType,
495}
496
497impl WriteSpec {
498    fn default() -> Self {
499        Self {
500            input: InputType::UInt8,
501        }
502    }
503}
504
505fn parse_precision(arg: Option<&Value>) -> Result<WriteSpec, String> {
506    match arg {
507        None => Ok(WriteSpec::default()),
508        Some(value) => {
509            let text = scalar_string(
510                value,
511                "fwrite: precision argument must be a string scalar or character vector",
512            )?;
513            parse_precision_string(&text)
514        }
515    }
516}
517
518fn parse_precision_string(raw: &str) -> Result<WriteSpec, String> {
519    let trimmed = raw.trim();
520    if trimmed.is_empty() {
521        return Err("fwrite: precision argument must not be empty".to_string());
522    }
523    let lower = trimmed.to_ascii_lowercase();
524    if let Some((lhs, rhs)) = lower.split_once("=>") {
525        let lhs = lhs.trim();
526        let rhs = rhs.trim();
527        let input = parse_input_label(lhs)?;
528        let output = parse_input_label(rhs)?;
529        if input != output {
530            return Err(
531                "fwrite: differing input/output precisions are not implemented yet".to_string(),
532            );
533        }
534        Ok(WriteSpec { input })
535    } else {
536        parse_input_label(lower.trim()).map(|input| WriteSpec { input })
537    }
538}
539
540fn parse_skip(arg: Option<&Value>) -> Result<usize, String> {
541    match arg {
542        None => Ok(0),
543        Some(value) => {
544            let scalar = numeric_scalar(value, "fwrite: skip must be numeric")?;
545            if !scalar.is_finite() {
546                return Err("fwrite: skip value must be finite".to_string());
547            }
548            if scalar < 0.0 {
549                return Err("fwrite: skip value must be non-negative".to_string());
550            }
551            let rounded = scalar.round();
552            if (rounded - scalar).abs() > f64::EPSILON {
553                return Err("fwrite: skip value must be an integer".to_string());
554            }
555            if rounded > i64::MAX as f64 {
556                return Err("fwrite: skip value is too large".to_string());
557            }
558            Ok(rounded as usize)
559        }
560    }
561}
562
563#[derive(Clone, Copy, Debug)]
564enum MachineFormat {
565    Native,
566    LittleEndian,
567    BigEndian,
568}
569
570impl MachineFormat {
571    fn to_endianness(self) -> Endianness {
572        match self {
573            MachineFormat::Native => {
574                if cfg!(target_endian = "little") {
575                    Endianness::Little
576                } else {
577                    Endianness::Big
578                }
579            }
580            MachineFormat::LittleEndian => Endianness::Little,
581            MachineFormat::BigEndian => Endianness::Big,
582        }
583    }
584}
585
/// Concrete byte order used to serialise multi-byte elements.
#[derive(Debug, Clone, Copy)]
enum Endianness {
    /// Least-significant byte first.
    Little,
    /// Most-significant byte first.
    Big,
}
591
592fn parse_machine_format(arg: Option<&Value>, default_label: &str) -> Result<MachineFormat, String> {
593    match arg {
594        Some(value) => {
595            let text = scalar_string(
596                value,
597                "fwrite: machine format must be a string scalar or character vector",
598            )?;
599            machine_format_from_label(&text)
600        }
601        None => machine_format_from_label(default_label),
602    }
603}
604
605fn machine_format_from_label(label: &str) -> Result<MachineFormat, String> {
606    let trimmed = label.trim();
607    if trimmed.is_empty() {
608        return Err("fwrite: machine format must not be empty".to_string());
609    }
610    let lower = trimmed.to_ascii_lowercase();
611    let collapsed: String = lower
612        .chars()
613        .filter(|c| !matches!(c, '-' | '_' | ' '))
614        .collect();
615    if matches!(collapsed.as_str(), "native" | "n" | "system" | "default") {
616        return Ok(MachineFormat::Native);
617    }
618    if matches!(
619        collapsed.as_str(),
620        "l" | "le" | "littleendian" | "pc" | "intel"
621    ) {
622        return Ok(MachineFormat::LittleEndian);
623    }
624    if matches!(
625        collapsed.as_str(),
626        "b" | "be" | "bigendian" | "mac" | "motorola"
627    ) {
628        return Ok(MachineFormat::BigEndian);
629    }
630    if lower.starts_with("ieee-le") {
631        return Ok(MachineFormat::LittleEndian);
632    }
633    if lower.starts_with("ieee-be") {
634        return Ok(MachineFormat::BigEndian);
635    }
636    Err(format!("fwrite: unsupported machine format '{trimmed}'"))
637}
638
639fn scalar_string(value: &Value, err: &str) -> Result<String, String> {
640    match value {
641        Value::String(s) => Ok(s.clone()),
642        Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
643        Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
644        _ => Err(err.to_string()),
645    }
646}
647
648fn numeric_scalar(value: &Value, err: &str) -> Result<f64, String> {
649    match value {
650        Value::Num(n) => Ok(*n),
651        Value::Int(int) => Ok(int.to_f64()),
652        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
653        Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
654        Value::LogicalArray(la) if la.data.len() == 1 => {
655            Ok(if la.data[0] != 0 { 1.0 } else { 0.0 })
656        }
657        _ => Err(err.to_string()),
658    }
659}
660
661fn flatten_elements(value: &Value) -> Result<Vec<f64>, String> {
662    match value {
663        Value::Tensor(tensor) => Ok(tensor.data.clone()),
664        Value::Num(n) => Ok(vec![*n]),
665        Value::Int(int) => Ok(vec![int.to_f64()]),
666        Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
667        Value::LogicalArray(array) => Ok(array
668            .data
669            .iter()
670            .map(|bit| if *bit != 0 { 1.0 } else { 0.0 })
671            .collect()),
672        Value::CharArray(ca) => Ok(flatten_char_array(ca)),
673        Value::String(text) => Ok(text.chars().map(|ch| ch as u32 as f64).collect()),
674        Value::StringArray(sa) => Ok(flatten_string_array(sa)),
675        Value::GpuTensor(_) => Err("fwrite: expected host tensor data after gathering".to_string()),
676        Value::Complex(_, _) | Value::ComplexTensor(_) => {
677            Err("fwrite: complex values are not supported yet".to_string())
678        }
679        _ => Err(format!("fwrite: unsupported data type {:?}", value)),
680    }
681}
682
683fn flatten_char_array(ca: &CharArray) -> Vec<f64> {
684    let mut values = Vec::with_capacity(ca.rows.saturating_mul(ca.cols));
685    for c in 0..ca.cols {
686        for r in 0..ca.rows {
687            let idx = r * ca.cols + c;
688            values.push(ca.data[idx] as u32 as f64);
689        }
690    }
691    values
692}
693
694fn flatten_string_array(sa: &runmat_builtins::StringArray) -> Vec<f64> {
695    if sa.data.is_empty() {
696        return Vec::new();
697    }
698    let mut values = Vec::new();
699    for (idx, text) in sa.data.iter().enumerate() {
700        if idx > 0 {
701            values.push('\n' as u32 as f64);
702        }
703        values.extend(text.chars().map(|ch| ch as u32 as f64));
704    }
705    values
706}
707
708fn write_elements(
709    file: &mut File,
710    values: &[f64],
711    spec: WriteSpec,
712    skip: usize,
713    machine: MachineFormat,
714) -> Result<usize, String> {
715    let endianness = machine.to_endianness();
716    let skip_offset = skip as i64;
717    for &value in values {
718        match spec.input {
719            InputType::UInt8 => {
720                let byte = to_u8(value);
721                write_bytes(file, &[byte])?;
722            }
723            InputType::Int8 => {
724                let byte = to_i8(value) as u8;
725                write_bytes(file, &[byte])?;
726            }
727            InputType::UInt16 => {
728                let bytes = encode_u16(value, endianness);
729                write_bytes(file, &bytes)?;
730            }
731            InputType::Int16 => {
732                let bytes = encode_i16(value, endianness);
733                write_bytes(file, &bytes)?;
734            }
735            InputType::UInt32 => {
736                let bytes = encode_u32(value, endianness);
737                write_bytes(file, &bytes)?;
738            }
739            InputType::Int32 => {
740                let bytes = encode_i32(value, endianness);
741                write_bytes(file, &bytes)?;
742            }
743            InputType::UInt64 => {
744                let bytes = encode_u64(value, endianness);
745                write_bytes(file, &bytes)?;
746            }
747            InputType::Int64 => {
748                let bytes = encode_i64(value, endianness);
749                write_bytes(file, &bytes)?;
750            }
751            InputType::Float32 => {
752                let bytes = encode_f32(value, endianness);
753                write_bytes(file, &bytes)?;
754            }
755            InputType::Float64 => {
756                let bytes = encode_f64(value, endianness);
757                write_bytes(file, &bytes)?;
758            }
759        }
760
761        if skip > 0 {
762            file.seek(SeekFrom::Current(skip_offset))
763                .map_err(|err| format!("fwrite: failed to seek while applying skip ({err})"))?;
764        }
765    }
766    Ok(values.len())
767}
768
769fn write_bytes(file: &mut File, bytes: &[u8]) -> Result<(), String> {
770    file.write_all(bytes)
771        .map_err(|err| format!("fwrite: failed to write to file ({err})"))
772}
773
/// Converts `value` to `u8` following MATLAB's `uint8` conversion.
///
/// The value is rounded half away from zero and saturated to `0..=255`. `NaN` of either sign
/// becomes `0`, `+Inf` becomes `255`, and `-Inf` becomes `0`.
fn to_u8(value: f64) -> u8 {
    // NaN must be checked first. A sign-bit infinity test would send a positive NaN to 255,
    // and the previous NaN branch sat behind `!is_finite()` and was unreachable.
    if value.is_nan() {
        return 0;
    }
    value.round().clamp(0.0, f64::from(u8::MAX)) as u8
}
790
791fn to_i8(value: f64) -> i8 {
792    saturating_round(value, i8::MIN as f64, i8::MAX as f64) as i8
793}
794
795fn encode_u16(value: f64, endianness: Endianness) -> [u8; 2] {
796    let rounded = saturating_round(value, 0.0, u16::MAX as f64) as u16;
797    match endianness {
798        Endianness::Little => rounded.to_le_bytes(),
799        Endianness::Big => rounded.to_be_bytes(),
800    }
801}
802
803fn encode_i16(value: f64, endianness: Endianness) -> [u8; 2] {
804    let rounded = saturating_round(value, i16::MIN as f64, i16::MAX as f64) as i16;
805    match endianness {
806        Endianness::Little => rounded.to_le_bytes(),
807        Endianness::Big => rounded.to_be_bytes(),
808    }
809}
810
811fn encode_u32(value: f64, endianness: Endianness) -> [u8; 4] {
812    let rounded = saturating_round(value, 0.0, u32::MAX as f64) as u32;
813    match endianness {
814        Endianness::Little => rounded.to_le_bytes(),
815        Endianness::Big => rounded.to_be_bytes(),
816    }
817}
818
819fn encode_i32(value: f64, endianness: Endianness) -> [u8; 4] {
820    let rounded = saturating_round(value, i32::MIN as f64, i32::MAX as f64) as i32;
821    match endianness {
822        Endianness::Little => rounded.to_le_bytes(),
823        Endianness::Big => rounded.to_be_bytes(),
824    }
825}
826
827fn encode_u64(value: f64, endianness: Endianness) -> [u8; 8] {
828    let rounded = saturating_round(value, 0.0, u64::MAX as f64);
829    let as_u64 = if rounded.is_finite() {
830        rounded as u64
831    } else if rounded.is_sign_negative() {
832        0
833    } else {
834        u64::MAX
835    };
836    match endianness {
837        Endianness::Little => as_u64.to_le_bytes(),
838        Endianness::Big => as_u64.to_be_bytes(),
839    }
840}
841
842fn encode_i64(value: f64, endianness: Endianness) -> [u8; 8] {
843    let rounded = saturating_round(value, i64::MIN as f64, i64::MAX as f64);
844    let as_i64 = if rounded.is_finite() {
845        rounded as i64
846    } else if rounded.is_sign_negative() {
847        i64::MIN
848    } else {
849        i64::MAX
850    };
851    match endianness {
852        Endianness::Little => as_i64.to_le_bytes(),
853        Endianness::Big => as_i64.to_be_bytes(),
854    }
855}
856
857fn encode_f32(value: f64, endianness: Endianness) -> [u8; 4] {
858    let as_f32 = value as f32;
859    let bits = as_f32.to_bits();
860    match endianness {
861        Endianness::Little => bits.to_le_bytes(),
862        Endianness::Big => bits.to_be_bytes(),
863    }
864}
865
866fn encode_f64(value: f64, endianness: Endianness) -> [u8; 8] {
867    let bits = value.to_bits();
868    match endianness {
869        Endianness::Little => bits.to_le_bytes(),
870        Endianness::Big => bits.to_be_bytes(),
871    }
872}
873
/// Rounds `value` half away from zero and clamps it into `[min, max]`, following MATLAB's
/// integer conversion rules.
///
/// - `NaN` of either sign returns `0.0`; every caller's integer range contains zero.
/// - `+Inf` returns `max` and `-Inf` returns `min`.
///
/// The result is always finite when `min` and `max` are.
fn saturating_round(value: f64, min: f64, max: f64) -> f64 {
    // NaN must be tested before the infinity check. A NaN's sign bit is arbitrary, so the
    // previous `!is_finite()` + sign test mapped NaN to `max` or `min`. That left the
    // `rounded.is_nan()` branch unreachable.
    if value.is_nan() {
        return 0.0;
    }
    if value.is_infinite() {
        return if value.is_sign_negative() { min } else { max };
    }
    value.round().max(min).min(max)
}
890
/// Element class named by a precision label (produced by `parse_input_label`).
///
/// Variant order is intentionally kept stable so implicit discriminants do not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InputType {
    /// Unsigned 8-bit integer.
    UInt8,
    /// Signed 8-bit integer.
    Int8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Signed 16-bit integer.
    Int16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Signed 32-bit integer.
    Int32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// Signed 64-bit integer.
    Int64,
    /// IEEE-754 single precision.
    Float32,
    /// IEEE-754 double precision.
    Float64,
}
904
905fn parse_input_label(label: &str) -> Result<InputType, String> {
906    match label {
907        "double" | "float64" | "real*8" => Ok(InputType::Float64),
908        "single" | "float32" | "real*4" => Ok(InputType::Float32),
909        "int8" | "schar" | "integer*1" => Ok(InputType::Int8),
910        "uint8" | "uchar" | "unsignedchar" | "char" | "byte" => Ok(InputType::UInt8),
911        "int16" | "short" | "integer*2" => Ok(InputType::Int16),
912        "uint16" | "ushort" | "unsignedshort" => Ok(InputType::UInt16),
913        "int32" | "integer*4" | "long" => Ok(InputType::Int32),
914        "uint32" | "unsignedint" | "unsignedlong" => Ok(InputType::UInt32),
915        "int64" | "integer*8" | "longlong" => Ok(InputType::Int64),
916        "uint64" | "unsignedlonglong" => Ok(InputType::UInt64),
917        other => Err(format!("fwrite: unsupported precision '{other}'")),
918    }
919}
920
/// Unit tests for the `fwrite` builtin.
///
/// Covered areas:
/// - descriptor signatures,
/// - byte encodings for several precisions and byte orders,
/// - skip padding,
/// - GPU gathering,
/// - argument-validation errors.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::io::filetext::registry;
    use crate::builtins::io::filetext::{fclose, fopen};
    use crate::RuntimeError;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::Tensor;
    use runmat_filesystem::File;
    use runmat_time::system_time_now;
    use std::io::Read;
    use std::path::PathBuf;
    use std::time::UNIX_EPOCH;

    /// Extracts the message text from a `RuntimeError` so tests can match on substrings.
    fn unwrap_error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async `evaluate` entry point to completion on the current thread.
    fn run_evaluate(
        fid_value: &Value,
        data_value: &Value,
        rest: &[Value],
    ) -> BuiltinResult<FwriteEval> {
        futures::executor::block_on(evaluate(fid_value, data_value, rest))
    }

    /// Synchronous wrapper around `fopen::evaluate`.
    fn run_fopen(args: &[Value]) -> BuiltinResult<fopen::FopenEval> {
        futures::executor::block_on(fopen::evaluate(args))
    }

    /// Synchronous wrapper around `fclose::evaluate`.
    fn run_fclose(args: &[Value]) -> BuiltinResult<fclose::FcloseEval> {
        futures::executor::block_on(fclose::evaluate(args))
    }

    /// Acquires the registry's test lock. Presumably this serializes tests that share the
    /// file-identifier registry — NOTE(review): confirm in `registry::test_guard`.
    fn registry_guard() -> std::sync::MutexGuard<'static, ()> {
        registry::test_guard()
    }

    /// The builtin descriptor advertises the plain, skip, machinefmt, and
    /// skip + machinefmt call forms.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = FWRITE_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"count = fwrite(fid, data)"));
        assert!(labels.contains(&"count = fwrite(fid, data, precision, skip)"));
        assert!(labels.contains(&"count = fwrite(fid, data, precision, machinefmt)"));
        assert!(labels.contains(&"count = fwrite(fid, data, precision, skip, machinefmt)"));
    }

    /// With no precision argument, each element is written as a single `uint8` byte and
    /// the returned count equals the number of elements.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_default_uint8_bytes() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_uint8");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.0, 2.0, 255.0], vec![3, 1]).unwrap();
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &Vec::new())
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![1u8, 2, 255]);
        test_support::fs::remove_file(path).unwrap();
    }

    /// `double` precision on a file opened without a machine format writes 8-byte
    /// doubles in the host's byte order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_double_precision_writes_native_endian() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_double");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.5, -2.25], vec![2, 1]).unwrap();
        let args = vec![Value::from("double")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        // The expected bytes follow the compile target's endianness.
        let expected: Vec<u8> = if cfg!(target_endian = "little") {
            [1.5f64.to_le_bytes(), (-2.25f64).to_le_bytes()].concat()
        } else {
            [1.5f64.to_be_bytes(), (-2.25f64).to_be_bytes()].concat()
        };
        assert_eq!(bytes, expected);
        test_support::fs::remove_file(path).unwrap();
    }

    /// A file opened with the `ieee-be` machine format writes `uint16` values big-endian.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_big_endian_uint16() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_be");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
            Value::from("ieee-be"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        // 258 == 0x0102 and 772 == 0x0304, so big-endian output is 01 02 03 04.
        let tensor = Tensor::new(vec![258.0, 772.0], vec![2, 1]).unwrap();
        let args = vec![Value::from("uint16")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![0x01, 0x02, 0x03, 0x04]);
        test_support::fs::remove_file(path).unwrap();
    }

    /// A skip of 1 leaves one zero byte between consecutive elements. The expected
    /// layout shows no skip before the first element and none after the last.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_skip_inserts_padding() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![10.0, 20.0, 30.0], vec![3, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(1.0)];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        // The count reports elements written, not bytes (padding is excluded).
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![10u8, 0, 20, 0, 30]);
        test_support::fs::remove_file(path).unwrap();
    }

    /// Data held in a GPU handle (test provider) are gathered to the host and written
    /// like an ordinary tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_gpu_tensor_gathers_before_write() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_gpu");

        test_support::with_test_provider(|provider| {
            // NOTE(review): the registry was already reset above; this second reset looks
            // redundant unless `with_test_provider` touches the registry — confirm.
            registry::reset_for_tests();
            let open = run_fopen(&[
                Value::from(path.to_string_lossy().to_string()),
                Value::from("w+b"),
            ])
            .expect("fopen");
            let fid = open.as_open().unwrap().fid as i32;

            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::from("uint16")];
            let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
                .expect("fwrite");
            assert_eq!(eval.count(), 4);

            run_fclose(&[Value::Num(fid as f64)]).unwrap();
        });

        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        // Four uint16 values -> 8 bytes, decoded in the host's byte order.
        assert_eq!(bytes.len(), 8);
        let mut decoded = Vec::new();
        for chunk in bytes.chunks_exact(2) {
            let value = if cfg!(target_endian = "little") {
                u16::from_le_bytes([chunk[0], chunk[1]])
            } else {
                u16::from_be_bytes([chunk[0], chunk[1]])
            };
            decoded.push(value);
        }
        assert_eq!(decoded, vec![1u16, 2, 3, 4]);
        test_support::fs::remove_file(path).unwrap();
    }

    /// An unrecognised precision label produces an "unsupported precision" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_precision_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_invalid_precision");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("bogus-class")];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("unsupported precision"));
        // Best-effort close; the test only cares about the error above.
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    /// A negative skip value is rejected with a "skip value must be non-negative" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_negative_skip_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_negative_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![10.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(-1.0)];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("skip value must be non-negative"));
        // Best-effort close; the test only cares about the error above.
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    /// With the `wgpu` feature enabled, a tensor uploaded through the WGPU provider is
    /// written as `double` values that match the original host data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fwrite_wgpu_tensor_roundtrip() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_wgpu_roundtrip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let provider = provider::register_wgpu_provider(provider::WgpuProviderOptions::default())
            .expect("wgpu provider");

        let tensor = Tensor::new(vec![0.5, -1.25, 3.75], vec![3, 1]).unwrap();
        let expected = tensor.data.clone();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload to gpu");
        let args = vec![Value::from("double")];
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        // Three doubles -> 24 bytes, decoded in the host's byte order.
        assert_eq!(bytes.len(), 24);
        for (chunk, expected_value) in bytes.chunks_exact(8).zip(expected.iter()) {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(chunk);
            let value = if cfg!(target_endian = "little") {
                f64::from_le_bytes(buf)
            } else {
                f64::from_be_bytes(buf)
            };
            assert!(
                (value - expected_value).abs() < 1e-12,
                "mismatch: {} vs {}",
                value,
                expected_value
            );
        }
        test_support::fs::remove_file(path).unwrap();
    }

    /// A negative file identifier is rejected with a
    /// "file identifier must be non-negative" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_identifier_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(-1.0), &Value::Num(1.0), &Vec::new()).unwrap_err(),
        );
        assert!(err.contains("file identifier must be non-negative"));
    }

    /// Builds a temp-directory path from `prefix` and the current time (seconds plus
    /// sub-second nanoseconds). Uniqueness is best-effort, not guaranteed.
    fn unique_path(prefix: &str) -> PathBuf {
        let now = system_time_now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards");
        let filename = format!(
            "runmat_{prefix}_{}_{}.tmp",
            now.as_secs(),
            now.subsec_nanos()
        );
        std::env::temp_dir().join(filename)
    }
}