1use std::io::{Seek, SeekFrom, Write};
3
4use runmat_builtins::{CharArray, Value};
5use runmat_macros::runtime_builtin;
6
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::io::filetext::registry;
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13use runmat_filesystem::File;
14
/// GPU metadata for `fwrite`.
///
/// `fwrite` performs host-side binary file I/O only: no precisions or provider hooks are
/// advertised, and the residency policy `GatherImmediately` means GPU arguments are gathered
/// to the CPU before anything is written.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "fwrite",
    op_kind: GpuOpKind::Custom("file-io-write"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-only binary file I/O; GPU arguments are gathered to the CPU prior to writing.",
};
30
/// Fusion metadata for `fwrite`.
///
/// `fwrite` has no elementwise or reduction template (both are `None`); the spec exists so
/// the builtin is still described in the fusion registry.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::filetext::fwrite")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "fwrite",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "File I/O is never fused; metadata recorded for completeness.",
};
41
42const BUILTIN_NAME: &str = "fwrite";
43
44fn fwrite_error(message: impl Into<String>) -> RuntimeError {
45 build_runtime_error(message)
46 .with_builtin(BUILTIN_NAME)
47 .build()
48}
49
50fn map_control_flow(err: RuntimeError) -> RuntimeError {
51 let message = err.message().to_string();
52 let identifier = err.identifier().map(|value| value.to_string());
53 let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {message}"))
54 .with_builtin(BUILTIN_NAME)
55 .with_source(err);
56 if let Some(identifier) = identifier {
57 builder = builder.with_identifier(identifier);
58 }
59 builder.build()
60}
61
62fn map_string_result<T>(result: Result<T, String>) -> BuiltinResult<T> {
63 result.map_err(fwrite_error)
64}
65
66#[runtime_builtin(
67 name = "fwrite",
68 category = "io/filetext",
69 summary = "Write binary data to a file identifier.",
70 keywords = "fwrite,file,io,binary,precision",
71 accel = "cpu",
72 type_resolver(crate::builtins::io::type_resolvers::fwrite_type),
73 builtin_path = "crate::builtins::io::filetext::fwrite"
74)]
75async fn fwrite_builtin(fid: Value, data: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
76 let eval = evaluate(&fid, &data, &rest).await?;
77 Ok(Value::Num(eval.count as f64))
78}
79
/// Result of a successful `fwrite` evaluation.
#[derive(Debug, Clone)]
pub struct FwriteEval {
    /// Number of elements written to the file.
    count: usize,
}

impl FwriteEval {
    /// Record that `count` elements were written.
    fn new(count: usize) -> Self {
        FwriteEval { count }
    }

