Skip to main content

runmat_runtime/builtins/io/filetext/
fwrite.rs

1//! MATLAB-compatible `fwrite` builtin for RunMat.
2use std::io::{Seek, SeekFrom, Write};
3
4use runmat_builtins::{CharArray, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::io::filetext::registry;
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13use runmat_filesystem::File;
14
15#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
16pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
17    name: "fwrite",
18    op_kind: GpuOpKind::Custom("file-io-write"),
19    supported_precisions: &[],
20    broadcast: BroadcastSemantics::None,
21    provider_hooks: &[],
22    constant_strategy: ConstantStrategy::InlineLiteral,
23    residency: ResidencyPolicy::GatherImmediately,
24    nan_mode: ReductionNaN::Include,
25    two_pass_threshold: None,
26    workgroup_size: None,
27    accepts_nan_mode: false,
28    notes: "Host-only binary file I/O; GPU arguments are gathered to the CPU prior to writing.",
29};
30
31#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
32pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
33    name: "fwrite",
34    shape: ShapeRequirements::Any,
35    constant_strategy: ConstantStrategy::InlineLiteral,
36    elementwise: None,
37    reduction: None,
38    emits_nan: false,
39    notes: "File I/O is never fused; metadata recorded for completeness.",
40};
41
42const BUILTIN_NAME: &str = "fwrite";
43
44fn fwrite_error(message: impl Into<String>) -> RuntimeError {
45    build_runtime_error(message)
46        .with_builtin(BUILTIN_NAME)
47        .build()
48}
49
50fn map_control_flow(err: RuntimeError) -> RuntimeError {
51    let message = err.message().to_string();
52    let identifier = err.identifier().map(|value| value.to_string());
53    let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {message}"))
54        .with_builtin(BUILTIN_NAME)
55        .with_source(err);
56    if let Some(identifier) = identifier {
57        builder = builder.with_identifier(identifier);
58    }
59    builder.build()
60}
61
62fn map_string_result<T>(result: Result<T, String>) -> BuiltinResult<T> {
63    result.map_err(fwrite_error)
64}
65
66#[runtime_builtin(
67    name = "fwrite",
68    category = "io/filetext",
69    summary = "Write binary data to a file identifier.",
70    keywords = "fwrite,file,io,binary,precision",
71    accel = "cpu",
72    type_resolver(crate::builtins::io::type_resolvers::fwrite_type),
73    builtin_path = "crate::builtins::io::filetext::fwrite"
74)]
75async fn fwrite_builtin(fid: Value, data: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
76    let eval = evaluate(&fid, &data, &rest).await?;
77    Ok(Value::Num(eval.count as f64))
78}
79
/// Outcome of one `fwrite` evaluation: how many elements reached the file.
#[derive(Clone, Debug)]
pub struct FwriteEval {
    count: usize,
}

impl FwriteEval {
    /// Record that `count` elements were written.
    fn new(count: usize) -> Self {
        FwriteEval { count }
    }

