1use std::io::{Seek, SeekFrom, Write};
3
4use runmat_builtins::{CharArray, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::io::filetext::registry;
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13use runmat_filesystem::File;
14
// GPU metadata for `fwrite`. The builtin has no device implementation: no
// precisions or provider hooks are advertised, and device-resident arguments
// are gathered to the host before any bytes are written (see `notes`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fwrite",
    op_kind: GpuOpKind::Custom("file-io-write"),
    // Empty: nothing about this builtin executes on the device.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Inputs are brought to the host; `evaluate` gathers every argument.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-only binary file I/O; GPU arguments are gathered to the CPU prior to writing.",
};
30
// Fusion metadata for `fwrite`. File I/O is never fused, so neither an
// elementwise nor a reduction template is provided.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fwrite",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "File I/O is never fused; metadata recorded for completeness.",
};
41
42const BUILTIN_NAME: &str = "fwrite";
43
44fn fwrite_error(message: impl Into<String>) -> RuntimeError {
45 build_runtime_error(message)
46 .with_builtin(BUILTIN_NAME)
47 .build()
48}
49
50fn map_control_flow(err: RuntimeError) -> RuntimeError {
51 let message = err.message().to_string();
52 let identifier = err.identifier().map(|value| value.to_string());
53 let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {message}"))
54 .with_builtin(BUILTIN_NAME)
55 .with_source(err);
56 if let Some(identifier) = identifier {
57 builder = builder.with_identifier(identifier);
58 }
59 builder.build()
60}
61
62fn map_string_result<T>(result: Result<T, String>) -> BuiltinResult<T> {
63 result.map_err(fwrite_error)
64}
65
66#[runtime_builtin(
67 name = "fwrite",
68 category = "io/filetext",
69 summary = "Write binary data to a file identifier.",
70 keywords = "fwrite,file,io,binary,precision",
71 accel = "cpu",
72 type_resolver(crate::builtins::io::type_resolvers::fwrite_type),
73 builtin_path = "crate::builtins::io::filetext::fwrite"
74)]
75async fn fwrite_builtin(fid: Value, data: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
76 let eval = evaluate(&fid, &data, &rest).await?;
77 Ok(Value::Num(eval.count as f64))
78}
79
/// Outcome of a successful `fwrite` evaluation.
#[derive(Clone, Debug)]
pub struct FwriteEval {
    /// Number of elements encoded and written (elements, not bytes).
    count: usize,
}

impl FwriteEval {
    fn new(count: usize) -> Self {
        FwriteEval { count }
    }