    /// Number of elements written; this is the value `fwrite` returns to the caller.
    pub fn count(&self) -> usize {
        let FwriteEval { count } = self;
        *count
    }
}
96
97pub async fn evaluate(
99 fid_value: &Value,
100 data_value: &Value,
101 rest: &[Value],
102) -> BuiltinResult<FwriteEval> {
103 let fid_host = gather_value(fid_value).await?;
104 let fid = map_string_result(parse_fid(&fid_host))?;
105 if fid < 0 {
106 return Err(fwrite_error("fwrite: file identifier must be non-negative"));
107 }
108 if fid < 3 {
109 return Err(fwrite_error(
110 "fwrite: standard input/output identifiers are not supported yet",
111 ));
112 }
113
114 let info = registry::info_for(fid).ok_or_else(|| {
115 fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
116 })?;
117 let handle = registry::take_handle(fid).ok_or_else(|| {
118 fwrite_error("fwrite: Invalid file identifier. Use fopen to generate a valid file ID.")
119 })?;
120
121 let mut file = handle
122 .lock()
123 .map_err(|_| fwrite_error("fwrite: failed to lock file handle (poisoned mutex)"))?;
124
125 let data_host = gather_value(data_value).await?;
126 let rest_host = gather_args(rest).await?;
127 let (precision_arg, skip_arg, machine_arg) = map_string_result(classify_arguments(&rest_host))?;
128
129 let precision_spec = map_string_result(parse_precision(precision_arg))?;
130 let skip_bytes = map_string_result(parse_skip(skip_arg))?;
131 let machine_format = map_string_result(parse_machine_format(machine_arg, &info.machinefmt))?;
132
133 let elements = map_string_result(flatten_elements(&data_host))?;
134 let count = map_string_result(write_elements(
135 &mut file,
136 &elements,
137 precision_spec,
138 skip_bytes,
139 machine_format,
140 ))?;
141 Ok(FwriteEval::new(count))
142}
143
144async fn gather_value(value: &Value) -> BuiltinResult<Value> {
145 gather_if_needed_async(value)
146 .await
147 .map_err(map_control_flow)
148}
149
150async fn gather_args(args: &[Value]) -> BuiltinResult<Vec<Value>> {
151 let mut gathered = Vec::with_capacity(args.len());
152 for value in args {
153 gathered.push(
154 gather_if_needed_async(value)
155 .await
156 .map_err(map_control_flow)?,
157 );
158 }
159 Ok(gathered)
160}
161
162fn parse_fid(value: &Value) -> Result<i32, String> {
163 let scalar = match value {
164 Value::Num(n) => *n,
165 Value::Int(int) => int.to_f64(),
166 _ => return Err("fwrite: file identifier must be numeric".to_string()),
167 };
168 if !scalar.is_finite() {
169 return Err("fwrite: file identifier must be finite".to_string());
170 }
171 if scalar.fract().abs() > f64::EPSILON {
172 return Err("fwrite: file identifier must be an integer".to_string());
173 }
174 Ok(scalar as i32)
175}
176
177type FwriteArgs<'a> = (Option<&'a Value>, Option<&'a Value>, Option<&'a Value>);
178
179fn classify_arguments(args: &[Value]) -> Result<FwriteArgs<'_>, String> {
180 match args.len() {
181 0 => Ok((None, None, None)),
182 1 => {
183 if is_string_like(&args[0]) {
184 Ok((Some(&args[0]), None, None))
185 } else {
186 Err(
187 "fwrite: precision argument must be a string scalar or character vector"
188 .to_string(),
189 )
190 }
191 }
192 2 => {
193 if !is_string_like(&args[0]) {
194 return Err(
195 "fwrite: precision argument must be a string scalar or character vector"
196 .to_string(),
197 );
198 }
199 if is_numeric_like(&args[1]) {
200 Ok((Some(&args[0]), Some(&args[1]), None))
201 } else if is_string_like(&args[1]) {
202 Ok((Some(&args[0]), None, Some(&args[1])))
203 } else {
204 Err("fwrite: invalid argument combination (expected numeric skip or machine format string)".to_string())
205 }
206 }
207 3 => {
208 if !is_string_like(&args[0]) || !is_numeric_like(&args[1]) || !is_string_like(&args[2])
209 {
210 return Err("fwrite: expected arguments (precision, skip, machinefmt)".to_string());
211 }
212 Ok((Some(&args[0]), Some(&args[1]), Some(&args[2])))
213 }
214 _ => Err("fwrite: too many input arguments".to_string()),
215 }
216}
217
218fn is_string_like(value: &Value) -> bool {
219 match value {
220 Value::String(_) => true,
221 Value::CharArray(ca) => ca.rows == 1,
222 Value::StringArray(sa) => sa.data.len() == 1,
223 _ => false,
224 }
225}
226
227fn is_numeric_like(value: &Value) -> bool {
228 match value {
229 Value::Num(_) | Value::Int(_) | Value::Bool(_) => true,
230 Value::Tensor(t) => t.data.len() == 1,
231 Value::LogicalArray(la) => la.data.len() == 1,
232 _ => false,
233 }
234}
235
236#[derive(Clone, Copy, Debug)]
237struct WriteSpec {
238 input: InputType,
239}
240
241impl WriteSpec {
242 fn default() -> Self {
243 Self {
244 input: InputType::UInt8,
245 }
246 }
247}
248
249fn parse_precision(arg: Option<&Value>) -> Result<WriteSpec, String> {
250 match arg {
251 None => Ok(WriteSpec::default()),
252 Some(value) => {
253 let text = scalar_string(
254 value,
255 "fwrite: precision argument must be a string scalar or character vector",
256 )?;
257 parse_precision_string(&text)
258 }
259 }
260}
261
262fn parse_precision_string(raw: &str) -> Result<WriteSpec, String> {
263 let trimmed = raw.trim();
264 if trimmed.is_empty() {
265 return Err("fwrite: precision argument must not be empty".to_string());
266 }
267 let lower = trimmed.to_ascii_lowercase();
268 if let Some((lhs, rhs)) = lower.split_once("=>") {
269 let lhs = lhs.trim();
270 let rhs = rhs.trim();
271 let input = parse_input_label(lhs)?;
272 let output = parse_input_label(rhs)?;
273 if input != output {
274 return Err(
275 "fwrite: differing input/output precisions are not implemented yet".to_string(),
276 );
277 }
278 Ok(WriteSpec { input })
279 } else {
280 parse_input_label(lower.trim()).map(|input| WriteSpec { input })
281 }
282}
283
284fn parse_skip(arg: Option<&Value>) -> Result<usize, String> {
285 match arg {
286 None => Ok(0),
287 Some(value) => {
288 let scalar = numeric_scalar(value, "fwrite: skip must be numeric")?;
289 if !scalar.is_finite() {
290 return Err("fwrite: skip value must be finite".to_string());
291 }
292 if scalar < 0.0 {
293 return Err("fwrite: skip value must be non-negative".to_string());
294 }
295 let rounded = scalar.round();
296 if (rounded - scalar).abs() > f64::EPSILON {
297 return Err("fwrite: skip value must be an integer".to_string());
298 }
299 if rounded > i64::MAX as f64 {
300 return Err("fwrite: skip value is too large".to_string());
301 }
302 Ok(rounded as usize)
303 }
304 }
305}
306
/// Byte-order request for a write, as named by a machine-format label.
#[derive(Clone, Copy, Debug)]
enum MachineFormat {
    Native,
    LittleEndian,
    BigEndian,
}