    /// Number of elements successfully written.
    pub fn count(&self) -> usize {
        let FwriteEval { count } = self;
        *count
    }
}
96
97/// Evaluate the `fwrite` builtin without invoking the runtime dispatcher.
98pub async fn evaluate(
99    fid_value: &Value,
100    data_value: &Value,
101    rest: &[Value],
102) -> BuiltinResult<FwriteEval> {
103    let fid_host = gather_value(fid_value).await?;
104    let fid = map_string_result(parse_fid(&fid_host))?;
105    if fid < 0 {
106        return Err(fwrite_error("fwrite: file identifier must be non-negative"));
107    }
108    if fid < 3 {
109        return Err(fwrite_error(
110            "fwrite: standard input/output identifiers are not supported yet",
111        ));
112    }
113
114    let info = registry::info_for(fid).ok_or_else(|| {
115        fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
116    })?;
117    let handle = registry::take_handle(fid).ok_or_else(|| {
118        fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
119    })?;
120
121    let data_host = gather_value(data_value).await?;
122    let rest_host = gather_args(rest).await?;
123    let (precision_arg, skip_arg, machine_arg) = map_string_result(classify_arguments(&rest_host))?;
124
125    let precision_spec = map_string_result(parse_precision(precision_arg))?;
126    let skip_bytes = map_string_result(parse_skip(skip_arg))?;
127    let machine_format = map_string_result(parse_machine_format(machine_arg, &info.machinefmt))?;
128
129    let mut guard = handle
130        .lock()
131        .map_err(|_| fwrite_error("fwrite: failed to lock file handle (poisoned mutex)"))?;
132    let file = guard.as_mut().ok_or_else(|| {
133        fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
134    })?;
135
136    let elements = map_string_result(flatten_elements(&data_host))?;
137    let count = map_string_result(write_elements(
138        file,
139        &elements,
140        precision_spec,
141        skip_bytes,
142        machine_format,
143    ))?;
144    Ok(FwriteEval::new(count))
145}
146
147async fn gather_value(value: &Value) -> BuiltinResult<Value> {
148    gather_if_needed_async(value)
149        .await
150        .map_err(map_control_flow)
151}
152
153async fn gather_args(args: &[Value]) -> BuiltinResult<Vec<Value>> {
154    let mut gathered = Vec::with_capacity(args.len());
155    for value in args {
156        gathered.push(
157            gather_if_needed_async(value)
158                .await
159                .map_err(map_control_flow)?,
160        );
161    }
162    Ok(gathered)
163}
164
165fn parse_fid(value: &Value) -> Result<i32, String> {
166    let scalar = match value {
167        Value::Num(n) => *n,
168        Value::Int(int) => int.to_f64(),
169        _ => return Err("fwrite: file identifier must be numeric".to_string()),
170    };
171    if !scalar.is_finite() {
172        return Err("fwrite: file identifier must be finite".to_string());
173    }
174    if scalar.fract().abs() > f64::EPSILON {
175        return Err("fwrite: file identifier must be an integer".to_string());
176    }
177    Ok(scalar as i32)
178}
179
180type FwriteArgs<'a> = (Option<&'a Value>, Option<&'a Value>, Option<&'a Value>);
181
182fn classify_arguments(args: &[Value]) -> Result<FwriteArgs<'_>, String> {
183    match args.len() {
184        0 => Ok((None, None, None)),
185        1 => {
186            if is_string_like(&args[0]) {
187                Ok((Some(&args[0]), None, None))
188            } else {
189                Err(
190                    "fwrite: precision argument must be a string scalar or character vector"
191                        .to_string(),
192                )
193            }
194        }
195        2 => {
196            if !is_string_like(&args[0]) {
197                return Err(
198                    "fwrite: precision argument must be a string scalar or character vector"
199                        .to_string(),
200                );
201            }
202            if is_numeric_like(&args[1]) {
203                Ok((Some(&args[0]), Some(&args[1]), None))
204            } else if is_string_like(&args[1]) {
205                Ok((Some(&args[0]), None, Some(&args[1])))
206            } else {
207                Err("fwrite: invalid argument combination (expected numeric skip or machine format string)".to_string())
208            }
209        }
210        3 => {
211            if !is_string_like(&args[0]) || !is_numeric_like(&args[1]) || !is_string_like(&args[2])
212            {
213                return Err("fwrite: expected arguments (precision, skip, machinefmt)".to_string());
214            }
215            Ok((Some(&args[0]), Some(&args[1]), Some(&args[2])))
216        }
217        _ => Err("fwrite: too many input arguments".to_string()),
218    }
219}
220
221fn is_string_like(value: &Value) -> bool {
222    match value {
223        Value::String(_) => true,
224        Value::CharArray(ca) => ca.rows == 1,
225        Value::StringArray(sa) => sa.data.len() == 1,
226        _ => false,
227    }
228}
229
230fn is_numeric_like(value: &Value) -> bool {
231    match value {
232        Value::Num(_) | Value::Int(_) | Value::Bool(_) => true,
233        Value::Tensor(t) => t.data.len() == 1,
234        Value::LogicalArray(la) => la.data.len() == 1,
235        _ => false,
236    }
237}
238
239#[derive(Clone, Copy, Debug)]
240struct WriteSpec {
241    input: InputType,
242}
243
244impl WriteSpec {
245    fn default() -> Self {
246        Self {
247            input: InputType::UInt8,
248        }
249    }
250}
251
252fn parse_precision(arg: Option<&Value>) -> Result<WriteSpec, String> {
253    match arg {
254        None => Ok(WriteSpec::default()),
255        Some(value) => {
256            let text = scalar_string(
257                value,
258                "fwrite: precision argument must be a string scalar or character vector",
259            )?;
260            parse_precision_string(&text)
261        }
262    }
263}
264
265fn parse_precision_string(raw: &str) -> Result<WriteSpec, String> {
266    let trimmed = raw.trim();
267    if trimmed.is_empty() {
268        return Err("fwrite: precision argument must not be empty".to_string());
269    }
270    let lower = trimmed.to_ascii_lowercase();
271    if let Some((lhs, rhs)) = lower.split_once("=>") {
272        let lhs = lhs.trim();
273        let rhs = rhs.trim();
274        let input = parse_input_label(lhs)?;
275        let output = parse_input_label(rhs)?;
276        if input != output {
277            return Err(
278                "fwrite: differing input/output precisions are not implemented yet".to_string(),
279            );
280        }
281        Ok(WriteSpec { input })
282    } else {
283        parse_input_label(lower.trim()).map(|input| WriteSpec { input })
284    }
285}
286
287fn parse_skip(arg: Option<&Value>) -> Result<usize, String> {
288    match arg {
289        None => Ok(0),
290        Some(value) => {
291            let scalar = numeric_scalar(value, "fwrite: skip must be numeric")?;
292            if !scalar.is_finite() {
293                return Err("fwrite: skip value must be finite".to_string());
294            }
295            if scalar < 0.0 {
296                return Err("fwrite: skip value must be non-negative".to_string());
297            }
298            let rounded = scalar.round();
299            if (rounded - scalar).abs() > f64::EPSILON {
300                return Err("fwrite: skip value must be an integer".to_string());
301            }
302            if rounded > i64::MAX as f64 {
303                return Err("fwrite: skip value is too large".to_string());
304            }
305            Ok(rounded as usize)
306        }
307    }
308}
309
/// Byte-order request, either given explicitly or inherited from `fopen`.
#[derive(Clone, Copy, Debug)]
enum MachineFormat {
    /// Follow the byte order of the compilation target.
    Native,
    LittleEndian,
    BigEndian,
}

