Skip to main content

runmat_runtime/builtins/io/filetext/
fwrite.rs

1//! MATLAB-compatible `fwrite` builtin for RunMat.
2use std::io::{Seek, SeekFrom, Write};
3
4use runmat_builtins::{CharArray, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::io::filetext::registry;
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13use runmat_filesystem::File;
14
/// GPU metadata for `fwrite`.
///
/// `fwrite` is host-side file I/O: no provider hooks or precisions are registered (both lists
/// are empty), and GPU-resident arguments are gathered to the CPU before writing
/// (`ResidencyPolicy::GatherImmediately`, see also `evaluate`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fwrite",
    op_kind: GpuOpKind::Custom("file-io-write"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-only binary file I/O; GPU arguments are gathered to the CPU prior to writing.",
};
30
/// Fusion metadata for `fwrite`.
///
/// File I/O never takes part in kernel fusion, so neither an elementwise nor a reduction
/// template is provided; the entry exists only so the builtin is described consistently.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fwrite",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "File I/O is never fused; metadata recorded for completeness.",
};
41
/// Builtin name attached (via `with_builtin`) to every runtime error raised by this module.
const BUILTIN_NAME: &str = "fwrite";
43
44fn fwrite_error(message: impl Into<String>) -> RuntimeError {
45    build_runtime_error(message)
46        .with_builtin(BUILTIN_NAME)
47        .build()
48}
49
50fn map_control_flow(err: RuntimeError) -> RuntimeError {
51    let message = err.message().to_string();
52    let identifier = err.identifier().map(|value| value.to_string());
53    let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {message}"))
54        .with_builtin(BUILTIN_NAME)
55        .with_source(err);
56    if let Some(identifier) = identifier {
57        builder = builder.with_identifier(identifier);
58    }
59    builder.build()
60}
61
62fn map_string_result<T>(result: Result<T, String>) -> BuiltinResult<T> {
63    result.map_err(fwrite_error)
64}
65
66#[runtime_builtin(
67    name = "fwrite",
68    category = "io/filetext",
69    summary = "Write binary data to a file identifier.",
70    keywords = "fwrite,file,io,binary,precision",
71    accel = "cpu",
72    type_resolver(crate::builtins::io::type_resolvers::fwrite_type),
73    builtin_path = "crate::builtins::io::filetext::fwrite"
74)]
75async fn fwrite_builtin(fid: Value, data: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
76    let eval = evaluate(&fid, &data, &rest).await?;
77    Ok(Value::Num(eval.count as f64))
78}
79
/// Outcome of a successful `fwrite` evaluation.
#[derive(Debug, Clone)]
pub struct FwriteEval {
    /// How many elements were written to the file.
    count: usize,
}

impl FwriteEval {
    /// Record the number of elements that were written.
    fn new(count: usize) -> Self {
        FwriteEval { count }
    }