impl MachineFormat {
    /// Resolve the request to a concrete byte order; `Native` follows the compilation target.
    fn to_endianness(self) -> Endianness {
        match self {
            MachineFormat::LittleEndian => Endianness::Little,
            MachineFormat::BigEndian => Endianness::Big,
            MachineFormat::Native if cfg!(target_endian = "big") => Endianness::Big,
            MachineFormat::Native => Endianness::Little,
        }
    }
}

/// Concrete byte order used when encoding multi-byte elements.
#[derive(Clone, Copy, Debug)]
enum Endianness {
    Little,
    Big,
}
335
336fn parse_machine_format(arg: Option<&Value>, default_label: &str) -> Result<MachineFormat, String> {
337 match arg {
338 Some(value) => {
339 let text = scalar_string(
340 value,
341 "fwrite: machine format must be a string scalar or character vector",
342 )?;
343 machine_format_from_label(&text)
344 }
345 None => machine_format_from_label(default_label),
346 }
347}
348
349fn machine_format_from_label(label: &str) -> Result<MachineFormat, String> {
350 let trimmed = label.trim();
351 if trimmed.is_empty() {
352 return Err("fwrite: machine format must not be empty".to_string());
353 }
354 let lower = trimmed.to_ascii_lowercase();
355 let collapsed: String = lower
356 .chars()
357 .filter(|c| !matches!(c, '-' | '_' | ' '))
358 .collect();
359 if matches!(collapsed.as_str(), "native" | "n" | "system" | "default") {
360 return Ok(MachineFormat::Native);
361 }
362 if matches!(
363 collapsed.as_str(),
364 "l" | "le" | "littleendian" | "pc" | "intel"
365 ) {
366 return Ok(MachineFormat::LittleEndian);
367 }
368 if matches!(
369 collapsed.as_str(),
370 "b" | "be" | "bigendian" | "mac" | "motorola"
371 ) {
372 return Ok(MachineFormat::BigEndian);
373 }
374 if lower.starts_with("ieee-le") {
375 return Ok(MachineFormat::LittleEndian);
376 }
377 if lower.starts_with("ieee-be") {
378 return Ok(MachineFormat::BigEndian);
379 }
380 Err(format!("fwrite: unsupported machine format '{trimmed}'"))
381}
382
383fn scalar_string(value: &Value, err: &str) -> Result<String, String> {
384 match value {
385 Value::String(s) => Ok(s.clone()),
386 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
387 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
388 _ => Err(err.to_string()),
389 }
390}
391
392fn numeric_scalar(value: &Value, err: &str) -> Result<f64, String> {
393 match value {
394 Value::Num(n) => Ok(*n),
395 Value::Int(int) => Ok(int.to_f64()),
396 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
397 Value::Tensor(t) if t.data.len() == 1 => Ok(t.data[0]),
398 Value::LogicalArray(la) if la.data.len() == 1 => {
399 Ok(if la.data[0] != 0 { 1.0 } else { 0.0 })
400 }
401 _ => Err(err.to_string()),
402 }
403}
404
405fn flatten_elements(value: &Value) -> Result<Vec<f64>, String> {
406 match value {
407 Value::Tensor(tensor) => Ok(tensor.data.clone()),
408 Value::Num(n) => Ok(vec![*n]),
409 Value::Int(int) => Ok(vec![int.to_f64()]),
410 Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
411 Value::LogicalArray(array) => Ok(array
412 .data
413 .iter()
414 .map(|bit| if *bit != 0 { 1.0 } else { 0.0 })
415 .collect()),
416 Value::CharArray(ca) => Ok(flatten_char_array(ca)),
417 Value::String(text) => Ok(text.chars().map(|ch| ch as u32 as f64).collect()),
418 Value::StringArray(sa) => Ok(flatten_string_array(sa)),
419 Value::GpuTensor(_) => Err("fwrite: expected host tensor data after gathering".to_string()),
420 Value::Complex(_, _) | Value::ComplexTensor(_) => {
421 Err("fwrite: complex values are not supported yet".to_string())
422 }
423 _ => Err(format!("fwrite: unsupported data type {:?}", value)),
424 }
425}
426
427fn flatten_char_array(ca: &CharArray) -> Vec<f64> {
428 let mut values = Vec::with_capacity(ca.rows.saturating_mul(ca.cols));
429 for c in 0..ca.cols {
430 for r in 0..ca.rows {
431 let idx = r * ca.cols + c;
432 values.push(ca.data[idx] as u32 as f64);
433 }
434 }
435 values
436}
437
438fn flatten_string_array(sa: &runmat_builtins::StringArray) -> Vec<f64> {
439 if sa.data.is_empty() {
440 return Vec::new();
441 }
442 let mut values = Vec::new();
443 for (idx, text) in sa.data.iter().enumerate() {
444 if idx > 0 {
445 values.push('\n' as u32 as f64);
446 }
447 values.extend(text.chars().map(|ch| ch as u32 as f64));
448 }
449 values
450}
451
452fn write_elements(
453 file: &mut File,
454 values: &[f64],
455 spec: WriteSpec,
456 skip: usize,
457 machine: MachineFormat,
458) -> Result<usize, String> {
459 let endianness = machine.to_endianness();
460 let skip_offset = skip as i64;
461 for &value in values {
462 match spec.input {
463 InputType::UInt8 => {
464 let byte = to_u8(value);
465 write_bytes(file, &[byte])?;
466 }
467 InputType::Int8 => {
468 let byte = to_i8(value) as u8;
469 write_bytes(file, &[byte])?;
470 }
471 InputType::UInt16 => {
472 let bytes = encode_u16(value, endianness);
473 write_bytes(file, &bytes)?;
474 }
475 InputType::Int16 => {
476 let bytes = encode_i16(value, endianness);
477 write_bytes(file, &bytes)?;
478 }
479 InputType::UInt32 => {
480 let bytes = encode_u32(value, endianness);
481 write_bytes(file, &bytes)?;
482 }
483 InputType::Int32 => {
484 let bytes = encode_i32(value, endianness);
485 write_bytes(file, &bytes)?;
486 }
487 InputType::UInt64 => {
488 let bytes = encode_u64(value, endianness);
489 write_bytes(file, &bytes)?;
490 }
491 InputType::Int64 => {
492 let bytes = encode_i64(value, endianness);
493 write_bytes(file, &bytes)?;
494 }
495 InputType::Float32 => {
496 let bytes = encode_f32(value, endianness);
497 write_bytes(file, &bytes)?;
498 }
499 InputType::Float64 => {
500 let bytes = encode_f64(value, endianness);
501 write_bytes(file, &bytes)?;
502 }
503 }
504
505 if skip > 0 {
506 file.seek(SeekFrom::Current(skip_offset))
507 .map_err(|err| format!("fwrite: failed to seek while applying skip ({err})"))?;
508 }
509 }
510 Ok(values.len())
511}
512
513fn write_bytes(file: &mut File, bytes: &[u8]) -> Result<(), String> {
514 file.write_all(bytes)
515 .map_err(|err| format!("fwrite: failed to write to file ({err})"))
516}
517
/// Convert `value` to a `uint8` byte the way MATLAB casts to an integer class.
///
/// The value is rounded half away from zero and saturated to `0..=255`; `NaN` becomes `0`.
/// Previously the non-finite check ran before the NaN check, so `NaN` was sent to `255`
/// (depending on its sign bit) and the intended NaN branch was unreachable.
fn to_u8(value: f64) -> u8 {
    // Float-to-integer `as` casts saturate at the target's bounds (including ±∞) and map NaN
    // to 0, so after rounding the cast performs exactly the required clamping.
    value.round() as u8
}
534
535fn to_i8(value: f64) -> i8 {
536 saturating_round(value, i8::MIN as f64, i8::MAX as f64) as i8
537}
538
539fn encode_u16(value: f64, endianness: Endianness) -> [u8; 2] {
540 let rounded = saturating_round(value, 0.0, u16::MAX as f64) as u16;
541 match endianness {
542 Endianness::Little => rounded.to_le_bytes(),
543 Endianness::Big => rounded.to_be_bytes(),
544 }
545}
546
547fn encode_i16(value: f64, endianness: Endianness) -> [u8; 2] {
548 let rounded = saturating_round(value, i16::MIN as f64, i16::MAX as f64) as i16;
549 match endianness {
550 Endianness::Little => rounded.to_le_bytes(),
551 Endianness::Big => rounded.to_be_bytes(),
552 }
553}
554
555fn encode_u32(value: f64, endianness: Endianness) -> [u8; 4] {
556 let rounded = saturating_round(value, 0.0, u32::MAX as f64) as u32;
557 match endianness {
558 Endianness::Little => rounded.to_le_bytes(),
559 Endianness::Big => rounded.to_be_bytes(),
560 }
561}
562
563fn encode_i32(value: f64, endianness: Endianness) -> [u8; 4] {
564 let rounded = saturating_round(value, i32::MIN as f64, i32::MAX as f64) as i32;
565 match endianness {
566 Endianness::Little => rounded.to_le_bytes(),
567 Endianness::Big => rounded.to_be_bytes(),
568 }
569}
570
571fn encode_u64(value: f64, endianness: Endianness) -> [u8; 8] {
572 let rounded = saturating_round(value, 0.0, u64::MAX as f64);
573 let as_u64 = if rounded.is_finite() {
574 rounded as u64
575 } else if rounded.is_sign_negative() {
576 0
577 } else {
578 u64::MAX
579 };
580 match endianness {
581 Endianness::Little => as_u64.to_le_bytes(),
582 Endianness::Big => as_u64.to_be_bytes(),
583 }
584}
585
586fn encode_i64(value: f64, endianness: Endianness) -> [u8; 8] {
587 let rounded = saturating_round(value, i64::MIN as f64, i64::MAX as f64);
588 let as_i64 = if rounded.is_finite() {
589 rounded as i64
590 } else if rounded.is_sign_negative() {
591 i64::MIN
592 } else {
593 i64::MAX
594 };
595 match endianness {
596 Endianness::Little => as_i64.to_le_bytes(),
597 Endianness::Big => as_i64.to_be_bytes(),
598 }
599}
600
601fn encode_f32(value: f64, endianness: Endianness) -> [u8; 4] {
602 let as_f32 = value as f32;
603 let bits = as_f32.to_bits();
604 match endianness {
605 Endianness::Little => bits.to_le_bytes(),
606 Endianness::Big => bits.to_be_bytes(),
607 }
608}
609
610fn encode_f64(value: f64, endianness: Endianness) -> [u8; 8] {
611 let bits = value.to_bits();
612 match endianness {
613 Endianness::Little => bits.to_le_bytes(),
614 Endianness::Big => bits.to_be_bytes(),
615 }
616}
617
/// Round `value` half away from zero and clamp it to `[min, max]`.
///
/// * `NaN` maps to `0.0`, matching MATLAB's NaN-to-zero rule for integer classes. Every
///   caller passes integer-type bounds with `min <= 0.0 <= max`, so that result is in range.
/// * `+∞`/`-∞` saturate to `max`/`min`.
///
/// The result is therefore always finite for finite bounds.
fn saturating_round(value: f64, min: f64, max: f64) -> f64 {
    // NaN must be handled first: it has no meaningful sign to saturate by.
    if value.is_nan() {
        return 0.0;
    }
    // `f64::max`/`f64::min` (unlike `f64::clamp`) never panic on inconsistent bounds.
    value.round().max(min).min(max)
}
634
/// Element types `fwrite` can encode; selected from the precision string by
/// `parse_input_label`. Integer variants are rounded and saturated before encoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InputType {
    /// 8-bit unsigned integer (the default precision).
    UInt8,
    /// 8-bit signed integer.
    Int8,
    /// 16-bit unsigned integer.
    UInt16,
    /// 16-bit signed integer.
    Int16,
    /// 32-bit unsigned integer.
    UInt32,
    /// 32-bit signed integer.
    Int32,
    /// 64-bit unsigned integer.
    UInt64,
    /// 64-bit signed integer.
    Int64,
    /// IEEE-754 single precision.
    Float32,
    /// IEEE-754 double precision.
    Float64,
}
648
649fn parse_input_label(label: &str) -> Result<InputType, String> {
650 match label {
651 "double" | "float64" | "real*8" => Ok(InputType::Float64),
652 "single" | "float32" | "real*4" => Ok(InputType::Float32),
653 "int8" | "schar" | "integer*1" => Ok(InputType::Int8),
654 "uint8" | "uchar" | "unsignedchar" | "char" | "byte" => Ok(InputType::UInt8),
655 "int16" | "short" | "integer*2" => Ok(InputType::Int16),
656 "uint16" | "ushort" | "unsignedshort" => Ok(InputType::UInt16),
657 "int32" | "integer*4" | "long" => Ok(InputType::Int32),
658 "uint32" | "unsignedint" | "unsignedlong" => Ok(InputType::UInt32),
659 "int64" | "integer*8" | "longlong" => Ok(InputType::Int64),
660 "uint64" | "unsignedlonglong" => Ok(InputType::UInt64),
661 other => Err(format!("fwrite: unsupported precision '{other}'")),
662 }
663}
664
#[cfg(test)]
pub(crate) mod tests {
    // Tests drive the async `evaluate` entry point synchronously via
    // `futures::executor::block_on`. Each test closes the file with `fclose` before inspecting
    // the bytes on disk.
    use super::*;
    use crate::builtins::common::test_support;
    use crate::builtins::io::filetext::registry;
    use crate::builtins::io::filetext::{fclose, fopen};
    use crate::RuntimeError;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::Tensor;
    use runmat_filesystem::File;
    use runmat_time::system_time_now;
    use std::io::Read;
    use std::path::PathBuf;
    use std::time::UNIX_EPOCH;