impl MachineFormat {
    /// Resolve the request to a concrete byte order.
    fn to_endianness(self) -> Endianness {
        match self {
            MachineFormat::LittleEndian => Endianness::Little,
            MachineFormat::BigEndian => Endianness::Big,
            MachineFormat::Native if cfg!(target_endian = "big") => Endianness::Big,
            MachineFormat::Native => Endianness::Little,
        }
    }
}

/// Concrete byte order used when encoding multi-byte values.
#[derive(Debug, Clone, Copy)]
enum Endianness {
    Little,
    Big,
}
338
339fn parse_machine_format(arg: Option<&Value>, default_label: &str) -> Result<MachineFormat, String> {
340    match arg {
341        Some(value) => {
342            let text = scalar_string(
343                value,
344                "fwrite: machine format must be a string scalar or character vector",
345            )?;
346            machine_format_from_label(&text)
347        }
348        None => machine_format_from_label(default_label),
349    }
350}
351
352fn machine_format_from_label(label: &str) -> Result<MachineFormat, String> {
353    let trimmed = label.trim();
354    if trimmed.is_empty() {
355        return Err("fwrite: machine format must not be empty".to_string());
356    }
357    let lower = trimmed.to_ascii_lowercase();
358    let collapsed: String = lower
359        .chars()
360        .filter(|c| !matches!(c, '-' | '_' | ' '))
361        .collect();
362    if matches!(collapsed.as_str(), "native" | "n" | "system" | "default") {
363        return Ok(MachineFormat::Native);
364    }
365    if matches!(
366        collapsed.as_str(),
367        "l" | "le" | "littleendian" | "pc" | "intel"
368    ) {
369        return Ok(MachineFormat::LittleEndian);
370    }
371    if matches!(
372        collapsed.as_str(),
373        "b" | "be" | "bigendian" | "mac" | "motorola"
374    ) {
375        return Ok(MachineFormat::BigEndian);
376    }
377    if lower.starts_with("ieee-le") {
378        return Ok(MachineFormat::LittleEndian);
379    }
380    if lower.starts_with("ieee-be") {
381        return Ok(MachineFormat::BigEndian);
382    }
383    Err(format!("fwrite: unsupported machine format '{trimmed}'"))
384}
385
386fn scalar_string(value: &Value, err: &str) -> Result<String, String> {
387    match value {
388        Value::String(s) => Ok(s.clone()),
389        Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
390        Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
391        _ => Err(err.to_string()),
392    }
393}
394
395fn numeric_scalar(value: &Value, err: &str) -> Result<f64, String> {
396    match value {
397        Value::Num(n) => Ok(*n),
398        Value::Int(int) => Ok(int.to_f64()),
399        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
400        Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
401        Value::LogicalArray(la) if la.data.len() == 1 => {
402            Ok(if la.data[0] != 0 { 1.0 } else { 0.0 })
403        }
404        _ => Err(err.to_string()),
405    }
406}
407
408fn flatten_elements(value: &Value) -> Result<Vec<f64>, String> {
409    match value {
410        Value::Tensor(tensor) => Ok(tensor.data.clone()),
411        Value::Num(n) => Ok(vec![*n]),
412        Value::Int(int) => Ok(vec![int.to_f64()]),
413        Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
414        Value::LogicalArray(array) => Ok(array
415            .data
416            .iter()
417            .map(|bit| if *bit != 0 { 1.0 } else { 0.0 })
418            .collect()),
419        Value::CharArray(ca) => Ok(flatten_char_array(ca)),
420        Value::String(text) => Ok(text.chars().map(|ch| ch as u32 as f64).collect()),
421        Value::StringArray(sa) => Ok(flatten_string_array(sa)),
422        Value::GpuTensor(_) => Err("fwrite: expected host tensor data after gathering".to_string()),
423        Value::Complex(_, _) | Value::ComplexTensor(_) => {
424            Err("fwrite: complex values are not supported yet".to_string())
425        }
426        _ => Err(format!("fwrite: unsupported data type {:?}", value)),
427    }
428}
429
430fn flatten_char_array(ca: &CharArray) -> Vec<f64> {
431    let mut values = Vec::with_capacity(ca.rows.saturating_mul(ca.cols));
432    for c in 0..ca.cols {
433        for r in 0..ca.rows {
434            let idx = r * ca.cols + c;
435            values.push(ca.data[idx] as u32 as f64);
436        }
437    }
438    values
439}
440
441fn flatten_string_array(sa: &runmat_builtins::StringArray) -> Vec<f64> {
442    if sa.data.is_empty() {
443        return Vec::new();
444    }
445    let mut values = Vec::new();
446    for (idx, text) in sa.data.iter().enumerate() {
447        if idx > 0 {
448            values.push('\n' as u32 as f64);
449        }
450        values.extend(text.chars().map(|ch| ch as u32 as f64));
451    }
452    values
453}
454
455fn write_elements(
456    file: &mut File,
457    values: &[f64],
458    spec: WriteSpec,
459    skip: usize,
460    machine: MachineFormat,
461) -> Result<usize, String> {
462    let endianness = machine.to_endianness();
463    let skip_offset = skip as i64;
464    for &value in values {
465        match spec.input {
466            InputType::UInt8 => {
467                let byte = to_u8(value);
468                write_bytes(file, &[byte])?;
469            }
470            InputType::Int8 => {
471                let byte = to_i8(value) as u8;
472                write_bytes(file, &[byte])?;
473            }
474            InputType::UInt16 => {
475                let bytes = encode_u16(value, endianness);
476                write_bytes(file, &bytes)?;
477            }
478            InputType::Int16 => {
479                let bytes = encode_i16(value, endianness);
480                write_bytes(file, &bytes)?;
481            }
482            InputType::UInt32 => {
483                let bytes = encode_u32(value, endianness);
484                write_bytes(file, &bytes)?;
485            }
486            InputType::Int32 => {
487                let bytes = encode_i32(value, endianness);
488                write_bytes(file, &bytes)?;
489            }
490            InputType::UInt64 => {
491                let bytes = encode_u64(value, endianness);
492                write_bytes(file, &bytes)?;
493            }
494            InputType::Int64 => {
495                let bytes = encode_i64(value, endianness);
496                write_bytes(file, &bytes)?;
497            }
498            InputType::Float32 => {
499                let bytes = encode_f32(value, endianness);
500                write_bytes(file, &bytes)?;
501            }
502            InputType::Float64 => {
503                let bytes = encode_f64(value, endianness);
504                write_bytes(file, &bytes)?;
505            }
506        }
507
508        if skip > 0 {
509            file.seek(SeekFrom::Current(skip_offset))
510                .map_err(|err| format!("fwrite: failed to seek while applying skip ({err})"))?;
511        }
512    }
513    Ok(values.len())
514}
515
516fn write_bytes(file: &mut File, bytes: &[u8]) -> Result<(), String> {
517    file.write_all(bytes)
518        .map_err(|err| format!("fwrite: failed to write to file ({err})"))
519}
520
/// Convert a double to `uint8` like MATLAB does: round half away from zero,
/// saturate to `[0, 255]`, and map NaN to 0.
fn to_u8(value: f64) -> u8 {
    // `f64::round` rounds half away from zero. A float-to-int `as` cast then
    // saturates (negative and -inf become 0, >255 and +inf become 255) and maps
    // NaN to 0. The previous version sent NaN through its `!is_finite` branch,
    // where it became 255, and its own `is_nan` check was unreachable.
    value.round() as u8
}
537
538fn to_i8(value: f64) -> i8 {
539    saturating_round(value, i8::MIN as f64, i8::MAX as f64) as i8
540}
541
542fn encode_u16(value: f64, endianness: Endianness) -> [u8; 2] {
543    let rounded = saturating_round(value, 0.0, u16::MAX as f64) as u16;
544    match endianness {
545        Endianness::Little => rounded.to_le_bytes(),
546        Endianness::Big => rounded.to_be_bytes(),
547    }
548}
549
550fn encode_i16(value: f64, endianness: Endianness) -> [u8; 2] {
551    let rounded = saturating_round(value, i16::MIN as f64, i16::MAX as f64) as i16;
552    match endianness {
553        Endianness::Little => rounded.to_le_bytes(),
554        Endianness::Big => rounded.to_be_bytes(),
555    }
556}
557
558fn encode_u32(value: f64, endianness: Endianness) -> [u8; 4] {
559    let rounded = saturating_round(value, 0.0, u32::MAX as f64) as u32;
560    match endianness {
561        Endianness::Little => rounded.to_le_bytes(),
562        Endianness::Big => rounded.to_be_bytes(),
563    }
564}
565
566fn encode_i32(value: f64, endianness: Endianness) -> [u8; 4] {
567    let rounded = saturating_round(value, i32::MIN as f64, i32::MAX as f64) as i32;
568    match endianness {
569        Endianness::Little => rounded.to_le_bytes(),
570        Endianness::Big => rounded.to_be_bytes(),
571    }
572}
573
574fn encode_u64(value: f64, endianness: Endianness) -> [u8; 8] {
575    let rounded = saturating_round(value, 0.0, u64::MAX as f64);
576    let as_u64 = if rounded.is_finite() {
577        rounded as u64
578    } else if rounded.is_sign_negative() {
579        0
580    } else {
581        u64::MAX
582    };
583    match endianness {
584        Endianness::Little => as_u64.to_le_bytes(),
585        Endianness::Big => as_u64.to_be_bytes(),
586    }
587}
588
589fn encode_i64(value: f64, endianness: Endianness) -> [u8; 8] {
590    let rounded = saturating_round(value, i64::MIN as f64, i64::MAX as f64);
591    let as_i64 = if rounded.is_finite() {
592        rounded as i64
593    } else if rounded.is_sign_negative() {
594        i64::MIN
595    } else {
596        i64::MAX
597    };
598    match endianness {
599        Endianness::Little => as_i64.to_le_bytes(),
600        Endianness::Big => as_i64.to_be_bytes(),
601    }
602}
603
604fn encode_f32(value: f64, endianness: Endianness) -> [u8; 4] {
605    let as_f32 = value as f32;
606    let bits = as_f32.to_bits();
607    match endianness {
608        Endianness::Little => bits.to_le_bytes(),
609        Endianness::Big => bits.to_be_bytes(),
610    }
611}
612
613fn encode_f64(value: f64, endianness: Endianness) -> [u8; 8] {
614    let bits = value.to_bits();
615    match endianness {
616        Endianness::Little => bits.to_le_bytes(),
617        Endianness::Big => bits.to_be_bytes(),
618    }
619}
620
/// Round `value` half away from zero and clamp the result into `[min, max]`.
///
/// - Infinities saturate to the matching bound.
/// - NaN maps to `0.0` (clamped into the range, should 0 lie outside it),
///   matching MATLAB's conversion of NaN to integer classes.
///
/// The result is always finite when `min` and `max` are finite.
fn saturating_round(value: f64, min: f64, max: f64) -> f64 {
    // Test NaN first: NaN is also non-finite, so checking `is_finite` first would
    // send it to `max` (or `min` for a negative-signed NaN).
    if value.is_nan() {
        return if 0.0 < min {
            min
        } else if 0.0 > max {
            max
        } else {
            0.0
        };
    }
    if value.is_infinite() {
        return if value.is_sign_negative() { min } else { max };
    }
    let rounded = value.round();
    if rounded < min {
        min
    } else if rounded > max {
        max
    } else {
        rounded
    }
}
637
/// Numeric class each element is converted to before being encoded.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InputType {
    UInt8,
    Int8,
    UInt16,
    Int16,
    UInt32,
    Int32,
    UInt64,
    Int64,
    Float32,
    Float64,
}

