1use std::collections::HashMap;
4use std::io::{BufReader, Cursor, Read};
5use std::path::{Path, PathBuf};
6
7use regex::Regex;
8use runmat_builtins::{
9 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11 CharArray, ComplexTensor, LogicalArray, StringArray, StructValue, Tensor, Value,
12};
13use runmat_filesystem::File;
14use runmat_macros::runtime_builtin;
15
16use super::format::{
17 MatArray, MatClass, MatData, FLAG_COMPLEX, FLAG_LOGICAL, MAT_HEADER_LEN, MI_DOUBLE, MI_INT32,
18 MI_INT8, MI_MATRIX, MI_UINT16, MI_UINT32, MI_UINT8,
19};
20use crate::builtins::common::spec::{
21 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22 ReductionNaN, ResidencyPolicy, ShapeRequirements,
23};
24use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
25
/// Output descriptor shared by every `load` signature: the struct of loaded variables.
const LOAD_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "S",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Struct containing the loaded variables.",
}];
/// Inputs for the zero-argument form `S = load()`.
const LOAD_INPUTS_NONE: [BuiltinParamDescriptor; 0] = [];
/// Inputs for `S = load(filename)`.
///
/// The advertised default (`"matlab.mat"`) mirrors the fallback path used by `evaluate`.
const LOAD_INPUTS_FILENAME: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "filename",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Required,
    default: Some("\"matlab.mat\""),
    description: "MAT-file path.",
}];
/// Inputs for `S = load(filename, varName1, varName2, ...)`.
const LOAD_INPUTS_FILENAME_VARS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "filename",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"matlab.mat\""),
        description: "MAT-file path.",
    },
    BuiltinParamDescriptor {
        name: "varName",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Variable names to load.",
    },
];
/// Inputs for `S = load(filename, "-regexp", pattern1, ...)`.
const LOAD_INPUTS_FILENAME_REGEXP: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "filename",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"matlab.mat\""),
        description: "MAT-file path.",
    },
    BuiltinParamDescriptor {
        name: "option",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"-regexp\""),
        description: "Regular-expression selection option.",
    },
    BuiltinParamDescriptor {
        name: "pattern",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Regex patterns matched against variable names.",
    },
];
/// Inputs for the option-first form `S = load(option, value, ...)`.
const LOAD_INPUTS_OPTIONS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "option",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Compatibility options such as '-mat' and '-regexp'.",
    },
    BuiltinParamDescriptor {
        name: "value",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option arguments and variable selectors.",
    },
];
/// Every call form advertised for `load`, in the order exposed through `LOAD_DESCRIPTOR`.
const LOAD_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
    BuiltinSignatureDescriptor {
        label: "S = load()",
        inputs: &LOAD_INPUTS_NONE,
        outputs: &LOAD_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "S = load(filename)",
        inputs: &LOAD_INPUTS_FILENAME,
        outputs: &LOAD_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "S = load(filename, varName1, varName2, ...)",
        inputs: &LOAD_INPUTS_FILENAME_VARS,
        outputs: &LOAD_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "S = load(filename, \"-regexp\", pattern1, ...)",
        inputs: &LOAD_INPUTS_FILENAME_REGEXP,
        outputs: &LOAD_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "S = load(option, value, ...)",
        inputs: &LOAD_INPUTS_OPTIONS,
        outputs: &LOAD_OUTPUT,
    },
];
// Error descriptors for `load`. The helpers in this file (`load_error_with`,
// `load_error_with_source`) copy only `identifier` onto the raised error; the
// `code`, `when` and `message` fields are descriptor metadata.