    /// Extract the message text from a runtime error for substring assertions.
    fn unwrap_error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    /// Run `evaluate` to completion on the current thread.
    fn run_evaluate(
        fid_value: &Value,
        data_value: &Value,
        rest: &[Value],
    ) -> BuiltinResult<FwriteEval> {
        futures::executor::block_on(evaluate(fid_value, data_value, rest))
    }

    /// Synchronous wrapper around `fopen::evaluate`.
    fn run_fopen(args: &[Value]) -> BuiltinResult<fopen::FopenEval> {
        futures::executor::block_on(fopen::evaluate(args))
    }

    /// Synchronous wrapper around `fclose::evaluate`.
    fn run_fclose(args: &[Value]) -> BuiltinResult<fclose::FcloseEval> {
        futures::executor::block_on(fclose::evaluate(args))
    }

    /// Acquire `registry::test_guard()`. Presumably this serializes tests that share the
    /// global file registry — NOTE(review): confirm in the registry module.
    fn registry_guard() -> std::sync::MutexGuard<'static, ()> {
        registry::test_guard()
    }

    // With no precision argument, each element is written as a single uint8 byte.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_default_uint8_bytes() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_uint8");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.0, 2.0, 255.0], vec![3, 1]).unwrap();
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &Vec::new())
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![1u8, 2, 255]);
        test_support::fs::remove_file(path).unwrap();
    }

    // 'double' precision on a file opened without a machine format uses the host byte order.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_double_precision_writes_native_endian() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_double");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.5, -2.25], vec![2, 1]).unwrap();
        let args = vec![Value::from("double")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        let expected: Vec<u8> = if cfg!(target_endian = "little") {
            [1.5f64.to_le_bytes(), (-2.25f64).to_le_bytes()].concat()
        } else {
            [1.5f64.to_be_bytes(), (-2.25f64).to_be_bytes()].concat()
        };
        assert_eq!(bytes, expected);
        test_support::fs::remove_file(path).unwrap();
    }

    // The machine format given to fopen ('ieee-be') is inherited by fwrite when no
    // machinefmt argument is passed: 258 = 0x0102 and 772 = 0x0304.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_big_endian_uint16() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_be");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
            Value::from("ieee-be"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![258.0, 772.0], vec![2, 1]).unwrap();
        let args = vec![Value::from("uint16")];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 2);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![0x01, 0x02, 0x03, 0x04]);
        test_support::fs::remove_file(path).unwrap();
    }

    // A skip of 1 leaves a zero byte between elements; the seek after the final element does
    // not extend the file, so there is no trailing padding.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_skip_inserts_padding() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![10.0, 20.0, 30.0], vec![3, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(1.0)];
        let eval =
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let bytes = test_support::fs::read(&path).expect("read");
        assert_eq!(bytes, vec![10u8, 0, 20, 0, 30]);
        test_support::fs::remove_file(path).unwrap();
    }

    // A tensor uploaded to the test provider is gathered to the host before being written.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_gpu_tensor_gathers_before_write() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_gpu");

        test_support::with_test_provider(|provider| {
            registry::reset_for_tests();
            let open = run_fopen(&[
                Value::from(path.to_string_lossy().to_string()),
                Value::from("w+b"),
            ])
            .expect("fopen");
            let fid = open.as_open().unwrap().fid as i32;

            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::from("uint16")];
            let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
                .expect("fwrite");
            assert_eq!(eval.count(), 4);

            run_fclose(&[Value::Num(fid as f64)]).unwrap();
        });

        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        assert_eq!(bytes.len(), 8);
        let mut decoded = Vec::new();
        for chunk in bytes.chunks_exact(2) {
            let value = if cfg!(target_endian = "little") {
                u16::from_le_bytes([chunk[0], chunk[1]])
            } else {
                u16::from_be_bytes([chunk[0], chunk[1]])
            };
            decoded.push(value);
        }
        assert_eq!(decoded, vec![1u16, 2, 3, 4]);
        test_support::fs::remove_file(path).unwrap();
    }

    // Unknown precision labels are rejected with an "unsupported precision" message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_precision_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_invalid_precision");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("bogus-class")];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("unsupported precision"));
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    // Negative skip values are rejected before anything is written.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_negative_skip_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_negative_skip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let tensor = Tensor::new(vec![10.0], vec![1, 1]).unwrap();
        let args = vec![Value::from("uint8"), Value::Num(-1.0)];
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(fid as f64), &Value::Tensor(tensor), &args).unwrap_err(),
        );
        assert!(err.contains("skip value must be non-negative"));
        let _ = run_fclose(&[Value::Num(fid as f64)]);
        test_support::fs::remove_file(path).unwrap();
    }

    // With the `wgpu` feature, a tensor uploaded to the real wgpu provider round-trips as
    // native-endian doubles.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn fwrite_wgpu_tensor_roundtrip() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let path = unique_path("fwrite_wgpu_roundtrip");
        let open = run_fopen(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("w+b"),
        ])
        .expect("fopen");
        let fid = open.as_open().unwrap().fid as i32;

        let provider = provider::register_wgpu_provider(provider::WgpuProviderOptions::default())
            .expect("wgpu provider");

        let tensor = Tensor::new(vec![0.5, -1.25, 3.75], vec![3, 1]).unwrap();
        let expected = tensor.data.clone();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload to gpu");
        let args = vec![Value::from("double")];
        let eval = run_evaluate(&Value::Num(fid as f64), &Value::GpuTensor(handle), &args)
            .expect("fwrite");
        assert_eq!(eval.count(), 3);

        run_fclose(&[Value::Num(fid as f64)]).unwrap();

        let mut file = File::open(&path).expect("open");
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).expect("read");
        assert_eq!(bytes.len(), 24);
        for (chunk, expected_value) in bytes.chunks_exact(8).zip(expected.iter()) {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(chunk);
            let value = if cfg!(target_endian = "little") {
                f64::from_le_bytes(buf)
            } else {
                f64::from_be_bytes(buf)
            };
            assert!(
                (value - expected_value).abs() < 1e-12,
                "mismatch: {} vs {}",
                value,
                expected_value
            );
        }
        test_support::fs::remove_file(path).unwrap();
    }

    // Negative identifiers are rejected before any registry lookup.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fwrite_invalid_identifier_errors() {
        let _guard = registry_guard();
        registry::reset_for_tests();
        let err = unwrap_error_message(
            run_evaluate(&Value::Num(-1.0), &Value::Num(1.0), &Vec::new()).unwrap_err(),
        );
        assert!(err.contains("file identifier must be non-negative"));
    }

    /// Build a temp-directory path whose name embeds `prefix` and the current time (seconds
    /// plus nanoseconds), making collisions between tests unlikely.
    fn unique_path(prefix: &str) -> PathBuf {
        let now = system_time_now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards");
        let filename = format!(
            "runmat_{prefix}_{}_{}.tmp",
            now.as_secs(),
            now.subsec_nanos()
        );
        std::env::temp_dir().join(filename)
    }
}