    /// Number of elements successfully written.
    pub fn count(&self) -> usize {
        let Self { count } = self;
        *count
    }
}
96
97/// Evaluate the `fwrite` builtin without invoking the runtime dispatcher.
98pub async fn evaluate(
99    fid_value: &Value,
100    data_value: &Value,
101    rest: &[Value],
102) -> BuiltinResult<FwriteEval> {
103    let fid_host = gather_value(fid_value).await?;
104    let fid = map_string_result(parse_fid(&fid_host))?;
105    if fid < 0 {
106        return Err(fwrite_error("fwrite: file identifier must be non-negative"));
107    }
108    if fid < 3 {
109        return Err(fwrite_error(
110            "fwrite: standard input/output identifiers are not supported yet",
111        ));
112    }
113
114    let info = registry::info_for(fid).ok_or_else(|| {
115        fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
116    })?;
117    let handle = registry::take_handle(fid).ok_or_else(|| {
118        fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
119    })?;
120
121    let mut file = handle
122        .lock()
123        .map_err(|_| fwrite_error("fwrite: failed to lock file handle (poisoned mutex)"))?;
124
125    let data_host = gather_value(data_value).await?;
126    let rest_host = gather_args(rest).await?;
127    let (precision_arg, skip_arg, machine_arg) = map_string_result(classify_arguments(&rest_host))?;
128
129    let precision_spec = map_string_result(parse_precision(precision_arg))?;
130    let skip_bytes = map_string_result(parse_skip(skip_arg))?;
131    let machine_format = map_string_result(parse_machine_format(machine_arg, &info.machinefmt))?;
132
133    let elements = map_string_result(flatten_elements(&data_host))?;
134    let count = map_string_result(write_elements(
135        &mut file,
136        &elements,
137        precision_spec,
138        skip_bytes,
139        machine_format,
140    ))?;
141    Ok(FwriteEval::new(count))
142}
143
144async fn gather_value(value: &Value) -> BuiltinResult<Value> {
145    gather_if_needed_async(value)
146        .await
147        .map_err(map_control_flow)
148}
149
150async fn gather_args(args: &[Value]) -> BuiltinResult<Vec<Value>> {
151    let mut gathered = Vec::with_capacity(args.len());
152    for value in args {
153        gathered.push(
154            gather_if_needed_async(value)
155                .await
156                .map_err(map_control_flow)?,
157        );
158    }
159    Ok(gathered)
160}
161
162fn parse_fid(value: &Value) -> Result<i32, String> {
163    let scalar = match value {
164        Value::Num(n) => *n,
165        Value::Int(int) => int.to_f64(),
166        _ => return Err("fwrite: file identifier must be numeric".to_string()),
167    };
168    if !scalar.is_finite() {
169        return Err("fwrite: file identifier must be finite".to_string());
170    }
171    if scalar.fract().abs() > f64::EPSILON {
172        return Err("fwrite: file identifier must be an integer".to_string());
173    }
174    Ok(scalar as i32)
175}
176
177type FwriteArgs<'a> = (Option<&'a Value>, Option<&'a Value>, Option<&'a Value>);
178
179fn classify_arguments(args: &[Value]) -> Result<FwriteArgs<'_>, String> {
180    match args.len() {
181        0 => Ok((None, None, None)),
182        1 => {
183            if is_string_like(&args[0]) {
184                Ok((Some(&args[0]), None, None))
185            } else {
186                Err(
187                    "fwrite: precision argument must be a string scalar or character vector"
188                        .to_string(),
189                )
190            }
191        }
192        2 => {
193            if !is_string_like(&args[0]) {
194                return Err(
195                    "fwrite: precision argument must be a string scalar or character vector"
196                        .to_string(),
197                );
198            }
199            if is_numeric_like(&args[1]) {
200                Ok((Some(&args[0]), Some(&args[1]), None))
201            } else if is_string_like(&args[1]) {
202                Ok((Some(&args[0]), None, Some(&args[1])))
203            } else {
204                Err("fwrite: invalid argument combination (expected numeric skip or machine format string)".to_string())
205            }
206        }
207        3 => {
208            if !is_string_like(&args[0]) || !is_numeric_like(&args[1]) || !is_string_like(&args[2])
209            {
210                return Err("fwrite: expected arguments (precision, skip, machinefmt)".to_string());
211            }
212            Ok((Some(&args[0]), Some(&args[1]), Some(&args[2])))
213        }
214        _ => Err("fwrite: too many input arguments".to_string()),
215    }
216}
217
218fn is_string_like(value: &Value) -> bool {
219    match value {
220        Value::String(_) => true,
221        Value::CharArray(ca) => ca.rows == 1,
222        Value::StringArray(sa) => sa.data.len() == 1,
223        _ => false,
224    }
225}
226
227fn is_numeric_like(value: &Value) -> bool {
228    match value {
229        Value::Num(_) | Value::Int(_) | Value::Bool(_) => true,
230        Value::Tensor(t) => t.data.len() == 1,
231        Value::LogicalArray(la) => la.data.len() == 1,
232        _ => false,
233    }
234}
235
236#[derive(Clone, Copy, Debug)]
237struct WriteSpec {
238    input: InputType,
239}
240
241impl WriteSpec {
242    fn default() -> Self {
243        Self {
244            input: InputType::UInt8,
245        }
246    }
247}
248
249fn parse_precision(arg: Option<&Value>) -> Result<WriteSpec, String> {
250    match arg {
251        None => Ok(WriteSpec::default()),
252        Some(value) => {
253            let text = scalar_string(
254                value,
255                "fwrite: precision argument must be a string scalar or character vector",
256            )?;
257            parse_precision_string(&text)
258        }
259    }
260}
261
262fn parse_precision_string(raw: &str) -> Result<WriteSpec, String> {
263    let trimmed = raw.trim();
264    if trimmed.is_empty() {
265        return Err("fwrite: precision argument must not be empty".to_string());
266    }
267    let lower = trimmed.to_ascii_lowercase();
268    if let Some((lhs, rhs)) = lower.split_once("=>") {
269        let lhs = lhs.trim();
270        let rhs = rhs.trim();
271        let input = parse_input_label(lhs)?;
272        let output = parse_input_label(rhs)?;
273        if input != output {
274            return Err(
275                "fwrite: differing input/output precisions are not implemented yet".to_string(),
276            );
277        }
278        Ok(WriteSpec { input })
279    } else {
280        parse_input_label(lower.trim()).map(|input| WriteSpec { input })
281    }
282}
283
284fn parse_skip(arg: Option<&Value>) -> Result<usize, String> {
285    match arg {
286        None => Ok(0),
287        Some(value) => {
288            let scalar = numeric_scalar(value, "fwrite: skip must be numeric")?;
289            if !scalar.is_finite() {
290                return Err("fwrite: skip value must be finite".to_string());
291            }
292            if scalar < 0.0 {
293                return Err("fwrite: skip value must be non-negative".to_string());
294            }
295            let rounded = scalar.round();
296            if (rounded - scalar).abs() > f64::EPSILON {
297                return Err("fwrite: skip value must be an integer".to_string());
298            }
299            if rounded > i64::MAX as f64 {
300                return Err("fwrite: skip value is too large".to_string());
301            }
302            Ok(rounded as usize)
303        }
304    }
305}
306
/// Byte-order request resolved from a `machinefmt` label.
#[derive(Debug, Clone, Copy)]
enum MachineFormat {
    /// Use the byte order of the compilation target.
    Native,
    /// Least-significant byte first.
    LittleEndian,
    /// Most-significant byte first.
    BigEndian,
}
313
314impl MachineFormat {
315    fn to_endianness(self) -> Endianness {
316        match self {
317            MachineFormat::Native => {
318                if cfg!(target_endian = "little") {
319                    Endianness::Little
320                } else {
321                    Endianness::Big
322                }
323            }
324            MachineFormat::LittleEndian => Endianness::Little,
325            MachineFormat::BigEndian => Endianness::Big,
326        }
327    }
328}
329
/// Concrete byte order used when encoding multi-byte values.
#[derive(Debug, Clone, Copy)]
enum Endianness {
    /// Least-significant byte first.
    Little,
    /// Most-significant byte first.
    Big,
}
335
336fn parse_machine_format(arg: Option<&Value>, default_label: &str) -> Result<MachineFormat, String> {
337    match arg {
338        Some(value) => {
339            let text = scalar_string(
340                value,
341                "fwrite: machine format must be a string scalar or character vector",
342            )?;
343            machine_format_from_label(&text)
344        }
345        None => machine_format_from_label(default_label),
346    }
347}
348
349fn machine_format_from_label(label: &str) -> Result<MachineFormat, String> {
350    let trimmed = label.trim();
351    if trimmed.is_empty() {
352        return Err("fwrite: machine format must not be empty".to_string());
353    }
354    let lower = trimmed.to_ascii_lowercase();
355    let collapsed: String = lower
356        .chars()
357        .filter(|c| !matches!(c, '-' | '_' | ' '))
358        .collect();
359    if matches!(collapsed.as_str(), "native" | "n" | "system" | "default") {
360        return Ok(MachineFormat::Native);
361    }
362    if matches!(
363        collapsed.as_str(),
364        "l" | "le" | "littleendian" | "pc" | "intel"
365    ) {
366        return Ok(MachineFormat::LittleEndian);
367    }
368    if matches!(
369        collapsed.as_str(),
370        "b" | "be" | "bigendian" | "mac" | "motorola"
371    ) {
372        return Ok(MachineFormat::BigEndian);
373    }
374    if lower.starts_with("ieee-le") {
375        return Ok(MachineFormat::LittleEndian);
376    }
377    if lower.starts_with("ieee-be") {
378        return Ok(MachineFormat::BigEndian);
379    }
380    Err(format!("fwrite: unsupported machine format '{trimmed}'"))
381}
382
383fn scalar_string(value: &Value, err: &str) -> Result<String, String> {
384    match value {
385        Value::String(s) => Ok(s.clone()),
386        Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
387        Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
388        _ => Err(err.to_string()),
389    }
390}
391
392fn numeric_scalar(value: &Value, err: &str) -> Result<f64, String> {
393    match value {
394        Value::Num(n) => Ok(*n),
395        Value::Int(int) => Ok(int.to_f64()),
396        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
397        Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
398        Value::LogicalArray(la) if la.data.len() == 1 => {
399            Ok(if la.data[0] != 0 { 1.0 } else { 0.0 })
400        }
401        _ => Err(err.to_string()),
402    }
403}
404
405fn flatten_elements(value: &Value) -> Result<Vec<f64>, String> {
406    match value {
407        Value::Tensor(tensor) => Ok(tensor.data.clone()),
408        Value::Num(n) => Ok(vec![*n]),
409        Value::Int(int) => Ok(vec![int.to_f64()]),
410        Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
411        Value::LogicalArray(array) => Ok(array
412            .data
413            .iter()
414            .map(|bit| if *bit != 0 { 1.0 } else { 0.0 })
415            .collect()),
416        Value::CharArray(ca) => Ok(flatten_char_array(ca)),
417        Value::String(text) => Ok(text.chars().map(|ch| ch as u32 as f64).collect()),
418        Value::StringArray(sa) => Ok(flatten_string_array(sa)),
419        Value::GpuTensor(_) => Err("fwrite: expected host tensor data after gathering".to_string()),
420        Value::Complex(_, _) | Value::ComplexTensor(_) => {
421            Err("fwrite: complex values are not supported yet".to_string())
422        }
423        _ => Err(format!("fwrite: unsupported data type {:?}", value)),
424    }
425}
426
427fn flatten_char_array(ca: &CharArray) -> Vec<f64> {
428    let mut values = Vec::with_capacity(ca.rows.saturating_mul(ca.cols));
429    for c in 0..ca.cols {
430        for r in 0..ca.rows {
431            let idx = r * ca.cols + c;
432            values.push(ca.data[idx] as u32 as f64);
433        }
434    }
435    values
436}
437
438fn flatten_string_array(sa: &runmat_builtins::StringArray) -> Vec<f64> {
439    if sa.data.is_empty() {
440        return Vec::new();
441    }
442    let mut values = Vec::new();
443    for (idx, text) in sa.data.iter().enumerate() {
444        if idx > 0 {
445            values.push('\n' as u32 as f64);
446        }
447        values.extend(text.chars().map(|ch| ch as u32 as f64));
448    }
449    values
450}
451
452fn write_elements(
453    file: &mut File,
454    values: &[f64],
455    spec: WriteSpec,
456    skip: usize,
457    machine: MachineFormat,
458) -> Result<usize, String> {
459    let endianness = machine.to_endianness();
460    let skip_offset = skip as i64;
461    for &value in values {
462        match spec.input {
463            InputType::UInt8 => {
464                let byte = to_u8(value);
465                write_bytes(file, &[byte])?;
466            }
467            InputType::Int8 => {
468                let byte = to_i8(value) as u8;
469                write_bytes(file, &[byte])?;
470            }
471            InputType::UInt16 => {
472                let bytes = encode_u16(value, endianness);
473                write_bytes(file, &bytes)?;
474            }
475            InputType::Int16 => {
476                let bytes = encode_i16(value, endianness);
477                write_bytes(file, &bytes)?;
478            }
479            InputType::UInt32 => {
480                let bytes = encode_u32(value, endianness);
481                write_bytes(file, &bytes)?;
482            }
483            InputType::Int32 => {
484                let bytes = encode_i32(value, endianness);
485                write_bytes(file, &bytes)?;
486            }
487            InputType::UInt64 => {
488                let bytes = encode_u64(value, endianness);
489                write_bytes(file, &bytes)?;
490            }
491            InputType::Int64 => {
492                let bytes = encode_i64(value, endianness);
493                write_bytes(file, &bytes)?;
494            }
495            InputType::Float32 => {
496                let bytes = encode_f32(value, endianness);
497                write_bytes(file, &bytes)?;
498            }
499            InputType::Float64 => {
500                let bytes = encode_f64(value, endianness);
501                write_bytes(file, &bytes)?;
502            }
503        }
504
505        if skip > 0 {
506            file.seek(SeekFrom::Current(skip_offset))
507                .map_err(|err| format!("fwrite: failed to seek while applying skip ({err})"))?;
508        }
509    }
510    Ok(values.len())
511}
512
513fn write_bytes(file: &mut File, bytes: &[u8]) -> Result<(), String> {
514    file.write_all(bytes)
515        .map_err(|err| format!("fwrite: failed to write to file ({err})"))
516}
517
/// Convert `value` to `uint8`: round half away from zero, then clamp to `0..=255`.
///
/// Infinities saturate to the matching bound, and NaN becomes 0 (MATLAB: `uint8(NaN) == 0`).
/// NaN is tested first because it is also non-finite. Routing it through the infinity branch
/// would pick 0 or 255 based on the NaN's sign bit, which depends on how the NaN was produced.
fn to_u8(value: f64) -> u8 {
    if value.is_nan() {
        return 0;
    }
    if value.is_infinite() {
        return if value.is_sign_negative() { 0 } else { u8::MAX };
    }
    value.round().clamp(0.0, f64::from(u8::MAX)) as u8
}
534
535fn to_i8(value: f64) -> i8 {
536    saturating_round(value, i8::MIN as f64, i8::MAX as f64) as i8
537}
538
539fn encode_u16(value: f64, endianness: Endianness) -> [u8; 2] {
540    let rounded = saturating_round(value, 0.0, u16::MAX as f64) as u16;
541    match endianness {
542        Endianness::Little => rounded.to_le_bytes(),
543        Endianness::Big => rounded.to_be_bytes(),
544    }
545}
546
547fn encode_i16(value: f64, endianness: Endianness) -> [u8; 2] {
548    let rounded = saturating_round(value, i16::MIN as f64, i16::MAX as f64) as i16;
549    match endianness {
550        Endianness::Little => rounded.to_le_bytes(),
551        Endianness::Big => rounded.to_be_bytes(),
552    }
553}
554
555fn encode_u32(value: f64, endianness: Endianness) -> [u8; 4] {
556    let rounded = saturating_round(value, 0.0, u32::MAX as f64) as u32;
557    match endianness {
558        Endianness::Little => rounded.to_le_bytes(),
559        Endianness::Big => rounded.to_be_bytes(),
560    }
561}
562
563fn encode_i32(value: f64, endianness: Endianness) -> [u8; 4] {
564    let rounded = saturating_round(value, i32::MIN as f64, i32::MAX as f64) as i32;
565    match endianness {
566        Endianness::Little => rounded.to_le_bytes(),
567        Endianness::Big => rounded.to_be_bytes(),
568    }
569}
570
571fn encode_u64(value: f64, endianness: Endianness) -> [u8; 8] {
572    let rounded = saturating_round(value, 0.0, u64::MAX as f64);
573    let as_u64 = if rounded.is_finite() {
574        rounded as u64
575    } else if rounded.is_sign_negative() {
576        0
577    } else {
578        u64::MAX
579    };
580    match endianness {
581        Endianness::Little => as_u64.to_le_bytes(),
582        Endianness::Big => as_u64.to_be_bytes(),
583    }
584}
585
586fn encode_i64(value: f64, endianness: Endianness) -> [u8; 8] {
587    let rounded = saturating_round(value, i64::MIN as f64, i64::MAX as f64);
588    let as_i64 = if rounded.is_finite() {
589        rounded as i64
590    } else if rounded.is_sign_negative() {
591        i64::MIN
592    } else {
593        i64::MAX
594    };
595    match endianness {
596        Endianness::Little => as_i64.to_le_bytes(),
597        Endianness::Big => as_i64.to_be_bytes(),
598    }
599}
600
601fn encode_f32(value: f64, endianness: Endianness) -> [u8; 4] {
602    let as_f32 = value as f32;
603    let bits = as_f32.to_bits();
604    match endianness {
605        Endianness::Little => bits.to_le_bytes(),
606        Endianness::Big => bits.to_be_bytes(),
607    }
608}
609
610fn encode_f64(value: f64, endianness: Endianness) -> [u8; 8] {
611    let bits = value.to_bits();
612    match endianness {
613        Endianness::Little => bits.to_le_bytes(),
614        Endianness::Big => bits.to_be_bytes(),
615    }
616}
617
/// Round `value` to the nearest integer (ties away from zero) and clamp it to `[min, max]`.
///
/// Infinities saturate to the matching bound. NaN converts to zero, clamped into the range,
/// matching MATLAB's integer conversion (e.g. `int16(NaN) == 0`). NaN is tested first because
/// it is also non-finite. Sending it through the infinity branch would pick `min` or `max`
/// from the NaN's sign bit, which depends on how the NaN was produced.
///
/// `max`/`min` are used instead of `f64::clamp` so that a mis-ordered range cannot panic.
fn saturating_round(value: f64, min: f64, max: f64) -> f64 {
    if value.is_nan() {
        return 0.0_f64.max(min).min(max);
    }
    if value.is_infinite() {
        return if value.is_sign_negative() { min } else { max };
    }
    value.round().max(min).min(max)
}
634
/// Element type each value is converted to before it is written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InputType {
    /// 8-bit unsigned integer (the default precision).
    UInt8,
    /// 8-bit signed integer.
    Int8,
    /// 16-bit unsigned integer.
    UInt16,
    /// 16-bit signed integer.
    Int16,
    /// 32-bit unsigned integer.
    UInt32,
    /// 32-bit signed integer.
    Int32,
    /// 64-bit unsigned integer.
    UInt64,
    /// 64-bit signed integer.
    Int64,
    /// IEEE-754 single precision.
    Float32,
    /// IEEE-754 double precision.
    Float64,
}
648
649fn parse_input_label(label: &str) -> Result<InputType, String> {
650    match label {
651        "double" | "float64" | "real*8" => Ok(InputType::Float64),
652        "single" | "float32" | "real*4" => Ok(InputType::Float32),
653        "int8" | "schar" | "integer*1" => Ok(InputType::Int8),
654        "uint8" | "uchar" | "unsignedchar" | "char" | "byte" => Ok(InputType::UInt8),
655        "int16" | "short" | "integer*2" => Ok(InputType::Int16),
656        "uint16" | "ushort" | "unsignedshort" => Ok(InputType::UInt16),
657        "int32" | "integer*4" | "long" => Ok(InputType::Int32),
658        "uint32" | "unsignedint" | "unsignedlong" => Ok(InputType::UInt32),
659        "int64" | "integer*8" | "longlong" => Ok(InputType::Int64),
660        "uint64" | "unsignedlonglong" => Ok(InputType::UInt64),
661        other => Err(format!("fwrite: unsupported precision '{other}'")),
662    }
663}
664
/// Unit tests for `fwrite`.
///
/// Each registry-touching test first takes `registry_guard()` (a guard over a shared `()`
/// mutex, presumably serialising access to the global file registry). It then calls
/// `registry::reset_for_tests()` so open identifiers from other tests are not visible. Files
/// are closed with `fclose` before their contents are read back.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::io::filetext::registry;
    use crate::builtins::io::filetext::{fclose, fopen};
    use crate::RuntimeError;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::Tensor;
    use runmat_filesystem::File;
    use runmat_time::system_time_now;
    use std::io::Read;
    use std::path::PathBuf;
    use std::time::UNIX_EPOCH;

    /// Extract the plain message text from a runtime error for substring assertions.
    fn unwrap_error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drive the async `evaluate` to completion on the current thread.
    fn run_evaluate(
        fid_value: &Value,
        data_value: &Value,
        rest: &[Value],
    ) -> BuiltinResult<FwriteEval> {
        futures::executor::block_on(evaluate(fid_value, data_value, rest))
    }

    /// Drive `fopen::evaluate` to completion on the current thread.
    fn run_fopen(args: &[Value]) -> BuiltinResult<fopen::FopenEval> {
        futures::executor::block_on(fopen::evaluate(args))
    }

    /// Drive `fclose::evaluate` to completion on the current thread.
    fn run_fclose(args: &[Value]) -> BuiltinResult<fclose::FcloseEval> {
        futures::executor::block_on(fclose::evaluate(args))
    }

    /// Hold the registry's test lock for the duration of a test.
    fn registry_guard() -> std::sync::MutexGuard<'static, ()> {
        registry::test_guard()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_default_uint8_bytes() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_uint8");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        // No precision argument: each value is written as one uint8 byte.
        let tensor = Tensor::new(vec![1.0, 2.0, 255.0], vec![3, 1]).unwrap();
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &Vec::new())
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        // Close before reading the file back from disk.
        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![1u8, 2, 255]);
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_double_precision_writes_native_endian() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_double");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.5, -2.25], vec![2, 1]).unwrap();
        let args = vec![Value::from("double")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        // The file was opened without a machine format, so bytes follow the host's order.
        let bytes = test_support::fs::read(&path).expect("read");
        let expected: Vec<u8> = if cfg!(target_endian = "little") {
            [1.5f64.to_le_bytes(), (-2.25f64).to_le_bytes()].concat()
        } else {
            [1.5f64.to_be_bytes(), (-2.25f64).to_be_bytes()].concat()
        };
        assert_eq!(bytes, expected);
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_big_endian_uint16() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_be");
        // The machine format is given to fopen only; fwrite falls back to the file's format.
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
            Value::from("ieee-be"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        // 258 = 0x0102 and 772 = 0x0304, so big-endian output is 01 02 03 04.
        let tensor = Tensor::new(vec![258.0, 772.0], vec![2, 1]).unwrap();
        let args = vec![Value::from("uint16")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![0x01, 0x02, 0x03, 0x04]);
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_skip_inserts_padding() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        // skip = 1: one byte is skipped after each value. The gaps read back as 0, and the
        // trailing skip after the last value does not extend the file.
        let tensor = Tensor::new(vec![10.0, 20.0, 30.0], vec![3, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(1.0)];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![10u8, 0, 20, 0, 30]);
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_gpu_tensor_gathers_before_write() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_gpu");

        test_support::with_test_provider(|provider| {
            registry::reset_for_tests();
            let open = run_fopen(&[
                Value::from(path.to_string_lossy().to_string()),
                Value::from("w+b"),
            ])
            .expect("fopen");
            let fid = open.as_open().unwrap().fid as i32;

            // Upload host data so fwrite receives a GpuTensor and has to gather it first.
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::from("uint16")];
            let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
                .expect("fwrite");
            assert_eq!(eval.count(), 4);

            run_fclose(&[Value::Num(fid as f64)]).unwrap();
        });

        // Four uint16 values in host byte order: 8 bytes total.
        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        assert_eq!(bytes.len(), 8);
        let mut decoded = Vec::new();
        for chunk in bytes.chunks_exact(2) {
            let value = if cfg!(target_endian = "little") {
                u16::from_le_bytes([chunk[0], chunk[1]])
            } else {
                u16::from_be_bytes([chunk[0], chunk[1]])
            };
            decoded.push(value);
        }
        assert_eq!(decoded, vec![1u16, 2, 3, 4]);
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_precision_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_invalid_precision");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("bogus-class")];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("unsupported precision"));
        // Best-effort cleanup; the close result is irrelevant to this test.
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_negative_skip_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_negative_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![10.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(-1.0)];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("skip value must be non-negative"));
        // Best-effort cleanup; the close result is irrelevant to this test.
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fwrite_wgpu_tensor_roundtrip() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_wgpu_roundtrip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let provider = provider::register_wgpu_provider(provider::WgpuProviderOptions::default())
            .expect("wgpu provider");

        let tensor = Tensor::new(vec![0.5, -1.25, 3.75], vec![3, 1]).unwrap();
        let expected = tensor.data.clone();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload to gpu");
        let args = vec![Value::from("double")];
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        // Three doubles in host byte order: 24 bytes. Compare with a small tolerance.
        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        assert_eq!(bytes.len(), 24);
        for (chunk, expected_value) in bytes.chunks_exact(8).zip(expected.iter()) {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(chunk);
            let value = if cfg!(target_endian = "little") {
                f64::from_le_bytes(buf)
            } else {
                f64::from_be_bytes(buf)
            };
            assert!(
                (value - expected_value).abs() < 1e-12,
                "mismatch: {} vs {}",
                value,
                expected_value
            );
        }
        test_support::fs::remove_file(path).unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_identifier_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(-1.0), &Value::Num(1.0), &Vec::new()).unwrap_err(),
        );
        assert!(err.contains("file identifier must be non-negative"));
    }

    /// Build a path in the system temp directory, made unique per call by embedding the
    /// current time (seconds and sub-second nanoseconds) after `prefix`.
    fn unique_path(prefix: &str) -> PathBuf {
        let now = system_time_now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards");
        let filename = format!(
            "runmat_{prefix}_{}_{}.tmp",
            now.as_secs(),
            now.subsec_nanos()
        );
        std::env::temp_dir().join(filename)
    }
}