1use std::collections::HashMap;
4use std::fs::File;
5use std::io::{BufReader, Cursor, Read};
6use std::path::{Path, PathBuf};
7
8use regex::Regex;
9use runmat_builtins::{
10 CharArray, ComplexTensor, LogicalArray, StringArray, StructValue, Tensor, Value,
11};
12use runmat_macros::runtime_builtin;
13
14use super::format::{
15 MatArray, MatClass, MatData, FLAG_COMPLEX, FLAG_LOGICAL, MAT_HEADER_LEN, MI_DOUBLE, MI_INT32,
16 MI_INT8, MI_MATRIX, MI_UINT16, MI_UINT32, MI_UINT8,
17};
18use crate::builtins::common::spec::{
19 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
20 ReductionNaN, ResidencyPolicy, ShapeRequirements,
21};
22use crate::{gather_if_needed, make_cell, register_builtin_fusion_spec, register_builtin_gpu_spec};
23
24#[cfg(feature = "doc_export")]
25use crate::register_builtin_doc_text;
26
27#[cfg(feature = "doc_export")]
28pub const DOC_MD: &str = r#"---
29title: "load"
30category: "io/mat"
31keywords: ["load", "mat", "workspace", "io", "matlab load", "regex load"]
32summary: "Load variables from a MATLAB-compatible MAT-file into the workspace or return them as a struct."
33references:
34 - https://www.mathworks.com/help/matlab/ref/load.html
35gpu_support:
36 elementwise: false
37 reduction: false
38 precisions: []
39 broadcasting: "none"
40 notes: "Files are read on the host. When auto-offload is enabled, planners may later promote tensors to the GPU; no provider hooks are required at load time."
41fusion:
42 elementwise: false
43 reduction: false
44 max_inputs: 1
45 constants: "inline"
46requires_feature: null
47tested:
48 unit: "builtins::io::mat::load::tests"
49 integration:
50 - "builtins::io::mat::load::tests::load_selected_variables"
51 - "builtins::io::mat::load::tests::load_regex_selection"
52---
53
54# What does the `load` function do in MATLAB / RunMat?
55`load` reads variables from a MAT-file (Level-5 layout) and brings them into the current workspace. Like MATLAB, it can either populate variables directly or return a struct containing the loaded data.
56
57## How does the `load` function behave in MATLAB / RunMat?
58- `load filename` reads every variable stored in `filename.mat` and assigns them into the caller's workspace. When no extension is supplied, `.mat` is appended automatically. Set `RUNMAT_LOAD_DEFAULT_PATH` to override the default `matlab.mat` target when no filename argument is provided.
59- `S = load(filename)` loads the file but returns a struct instead of modifying the workspace. The struct fields mirror the variables stored in the MAT-file.
60- `load(filename, 'A', 'B')` restricts the operation to the listed variable names. String scalars, char vectors, string arrays, or cell arrays of character vectors are accepted.
61- `load(filename, '-regexp', '^foo', 'bar$')` selects variables whose names match any of the supplied regular expressions.
62- Repeated names are deduplicated so that the last occurrence wins, mirroring MATLAB's behavior.
- Unsupported data classes trigger descriptive errors. RunMat currently supports double and complex numeric arrays, logical arrays, character arrays, string arrays (stored as cell-of-char data), scalar structs (struct arrays are not supported yet), and cells (currently limited to two dimensions) whose elements are composed of the supported types.
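- Uncompressed, little-endian Level-5 MAT-files, such as those written by RunMat's `save`, are supported. Big-endian, compressed (`miCOMPRESSED`), and v7.3 (HDF5) MAT-files are not supported yet. MATLAB's default `-v7` format is compressed, so files saved with MATLAB's default settings cannot be loaded yet.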
65
66## `load` Function GPU Execution Behaviour
67`load` always reads data on the host. The resulting values start on the CPU. When RunMat Accelerate is active, auto-offload heuristics may later decide to promote tensors to the GPU if they participate in accelerated expressions, but no provider hooks are required during the `load` operation itself. GPU-resident variables that were saved earlier are gathered back to host memory as part of file serialisation, so loading them produces standard host values.
68
69## Examples of using the `load` function in MATLAB / RunMat
70
71### Load the entire file into the workspace
72```matlab
73load('results.mat');
74disp(norm(weights));
75```
76Expected outcome: every variable contained in `results.mat` becomes available in the caller's workspace.
77
78### Load a subset of variables by name
79```matlab
80load('sim_state.mat', 'state', 'time');
81plot(time, state);
82```
83Only `state` and `time` are created; other variables in the file are ignored.
84
85### Load variables using regular expressions
86```matlab
load('checkpoint.mat', '-regexp', '^layer_\d+$');
88```
89All variables whose names look like `layer_0`, `layer_1`, … are loaded.
90
91### Capture loaded variables in a struct without altering the workspace
92```matlab
93S = load('snapshot.mat');
94disp(fieldnames(S));
95```
96`S` contains one field per variable stored in `snapshot.mat`, leaving the workspace untouched.
97
98### Combine explicit names and regex filters
99```matlab
100model = load('model.mat', 'config', '-regexp', '^weights_(conv|fc)');
101```
The returned struct includes the `config` variable and every variable whose name starts with `weights_conv` or `weights_fc`.
103
104### Honour a custom default filename
105```matlab
106setenv('RUNMAT_LOAD_DEFAULT_PATH', fullfile(tempdir, 'autosave.mat'));
107load();
108```
109With no arguments, `load` falls back to the file specified by `RUNMAT_LOAD_DEFAULT_PATH`.
110
111### Load character and string data
112```matlab
113values = load('strings.mat', 'labels');
114disp(values.labels(1));
115```
116String arrays saved by RunMat are reconstructed faithfully from the underlying MAT-file representation.
117
118## GPU residency in RunMat (Do I need `gpuArray`?)
119No manual action is required. `load` always creates host values. When the auto-offload planner decides that downstream computations benefit from GPU execution, it will promote tensors automatically. You can still call `gpuArray` on loaded variables explicitly if you want to pin them to the device immediately.
120
121## FAQ
122
123### Does `load` support ASCII text files?
No. RunMat's `load` only reads MAT-files. MATLAB's `-ascii` mode is not implemented and is rejected as an unsupported option. Read text and delimited files with `readmatrix`, `readtable`, or other file I/O utilities such as `fileread`.
125
126### How are structures handled?
Scalar structures are reconstructed as `struct` values whose fields match the MAT-file content. Nested structs, cells, logical arrays, and numeric data are all supported. Struct arrays (non-scalar structs) are not supported yet and produce an error.
128
129### Will `load` overwrite existing variables?
130Yes. When you call `load` without capturing the output struct, any variables with matching names in the caller's workspace are overwritten with the values from the MAT-file.
131
132### What happens if a requested variable is missing?
RunMat raises a descriptive error: `load: variable 'foo' was not found in the file`. This is stricter than MATLAB, which only issues a warning and continues loading the remaining variables.
134
135### Can I load into a different workspace?
136Use MATLAB-compatible functions such as `assignin` (when available) if you need to populate a different scope explicitly. The `load` builtin itself targets the caller workspace by default.
137
138### How are GPU arrays handled?
139GPU-resident values are serialised to host data when saved. Loading the resulting MAT-file produces standard host arrays. Downstream acceleration is handled automatically by RunMat Accelerate.
140
141### How do I detect which variables were loaded?
142Use the struct form: `info = load(filename);` and then inspect `fieldnames(info)` or `isfield` to programmatically check what was present in the MAT-file.
143
144## See Also
145[save](./save), [who](../../introspection/who), [fileread](../filetext/fileread), [matfile](https://www.mathworks.com/help/matlab/ref/matfile.html)
146
147## Source & Feedback
148- Implementation: [`crates/runmat-runtime/src/builtins/io/mat/load.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/io/mat/load.rs)
149- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
150"#;
151
152pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
153 name: "load",
154 op_kind: GpuOpKind::Custom("io-load"),
155 supported_precisions: &[],
156 broadcast: BroadcastSemantics::None,
157 provider_hooks: &[],
158 constant_strategy: ConstantStrategy::InlineLiteral,
159 residency: ResidencyPolicy::NewHandle,
160 nan_mode: ReductionNaN::Include,
161 two_pass_threshold: None,
162 workgroup_size: None,
163 accepts_nan_mode: false,
164 notes: "Reads MAT-files on the host and produces CPU-resident values. Providers are not involved until accelerated code later promotes the results.",
165};
166
167register_builtin_gpu_spec!(GPU_SPEC);
168
169pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
170 name: "load",
171 shape: ShapeRequirements::Any,
172 constant_strategy: ConstantStrategy::InlineLiteral,
173 elementwise: None,
174 reduction: None,
175 emits_nan: false,
176 notes: "File I/O is not eligible for fusion. Registration exists for documentation completeness only.",
177};
178
179register_builtin_fusion_spec!(FUSION_SPEC);
180
181#[cfg(feature = "doc_export")]
182register_builtin_doc_text!("load", DOC_MD);
183
184#[runtime_builtin(
185 name = "load",
186 category = "io/mat",
187 summary = "Load variables from a MAT-file.",
188 keywords = "load,mat,workspace",
189 accel = "cpu",
190 sink = true
191)]
192fn load_builtin(args: Vec<Value>) -> Result<Value, String> {
193 let eval = evaluate(&args)?;
194 Ok(eval.first_output())
195}
196
197#[derive(Clone, Debug)]
198pub struct LoadEval {
199 variables: Vec<(String, Value)>,
200}
201
202impl LoadEval {
203 pub fn first_output(&self) -> Value {
204 let mut st = StructValue::new();
205 for (name, value) in &self.variables {
206 st.fields.insert(name.clone(), value.clone());
207 }
208 Value::Struct(st)
209 }
210
211 pub fn variables(&self) -> &[(String, Value)] {
212 &self.variables
213 }
214
215 pub fn into_variables(self) -> Vec<(String, Value)> {
216 self.variables
217 }
218}
219
220struct LoadRequest {
221 variables: Vec<String>,
222 regex_patterns: Vec<Regex>,
223}
224
225pub fn evaluate(args: &[Value]) -> Result<LoadEval, String> {
226 let mut host_args = Vec::with_capacity(args.len());
227 for arg in args {
228 host_args.push(gather_if_needed(arg)?);
229 }
230
231 let invocation = parse_invocation(&host_args)?;
232
233 let mut path_value = if let Some(path) = invocation.path_value {
234 path
235 } else {
236 Value::from("matlab.mat")
237 };
238
239 if invocation.path_was_default {
240 if let Ok(override_path) = std::env::var("RUNMAT_LOAD_DEFAULT_PATH") {
241 path_value = Value::from(override_path);
242 }
243 }
244
245 let mut regex_patterns = Vec::with_capacity(invocation.regex_tokens.len());
246 for pattern in invocation.regex_tokens {
247 let regex = Regex::new(&pattern)
248 .map_err(|err| format!("load: invalid regular expression '{pattern}': {err}"))?;
249 regex_patterns.push(regex);
250 }
251
252 let request = LoadRequest {
253 variables: invocation.variables,
254 regex_patterns,
255 };
256 let path = normalise_path(&path_value)?;
257 let entries = read_mat_file(&path)?;
258
259 let selected = select_variables(&entries, &request)?;
260 Ok(LoadEval {
261 variables: selected,
262 })
263}
264
265struct ParsedInvocation {
266 path_value: Option<Value>,
267 path_was_default: bool,
268 variables: Vec<String>,
269 regex_tokens: Vec<String>,
270}
271
272fn parse_invocation(values: &[Value]) -> Result<ParsedInvocation, String> {
273 let mut path_value = None;
274 let mut path_was_default = false;
275 let mut variables = Vec::new();
276 let mut regex_tokens = Vec::new();
277 let mut idx = 0usize;
278 while idx < values.len() {
279 if let Some(flag) = option_token(&values[idx])? {
280 match flag.as_str() {
281 "-mat" => {
282 idx += 1;
283 continue;
284 }
285 "-regexp" => {
286 idx += 1;
287 if idx >= values.len() {
288 return Err("load: '-regexp' requires at least one pattern".to_string());
289 }
290 while idx < values.len() {
291 if option_token(&values[idx])?.is_some() {
292 break;
293 }
294 let names = extract_names(&values[idx])?;
295 if names.is_empty() {
296 return Err(
297 "load: '-regexp' requires non-empty pattern strings".to_string()
298 );
299 }
300 regex_tokens.extend(names);
301 idx += 1;
302 }
303 continue;
304 }
305 other => {
306 return Err(format!("load: unsupported option '{other}'"));
307 }
308 }
309 } else {
310 if path_value.is_none() {
311 path_value = Some(values[idx].clone());
312 idx += 1;
313 continue;
314 }
315 let names = extract_names(&values[idx])?;
316 variables.extend(names);
317 idx += 1;
318 }
319 }
320
321 if path_value.is_none() {
322 path_was_default = true;
323 }
324
325 Ok(ParsedInvocation {
326 path_value,
327 path_was_default,
328 variables,
329 regex_tokens,
330 })
331}
332
333fn normalise_path(value: &Value) -> Result<PathBuf, String> {
334 let raw = value_to_string_scalar(value)
335 .ok_or_else(|| "load: filename must be a character vector or string scalar".to_string())?;
336 let mut path = PathBuf::from(raw);
337 if path.extension().is_none() {
338 path.set_extension("mat");
339 }
340 Ok(path)
341}
342
343fn select_variables(
344 entries: &[(String, Value)],
345 request: &LoadRequest,
346) -> Result<Vec<(String, Value)>, String> {
347 if request.variables.is_empty() && request.regex_patterns.is_empty() {
348 return Ok(entries.to_vec());
349 }
350
351 let mut by_name: HashMap<&str, &Value> = HashMap::with_capacity(entries.len());
352 for (name, value) in entries {
353 by_name.insert(name, value);
354 }
355
356 let mut selected = Vec::new();
357
358 for name in &request.variables {
359 let value = by_name
360 .get(name.as_str())
361 .ok_or_else(|| format!("load: variable '{name}' was not found in the file"))?;
362 insert_or_replace(&mut selected, name, (*value).clone());
363 }
364
365 if !request.regex_patterns.is_empty() {
366 let mut matched = 0usize;
367 for (name, value) in entries {
368 if request
369 .regex_patterns
370 .iter()
371 .any(|regex| regex.is_match(name))
372 {
373 matched += 1;
374 insert_or_replace(&mut selected, name, value.clone());
375 }
376 }
377 if matched == 0 && request.variables.is_empty() {
378 return Err("load: no variables matched '-regexp' patterns".to_string());
379 }
380 }
381
382 if selected.is_empty() {
383 return Err("load: no variables selected".to_string());
384 }
385
386 Ok(selected)
387}
388
389fn insert_or_replace(selected: &mut Vec<(String, Value)>, name: &str, value: Value) {
390 if let Some(entry) = selected.iter_mut().find(|(existing, _)| existing == name) {
391 entry.1 = value;
392 } else {
393 selected.push((name.to_string(), value));
394 }
395}
396
397pub(crate) fn read_mat_file(path: &Path) -> Result<Vec<(String, Value)>, String> {
398 let file = File::open(path)
399 .map_err(|err| format!("load: failed to open '{}': {err}", path.display()))?;
400 let mut reader = BufReader::new(file);
401
402 let mut header = [0u8; MAT_HEADER_LEN];
403 reader
404 .read_exact(&mut header)
405 .map_err(|err| format!("load: failed to read MAT-file header: {err}"))?;
406 if header[126] != b'I' || header[127] != b'M' {
407 return Err("load: file is not a MATLAB Level-5 MAT-file".to_string());
408 }
409
410 let mut variables = Vec::new();
411 while let Some(tagged) = read_tagged(&mut reader, true)? {
412 if tagged.data_type != MI_MATRIX {
413 continue;
414 }
415 let parsed = parse_matrix(&tagged.data)?;
416 let value = mat_array_to_value(parsed.array)?;
417 variables.push((parsed.name, value));
418 }
419 Ok(variables)
420}
421
422struct ParsedMatrix {
423 name: String,
424 array: MatArray,
425}
426
427fn parse_matrix(buffer: &[u8]) -> Result<ParsedMatrix, String> {
428 let mut cursor = Cursor::new(buffer);
429
430 let flags = read_tagged(&mut cursor, false)?
431 .ok_or_else(|| "load: matrix element missing array flags".to_string())?;
432 if flags.data_type != MI_UINT32 || flags.data.len() < 8 {
433 return Err("load: invalid array flags block".to_string());
434 }
435 let flags0 = u32::from_le_bytes(flags.data[0..4].try_into().unwrap());
436 let class_code = flags0 & 0xFF;
437 let mut class = MatClass::from_class_code(class_code)
438 .ok_or_else(|| "load: unsupported MATLAB class".to_string())?;
439 let is_logical = (flags0 & FLAG_LOGICAL) != 0;
440 let has_imag = (flags0 & FLAG_COMPLEX) != 0;
441 if matches!(class, MatClass::Double) && is_logical {
442 class = MatClass::Logical;
443 }
444
445 let dims_elem = read_tagged(&mut cursor, false)?
446 .ok_or_else(|| "load: matrix element missing dimensions".to_string())?;
447 if dims_elem.data_type != MI_INT32 {
448 return Err("load: dimension block must use MI_INT32".to_string());
449 }
450 if dims_elem.data.is_empty() || dims_elem.data.len() % 4 != 0 {
451 return Err("load: malformed dimension block".to_string());
452 }
453 let mut dims = Vec::with_capacity(dims_elem.data.len() / 4);
454 for chunk in dims_elem.data.chunks_exact(4) {
455 let value = i32::from_le_bytes(chunk.try_into().unwrap());
456 if value < 0 {
457 return Err("load: negative dimensions are not supported".to_string());
458 }
459 dims.push(value as usize);
460 }
461 if dims.is_empty() {
462 dims.push(1);
463 dims.push(1);
464 }
465
466 let name_elem = read_tagged(&mut cursor, false)?
467 .ok_or_else(|| "load: matrix element missing name".to_string())?;
468 let name = match name_elem.data_type {
469 MI_INT8 | MI_UINT8 => bytes_to_string(&name_elem.data),
470 MI_UINT16 => {
471 let mut bytes = Vec::with_capacity(name_elem.data.len());
472 for chunk in name_elem.data.chunks_exact(2) {
473 let code = u16::from_le_bytes(chunk.try_into().unwrap());
474 if code == 0 {
475 break;
476 }
477 if let Some(ch) = char::from_u32(code as u32) {
478 bytes.push(ch);
479 }
480 }
481 bytes.into_iter().collect()
482 }
483 _ => {
484 return Err("load: unsupported array name encoding".to_string());
485 }
486 };
487
488 let array = match class {
489 MatClass::Double => parse_double_array(&mut cursor, dims, has_imag)?,
490 MatClass::Logical => parse_logical_array(&mut cursor, dims)?,
491 MatClass::Char => parse_char_array(&mut cursor, dims)?,
492 MatClass::Cell => parse_cell_array(&mut cursor, dims)?,
493 MatClass::Struct => parse_struct(&mut cursor, dims)?,
494 };
495
496 Ok(ParsedMatrix { name, array })
497}
498
499fn parse_double_array(
500 cursor: &mut Cursor<&[u8]>,
501 dims: Vec<usize>,
502 has_imag: bool,
503) -> Result<MatArray, String> {
504 let real_elem = read_tagged(cursor, false)?
505 .ok_or_else(|| "load: numeric array missing real component".to_string())?;
506 if real_elem.data_type != MI_DOUBLE || real_elem.data.len() % 8 != 0 {
507 return Err("load: numeric data must be stored as MI_DOUBLE".to_string());
508 }
509 let mut real = Vec::with_capacity(real_elem.data.len() / 8);
510 for chunk in real_elem.data.chunks_exact(8) {
511 real.push(f64::from_le_bytes(chunk.try_into().unwrap()));
512 }
513
514 let imag = if has_imag {
515 let imag_elem = read_tagged(cursor, false)?
516 .ok_or_else(|| "load: numeric array missing imaginary component".to_string())?;
517 if imag_elem.data_type != MI_DOUBLE || imag_elem.data.len() % 8 != 0 {
518 return Err("load: imaginary component must be MI_DOUBLE".to_string());
519 }
520 let mut imag = Vec::with_capacity(imag_elem.data.len() / 8);
521 for chunk in imag_elem.data.chunks_exact(8) {
522 imag.push(f64::from_le_bytes(chunk.try_into().unwrap()));
523 }
524 Some(imag)
525 } else {
526 None
527 };
528
529 Ok(MatArray {
530 class: MatClass::Double,
531 dims,
532 data: MatData::Double { real, imag },
533 })
534}
535
536fn parse_logical_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
537 let elem = read_tagged(cursor, false)?
538 .ok_or_else(|| "load: logical array missing data block".to_string())?;
539 if elem.data_type != MI_UINT8 {
540 return Err("load: logical arrays must be stored as MI_UINT8".to_string());
541 }
542 Ok(MatArray {
543 class: MatClass::Logical,
544 dims,
545 data: MatData::Logical { data: elem.data },
546 })
547}
548
549fn parse_char_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
550 let elem = read_tagged(cursor, false)?
551 .ok_or_else(|| "load: character array missing data block".to_string())?;
552 if elem.data_type != MI_UINT16 {
553 return Err("load: character data must be stored as MI_UINT16".to_string());
554 }
555 if elem.data.len() % 2 != 0 {
556 return Err("load: malformed character data".to_string());
557 }
558 let mut data = Vec::with_capacity(elem.data.len() / 2);
559 for chunk in elem.data.chunks_exact(2) {
560 data.push(u16::from_le_bytes(chunk.try_into().unwrap()));
561 }
562 Ok(MatArray {
563 class: MatClass::Char,
564 dims,
565 data: MatData::Char { data },
566 })
567}
568
569fn parse_cell_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
570 let total: usize = dims
571 .iter()
572 .copied()
573 .fold(1usize, |acc, d| acc.saturating_mul(d));
574 let mut elements = Vec::with_capacity(total);
575 for _ in 0..total {
576 let elem = read_tagged(cursor, false)?
577 .ok_or_else(|| "load: cell element missing matrix payload".to_string())?;
578 if elem.data_type != MI_MATRIX {
579 return Err("load: cell elements must be matrices".to_string());
580 }
581 let parsed = parse_matrix(&elem.data)?;
582 elements.push(parsed.array);
583 }
584 Ok(MatArray {
585 class: MatClass::Cell,
586 dims,
587 data: MatData::Cell { elements },
588 })
589}
590
591fn parse_struct(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
592 if dims.len() != 2 || dims[0] != 1 || dims[1] != 1 {
593 return Err("load: struct arrays are not supported yet".to_string());
594 }
595 let len_elem = read_tagged(cursor, false)?
596 .ok_or_else(|| "load: struct missing maximum field length specifier".to_string())?;
597 if len_elem.data_type != MI_INT32 || len_elem.data.len() != 4 {
598 return Err("load: struct field length must be MI_INT32".to_string());
599 }
600 let max_len = i32::from_le_bytes(len_elem.data[..4].try_into().unwrap());
601 if max_len <= 0 {
602 return Err("load: struct field length must be positive".to_string());
603 }
604
605 let names_elem = read_tagged(cursor, false)?
606 .ok_or_else(|| "load: struct missing field name table".to_string())?;
607 if names_elem.data_type != MI_INT8 && names_elem.data_type != MI_UINT8 {
608 return Err("load: struct field names must be stored as MI_INT8/MI_UINT8".to_string());
609 }
610 if names_elem.data.len() % (max_len as usize) != 0 {
611 return Err("load: malformed struct field name table".to_string());
612 }
613 let field_count = names_elem.data.len() / (max_len as usize);
614 let mut field_names = Vec::with_capacity(field_count);
615 for i in 0..field_count {
616 let start = i * (max_len as usize);
617 let end = start + (max_len as usize);
618 let slice = &names_elem.data[start..end];
619 field_names.push(bytes_to_string(slice));
620 }
621
622 let mut field_values = Vec::with_capacity(field_count);
623 for _ in 0..field_count {
624 let elem = read_tagged(cursor, false)?
625 .ok_or_else(|| "load: struct field missing matrix payload".to_string())?;
626 if elem.data_type != MI_MATRIX {
627 return Err("load: struct fields must be matrices".to_string());
628 }
629 let parsed = parse_matrix(&elem.data)?;
630 field_values.push(parsed.array);
631 }
632
633 Ok(MatArray {
634 class: MatClass::Struct,
635 dims,
636 data: MatData::Struct {
637 field_names,
638 field_values,
639 },
640 })
641}
642
643fn mat_array_to_value(array: MatArray) -> Result<Value, String> {
644 match array.data {
645 MatData::Double { real, imag } => {
646 let len = real.len();
647 if let Some(imag) = imag {
648 if imag.len() != len {
649 return Err("load: complex data has mismatched real/imag parts".to_string());
650 }
651 if len == 1 {
652 Ok(Value::Complex(real[0], imag[0]))
653 } else {
654 let mut pairs = Vec::with_capacity(len);
655 for i in 0..len {
656 pairs.push((real[i], imag[i]));
657 }
658 let tensor = ComplexTensor::new(pairs, array.dims.clone())
659 .map_err(|e| format!("load: {e}"))?;
660 Ok(Value::ComplexTensor(tensor))
661 }
662 } else if len == 1 {
663 Ok(Value::Num(real[0]))
664 } else {
665 let tensor =
666 Tensor::new(real, array.dims.clone()).map_err(|e| format!("load: {e}"))?;
667 Ok(Value::Tensor(tensor))
668 }
669 }
670 MatData::Logical { data } => {
671 let total: usize = array
672 .dims
673 .iter()
674 .copied()
675 .fold(1usize, |acc, d| acc.saturating_mul(d));
676 if data.len() != total {
677 return Err("load: logical data length mismatch".to_string());
678 }
679 if total == 1 {
680 Ok(Value::Bool(data.first().copied().unwrap_or(0) != 0))
681 } else {
682 let logical = LogicalArray::new(data, array.dims.clone())
683 .map_err(|e| format!("load: {e}"))?;
684 Ok(Value::LogicalArray(logical))
685 }
686 }
687 MatData::Char { data } => {
688 let rows = array.dims.first().copied().unwrap_or(1);
689 let cols = array.dims.get(1).copied().unwrap_or(1);
690 let mut chars = Vec::with_capacity(rows.saturating_mul(cols));
691 for code in data {
692 let ch = char::from_u32(code as u32).unwrap_or('\u{FFFD}');
693 chars.push(ch);
694 }
695 let char_array = CharArray::new(chars, rows, cols).map_err(|e| format!("load: {e}"))?;
696 Ok(Value::CharArray(char_array))
697 }
698 MatData::Cell { elements } => {
699 if let Some(strings) = cell_elements_to_strings(&elements) {
700 let string_array = StringArray::new(strings, array.dims.clone())
701 .map_err(|e| format!("load: {e}"))?;
702 return Ok(Value::StringArray(string_array));
703 }
704 if array.dims.len() != 2 {
705 return Err(
706 "load: cell arrays with more than two dimensions are not supported yet"
707 .to_string(),
708 );
709 }
710 let rows = array.dims[0];
711 let cols = array.dims[1];
712 let expected = rows.saturating_mul(cols);
713 if elements.len() != expected {
714 return Err("load: cell array element count mismatch".to_string());
715 }
716 let mut converted = Vec::with_capacity(elements.len());
717 for elem in elements {
718 converted.push(mat_array_to_value(elem)?);
719 }
720 let mut row_major = vec![Value::Num(0.0); expected];
721 for col in 0..cols {
722 for row in 0..rows {
723 let cm_idx = col * rows + row;
724 let rm_idx = row * cols + col;
725 row_major[rm_idx] = converted[cm_idx].clone();
726 }
727 }
728 make_cell(row_major, rows, cols)
729 }
730 MatData::Struct {
731 field_names,
732 field_values,
733 } => {
734 if field_names.len() != field_values.len() {
735 return Err("load: struct field metadata is inconsistent".to_string());
736 }
737 let mut st = StructValue::new();
738 for (name, value) in field_names.into_iter().zip(field_values.into_iter()) {
739 let converted = mat_array_to_value(value)?;
740 st.fields.insert(name, converted);
741 }
742 Ok(Value::Struct(st))
743 }
744 }
745}
746
747fn cell_elements_to_strings(elements: &[MatArray]) -> Option<Vec<String>> {
748 let mut strings = Vec::with_capacity(elements.len());
749 for element in elements {
750 if element.class != MatClass::Char {
751 return None;
752 }
753 let rows = element.dims.first().copied().unwrap_or(1);
754 if rows > 1 {
755 return None;
756 }
757 match &element.data {
758 MatData::Char { data } => strings.push(utf16_codes_to_string(data)),
759 _ => return None,
760 }
761 }
762 Some(strings)
763}
764
/// Decodes UTF-16 code units into a `String` and drops trailing NUL padding.
///
/// Surrogate pairs are combined into a single scalar value, so characters
/// outside the Basic Multilingual Plane survive. Unpaired surrogates become
/// U+FFFD.
fn utf16_codes_to_string(data: &[u16]) -> String {
    let mut text: String = char::decode_utf16(data.iter().copied())
        .map(|unit| unit.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect();
    let trimmed_len = text.trim_end_matches('\0').len();
    text.truncate(trimmed_len);
    text
}
775
776fn option_token(value: &Value) -> Result<Option<String>, String> {
777 if let Some(token) = value_to_string_scalar(value) {
778 if token.starts_with('-') {
779 return Ok(Some(token.to_ascii_lowercase()));
780 }
781 }
782 Ok(None)
783}
784
785fn extract_names(value: &Value) -> Result<Vec<String>, String> {
786 match value {
787 Value::String(s) => Ok(vec![s.clone()]),
788 Value::CharArray(ca) => Ok(char_array_rows_as_strings(ca)),
789 Value::StringArray(sa) => Ok(sa.data.clone()),
790 Value::Cell(ca) => {
791 let mut names = Vec::with_capacity(ca.data.len());
792 for handle in &ca.data {
793 let inner = unsafe { &*handle.as_raw() };
794 let text = value_to_string_scalar(inner).ok_or_else(|| {
795 "load: cell arrays used for variable selection must contain string scalars"
796 .to_string()
797 })?;
798 names.push(text);
799 }
800 Ok(names)
801 }
802 other => {
803 let gathered = gather_if_needed(other)?;
804 extract_names(&gathered)
805 }
806 }
807}
808
809fn value_to_string_scalar(value: &Value) -> Option<String> {
810 match value {
811 Value::String(s) => Some(s.clone()),
812 Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
813 Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
814 _ => None,
815 }
816}
817
818fn char_array_rows_as_strings(ca: &CharArray) -> Vec<String> {
819 let mut rows = Vec::with_capacity(ca.rows);
820 for r in 0..ca.rows {
821 let mut row = String::with_capacity(ca.cols);
822 for c in 0..ca.cols {
823 let idx = r * ca.cols + c;
824 row.push(ca.data[idx]);
825 }
826 let trimmed = row.trim_end_matches([' ', '\0']).to_string();
827 rows.push(trimmed);
828 }
829 rows
830}
831
/// Decodes a NUL-terminated (or NUL-padded) byte string.
///
/// Bytes from the first NUL onward are ignored. Invalid UTF-8 is replaced with
/// U+FFFD instead of discarding the whole name. An empty name would let
/// distinct struct fields silently overwrite each other.
fn bytes_to_string(bytes: &[u8]) -> String {
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    String::from_utf8_lossy(&bytes[..end]).into_owned()
}
840
/// One Level-5 data element: its type code and raw, unpadded payload.
struct TaggedData {
    data_type: u32,
    data: Vec<u8>,
}

/// Reads the next data element from `reader`.
///
/// Two tag forms are handled:
/// - the regular 8-byte tag (type word, then byte count);
/// - the compact "small data element" form, where the byte count sits in the
///   upper 16 bits of the type word and up to four payload bytes follow inline.
///
/// Regular elements are followed by padding up to an 8-byte boundary, and that
/// padding is consumed. The exception is `miCOMPRESSED` elements, which the
/// format does not pad.
///
/// When `allow_eof` is true, end of input while reading the first word returns
/// `Ok(None)`.
///
/// # Errors
/// Returns an error on any short read other than the permitted EOF.
fn read_tagged<R: Read>(reader: &mut R, allow_eof: bool) -> Result<Option<TaggedData>, String> {
    // Level-5 type code for zlib-compressed elements. Their byte count is
    // exact and no padding follows them.
    const MI_COMPRESSED: u32 = 15;

    let mut type_bytes = [0u8; 4];
    if let Err(err) = reader.read_exact(&mut type_bytes) {
        if allow_eof && err.kind() == std::io::ErrorKind::UnexpectedEof {
            return Ok(None);
        }
        return Err(format!("load: failed to read MAT element header: {err}"));
    }

    let type_field = u32::from_le_bytes(type_bytes);
    let compact_len = (type_field >> 16) as usize;
    if compact_len != 0 {
        let mut inline = [0u8; 4];
        reader
            .read_exact(&mut inline)
            .map_err(|err| format!("load: failed to read compact MAT element: {err}"))?;
        return Ok(Some(TaggedData {
            data_type: type_field & 0xFFFF,
            data: inline[..compact_len.min(4)].to_vec(),
        }));
    }

    let mut len_bytes = [0u8; 4];
    reader
        .read_exact(&mut len_bytes)
        .map_err(|err| format!("load: failed to read MAT element length: {err}"))?;
    let length = u32::from_le_bytes(len_bytes);

    // Read through `take` instead of pre-sizing a buffer from the untrusted
    // length field; a corrupt tag could otherwise demand a 4 GiB allocation.
    let mut data = Vec::with_capacity((length as usize).min(64 * 1024));
    reader
        .by_ref()
        .take(u64::from(length))
        .read_to_end(&mut data)
        .map_err(|err| format!("load: failed to read MAT element body: {err}"))?;
    if data.len() != length as usize {
        return Err(format!(
            "load: failed to read MAT element body: expected {length} bytes, got {}",
            data.len()
        ));
    }

    if type_field != MI_COMPRESSED {
        let padding = (8 - data.len() % 8) % 8;
        if padding != 0 {
            let mut pad = [0u8; 8];
            reader
                .read_exact(&mut pad[..padding])
                .map_err(|err| format!("load: failed to read MAT padding: {err}"))?;
        }
    }

    Ok(Some(TaggedData {
        data_type: type_field,
        data,
    }))
}
892
#[cfg(test)]
mod tests {
    //! Round-trip tests for `load`: each test seeds a thread-local workspace, writes it to a
    //! temporary MAT-file with the `save` builtin, and reads it back through `evaluate`.

    use super::*;
    use crate::workspace::WorkspaceResolver;
    use once_cell::sync::OnceCell;
    use runmat_builtins::StringArray;
    use std::cell::RefCell;
    use std::collections::HashMap;
    use tempfile::tempdir;

    thread_local! {
        // Per-thread variable table backing the test resolver; tests running on different
        // threads therefore never see each other's workspace contents.
        static TEST_WORKSPACE: RefCell<HashMap<String, Value>> = RefCell::new(HashMap::new());
    }

    /// Registers, at most once per process, a workspace resolver whose lookups and
    /// snapshots read the calling thread's `TEST_WORKSPACE` (snapshots sorted by name).
    fn ensure_test_resolver() {
        static INIT: OnceCell<()> = OnceCell::new();
        INIT.get_or_init(|| {
            crate::workspace::register_workspace_resolver(WorkspaceResolver {
                lookup: |name| TEST_WORKSPACE.with(|vars| vars.borrow().get(name).cloned()),
                snapshot: || {
                    let mut sorted = TEST_WORKSPACE.with(|vars| {
                        vars.borrow()
                            .iter()
                            .map(|(name, value)| (name.clone(), value.clone()))
                            .collect::<Vec<(String, Value)>>()
                    });
                    // Names are unique map keys, so an unstable sort yields the same order.
                    sorted.sort_unstable_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
                    sorted
                },
                globals: Vec::new,
            });
        });
    }

    /// Replaces the calling thread's test workspace with exactly `entries`
    /// (a repeated name keeps its last value).
    fn set_workspace(entries: &[(&str, Value)]) {
        let fresh: HashMap<String, Value> = entries
            .iter()
            .map(|(name, value)| ((*name).to_string(), value.clone()))
            .collect();
        TEST_WORKSPACE.with(|vars| *vars.borrow_mut() = fresh);
    }

    /// Installs the test resolver, seeds the workspace with `entries`, calls the `save`
    /// builtin with only the path `dir/file_name`, and returns that path as a `String`.
    fn save_workspace(
        dir: &std::path::Path,
        file_name: &str,
        entries: &[(&str, Value)],
    ) -> String {
        ensure_test_resolver();
        set_workspace(entries);
        let path = dir.join(file_name).to_string_lossy().into_owned();
        crate::call_builtin("save", &[Value::from(path.clone())]).unwrap();
        path
    }

    /// Loading with only a filename returns a struct whose `A` field is the saved 2x2 tensor.
    #[test]
    fn load_roundtrip_numeric() {
        let original = Tensor::new(vec![1.0, 4.0, 2.0, 5.0], vec![2, 2]).unwrap();
        let dir = tempdir().unwrap();
        let path = save_workspace(dir.path(), "numeric.mat", &[("A", Value::Tensor(original))]);

        let eval = evaluate(&[Value::from(path)]).expect("load numeric");
        let sv = match eval.first_output() {
            Value::Struct(sv) => sv,
            other => panic!("expected struct, got {other:?}"),
        };
        assert!(sv.fields.contains_key("A"));
        match sv.fields.get("A").unwrap() {
            Value::Tensor(t) => {
                assert_eq!(t.shape, [2, 2]);
                assert_eq!(t.data, [1.0, 4.0, 2.0, 5.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// Naming a single variable loads only that variable.
    #[test]
    fn load_selected_variables() {
        let dir = tempdir().unwrap();
        let path = save_workspace(
            dir.path(),
            "selection.mat",
            &[("signal", Value::Num(42.0)), ("noise", Value::Num(5.0))],
        );

        let eval = evaluate(&[Value::from(path), Value::from("signal")]).expect("load selection");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "signal");
        assert!(matches!(vars[0].1, Value::Num(42.0)));
    }

    /// `-regexp` selects every variable whose name matches the pattern.
    #[test]
    fn load_regex_selection() {
        let dir = tempdir().unwrap();
        let path = save_workspace(
            dir.path(),
            "regex.mat",
            &[
                ("w1", Value::Num(1.0)),
                ("w2", Value::Num(2.0)),
                ("bias", Value::Num(3.0)),
            ],
        );

        let eval = evaluate(&[
            Value::from(path),
            Value::from("-regexp"),
            Value::from("^w\\d$"),
        ])
        .expect("load regex");
        let mut names: Vec<_> = eval
            .variables()
            .iter()
            .map(|(name, _)| name.clone())
            .collect();
        names.sort_unstable();
        assert_eq!(names, ["w1", "w2"]);
    }

    /// Requesting a variable absent from the file reports it by name.
    #[test]
    fn load_missing_variable_errors() {
        let dir = tempdir().unwrap();
        let path = save_workspace(
            dir.path(),
            "missing.mat",
            &[("existing", Value::Num(7.0))],
        );

        let err = evaluate(&[Value::from(path), Value::from("missing")])
            .expect_err("expect missing variable error");
        assert!(err.contains("variable 'missing' was not found"));
    }

    /// A 1x2 string array survives a save/load round trip with shape and contents intact.
    #[test]
    fn load_string_array_roundtrip() {
        let strings = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
        let dir = tempdir().unwrap();
        let path = save_workspace(
            dir.path(),
            "strings.mat",
            &[("labels", Value::StringArray(strings))],
        );

        let eval = evaluate(&[Value::from(path)]).expect("load strings");
        let sv = match eval.first_output() {
            Value::Struct(sv) => sv,
            other => panic!("expected struct, got {other:?}"),
        };
        let labels = sv
            .fields
            .get("labels")
            .expect("labels field missing in struct");
        match labels {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, [1, 2]);
                assert_eq!(sa.data, ["foo", "bar"]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    /// An option (`-mat`) may precede the filename; trailing names still select variables.
    #[test]
    fn load_option_before_filename() {
        let dir = tempdir().unwrap();
        let path = save_workspace(
            dir.path(),
            "option_first.mat",
            &[("alpha", Value::Num(1.0)), ("beta", Value::Num(2.0))],
        );

        let eval = evaluate(&[Value::from("-mat"), Value::from(path), Value::from("beta")])
            .expect("load with option first");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "beta");
        assert!(matches!(vars[0].1, Value::Num(2.0)));
    }

    /// A space-padded char matrix of names selects each row's name with padding trimmed.
    #[test]
    fn load_char_array_names_trimmed() {
        let dir = tempdir().unwrap();
        let path = save_workspace(
            dir.path(),
            "char_names.mat",
            &[("short", Value::Num(5.0)), ("longer", Value::Num(9.0))],
        );

        // Left-align every name in a `cols`-wide field so the rows form a 2 x cols matrix.
        let cols: usize = 6;
        let padded: String = ["short", "longer"]
            .iter()
            .map(|name| format!("{name:<cols$}"))
            .collect();
        let name_array = CharArray::new(padded.chars().collect(), 2, cols).unwrap();

        let eval = evaluate(&[Value::from(path), Value::CharArray(name_array)])
            .expect("load with char array names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 2);
        assert_eq!(vars[0].0, "short");
        assert!(matches!(vars[0].1, Value::Num(5.0)));
        assert_eq!(vars[1].0, "longer");
        assert!(matches!(vars[1].1, Value::Num(9.0)));
    }

    /// Repeating a variable name yields a single loaded variable.
    #[test]
    fn load_duplicate_names_last_wins() {
        let dir = tempdir().unwrap();
        let path = save_workspace(dir.path(), "duplicates.mat", &[("dup", Value::Num(11.0))]);

        let eval = evaluate(&[Value::from(path), Value::from("dup"), Value::from("dup")])
            .expect("load with duplicate names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "dup");
        assert!(matches!(vars[0].1, Value::Num(11.0)));
    }

    /// A GPU-resident variable is saved and comes back from `load` as a host tensor.
    #[test]
    #[cfg(feature = "wgpu")]
    fn load_wgpu_tensor_roundtrip() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };
        use runmat_accelerate_api::HostTensorView;

        ensure_test_resolver();
        // Treat the test as skipped when the wgpu provider cannot be registered or no
        // provider ends up active.
        if register_wgpu_provider(WgpuProviderOptions::default()).is_err() {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap();
        let handle = provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect("upload tensor");
        set_workspace(&[("gpu_var", Value::GpuTensor(handle))]);

        let dir = tempdir().unwrap();
        let path = dir
            .path()
            .join("wgpu_load.mat")
            .to_string_lossy()
            .into_owned();
        crate::call_builtin("save", &[Value::from(path.clone()), Value::from("gpu_var")]).unwrap();

        let eval = evaluate(&[Value::from(path)]).expect("load wgpu file");
        let sv = match eval.first_output() {
            Value::Struct(sv) => sv,
            other => panic!("expected struct, got {other:?}"),
        };
        match sv.fields.get("gpu_var") {
            Some(Value::Tensor(t)) => {
                assert_eq!(t.shape, [2, 2]);
                assert_eq!(t.data, tensor.data);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// The exported documentation contains at least one example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let examples = crate::builtins::common::test_support::doc_examples(DOC_MD);
        assert!(!examples.is_empty());
    }
}