/// Malformed arguments, plus undecodable MAT content reported through `load_error`.
const LOAD_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOAD.INVALID_ARGUMENT",
    identifier: Some("RunMat:load:InvalidArgument"),
    when: "Arguments do not match a supported load invocation form.",
    message: "load: invalid argument",
};
/// Unknown options, missing `-regexp` patterns, and regexes that fail to compile.
const LOAD_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOAD.INVALID_OPTION",
    identifier: Some("RunMat:load:InvalidOption"),
    when: "An option token or option argument is invalid.",
    message: "load: invalid option",
};
/// Filename argument that is not a character vector or string scalar.
const LOAD_ERROR_FILENAME: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOAD.FILENAME",
    identifier: Some("RunMat:load:Filename"),
    when: "Filename is invalid or cannot be normalized.",
    message: "load: invalid filename",
};
/// Requested variable missing from the file, or a selection that matched nothing.
const LOAD_ERROR_SELECTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOAD.SELECTION",
    identifier: Some("RunMat:load:Selection"),
    when: "Requested variables are missing or no variables are selected.",
    message: "load: variable selection failed",
};
/// Failures opening the file or reading its header/element bytes.
const LOAD_ERROR_IO: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOAD.IO",
    identifier: Some("RunMat:load:Io"),
    when: "MAT-file data cannot be read or decoded.",
    message: "load: MAT-file I/O failure",
};
/// Raised by `load_builtin` when more than one output is requested.
const LOAD_ERROR_OUTPUT_COUNT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOAD.OUTPUT_COUNT",
    identifier: Some("RunMat:load:OutputCount"),
    when: "Caller requests more outputs than supported by load.",
    message: "load: unsupported output count",
};
/// Raised by `load_builtin` when a workspace assignment fails in statement form.
const LOAD_ERROR_WORKSPACE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOAD.WORKSPACE",
    identifier: Some("RunMat:load:Workspace"),
    when: "Statement-form load cannot assign values into workspace.",
    message: "load: workspace assignment failed",
};
/// All error descriptors advertised through `LOAD_DESCRIPTOR`.
const LOAD_ERRORS: [BuiltinErrorDescriptor; 7] = [
    LOAD_ERROR_INVALID_ARGUMENT,
    LOAD_ERROR_INVALID_OPTION,
    LOAD_ERROR_FILENAME,
    LOAD_ERROR_SELECTION,
    LOAD_ERROR_IO,
    LOAD_ERROR_OUTPUT_COUNT,
    LOAD_ERROR_WORKSPACE,
];
/// Public descriptor for `load`, referenced by the `descriptor(...)` argument of
/// the `#[runtime_builtin]` attribute on `load_builtin`.
pub const LOAD_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &LOAD_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &LOAD_ERRORS,
};
180
// GPU registration for `load`. Per the `notes` string below, the builtin runs on
// the host and produces CPU-resident values; no provider hooks or precisions are
// declared.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::mat::load")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "load",
    op_kind: GpuOpKind::Custom("io-load"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Reads MAT-files on the host and produces CPU-resident values. Providers are not involved until accelerated code later promotes the results.",
};
196
// Fusion registration for `load`. It declares no elementwise or reduction kernel;
// per its `notes`, the entry exists for documentation completeness only.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::mat::load")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "load",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "File I/O is not eligible for fusion. Registration exists for documentation completeness only.",
};
207
208#[runtime_builtin(
209 name = "load",
210 category = "io/mat",
211 summary = "Load variables from a MAT-file.",
212 keywords = "load,mat,workspace",
213 accel = "cpu",
214 sink = true,
215 type_resolver(crate::builtins::io::type_resolvers::load_type),
216 descriptor(crate::builtins::io::mat::load::LOAD_DESCRIPTOR),
217 builtin_path = "crate::builtins::io::mat::load"
218)]
219async fn load_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
220 let eval = evaluate(&args).await?;
221
222 if let Some(n) = crate::output_count::current_output_count() {
225 if n > 1 {
226 return Err(load_error_with(
227 &LOAD_ERROR_OUTPUT_COUNT,
228 "load supports at most one output argument",
229 ));
230 }
231 }
232
233 if crate::output_context::requested_output_count() == Some(0) {
240 for (name, value) in eval.variables() {
241 crate::workspace::assign(name, value.clone())
242 .map_err(|err| load_error_with(&LOAD_ERROR_WORKSPACE, err))?;
243 }
244 return Ok(Value::OutputList(Vec::new()));
245 }
246
247 Ok(eval.first_output())
248}
249
/// Result of evaluating a `load` call: the selected variables as `(name, value)`
/// pairs, in the order produced by `select_variables`.
#[derive(Clone, Debug)]
pub struct LoadEval {
    variables: Vec<(String, Value)>,
}
254
255impl LoadEval {
256 pub fn first_output(&self) -> Value {
257 let mut st = StructValue::new();
258 for (name, value) in &self.variables {
259 st.fields.insert(name.clone(), value.clone());
260 }
261 Value::Struct(st)
262 }
263
264 pub fn variables(&self) -> &[(String, Value)] {
265 &self.variables
266 }
267
268 pub fn into_variables(self) -> Vec<(String, Value)> {
269 self.variables
270 }
271}
272
/// Variable selection requested by the caller.
///
/// When both lists are empty, `select_variables` returns every variable in the file.
struct LoadRequest {
    /// Exact variable names; each must exist in the file.
    variables: Vec<String>,
    /// Compiled `-regexp` patterns; a variable is selected if any of them matches its name.
    regex_patterns: Vec<Regex>,
}
277
/// Builtin name attached to every error raised by this module's error helpers.
const BUILTIN_NAME: &str = "load";
279
280fn load_error(message: impl Into<String>) -> RuntimeError {
281 load_error_with(&LOAD_ERROR_INVALID_ARGUMENT, message)
282}
283
284fn load_error_with(
285 error: &'static BuiltinErrorDescriptor,
286 message: impl Into<String>,
287) -> RuntimeError {
288 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
289 if let Some(identifier) = error.identifier {
290 builder = builder.with_identifier(identifier);
291 }
292 builder.build()
293}
294
295fn load_error_with_source(
296 error: &'static BuiltinErrorDescriptor,
297 message: impl Into<String>,
298 source: impl std::error::Error + Send + Sync + 'static,
299) -> RuntimeError {
300 let mut builder = build_runtime_error(message)
301 .with_builtin(BUILTIN_NAME)
302 .with_source(source);
303 if let Some(identifier) = error.identifier {
304 builder = builder.with_identifier(identifier);
305 }
306 builder.build()
307}
308
309pub async fn evaluate(args: &[Value]) -> BuiltinResult<LoadEval> {
310 let mut host_args = Vec::with_capacity(args.len());
311 for arg in args {
312 host_args.push(gather_if_needed_async(arg).await?);
313 }
314
315 let invocation = parse_invocation(&host_args).await?;
316
317 let mut path_value = if let Some(path) = invocation.path_value {
318 path
319 } else {
320 Value::from("matlab.mat")
321 };
322
323 if invocation.path_was_default {
324 if let Ok(override_path) = std::env::var("RUNMAT_LOAD_DEFAULT_PATH") {
325 path_value = Value::from(override_path);
326 }
327 }
328
329 let mut regex_patterns = Vec::with_capacity(invocation.regex_tokens.len());
330 for pattern in invocation.regex_tokens {
331 let regex = Regex::new(&pattern).map_err(|err| {
332 load_error_with_source(
333 &LOAD_ERROR_INVALID_OPTION,
334 format!("load: invalid regular expression '{pattern}': {err}"),
335 err,
336 )
337 })?;
338 regex_patterns.push(regex);
339 }
340
341 let request = LoadRequest {
342 variables: invocation.variables,
343 regex_patterns,
344 };
345 let path = normalise_path(&path_value)?;
346 let entries = read_mat_file(&path).await?;
347
348 let selected = select_variables(&entries, &request)?;
349 Ok(LoadEval {
350 variables: selected,
351 })
352}
353
/// Raw result of argument parsing, before regexes are compiled or the path is normalised.
struct ParsedInvocation {
    /// First non-option argument, if any (still an unvalidated `Value`).
    path_value: Option<Value>,
    /// True when no path argument was supplied (enables the default/env-override path).
    path_was_default: bool,
    /// Variable names from non-option arguments after the path.
    variables: Vec<String>,
    /// Pattern strings collected after `-regexp`.
    regex_tokens: Vec<String>,
}
360
361async fn parse_invocation(values: &[Value]) -> BuiltinResult<ParsedInvocation> {
362 let mut path_value = None;
363 let mut path_was_default = false;
364 let mut variables = Vec::new();
365 let mut regex_tokens = Vec::new();
366 let mut idx = 0usize;
367 while idx < values.len() {
368 if let Some(flag) = option_token(&values[idx])? {
369 match flag.as_str() {
370 "-mat" => {
371 idx += 1;
372 continue;
373 }
374 "-regexp" => {
375 idx += 1;
376 if idx >= values.len() {
377 return Err(load_error_with(
378 &LOAD_ERROR_INVALID_OPTION,
379 "load: '-regexp' requires at least one pattern",
380 ));
381 }
382 while idx < values.len() {
383 if option_token(&values[idx])?.is_some() {
384 break;
385 }
386 let names = extract_names(&values[idx]).await?;
387 if names.is_empty() {
388 return Err(load_error_with(
389 &LOAD_ERROR_INVALID_OPTION,
390 "load: '-regexp' requires non-empty pattern strings",
391 ));
392 }
393 regex_tokens.extend(names);
394 idx += 1;
395 }
396 continue;
397 }
398 other => {
399 return Err(load_error_with(
400 &LOAD_ERROR_INVALID_OPTION,
401 format!("load: unsupported option '{other}'"),
402 ));
403 }
404 }
405 } else {
406 if path_value.is_none() {
407 path_value = Some(values[idx].clone());
408 idx += 1;
409 continue;
410 }
411 let names = extract_names(&values[idx]).await?;
412 variables.extend(names);
413 idx += 1;
414 }
415 }
416
417 if path_value.is_none() {
418 path_was_default = true;
419 }
420
421 Ok(ParsedInvocation {
422 path_value,
423 path_was_default,
424 variables,
425 regex_tokens,
426 })
427}
428
429fn normalise_path(value: &Value) -> BuiltinResult<PathBuf> {
430 let raw = value_to_string_scalar(value).ok_or_else(|| {
431 load_error_with(
432 &LOAD_ERROR_FILENAME,
433 "load: filename must be a character vector or string scalar",
434 )
435 })?;
436 let mut path = PathBuf::from(raw);
437 if path.extension().is_none() {
438 path.set_extension("mat");
439 }
440 Ok(path)
441}
442
443fn select_variables(
444 entries: &[(String, Value)],
445 request: &LoadRequest,
446) -> BuiltinResult<Vec<(String, Value)>> {
447 if request.variables.is_empty() && request.regex_patterns.is_empty() {
448 return Ok(entries.to_vec());
449 }
450
451 let mut by_name: HashMap<&str, &Value> = HashMap::with_capacity(entries.len());
452 for (name, value) in entries {
453 by_name.insert(name, value);
454 }
455
456 let mut selected = Vec::new();
457
458 for name in &request.variables {
459 let value = by_name.get(name.as_str()).ok_or_else(|| {
460 load_error_with(
461 &LOAD_ERROR_SELECTION,
462 format!("load: variable '{name}' was not found in the file"),
463 )
464 })?;
465 insert_or_replace(&mut selected, name, (*value).clone());
466 }
467
468 if !request.regex_patterns.is_empty() {
469 let mut matched = 0usize;
470 for (name, value) in entries {
471 if request
472 .regex_patterns
473 .iter()
474 .any(|regex| regex.is_match(name))
475 {
476 matched += 1;
477 insert_or_replace(&mut selected, name, value.clone());
478 }
479 }
480 if matched == 0 && request.variables.is_empty() {
481 return Err(load_error_with(
482 &LOAD_ERROR_SELECTION,
483 "load: no variables matched '-regexp' patterns",
484 ));
485 }
486 }
487
488 if selected.is_empty() {
489 return Err(load_error_with(
490 &LOAD_ERROR_SELECTION,
491 "load: no variables selected",
492 ));
493 }
494
495 Ok(selected)
496}
497
498fn insert_or_replace(selected: &mut Vec<(String, Value)>, name: &str, value: Value) {
499 if let Some(entry) = selected.iter_mut().find(|(existing, _)| existing == name) {
500 entry.1 = value;
501 } else {
502 selected.push((name.to_string(), value));
503 }
504}
505
506pub(crate) async fn read_mat_file_for_builtin(
507 path: &Path,
508 builtin: &str,
509) -> crate::BuiltinResult<Vec<(String, Value)>> {
510 match read_mat_file(path).await {
511 Ok(entries) => Ok(entries),
512 Err(err) => {
513 let message = err.message().replacen("load:", &format!("{builtin}:"), 1);
514 let mut builder = build_runtime_error(message).with_builtin(builtin);
515 if let Some(identifier) = err.identifier() {
516 builder = builder.with_identifier(identifier);
517 }
518 Err(builder.with_source(err).build())
519 }
520 }
521}
522
523pub(crate) async fn read_mat_file(path: &Path) -> BuiltinResult<Vec<(String, Value)>> {
524 let file = File::open_async(path).await.map_err(|err| {
525 load_error_with_source(
526 &LOAD_ERROR_IO,
527 format!("load: failed to open '{}': {err}", path.display()),
528 err,
529 )
530 })?;
531 let mut reader = BufReader::new(file);
532 read_mat_reader(&mut reader)
533}
534
535pub fn decode_workspace_from_mat_bytes(bytes: &[u8]) -> BuiltinResult<Vec<(String, Value)>> {
536 let mut cursor = Cursor::new(bytes);
537 read_mat_reader(&mut cursor)
538}
539
540fn read_mat_reader<R: Read>(reader: &mut R) -> BuiltinResult<Vec<(String, Value)>> {
541 let mut header = [0u8; MAT_HEADER_LEN];
542 reader.read_exact(&mut header).map_err(|err| {
543 load_error_with_source(
544 &LOAD_ERROR_IO,
545 format!("load: failed to read MAT-file header: {err}"),
546 err,
547 )
548 })?;
549 if header[126] != b'I' || header[127] != b'M' {
550 return Err(load_error("load: file is not a MATLAB Level-5 MAT-file"));
551 }
552
553 let mut variables = Vec::new();
554 while let Some(tagged) = read_tagged(reader, true)? {
555 if tagged.data_type != MI_MATRIX {
556 continue;
557 }
558 let parsed = parse_matrix(&tagged.data)?;
559 let value = mat_array_to_value(parsed.array)?;
560 variables.push((parsed.name, value));
561 }
562 Ok(variables)
563}
564
/// A decoded `MI_MATRIX` element: its array name and its class-specific contents.
struct ParsedMatrix {
    /// Array name; empty for cell elements and struct fields, whose names are not used.
    name: String,
    array: MatArray,
}
569
/// Decodes the body of one `MI_MATRIX` element (the bytes after its tag).
///
/// Sub-elements are read in fixed order: array flags, dimensions, array name,
/// then the class-specific payload. The payload is delegated to the matching
/// `parse_*` helper.
///
/// # Errors
/// - Invalid-argument error for a missing or malformed sub-element, an unsupported
///   class code or name encoding, or negative dimensions.
/// - Read errors from `read_tagged` are propagated.
fn parse_matrix(buffer: &[u8]) -> BuiltinResult<ParsedMatrix> {
    let mut cursor = Cursor::new(buffer);

    // Array flags: the first u32 holds the class code in its low byte and the
    // flag bits tested with FLAG_LOGICAL / FLAG_COMPLEX.
    let flags = read_tagged(&mut cursor, false)?
        .ok_or_else(|| load_error("load: matrix element missing array flags"))?;
    if flags.data_type != MI_UINT32 || flags.data.len() < 8 {
        return Err(load_error("load: invalid array flags block"));
    }
    // Cannot panic: the length check above guarantees at least 8 bytes.
    let flags0 = u32::from_le_bytes(flags.data[0..4].try_into().unwrap());
    let class_code = flags0 & 0xFF;
    let mut class = MatClass::from_class_code(class_code)
        .ok_or_else(|| load_error("load: unsupported MATLAB class"))?;
    let is_logical = (flags0 & FLAG_LOGICAL) != 0;
    let has_imag = (flags0 & FLAG_COMPLEX) != 0;
    // A double-class array carrying the logical flag is decoded as a logical array.
    if matches!(class, MatClass::Double) && is_logical {
        class = MatClass::Logical;
    }

    // Dimensions: a non-empty run of little-endian i32 values.
    let dims_elem = read_tagged(&mut cursor, false)?
        .ok_or_else(|| load_error("load: matrix element missing dimensions"))?;
    if dims_elem.data_type != MI_INT32 {
        return Err(load_error("load: dimension block must use MI_INT32"));
    }
    if dims_elem.data.is_empty() || dims_elem.data.len() % 4 != 0 {
        return Err(load_error("load: malformed dimension block"));
    }
    let mut dims = Vec::with_capacity(dims_elem.data.len() / 4);
    for chunk in dims_elem.data.chunks_exact(4) {
        let value = i32::from_le_bytes(chunk.try_into().unwrap());
        if value < 0 {
            return Err(load_error("load: negative dimensions are not supported"));
        }
        dims.push(value as usize);
    }
    // Defensive fallback: an empty dimension block is already rejected above.
    if dims.is_empty() {
        dims.push(1);
        dims.push(1);
    }

    // Array name. 8-bit names are cut at the first NUL. 16-bit names are converted
    // one code unit at a time, stopping at NUL; units that are not valid `char`s
    // (e.g. surrogates) are skipped.
    let name_elem = read_tagged(&mut cursor, false)?
        .ok_or_else(|| load_error("load: matrix element missing name"))?;
    let name = match name_elem.data_type {
        MI_INT8 | MI_UINT8 => bytes_to_string(&name_elem.data),
        MI_UINT16 => {
            let mut bytes = Vec::with_capacity(name_elem.data.len());
            for chunk in name_elem.data.chunks_exact(2) {
                let code = u16::from_le_bytes(chunk.try_into().unwrap());
                if code == 0 {
                    break;
                }
                if let Some(ch) = char::from_u32(code as u32) {
                    bytes.push(ch);
                }
            }
            bytes.into_iter().collect()
        }
        _ => {
            return Err(load_error("load: unsupported array name encoding"));
        }
    };

    // Class-specific payload follows the name.
    let array = match class {
        MatClass::Double => parse_double_array(&mut cursor, dims, has_imag)?,
        MatClass::Logical => parse_logical_array(&mut cursor, dims)?,
        MatClass::Char => parse_char_array(&mut cursor, dims)?,
        MatClass::Cell => parse_cell_array(&mut cursor, dims)?,
        MatClass::Struct => parse_struct(&mut cursor, dims)?,
    };

    Ok(ParsedMatrix { name, array })
}
641
642fn parse_double_array(
643 cursor: &mut Cursor<&[u8]>,
644 dims: Vec<usize>,
645 has_imag: bool,
646) -> BuiltinResult<MatArray> {
647 let real_elem = read_tagged(cursor, false)?
648 .ok_or_else(|| load_error("load: numeric array missing real component"))?;
649 if real_elem.data_type != MI_DOUBLE || real_elem.data.len() % 8 != 0 {
650 return Err(load_error("load: numeric data must be stored as MI_DOUBLE"));
651 }
652 let mut real = Vec::with_capacity(real_elem.data.len() / 8);
653 for chunk in real_elem.data.chunks_exact(8) {
654 real.push(f64::from_le_bytes(chunk.try_into().unwrap()));
655 }
656
657 let imag = if has_imag {
658 let imag_elem = read_tagged(cursor, false)?
659 .ok_or_else(|| load_error("load: numeric array missing imaginary component"))?;
660 if imag_elem.data_type != MI_DOUBLE || imag_elem.data.len() % 8 != 0 {
661 return Err(load_error("load: imaginary component must be MI_DOUBLE"));
662 }
663 let mut imag = Vec::with_capacity(imag_elem.data.len() / 8);
664 for chunk in imag_elem.data.chunks_exact(8) {
665 imag.push(f64::from_le_bytes(chunk.try_into().unwrap()));
666 }
667 Some(imag)
668 } else {
669 None
670 };
671
672 Ok(MatArray {
673 class: MatClass::Double,
674 dims,
675 data: MatData::Double { real, imag },
676 })
677}
678
679fn parse_logical_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
680 let elem = read_tagged(cursor, false)?
681 .ok_or_else(|| load_error("load: logical array missing data block"))?;
682 if elem.data_type != MI_UINT8 {
683 return Err(load_error(
684 "load: logical arrays must be stored as MI_UINT8",
685 ));
686 }
687 Ok(MatArray {
688 class: MatClass::Logical,
689 dims,
690 data: MatData::Logical { data: elem.data },
691 })
692}
693
694fn parse_char_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
695 let elem = read_tagged(cursor, false)?
696 .ok_or_else(|| load_error("load: character array missing data block"))?;
697 if elem.data_type != MI_UINT16 {
698 return Err(load_error(
699 "load: character data must be stored as MI_UINT16",
700 ));
701 }
702 if elem.data.len() % 2 != 0 {
703 return Err(load_error("load: malformed character data"));
704 }
705 let mut data = Vec::with_capacity(elem.data.len() / 2);
706 for chunk in elem.data.chunks_exact(2) {
707 data.push(u16::from_le_bytes(chunk.try_into().unwrap()));
708 }
709 Ok(MatArray {
710 class: MatClass::Char,
711 dims,
712 data: MatData::Char { data },
713 })
714}
715
716fn parse_cell_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
717 let total: usize = dims
718 .iter()
719 .copied()
720 .fold(1usize, |acc, d| acc.saturating_mul(d));
721 let mut elements = Vec::with_capacity(total);
722 for _ in 0..total {
723 let elem = read_tagged(cursor, false)?
724 .ok_or_else(|| load_error("load: cell element missing matrix payload"))?;
725 if elem.data_type != MI_MATRIX {
726 return Err(load_error("load: cell elements must be matrices"));
727 }
728 let parsed = parse_matrix(&elem.data)?;
729 elements.push(parsed.array);
730 }
731 Ok(MatArray {
732 class: MatClass::Cell,
733 dims,
734 data: MatData::Cell { elements },
735 })
736}
737
738fn parse_struct(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
739 if dims.len() != 2 || dims[0] != 1 || dims[1] != 1 {
740 return Err(load_error("load: struct arrays are not supported yet"));
741 }
742 let len_elem = read_tagged(cursor, false)?
743 .ok_or_else(|| load_error("load: struct missing maximum field length specifier"))?;
744 if len_elem.data_type != MI_INT32 || len_elem.data.len() != 4 {
745 return Err(load_error("load: struct field length must be MI_INT32"));
746 }
747 let max_len = i32::from_le_bytes(len_elem.data[..4].try_into().unwrap());
748 if max_len <= 0 {
749 return Err(load_error("load: struct field length must be positive"));
750 }
751
752 let names_elem = read_tagged(cursor, false)?
753 .ok_or_else(|| load_error("load: struct missing field name table"))?;
754 if names_elem.data_type != MI_INT8 && names_elem.data_type != MI_UINT8 {
755 return Err(load_error(
756 "load: struct field names must be stored as MI_INT8/MI_UINT8",
757 ));
758 }
759 if names_elem.data.len() % (max_len as usize) != 0 {
760 return Err(load_error("load: malformed struct field name table"));
761 }
762 let field_count = names_elem.data.len() / (max_len as usize);
763 let mut field_names = Vec::with_capacity(field_count);
764 for i in 0..field_count {
765 let start = i * (max_len as usize);
766 let end = start + (max_len as usize);
767 let slice = &names_elem.data[start..end];
768 field_names.push(bytes_to_string(slice));
769 }
770
771 let mut field_values = Vec::with_capacity(field_count);
772 for _ in 0..field_count {
773 let elem = read_tagged(cursor, false)?
774 .ok_or_else(|| load_error("load: struct field missing matrix payload"))?;
775 if elem.data_type != MI_MATRIX {
776 return Err(load_error("load: struct fields must be matrices"));
777 }
778 let parsed = parse_matrix(&elem.data)?;
779 field_values.push(parsed.array);
780 }
781
782 Ok(MatArray {
783 class: MatClass::Struct,
784 dims,
785 data: MatData::Struct {
786 field_names,
787 field_values,
788 },
789 })
790}
791
792fn mat_array_to_value(array: MatArray) -> BuiltinResult<Value> {
793 match array.data {
794 MatData::Double { real, imag } => {
795 let len = real.len();
796 if let Some(imag) = imag {
797 if imag.len() != len {
798 return Err(load_error(
799 "load: complex data has mismatched real/imag parts",
800 ));
801 }
802 if len == 1 {
803 Ok(Value::Complex(real[0], imag[0]))
804 } else {
805 let mut pairs = Vec::with_capacity(len);
806 for i in 0..len {
807 pairs.push((real[i], imag[i]));
808 }
809 let tensor = ComplexTensor::new(pairs, array.dims.clone())
810 .map_err(|e| load_error(format!("load: {e}")))?;
811 Ok(Value::ComplexTensor(tensor))
812 }
813 } else if len == 1 {
814 Ok(Value::Num(real[0]))
815 } else {
816 let tensor = Tensor::new(real, array.dims.clone())
817 .map_err(|e| load_error(format!("load: {e}")))?;
818 Ok(Value::Tensor(tensor))
819 }
820 }
821 MatData::Logical { data } => {
822 let total: usize = array
823 .dims
824 .iter()
825 .copied()
826 .fold(1usize, |acc, d| acc.saturating_mul(d));
827 if data.len() != total {
828 return Err(load_error("load: logical data length mismatch"));
829 }
830 if total == 1 {
831 Ok(Value::Bool(data.first().copied().unwrap_or(0) != 0))
832 } else {
833 let logical = LogicalArray::new(data, array.dims.clone())
834 .map_err(|e| load_error(format!("load: {e}")))?;
835 Ok(Value::LogicalArray(logical))
836 }
837 }
838 MatData::Char { data } => {
839 let rows = array.dims.first().copied().unwrap_or(1);
840 let cols = array.dims.get(1).copied().unwrap_or(1);
841 let mut chars = Vec::with_capacity(rows.saturating_mul(cols));
842 for code in data {
843 let ch = char::from_u32(code as u32).unwrap_or('\u{FFFD}');
844 chars.push(ch);
845 }
846 let char_array =
847 CharArray::new(chars, rows, cols).map_err(|e| load_error(format!("load: {e}")))?;
848 Ok(Value::CharArray(char_array))
849 }
850 MatData::Cell { elements } => {
851 if let Some(strings) = cell_elements_to_strings(&elements) {
852 let string_array = StringArray::new(strings, array.dims.clone())
853 .map_err(|e| load_error(format!("load: {e}")))?;
854 return Ok(Value::StringArray(string_array));
855 }
856 if array.dims.len() != 2 {
857 return Err(load_error(
858 "load: cell arrays with more than two dimensions are not supported yet",
859 ));
860 }
861 let rows = array.dims[0];
862 let cols = array.dims[1];
863 let expected = rows.saturating_mul(cols);
864 if elements.len() != expected {
865 return Err(load_error("load: cell array element count mismatch"));
866 }
867 let mut converted = Vec::with_capacity(elements.len());
868 for elem in elements {
869 converted.push(mat_array_to_value(elem)?);
870 }
871 let mut row_major = vec![Value::Num(0.0); expected];
872 for col in 0..cols {
873 for row in 0..rows {
874 let cm_idx = col * rows + row;
875 let rm_idx = row * cols + col;
876 row_major[rm_idx] = converted[cm_idx].clone();
877 }
878 }
879 make_cell(row_major, rows, cols).map_err(|err| load_error(format!("load: {err}")))
880 }
881 MatData::Struct {
882 field_names,
883 field_values,
884 } => {
885 if field_names.len() != field_values.len() {
886 return Err(load_error("load: struct field metadata is inconsistent"));
887 }
888 let mut st = StructValue::new();
889 for (name, value) in field_names.into_iter().zip(field_values.into_iter()) {
890 let converted = mat_array_to_value(value)?;
891 st.fields.insert(name, converted);
892 }
893 Ok(Value::Struct(st))
894 }
895 }
896}
897
898fn cell_elements_to_strings(elements: &[MatArray]) -> Option<Vec<String>> {
899 let mut strings = Vec::with_capacity(elements.len());
900 for element in elements {
901 if element.class != MatClass::Char {
902 return None;
903 }
904 let rows = element.dims.first().copied().unwrap_or(1);
905 if rows > 1 {
906 return None;
907 }
908 match &element.data {
909 MatData::Char { data } => strings.push(utf16_codes_to_string(data)),
910 _ => return None,
911 }
912 }
913 Some(strings)
914}
915
/// Decodes MAT char data (UTF-16 code units) into a `String` and strips trailing NULs.
///
/// Surrogate pairs are combined into a single scalar value, so non-BMP text such as
/// emoji round-trips. Unpaired surrogates become U+FFFD. Interior NULs are kept.
fn utf16_codes_to_string(data: &[u16]) -> String {
    let mut text: String = char::decode_utf16(data.iter().copied())
        .map(|unit| unit.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect();
    let kept = text.trim_end_matches('\0').len();
    text.truncate(kept);
    text
}
926
927fn option_token(value: &Value) -> BuiltinResult<Option<String>> {
928 if let Some(token) = value_to_string_scalar(value) {
929 if token.starts_with('-') {
930 return Ok(Some(token.to_ascii_lowercase()));
931 }
932 }
933 Ok(None)
934}
935
936#[async_recursion::async_recursion(?Send)]
937async fn extract_names(value: &Value) -> BuiltinResult<Vec<String>> {
938 match value {
939 Value::String(s) => Ok(vec![s.clone()]),
940 Value::CharArray(ca) => Ok(char_array_rows_as_strings(ca)),
941 Value::StringArray(sa) => Ok(sa.data.clone()),
942 Value::Cell(ca) => {
943 let mut names = Vec::with_capacity(ca.data.len());
944 for handle in &ca.data {
945 let inner = unsafe { &*handle.as_raw() };
946 let text = value_to_string_scalar(inner).ok_or_else(|| {
947 load_error(
948 "load: cell arrays used for variable selection must contain string scalars",
949 )
950 })?;
951 names.push(text);
952 }
953 Ok(names)
954 }
955 other => {
956 let gathered = gather_if_needed_async(other).await?;
957 extract_names(&gathered).await
958 }
959 }
960}
961
962fn value_to_string_scalar(value: &Value) -> Option<String> {
963 match value {
964 Value::String(s) => Some(s.clone()),
965 Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
966 Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
967 _ => None,
968 }
969}
970
971fn char_array_rows_as_strings(ca: &CharArray) -> Vec<String> {
972 let mut rows = Vec::with_capacity(ca.rows);
973 for r in 0..ca.rows {
974 let mut row = String::with_capacity(ca.cols);
975 for c in 0..ca.cols {
976 let idx = r * ca.cols + c;
977 row.push(ca.data[idx]);
978 }
979 let trimmed = row.trim_end_matches([' ', '\0']).to_string();
980 rows.push(trimmed);
981 }
982 rows
983}
984
/// Decodes a NUL-terminated (or NUL-padded) byte name into a `String`.
///
/// Bytes after the first NUL are ignored. Invalid UTF-8 is replaced with U+FFFD
/// rather than discarding the whole name. An all-empty fallback would make distinct
/// malformed names collide on `""` and overwrite each other.
fn bytes_to_string(bytes: &[u8]) -> String {
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    String::from_utf8_lossy(&bytes[..end]).into_owned()
}
993
/// One data element as returned by `read_tagged`.
struct TaggedData {
    /// Element type code. For compact (small data element) entries this is the low
    /// 16 bits of the tag word; otherwise it is the full first tag word.
    data_type: u32,
    /// Element payload, with any 8-byte alignment padding already consumed.
    data: Vec<u8>,
}
998
999fn read_tagged<R: Read>(reader: &mut R, allow_eof: bool) -> BuiltinResult<Option<TaggedData>> {
1000 let mut type_bytes = [0u8; 4];
1001 match reader.read_exact(&mut type_bytes) {
1002 Ok(()) => {}
1003 Err(err) => {
1004 if allow_eof && err.kind() == std::io::ErrorKind::UnexpectedEof {
1005 return Ok(None);
1006 }
1007 return Err(load_error_with_source(
1008 &LOAD_ERROR_IO,
1009 format!("load: failed to read MAT element header: {err}"),
1010 err,
1011 ));
1012 }
1013 }
1014
1015 let type_field = u32::from_le_bytes(type_bytes);
1016 if (type_field & 0xFFFF0000) != 0 {
1017 let data_type = type_field & 0x0000FFFF;
1018 let num_bytes = ((type_field & 0xFFFF0000) >> 16) as usize;
1019 let mut inline = [0u8; 4];
1020 reader.read_exact(&mut inline).map_err(|err| {
1021 load_error_with_source(
1022 &LOAD_ERROR_IO,
1023 format!("load: failed to read compact MAT element: {err}"),
1024 err,
1025 )
1026 })?;
1027 let mut data = inline[..num_bytes.min(4)].to_vec();
1028 data.truncate(num_bytes.min(4));
1029 Ok(Some(TaggedData { data_type, data }))
1030 } else {
1031 let mut len_bytes = [0u8; 4];
1032 reader.read_exact(&mut len_bytes).map_err(|err| {
1033 load_error_with_source(
1034 &LOAD_ERROR_IO,
1035 format!("load: failed to read MAT element length: {err}"),
1036 err,
1037 )
1038 })?;
1039 let length = u32::from_le_bytes(len_bytes) as usize;
1040 let mut data = vec![0u8; length];
1041 reader.read_exact(&mut data).map_err(|err| {
1042 load_error_with_source(
1043 &LOAD_ERROR_IO,
1044 format!("load: failed to read MAT element body: {err}"),
1045 err,
1046 )
1047 })?;
1048 let padding = (8 - (length % 8)) % 8;
1049 if padding != 0 {
1050 let mut pad = vec![0u8; padding];
1051 reader.read_exact(&mut pad).map_err(|err| {
1052 load_error_with_source(
1053 &LOAD_ERROR_IO,
1054 format!("load: failed to read MAT padding: {err}"),
1055 err,
1056 )
1057 })?;
1058 }
1059 Ok(Some(TaggedData {
1060 data_type: type_field,
1061 data,
1062 }))
1063 }
1064}
1065
1066#[cfg(test)]
1067pub(crate) mod tests {
1068 use super::*;
1069 use crate::workspace::WorkspaceResolver;
1070 use futures::executor::block_on;
1071 use runmat_builtins::StringArray;
1072 use runmat_thread_local::runmat_thread_local;
1073 use std::cell::RefCell;
1074 use std::collections::HashMap;
1075 use tempfile::tempdir;
1076
1077 runmat_thread_local! {
1078 static TEST_WORKSPACE: RefCell<HashMap<String, Value>> = RefCell::new(HashMap::new());
1079 }
1080
1081 fn ensure_test_resolver() {
1082 crate::workspace::register_workspace_resolver(WorkspaceResolver {
1083 lookup: |name| TEST_WORKSPACE.with(|slot| slot.borrow().get(name).cloned()),
1084 snapshot: || {
1085 let mut entries: Vec<(String, Value)> =
1086 TEST_WORKSPACE.with(|slot| slot.borrow().clone().into_iter().collect());
1087 entries.sort_by(|a, b| a.0.cmp(&b.0));
1088 entries
1089 },
1090 globals: || Vec::new(),
1091 assign: None,
1092 clear: None,
1093 remove: None,
1094 });
1095 }
1096
1097 fn set_workspace(entries: &[(&str, Value)]) {
1098 TEST_WORKSPACE.with(|slot| {
1099 let mut map = slot.borrow_mut();
1100 map.clear();
1101 for (name, value) in entries {
1102 map.insert((*name).to_string(), value.clone());
1103 }
1104 });
1105 }
1106
1107 fn workspace_guard() -> std::sync::MutexGuard<'static, ()> {
1108 crate::workspace::test_guard()
1109 }
1110
1111 fn assert_error_contains<T>(result: crate::BuiltinResult<T>, snippet: &str) {
1112 match result {
1113 Err(err) => {
1114 assert!(
1115 err.message().contains(snippet),
1116 "expected error to contain '{snippet}', got '{}'",
1117 err.message()
1118 );
1119 }
1120 Ok(_) => panic!("expected error containing '{snippet}'"),
1121 }
1122 }
1123
1124 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1125 #[test]
1126 fn load_descriptor_signatures_cover_core_forms() {
1127 let labels: Vec<&str> = LOAD_DESCRIPTOR
1128 .signatures
1129 .iter()
1130 .map(|sig| sig.label)
1131 .collect();
1132 assert!(labels.contains(&"S = load()"));
1133 assert!(labels.contains(&"S = load(filename)"));
1134 assert!(labels.contains(&"S = load(filename, varName1, varName2, ...)"));
1135 assert!(labels.contains(&"S = load(filename, \"-regexp\", pattern1, ...)"));
1136 }
1137
1138 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1139 #[test]
1140 fn load_roundtrip_numeric() {
1141 let _guard = workspace_guard();
1142 ensure_test_resolver();
1143 let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0], vec![2, 2]).unwrap();
1144 set_workspace(&[("A", Value::Tensor(tensor))]);
1145
1146 let dir = tempdir().unwrap();
1147 let path = dir.path().join("numeric.mat");
1148 let save_arg = Value::from(path.to_string_lossy().to_string());
1149 block_on(crate::call_builtin_async(
1150 "save",
1151 std::slice::from_ref(&save_arg),
1152 ))
1153 .unwrap();
1154
1155 let eval = block_on(evaluate(&[Value::from(path.to_string_lossy().to_string())]))
1156 .expect("load numeric");
1157 let struct_value = eval.first_output();
1158 match struct_value {
1159 Value::Struct(sv) => {
1160 assert!(sv.fields.contains_key("A"));
1161 match sv.fields.get("A").unwrap() {
1162 Value::Tensor(t) => {
1163 assert_eq!(t.shape, vec![2, 2]);
1164 assert_eq!(t.data, vec![1.0, 4.0, 2.0, 5.0]);
1165 }
1166 other => panic!("expected tensor, got {other:?}"),
1167 }
1168 }
1169 other => panic!("expected struct, got {other:?}"),
1170 }
1171 }
1172
1173 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1174 #[test]
1175 fn load_selected_variables() {
1176 let _guard = workspace_guard();
1177 ensure_test_resolver();
1178 set_workspace(&[("signal", Value::Num(42.0)), ("noise", Value::Num(5.0))]);
1179 let dir = tempdir().unwrap();
1180 let path = dir.path().join("selection.mat");
1181 let save_arg = Value::from(path.to_string_lossy().to_string());
1182 block_on(crate::call_builtin_async(
1183 "save",
1184 std::slice::from_ref(&save_arg),
1185 ))
1186 .unwrap();
1187
1188 let eval = block_on(evaluate(&[
1189 Value::from(path.to_string_lossy().to_string()),
1190 Value::from("signal"),
1191 ]))
1192 .expect("load selection");
1193 let vars = eval.variables();
1194 assert_eq!(vars.len(), 1);
1195 assert_eq!(vars[0].0, "signal");
1196 assert!(matches!(vars[0].1, Value::Num(42.0)));
1197 }
1198
1199 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1200 #[test]
1201 fn load_regex_selection() {
1202 let _guard = workspace_guard();
1203 ensure_test_resolver();
1204 set_workspace(&[
1205 ("w1", Value::Num(1.0)),
1206 ("w2", Value::Num(2.0)),
1207 ("bias", Value::Num(3.0)),
1208 ]);
1209 let dir = tempdir().unwrap();
1210 let path = dir.path().join("regex.mat");
1211 let save_arg = Value::from(path.to_string_lossy().to_string());
1212 block_on(crate::call_builtin_async(
1213 "save",
1214 std::slice::from_ref(&save_arg),
1215 ))
1216 .unwrap();
1217
1218 let eval = block_on(evaluate(&[
1219 Value::from(path.to_string_lossy().to_string()),
1220 Value::from("-regexp"),
1221 Value::from("^w\\d$"),
1222 ]))
1223 .expect("load regex");
1224 let mut names: Vec<_> = eval.variables().iter().map(|(n, _)| n.clone()).collect();
1225 names.sort();
1226 assert_eq!(names, vec!["w1".to_string(), "w2".to_string()]);
1227 }
1228
1229 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1230 #[test]
1231 fn load_missing_variable_errors() {
1232 let _guard = workspace_guard();
1233 ensure_test_resolver();
1234 set_workspace(&[("existing", Value::Num(7.0))]);
1235 let dir = tempdir().unwrap();
1236 let path = dir.path().join("missing.mat");
1237 let save_arg = Value::from(path.to_string_lossy().to_string());
1238 block_on(crate::call_builtin_async(
1239 "save",
1240 std::slice::from_ref(&save_arg),
1241 ))
1242 .unwrap();
1243
1244 assert_error_contains(
1245 block_on(evaluate(&[
1246 Value::from(path.to_string_lossy().to_string()),
1247 Value::from("missing"),
1248 ])),
1249 "variable 'missing' was not found",
1250 );
1251 }
1252
1253 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1254 #[test]
1255 fn load_string_array_roundtrip() {
1256 let _guard = workspace_guard();
1257 ensure_test_resolver();
1258 let strings = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
1259 set_workspace(&[("labels", Value::StringArray(strings))]);
1260 let dir = tempdir().unwrap();
1261 let path = dir.path().join("strings.mat");
1262 let save_arg = Value::from(path.to_string_lossy().to_string());
1263 block_on(crate::call_builtin_async(
1264 "save",
1265 std::slice::from_ref(&save_arg),
1266 ))
1267 .unwrap();
1268
1269 let eval = block_on(evaluate(&[Value::from(path.to_string_lossy().to_string())]))
1270 .expect("load strings");
1271 let struct_value = eval.first_output();
1272 match struct_value {
1273 Value::Struct(sv) => {
1274 let value = sv
1275 .fields
1276 .get("labels")
1277 .expect("labels field missing in struct");
1278 match value {
1279 Value::StringArray(sa) => {
1280 assert_eq!(sa.shape, vec![1, 2]);
1281 assert_eq!(sa.data, vec![String::from("foo"), String::from("bar")]);
1282 }
1283 other => panic!("expected string array, got {other:?}"),
1284 }
1285 }
1286 other => panic!("expected struct, got {other:?}"),
1287 }
1288 }
1289
1290 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1291 #[test]
1292 fn load_option_before_filename() {
1293 let _guard = workspace_guard();
1294 ensure_test_resolver();
1295 set_workspace(&[("alpha", Value::Num(1.0)), ("beta", Value::Num(2.0))]);
1296 let dir = tempdir().unwrap();
1297 let path = dir.path().join("option_first.mat");
1298 let save_arg = Value::from(path.to_string_lossy().to_string());
1299 block_on(crate::call_builtin_async(
1300 "save",
1301 std::slice::from_ref(&save_arg),
1302 ))
1303 .unwrap();
1304
1305 let eval = block_on(evaluate(&[
1306 Value::from("-mat"),
1307 Value::from(path.to_string_lossy().to_string()),
1308 Value::from("beta"),
1309 ]))
1310 .expect("load with option first");
1311 let vars = eval.variables();
1312 assert_eq!(vars.len(), 1);
1313 assert_eq!(vars[0].0, "beta");
1314 assert!(matches!(vars[0].1, Value::Num(2.0)));
1315 }
1316
1317 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1318 #[test]
1319 fn load_char_array_names_trimmed() {
1320 let _guard = workspace_guard();
1321 ensure_test_resolver();
1322 set_workspace(&[("short", Value::Num(5.0)), ("longer", Value::Num(9.0))]);
1323 let dir = tempdir().unwrap();
1324 let path = dir.path().join("char_names.mat");
1325 let save_arg = Value::from(path.to_string_lossy().to_string());
1326 block_on(crate::call_builtin_async(
1327 "save",
1328 std::slice::from_ref(&save_arg),
1329 ))
1330 .unwrap();
1331
1332 let cols = 6;
1333 let mut data = Vec::new();
1334 for name in ["short", "longer"] {
1335 let mut chars: Vec<char> = name.chars().collect();
1336 while chars.len() < cols {
1337 chars.push(' ');
1338 }
1339 data.extend(chars);
1340 }
1341 let name_array = CharArray::new(data, 2, cols).unwrap();
1342
1343 let eval = block_on(evaluate(&[
1344 Value::from(path.to_string_lossy().to_string()),
1345 Value::CharArray(name_array),
1346 ]))
1347 .expect("load with char array names");
1348 let vars = eval.variables();
1349 assert_eq!(vars.len(), 2);
1350 assert_eq!(vars[0].0, "short");
1351 assert!(matches!(vars[0].1, Value::Num(5.0)));
1352 assert_eq!(vars[1].0, "longer");
1353 assert!(matches!(vars[1].1, Value::Num(9.0)));
1354 }
1355
1356 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1357 #[test]
1358 fn load_duplicate_names_last_wins() {
1359 let _guard = workspace_guard();
1360 ensure_test_resolver();
1361 set_workspace(&[("dup", Value::Num(11.0))]);
1362 let dir = tempdir().unwrap();
1363 let path = dir.path().join("duplicates.mat");
1364 let save_arg = Value::from(path.to_string_lossy().to_string());
1365 block_on(crate::call_builtin_async(
1366 "save",
1367 std::slice::from_ref(&save_arg),
1368 ))
1369 .unwrap();
1370
1371 let eval = block_on(evaluate(&[
1372 Value::from(path.to_string_lossy().to_string()),
1373 Value::from("dup"),
1374 Value::from("dup"),
1375 ]))
1376 .expect("load with duplicate names");
1377 let vars = eval.variables();
1378 assert_eq!(vars.len(), 1);
1379 assert_eq!(vars[0].0, "dup");
1380 assert!(matches!(vars[0].1, Value::Num(11.0)));
1381 }
1382
1383 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1384 #[test]
1385 #[cfg(feature = "wgpu")]
1386 fn load_wgpu_tensor_roundtrip() {
1387 let _guard = workspace_guard();
1388 ensure_test_resolver();
1389 if runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1390 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1391 )
1392 .is_err()
1393 {
1394 return;
1395 }
1396 let Some(provider) = runmat_accelerate_api::provider() else {
1397 return;
1398 };
1399
1400 use runmat_accelerate_api::HostTensorView;
1401
1402 let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap();
1403 let view = HostTensorView {
1404 data: &tensor.data,
1405 shape: &tensor.shape,
1406 };
1407 let handle = provider.upload(&view).expect("upload tensor");
1408 set_workspace(&[("gpu_var", Value::GpuTensor(handle))]);
1409
1410 let dir = tempdir().unwrap();
1411 let path = dir.path().join("wgpu_load.mat");
1412 let save_args = vec![
1413 Value::from(path.to_string_lossy().to_string()),
1414 Value::from("gpu_var"),
1415 ];
1416 block_on(crate::call_builtin_async("save", &save_args)).unwrap();
1417
1418 let eval = block_on(evaluate(&[Value::from(path.to_string_lossy().to_string())]))
1419 .expect("load wgpu file");
1420 let struct_value = eval.first_output();
1421 match struct_value {
1422 Value::Struct(sv) => match sv.fields.get("gpu_var") {
1423 Some(Value::Tensor(t)) => {
1424 assert_eq!(t.shape, vec![2, 2]);
1425 assert_eq!(t.data, tensor.data);
1426 }
1427 other => panic!("expected tensor, got {other:?}"),
1428 },
1429 other => panic!("expected struct, got {other:?}"),
1430 }
1431 }
1432}