/// Map a lower-cased, trimmed MATLAB precision label to an [`InputType`].
///
/// Internal whitespace is ignored, so MATLAB's documented spellings
/// `'unsigned char'` and `'signed char'` resolve like their compact forms. The
/// MATLAB aliases `'int'`, `'uint'`, `'float'`, `'ulong'` and `'char*1'` are also
/// accepted. `'ulong'` maps to 32 bits, consistent with the existing
/// `'long'`/`'unsignedlong'` mapping.
///
/// # Errors
/// Returns `fwrite: unsupported precision '<label>'` for unknown labels.
fn parse_input_label(label: &str) -> Result<InputType, String> {
    let compact: String = label.chars().filter(|ch| !ch.is_whitespace()).collect();
    match compact.as_str() {
        "double" | "float64" | "real*8" => Ok(InputType::Float64),
        "single" | "float32" | "float" | "real*4" => Ok(InputType::Float32),
        "int8" | "schar" | "signedchar" | "integer*1" => Ok(InputType::Int8),
        "uint8" | "uchar" | "unsignedchar" | "char" | "char*1" | "byte" => Ok(InputType::UInt8),
        "int16" | "short" | "integer*2" => Ok(InputType::Int16),
        "uint16" | "ushort" | "unsignedshort" => Ok(InputType::UInt16),
        "int32" | "int" | "integer*4" | "long" => Ok(InputType::Int32),
        "uint32" | "uint" | "unsignedint" | "ulong" | "unsignedlong" => Ok(InputType::UInt32),
        "int64" | "integer*8" | "longlong" => Ok(InputType::Int64),
        "uint64" | "unsignedlonglong" => Ok(InputType::UInt64),
        // Report the label as the caller wrote it, not the whitespace-stripped form.
        _ => Err(format!("fwrite: unsupported precision '{label}'")),
    }
}
667
668#[cfg(test)]
669pub(crate) mod tests {
670    use super::*;
671    use crate::builtins::common::test_support;
672    use crate::builtins::io::filetext::registry;
673    use crate::builtins::io::filetext::{fclose, fopen};
674    use crate::RuntimeError;
675    #[cfg(feature = "wgpu")]
676    use runmat_accelerate::backend::wgpu::provider;
677    #[cfg(feature = "wgpu")]
678    use runmat_accelerate_api::AccelProvider;
679    use runmat_accelerate_api::HostTensorView;
680    use runmat_builtins::Tensor;
681    use runmat_filesystem::File;
682    use runmat_time::system_time_now;
683    use std::io::Read;
684    use std::path::PathBuf;
685    use std::time::UNIX_EPOCH;
686
687    fn unwrap_error_message(err: RuntimeError) -> String {
688        err.message().to_string()
689    }
690
691    fn run_evaluate(
692        fid_value: &Value,
693        data_value: &Value,
694        rest: &[Value],
695    ) -> BuiltinResult<FwriteEval> {
696        futures::executor::block_on(evaluate(fid_value, data_value, rest))
697    }
698
699    fn run_fopen(args: &[Value]) -> BuiltinResult<fopen::FopenEval> {
700        futures::executor::block_on(fopen::evaluate(args))
701    }
702
703    fn run_fclose(args: &[Value]) -> BuiltinResult<fclose::FcloseEval> {
704        futures::executor::block_on(fclose::evaluate(args))
705    }
706
707    fn registry_guard() -> std::sync::MutexGuard<'static, ()> {
708        registry::test_guard()
709    }
710
711    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
712    #[test]
713    fn fwrite_default_uint8_bytes() {
714        let _guard = registry_guard();
715        registry::reset_for_tests();
716        let path = unique_path("fwrite_uint8");
717        let open = run_fopen(&[
718            Value::from(path.to_string_lossy().to_string()),
719            Value::from("w+b"),
720        ])
721        .expect("fopen");
722        let fid = open.as_open().unwrap().fid as i32;
723
724        let tensor = Tensor::new(vec![1.0, 2.0, 255.0], vec![3, 1]).unwrap();
725        let eval = run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &Vec::new())
726            .expect("fwrite");
727        assert_eq!(eval.count(), 3);
728
729        run_fclose(&[Value::Num(fid as f64)]).unwrap();
730
731        let bytes = test_support::fs::read(&path).expect("read");
732        assert_eq!(bytes, vec![1u8, 2, 255]);
733        test_support::fs::remove_file(path).unwrap();
734    }
735
736    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
737    #[test]
738    fn fwrite_double_precision_writes_native_endian() {
739        let _guard = registry_guard();
740        registry::reset_for_tests();
741        let path = unique_path("fwrite_double");
742        let open = run_fopen(&[
743            Value::from(path.to_string_lossy().to_string()),
744            Value::from("w+b"),
745        ])
746        .expect("fopen");
747        let fid = open.as_open().unwrap().fid as i32;
748
749        let tensor = Tensor::new(vec![1.5, -2.25], vec![2, 1]).unwrap();
750        let args = vec![Value::from("double")];
751        let eval =
752            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
753        assert_eq!(eval.count(), 2);
754
755        run_fclose(&[Value::Num(fid as f64)]).unwrap();
756
757        let bytes = test_support::fs::read(&path).expect("read");
758        let expected: Vec<u8> = if cfg!(target_endian = "little") {
759            [1.5f64.to_le_bytes(), (-2.25f64).to_le_bytes()].concat()
760        } else {
761            [1.5f64.to_be_bytes(), (-2.25f64).to_be_bytes()].concat()
762        };
763        assert_eq!(bytes, expected);
764        test_support::fs::remove_file(path).unwrap();
765    }
766
767    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
768    #[test]
769    fn fwrite_big_endian_uint16() {
770        let _guard = registry_guard();
771        registry::reset_for_tests();
772        let path = unique_path("fwrite_be");
773        let open = run_fopen(&[
774            Value::from(path.to_string_lossy().to_string()),
775            Value::from("w+b"),
776            Value::from("ieee-be"),
777        ])
778        .expect("fopen");
779        let fid = open.as_open().unwrap().fid as i32;
780
781        let tensor = Tensor::new(vec![258.0, 772.0], vec![2, 1]).unwrap();
782        let args = vec![Value::from("uint16")];
783        let eval =
784            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
785        assert_eq!(eval.count(), 2);
786
787        run_fclose(&[Value::Num(fid as f64)]).unwrap();
788
789        let bytes = test_support::fs::read(&path).expect("read");
790        assert_eq!(bytes, vec![0x01, 0x02, 0x03, 0x04]);
791        test_support::fs::remove_file(path).unwrap();
792    }
793
794    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
795    #[test]
796    fn fwrite_skip_inserts_padding() {
797        let _guard = registry_guard();
798        registry::reset_for_tests();
799        let path = unique_path("fwrite_skip");
800        let open = run_fopen(&[
801            Value::from(path.to_string_lossy().to_string()),
802            Value::from("w+b"),
803        ])
804        .expect("fopen");
805        let fid = open.as_open().unwrap().fid as i32;
806
807        let tensor = Tensor::new(vec![10.0, 20.0, 30.0], vec![3, 1]).unwrap();
808        let args = vec![Value::from("uint8"), Value::Num(1.0)];
809        let eval =
810            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
811        assert_eq!(eval.count(), 3);
812
813        run_fclose(&[Value::Num(fid as f64)]).unwrap();
814
815        let bytes = test_support::fs::read(&path).expect("read");
816        assert_eq!(bytes, vec![10u8, 0, 20, 0, 30]);
817        test_support::fs::remove_file(path).unwrap();
818    }
819
820    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
821    #[test]
822    fn fwrite_gpu_tensor_gathers_before_write() {
823        let _guard = registry_guard();
824        registry::reset_for_tests();
825        let path = unique_path("fwrite_gpu");
826
827        test_support::with_test_provider(|provider| {
828            registry::reset_for_tests();
829            let open = run_fopen(&[
830                Value::from(path.to_string_lossy().to_string()),
831                Value::from("w+b"),
832            ])
833            .expect("fopen");
834            let fid = open.as_open().unwrap().fid as i32;
835
836            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
837            let view = HostTensorView {
838                data: &tensor.data,
839                shape: &tensor.shape,
840            };
841            let handle = provider.upload(&view).expect("upload");
842            let args = vec![Value::from("uint16")];
843            let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
844                .expect("fwrite");
845            assert_eq!(eval.count(), 4);
846
847            run_fclose(&[Value::Num(fid as f64)]).unwrap();
848        });
849
850        let mut file = File::open(&path).expect("open");
851        let mut bytes = Vec::new();
852        file.read_to_end(&mut bytes).expect("read");
853        assert_eq!(bytes.len(), 8);
854        let mut decoded = Vec::new();
855        for chunk in bytes.chunks_exact(2) {
856            let value = if cfg!(target_endian = "little") {
857                u16::from_le_bytes([chunk[0], chunk[1]])
858            } else {
859                u16::from_be_bytes([chunk[0], chunk[1]])
860            };
861            decoded.push(value);
862        }
863        assert_eq!(decoded, vec![1u16, 2, 3, 4]);
864        test_support::fs::remove_file(path).unwrap();
865    }
866
867    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
868    #[test]
869    fn fwrite_invalid_precision_errors() {
870        let _guard = registry_guard();
871        registry::reset_for_tests();
872        let path = unique_path("fwrite_invalid_precision");
873        let open = run_fopen(&[
874            Value::from(path.to_string_lossy().to_string()),
875            Value::from("w+b"),
876        ])
877        .expect("fopen");
878        let fid = open.as_open().unwrap().fid as i32;
879
880        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
881        let args = vec![Value::from("bogus-class")];
882        let err = unwrap_error_message(
883            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
884        );
885        assert!(err.contains("unsupported precision"));
886        let _ = run_fclose(&[Value::Num(fid as f64)]);
887        test_support::fs::remove_file(path).unwrap();
888    }
889
890    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
891    #[test]
892    fn fwrite_negative_skip_errors() {
893        let _guard = registry_guard();
894        registry::reset_for_tests();
895        let path = unique_path("fwrite_negative_skip");
896        let open = run_fopen(&[
897            Value::from(path.to_string_lossy().to_string()),
898            Value::from("w+b"),
899        ])
900        .expect("fopen");
901        let fid = open.as_open().unwrap().fid as i32;
902
903        let tensor = Tensor::new(vec![10.0], vec![1, 1]).unwrap();
904        let args = vec![Value::from("uint8"), Value::Num(-1.0)];
905        let err = unwrap_error_message(
906            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
907        );
908        assert!(err.contains("skip value must be non-negative"));
909        let _ = run_fclose(&[Value::Num(fid as f64)]);
910        test_support::fs::remove_file(path).unwrap();
911    }
912
913    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
914    #[test]
915    #[cfg(feature = "wgpu")]
916    fn fwrite_wgpu_tensor_roundtrip() {
917        let _guard = registry_guard();
918        registry::reset_for_tests();
919        let path = unique_path("fwrite_wgpu_roundtrip");
920        let open = run_fopen(&[
921            Value::from(path.to_string_lossy().to_string()),
922            Value::from("w+b"),
923        ])
924        .expect("fopen");
925        let fid = open.as_open().unwrap().fid as i32;
926
927        let provider = provider::register_wgpu_provider(provider::WgpuProviderOptions::default())
928            .expect("wgpu provider");
929
930        let tensor = Tensor::new(vec![0.5, -1.25, 3.75], vec![3, 1]).unwrap();
931        let expected = tensor.data.clone();
932        let view = HostTensorView {
933            data: &tensor.data,
934            shape: &tensor.shape,
935        };
936        let handle = provider.upload(&view).expect("upload to gpu");
937        let args = vec![Value::from("double")];
938        let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
939            .expect("fwrite");
940        assert_eq!(eval.count(), 3);
941
942        run_fclose(&[Value::Num(fid as f64)]).unwrap();
943
944        let mut file = File::open(&path).expect("open");
945        let mut bytes = Vec::new();
946        file.read_to_end(&mut bytes).expect("read");
947        assert_eq!(bytes.len(), 24);
948        for (chunk, expected_value) in bytes.chunks_exact(8).zip(expected.iter()) {
949            let mut buf = [0u8; 8];
950            buf.copy_from_slice(chunk);
951            let value = if cfg!(target_endian = "little") {
952                f64::from_le_bytes(buf)
953            } else {
954                f64::from_be_bytes(buf)
955            };
956            assert!(
957                (value - expected_value).abs() < 1e-12,
958                "mismatch: {} vs {}",
959                value,
960                expected_value
961            );
962        }
963        test_support::fs::remove_file(path).unwrap();
964    }
965
966    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
967    #[test]
968    fn fwrite_invalid_identifier_errors() {
969        let _guard = registry_guard();
970        registry::reset_for_tests();
971        let err = unwrap_error_message(
972            run_evaluate(&Value::Num(-1.0), &Value::Num(1.0), &Vec::new()).unwrap_err(),
973        );
974        assert!(err.contains("file identifier must be non-negative"));
975    }
976
977    fn unique_path(prefix: &str) -> PathBuf {
978        let now = system_time_now()
979            .duration_since(UNIX_EPOCH)
980            .expect("time went backwards");
981        let filename = format!(
982            "runmat_{prefix}_{}_{}.tmp",
983            now.as_secs(),
984            now.subsec_nanos()
985        );
986        std::env::temp_dir().join(filename)
987    }
988}