    /// Returns how many elements were written.
    pub fn count(&self) -> usize {
        let FwriteEval { count } = self;
        *count
    }
}
96
97pub async fn evaluate(
99 fid_value: &Value,
100 data_value: &Value,
101 rest: &[Value],
102) -> BuiltinResult<FwriteEval> {
103 let fid_host = gather_value(fid_value).await?;
104 let fid = map_string_result(parse_fid(&fid_host))?;
105 if fid < 0 {
106 return Err(fwrite_error("fwrite: file identifier must be non-negative"));
107 }
108 if fid < 3 {
109 return Err(fwrite_error(
110 "fwrite: standard input/output identifiers are not supported yet",
111 ));
112 }
113
114 let info = registry::info_for(fid).ok_or_else(|| {
115 fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
116 })?;
117 let handle = registry::take_handle(fid).ok_or_else(|| {
118 fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
119 })?;
120
121 let data_host = gather_value(data_value).await?;
122 let rest_host = gather_args(rest).await?;
123 let (precision_arg, skip_arg, machine_arg) = map_string_result(classify_arguments(&rest_host))?;
124
125 let precision_spec = map_string_result(parse_precision(precision_arg))?;
126 let skip_bytes = map_string_result(parse_skip(skip_arg))?;
127 let machine_format = map_string_result(parse_machine_format(machine_arg, &info.machinefmt))?;
128
129 let mut guard = handle
130 .lock()
131 .map_err(|_| fwrite_error("fwrite: failed to lock file handle (poisoned mutex)"))?;
132 let file = guard.as_mut().ok_or_else(|| {
133 fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
134 })?;
135
136 let elements = map_string_result(flatten_elements(&data_host))?;
137 let count = map_string_result(write_elements(
138 file,
139 &elements,
140 precision_spec,
141 skip_bytes,
142 machine_format,
143 ))?;
144 Ok(FwriteEval::new(count))
145}
146
147async fn gather_value(value: &Value) -> BuiltinResult<Value> {
148 gather_if_needed_async(value)
149 .await
150 .map_err(map_control_flow)
151}
152
153async fn gather_args(args: &[Value]) -> BuiltinResult<Vec<Value>> {
154 let mut gathered = Vec::with_capacity(args.len());
155 for value in args {
156 gathered.push(
157 gather_if_needed_async(value)
158 .await
159 .map_err(map_control_flow)?,
160 );
161 }
162 Ok(gathered)
163}
164
165fn parse_fid(value: &Value) -> Result<i32, String> {
166 let scalar = match value {
167 Value::Num(n) => *n,
168 Value::Int(int) => int.to_f64(),
169 _ => return Err("fwrite: file identifier must be numeric".to_string()),
170 };
171 if !scalar.is_finite() {
172 return Err("fwrite: file identifier must be finite".to_string());
173 }
174 if scalar.fract().abs() > f64::EPSILON {
175 return Err("fwrite: file identifier must be an integer".to_string());
176 }
177 Ok(scalar as i32)
178}
179
180type FwriteArgs<'a> = (Option<&'a Value>, Option<&'a Value>, Option<&'a Value>);
181
182fn classify_arguments(args: &[Value]) -> Result<FwriteArgs<'_>, String> {
183 match args.len() {
184 0 => Ok((None, None, None)),
185 1 => {
186 if is_string_like(&args[0]) {
187 Ok((Some(&args[0]), None, None))
188 } else {
189 Err(
190 "fwrite: precision argument must be a string scalar or character vector"
191 .to_string(),
192 )
193 }
194 }
195 2 => {
196 if !is_string_like(&args[0]) {
197 return Err(
198 "fwrite: precision argument must be a string scalar or character vector"
199 .to_string(),
200 );
201 }
202 if is_numeric_like(&args[1]) {
203 Ok((Some(&args[0]), Some(&args[1]), None))
204 } else if is_string_like(&args[1]) {
205 Ok((Some(&args[0]), None, Some(&args[1])))
206 } else {
207 Err("fwrite: invalid argument combination (expected numeric skip or machine format string)".to_string())
208 }
209 }
210 3 => {
211 if !is_string_like(&args[0]) || !is_numeric_like(&args[1]) || !is_string_like(&args[2])
212 {
213 return Err("fwrite: expected arguments (precision, skip, machinefmt)".to_string());
214 }
215 Ok((Some(&args[0]), Some(&args[1]), Some(&args[2])))
216 }
217 _ => Err("fwrite: too many input arguments".to_string()),
218 }
219}
220
221fn is_string_like(value: &Value) -> bool {
222 match value {
223 Value::String(_) => true,
224 Value::CharArray(ca) => ca.rows == 1,
225 Value::StringArray(sa) => sa.data.len() == 1,
226 _ => false,
227 }
228}
229
230fn is_numeric_like(value: &Value) -> bool {
231 match value {
232 Value::Num(_) | Value::Int(_) | Value::Bool(_) => true,
233 Value::Tensor(t) => t.data.len() == 1,
234 Value::LogicalArray(la) => la.data.len() == 1,
235 _ => false,
236 }
237}
238
239#[derive(Clone, Copy, Debug)]
240struct WriteSpec {
241 input: InputType,
242}
243
244impl WriteSpec {
245 fn default() -> Self {
246 Self {
247 input: InputType::UInt8,
248 }
249 }
250}
251
252fn parse_precision(arg: Option<&Value>) -> Result<WriteSpec, String> {
253 match arg {
254 None => Ok(WriteSpec::default()),
255 Some(value) => {
256 let text = scalar_string(
257 value,
258 "fwrite: precision argument must be a string scalar or character vector",
259 )?;
260 parse_precision_string(&text)
261 }
262 }
263}
264
265fn parse_precision_string(raw: &str) -> Result<WriteSpec, String> {
266 let trimmed = raw.trim();
267 if trimmed.is_empty() {
268 return Err("fwrite: precision argument must not be empty".to_string());
269 }
270 let lower = trimmed.to_ascii_lowercase();
271 if let Some((lhs, rhs)) = lower.split_once("=>") {
272 let lhs = lhs.trim();
273 let rhs = rhs.trim();
274 let input = parse_input_label(lhs)?;
275 let output = parse_input_label(rhs)?;
276 if input != output {
277 return Err(
278 "fwrite: differing input/output precisions are not implemented yet".to_string(),
279 );
280 }
281 Ok(WriteSpec { input })
282 } else {
283 parse_input_label(lower.trim()).map(|input| WriteSpec { input })
284 }
285}
286
287fn parse_skip(arg: Option<&Value>) -> Result<usize, String> {
288 match arg {
289 None => Ok(0),
290 Some(value) => {
291 let scalar = numeric_scalar(value, "fwrite: skip must be numeric")?;
292 if !scalar.is_finite() {
293 return Err("fwrite: skip value must be finite".to_string());
294 }
295 if scalar < 0.0 {
296 return Err("fwrite: skip value must be non-negative".to_string());
297 }
298 let rounded = scalar.round();
299 if (rounded - scalar).abs() > f64::EPSILON {
300 return Err("fwrite: skip value must be an integer".to_string());
301 }
302 if rounded > i64::MAX as f64 {
303 return Err("fwrite: skip value is too large".to_string());
304 }
305 Ok(rounded as usize)
306 }
307 }
308}
309
310#[derive(Clone, Copy, Debug)]
311enum MachineFormat {
312 Native,
313 LittleEndian,
314 BigEndian,
315}
316
317impl MachineFormat {
318 fn to_endianness(self) -> Endianness {
319 match self {
320 MachineFormat::Native => {
321 if cfg!(target_endian = "little") {
322 Endianness::Little
323 } else {
324 Endianness::Big
325 }
326 }
327 MachineFormat::LittleEndian => Endianness::Little,
328 MachineFormat::BigEndian => Endianness::Big,
329 }
330 }
331}
332
/// Concrete byte order used when serialising multi-byte values.
#[derive(Debug, Clone, Copy)]
enum Endianness {
    /// Least-significant byte first.
    Little,
    /// Most-significant byte first.
    Big,
}
338
339fn parse_machine_format(arg: Option<&Value>, default_label: &str) -> Result<MachineFormat, String> {
340 match arg {
341 Some(value) => {
342 let text = scalar_string(
343 value,
344 "fwrite: machine format must be a string scalar or character vector",
345 )?;
346 machine_format_from_label(&text)
347 }
348 None => machine_format_from_label(default_label),
349 }
350}
351
352fn machine_format_from_label(label: &str) -> Result<MachineFormat, String> {
353 let trimmed = label.trim();
354 if trimmed.is_empty() {
355 return Err("fwrite: machine format must not be empty".to_string());
356 }
357 let lower = trimmed.to_ascii_lowercase();
358 let collapsed: String = lower
359 .chars()
360 .filter(|c| !matches!(c, '-' | '_' | ' '))
361 .collect();
362 if matches!(collapsed.as_str(), "native" | "n" | "system" | "default") {
363 return Ok(MachineFormat::Native);
364 }
365 if matches!(
366 collapsed.as_str(),
367 "l" | "le" | "littleendian" | "pc" | "intel"
368 ) {
369 return Ok(MachineFormat::LittleEndian);
370 }
371 if matches!(
372 collapsed.as_str(),
373 "b" | "be" | "bigendian" | "mac" | "motorola"
374 ) {
375 return Ok(MachineFormat::BigEndian);
376 }
377 if lower.starts_with("ieee-le") {
378 return Ok(MachineFormat::LittleEndian);
379 }
380 if lower.starts_with("ieee-be") {
381 return Ok(MachineFormat::BigEndian);
382 }
383 Err(format!("fwrite: unsupported machine format '{trimmed}'"))
384}
385
386fn scalar_string(value: &Value, err: &str) -> Result<String, String> {
387 match value {
388 Value::String(s) => Ok(s.clone()),
389 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
390 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
391 _ => Err(err.to_string()),
392 }
393}
394
395fn numeric_scalar(value: &Value, err: &str) -> Result<f64, String> {
396 match value {
397 Value::Num(n) => Ok(*n),
398 Value::Int(int) => Ok(int.to_f64()),
399 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
400 Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
401 Value::LogicalArray(la) if la.data.len() == 1 => {
402 Ok(if la.data[0] != 0 { 1.0 } else { 0.0 })
403 }
404 _ => Err(err.to_string()),
405 }
406}
407
408fn flatten_elements(value: &Value) -> Result<Vec<f64>, String> {
409 match value {
410 Value::Tensor(tensor) => Ok(tensor.data.clone()),
411 Value::Num(n) => Ok(vec![*n]),
412 Value::Int(int) => Ok(vec![int.to_f64()]),
413 Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
414 Value::LogicalArray(array) => Ok(array
415 .data
416 .iter()
417 .map(|bit| if *bit != 0 { 1.0 } else { 0.0 })
418 .collect()),
419 Value::CharArray(ca) => Ok(flatten_char_array(ca)),
420 Value::String(text) => Ok(text.chars().map(|ch| ch as u32 as f64).collect()),
421 Value::StringArray(sa) => Ok(flatten_string_array(sa)),
422 Value::GpuTensor(_) => Err("fwrite: expected host tensor data after gathering".to_string()),
423 Value::Complex(_, _) | Value::ComplexTensor(_) => {
424 Err("fwrite: complex values are not supported yet".to_string())
425 }
426 _ => Err(format!("fwrite: unsupported data type {:?}", value)),
427 }
428}
429
430fn flatten_char_array(ca: &CharArray) -> Vec<f64> {
431 let mut values = Vec::with_capacity(ca.rows.saturating_mul(ca.cols));
432 for c in 0..ca.cols {
433 for r in 0..ca.rows {
434 let idx = r * ca.cols + c;
435 values.push(ca.data[idx] as u32 as f64);
436 }
437 }
438 values
439}
440
441fn flatten_string_array(sa: &runmat_builtins::StringArray) -> Vec<f64> {
442 if sa.data.is_empty() {
443 return Vec::new();
444 }
445 let mut values = Vec::new();
446 for (idx, text) in sa.data.iter().enumerate() {
447 if idx > 0 {
448 values.push('\n' as u32 as f64);
449 }
450 values.extend(text.chars().map(|ch| ch as u32 as f64));
451 }
452 values
453}
454
455fn write_elements(
456 file: &mut File,
457 values: &[f64],
458 spec: WriteSpec,
459 skip: usize,
460 machine: MachineFormat,
461) -> Result<usize, String> {
462 let endianness = machine.to_endianness();
463 let skip_offset = skip as i64;
464 for &value in values {
465 match spec.input {
466 InputType::UInt8 => {
467 let byte = to_u8(value);
468 write_bytes(file, &[byte])?;
469 }
470 InputType::Int8 => {
471 let byte = to_i8(value) as u8;
472 write_bytes(file, &[byte])?;
473 }
474 InputType::UInt16 => {
475 let bytes = encode_u16(value, endianness);
476 write_bytes(file, &bytes)?;
477 }
478 InputType::Int16 => {
479 let bytes = encode_i16(value, endianness);
480 write_bytes(file, &bytes)?;
481 }
482 InputType::UInt32 => {
483 let bytes = encode_u32(value, endianness);
484 write_bytes(file, &bytes)?;
485 }
486 InputType::Int32 => {
487 let bytes = encode_i32(value, endianness);
488 write_bytes(file, &bytes)?;
489 }
490 InputType::UInt64 => {
491 let bytes = encode_u64(value, endianness);
492 write_bytes(file, &bytes)?;
493 }
494 InputType::Int64 => {
495 let bytes = encode_i64(value, endianness);
496 write_bytes(file, &bytes)?;
497 }
498 InputType::Float32 => {
499 let bytes = encode_f32(value, endianness);
500 write_bytes(file, &bytes)?;
501 }
502 InputType::Float64 => {
503 let bytes = encode_f64(value, endianness);
504 write_bytes(file, &bytes)?;
505 }
506 }
507
508 if skip > 0 {
509 file.seek(SeekFrom::Current(skip_offset))
510 .map_err(|err| format!("fwrite: failed to seek while applying skip ({err})"))?;
511 }
512 }
513 Ok(values.len())
514}
515
516fn write_bytes(file: &mut File, bytes: &[u8]) -> Result<(), String> {
517 file.write_all(bytes)
518 .map_err(|err| format!("fwrite: failed to write to file ({err})"))
519}
520
/// Converts a value to `u8` the way MATLAB's `uint8` cast does.
///
/// The value is rounded half away from zero and saturated to `0..=255`:
/// `+inf` becomes 255, `-inf` becomes 0, and NaN becomes 0.
fn to_u8(value: f64) -> u8 {
    // NaN has to be checked first: it is also non-finite, so the infinity
    // branch below would otherwise map it to 255 or 0 depending on its sign bit.
    if value.is_nan() {
        return 0;
    }
    if value.is_infinite() {
        return if value.is_sign_negative() { 0 } else { u8::MAX };
    }
    value.round().clamp(0.0, f64::from(u8::MAX)) as u8
}
537
538fn to_i8(value: f64) -> i8 {
539 saturating_round(value, i8::MIN as f64, i8::MAX as f64) as i8
540}
541
542fn encode_u16(value: f64, endianness: Endianness) -> [u8; 2] {
543 let rounded = saturating_round(value, 0.0, u16::MAX as f64) as u16;
544 match endianness {
545 Endianness::Little => rounded.to_le_bytes(),
546 Endianness::Big => rounded.to_be_bytes(),
547 }
548}
549
550fn encode_i16(value: f64, endianness: Endianness) -> [u8; 2] {
551 let rounded = saturating_round(value, i16::MIN as f64, i16::MAX as f64) as i16;
552 match endianness {
553 Endianness::Little => rounded.to_le_bytes(),
554 Endianness::Big => rounded.to_be_bytes(),
555 }
556}
557
558fn encode_u32(value: f64, endianness: Endianness) -> [u8; 4] {
559 let rounded = saturating_round(value, 0.0, u32::MAX as f64) as u32;
560 match endianness {
561 Endianness::Little => rounded.to_le_bytes(),
562 Endianness::Big => rounded.to_be_bytes(),
563 }
564}
565
566fn encode_i32(value: f64, endianness: Endianness) -> [u8; 4] {
567 let rounded = saturating_round(value, i32::MIN as f64, i32::MAX as f64) as i32;
568 match endianness {
569 Endianness::Little => rounded.to_le_bytes(),
570 Endianness::Big => rounded.to_be_bytes(),
571 }
572}
573
574fn encode_u64(value: f64, endianness: Endianness) -> [u8; 8] {
575 let rounded = saturating_round(value, 0.0, u64::MAX as f64);
576 let as_u64 = if rounded.is_finite() {
577 rounded as u64
578 } else if rounded.is_sign_negative() {
579 0
580 } else {
581 u64::MAX
582 };
583 match endianness {
584 Endianness::Little => as_u64.to_le_bytes(),
585 Endianness::Big => as_u64.to_be_bytes(),
586 }
587}
588
589fn encode_i64(value: f64, endianness: Endianness) -> [u8; 8] {
590 let rounded = saturating_round(value, i64::MIN as f64, i64::MAX as f64);
591 let as_i64 = if rounded.is_finite() {
592 rounded as i64
593 } else if rounded.is_sign_negative() {
594 i64::MIN
595 } else {
596 i64::MAX
597 };
598 match endianness {
599 Endianness::Little => as_i64.to_le_bytes(),
600 Endianness::Big => as_i64.to_be_bytes(),
601 }
602}
603
604fn encode_f32(value: f64, endianness: Endianness) -> [u8; 4] {
605 let as_f32 = value as f32;
606 let bits = as_f32.to_bits();
607 match endianness {
608 Endianness::Little => bits.to_le_bytes(),
609 Endianness::Big => bits.to_be_bytes(),
610 }
611}
612
613fn encode_f64(value: f64, endianness: Endianness) -> [u8; 8] {
614 let bits = value.to_bits();
615 match endianness {
616 Endianness::Little => bits.to_le_bytes(),
617 Endianness::Big => bits.to_be_bytes(),
618 }
619}
620
/// Rounds `value` half away from zero and clamps it into `[min, max]`.
///
/// `+inf` maps to `max` and `-inf` maps to `min`. NaN maps to `0.0`, the way
/// MATLAB converts NaN when casting to an integer class. Every caller in this
/// file passes integer-type bounds, which all contain `0.0`.
fn saturating_round(value: f64, min: f64, max: f64) -> f64 {
    // NaN has to be checked before the infinity branch: it is also
    // non-finite, and would otherwise be mapped to `min` or `max` by its
    // sign bit.
    if value.is_nan() {
        return 0.0;
    }
    if value.is_infinite() {
        return if value.is_sign_negative() { min } else { max };
    }
    // `max`/`min` rather than `clamp`, so inconsistent bounds cannot panic.
    value.round().max(min).min(max)
}
637
/// Element encodings that `fwrite` can produce, one per supported precision class.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum InputType {
    /// 8-bit unsigned integer (the default precision).
    UInt8,
    /// 8-bit signed integer.
    Int8,
    /// 16-bit unsigned integer.
    UInt16,
    /// 16-bit signed integer.
    Int16,
    /// 32-bit unsigned integer.
    UInt32,
    /// 32-bit signed integer.
    Int32,
    /// 64-bit unsigned integer.
    UInt64,
    /// 64-bit signed integer.
    Int64,
    /// IEEE-754 single precision.
    Float32,
    /// IEEE-754 double precision.
    Float64,
}
651
652fn parse_input_label(label: &str) -> Result<InputType, String> {
653 match label {
654 "double" | "float64" | "real*8" => Ok(InputType::Float64),
655 "single" | "float32" | "real*4" => Ok(InputType::Float32),
656 "int8" | "schar" | "integer*1" => Ok(InputType::Int8),
657 "uint8" | "uchar" | "unsignedchar" | "char" | "byte" => Ok(InputType::UInt8),
658 "int16" | "short" | "integer*2" => Ok(InputType::Int16),
659 "uint16" | "ushort" | "unsignedshort" => Ok(InputType::UInt16),
660 "int32" | "integer*4" | "long" => Ok(InputType::Int32),
661 "uint32" | "unsignedint" | "unsignedlong" => Ok(InputType::UInt32),
662 "int64" | "integer*8" | "longlong" => Ok(InputType::Int64),
663 "uint64" | "unsignedlonglong" => Ok(InputType::UInt64),
664 other => Err(format!("fwrite: unsupported precision '{other}'")),
665 }
666}
667
// Unit tests for `fwrite`. Each test holds the registry guard so tests sharing
// the global file registry do not interleave, resets the registry, and writes
// to a unique temporary file that is removed at the end.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::io::filetext::registry;
    use crate::builtins::io::filetext::{fclose, fopen};
    use crate::RuntimeError;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::Tensor;
    use runmat_filesystem::File;
    use runmat_time::system_time_now;
    use std::io::Read;
    use std::path::PathBuf;
    use std::time::UNIX_EPOCH;

    /// Returns the message text of a runtime error.
    fn unwrap_error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    /// Runs the async `evaluate` to completion on the current thread.
    fn run_evaluate(
        fid_value: &Value,
        data_value: &Value,
        rest: &[Value],
    ) -> BuiltinResult<FwriteEval> {
        futures::executor::block_on(evaluate(fid_value, data_value, rest))
    }

    /// Runs `fopen::evaluate` to completion on the current thread.
    fn run_fopen(args: &[Value]) -> BuiltinResult<fopen::FopenEval> {
        futures::executor::block_on(fopen::evaluate(args))
    }

    /// Runs `fclose::evaluate` to completion on the current thread.
    fn run_fclose(args: &[Value]) -> BuiltinResult<fclose::FcloseEval> {
        futures::executor::block_on(fclose::evaluate(args))
    }

    /// Takes the registry's test lock. Hold the returned guard for the whole
    /// test body.
    fn registry_guard() -> std::sync::MutexGuard<'static, ()> {
        registry::test_guard()
    }

    // Default precision (no precision argument): each element is written as
    // exactly one unsigned byte.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_default_uint8_bytes() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_uint8");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.0, 2.0, 255.0], vec![3, 1]).unwrap();
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &Vec::new())
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        // Close the file before reading it back.
        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![1u8, 2, 255]);
        test_support::fs::remove_file(path).unwrap();
    }

    // `double` precision with no explicit machine format. The expected bytes
    // assume fopen's default format resolves to the host byte order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_double_precision_writes_native_endian() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_double");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.5, -2.25], vec![2, 1]).unwrap();
        let args = vec![Value::from("double")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        let expected: Vec<u8> = if cfg!(target_endian = "little") {
            [1.5f64.to_le_bytes(), (-2.25f64).to_le_bytes()].concat()
        } else {
            [1.5f64.to_be_bytes(), (-2.25f64).to_be_bytes()].concat()
        };
        assert_eq!(bytes, expected);
        test_support::fs::remove_file(path).unwrap();
    }

    // The machine format given to fopen ("ieee-be") is used when fwrite gets
    // no machinefmt argument: 258 = 0x0102 and 772 = 0x0304, written MSB first.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_big_endian_uint16() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_be");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
            Value::from("ieee-be"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![258.0, 772.0], vec![2, 1]).unwrap();
        let args = vec![Value::from("uint16")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![0x01, 0x02, 0x03, 0x04]);
        test_support::fs::remove_file(path).unwrap();
    }

    // skip = 1 leaves a one-byte gap after each value. The gaps read back as
    // zero in the fresh file, and the skip after the last value does not
    // extend the file.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_skip_inserts_padding() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![10.0, 20.0, 30.0], vec![3, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(1.0)];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![10u8, 0, 20, 0, 30]);
        test_support::fs::remove_file(path).unwrap();
    }

    // Data uploaded through the test provider is gathered to the host before
    // encoding; the file is read back after the provider scope ends.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_gpu_tensor_gathers_before_write() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_gpu");

        test_support::with_test_provider(|provider| {
            registry::reset_for_tests();
            let open = run_fopen(&[
                Value::from(path.to_string_lossy().to_string()),
                Value::from("w+b"),
            ])
            .expect("fopen");
            let fid = open.as_open().unwrap().fid as i32;

            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::from("uint16")];
            let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
                .expect("fwrite");
            assert_eq!(eval.count(), 4);

            run_fclose(&[Value::Num(fid as f64)]).unwrap();
        });

        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        // Four uint16 values, two bytes each.
        assert_eq!(bytes.len(), 8);
        let mut decoded = Vec::new();
        for chunk in bytes.chunks_exact(2) {
            let value = if cfg!(target_endian = "little") {
                u16::from_le_bytes([chunk[0], chunk[1]])
            } else {
                u16::from_be_bytes([chunk[0], chunk[1]])
            };
            decoded.push(value);
        }
        assert_eq!(decoded, vec![1u16, 2, 3, 4]);
        test_support::fs::remove_file(path).unwrap();
    }

    // An unknown precision label is reported as an error. The file is still
    // closed (result ignored) and removed.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_precision_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_invalid_precision");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("bogus-class")];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("unsupported precision"));
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    // A negative skip is rejected during argument parsing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_negative_skip_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_negative_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![10.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(-1.0)];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("skip value must be non-negative"));
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    // Round-trips doubles through a real wgpu provider (only built with the
    // `wgpu` feature) and compares the decoded values within 1e-12.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fwrite_wgpu_tensor_roundtrip() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_wgpu_roundtrip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let provider = provider::register_wgpu_provider(provider::WgpuProviderOptions::default())
            .expect("wgpu provider");

        let tensor = Tensor::new(vec![0.5, -1.25, 3.75], vec![3, 1]).unwrap();
        let expected = tensor.data.clone();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload to gpu");
        let args = vec![Value::from("double")];
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        // Three doubles, eight bytes each.
        assert_eq!(bytes.len(), 24);
        for (chunk, expected_value) in bytes.chunks_exact(8).zip(expected.iter()) {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(chunk);
            let value = if cfg!(target_endian = "little") {
                f64::from_le_bytes(buf)
            } else {
                f64::from_be_bytes(buf)
            };
            assert!(
                (value - expected_value).abs() < 1e-12,
                "mismatch: {} vs {}",
                value,
                expected_value
            );
        }
        test_support::fs::remove_file(path).unwrap();
    }

    // A negative identifier is rejected before any registry lookup.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_identifier_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(-1.0), &Value::Num(1.0), &Vec::new()).unwrap_err(),
        );
        assert!(err.contains("file identifier must be non-negative"));
    }

    /// Builds a temp-dir path whose name combines `prefix` with the current
    /// time (seconds and sub-second nanoseconds) to avoid collisions.
    fn unique_path(prefix: &str) -> PathBuf {
        let now = system_time_now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards");
        let filename = format!(
            "runmat_{prefix}_{}_{}.tmp",
            now.as_secs(),
            now.subsec_nanos()
        );
        std::env::temp_dir().join(filename)
    }
}