runmat_runtime/builtins/io/mat/
load.rs

1//! MATLAB-compatible `load` builtin for RunMat.
2
3use std::collections::HashMap;
4use std::fs::File;
5use std::io::{BufReader, Cursor, Read};
6use std::path::{Path, PathBuf};
7
8use regex::Regex;
9use runmat_builtins::{
10    CharArray, ComplexTensor, LogicalArray, StringArray, StructValue, Tensor, Value,
11};
12use runmat_macros::runtime_builtin;
13
14use super::format::{
15    MatArray, MatClass, MatData, FLAG_COMPLEX, FLAG_LOGICAL, MAT_HEADER_LEN, MI_DOUBLE, MI_INT32,
16    MI_INT8, MI_MATRIX, MI_UINT16, MI_UINT32, MI_UINT8,
17};
18use crate::builtins::common::spec::{
19    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
20    ReductionNaN, ResidencyPolicy, ShapeRequirements,
21};
22use crate::{gather_if_needed, make_cell, register_builtin_fusion_spec, register_builtin_gpu_spec};
23
24#[cfg(feature = "doc_export")]
25use crate::register_builtin_doc_text;
26
27#[cfg(feature = "doc_export")]
28pub const DOC_MD: &str = r#"---
29title: "load"
30category: "io/mat"
31keywords: ["load", "mat", "workspace", "io", "matlab load", "regex load"]
32summary: "Load variables from a MATLAB-compatible MAT-file into the workspace or return them as a struct."
33references:
34  - https://www.mathworks.com/help/matlab/ref/load.html
35gpu_support:
36  elementwise: false
37  reduction: false
38  precisions: []
39  broadcasting: "none"
40  notes: "Files are read on the host. When auto-offload is enabled, planners may later promote tensors to the GPU; no provider hooks are required at load time."
41fusion:
42  elementwise: false
43  reduction: false
44  max_inputs: 1
45  constants: "inline"
46requires_feature: null
47tested:
48  unit: "builtins::io::mat::load::tests"
49  integration:
50    - "builtins::io::mat::load::tests::load_selected_variables"
51    - "builtins::io::mat::load::tests::load_regex_selection"
52---
53
54# What does the `load` function do in MATLAB / RunMat?
55`load` reads variables from a MAT-file (Level-5 layout) and brings them into the current workspace. Like MATLAB, it can either populate variables directly or return a struct containing the loaded data.
56
57## How does the `load` function behave in MATLAB / RunMat?
58- `load filename` reads every variable stored in `filename.mat` and assigns them into the caller's workspace. When no extension is supplied, `.mat` is appended automatically. Set `RUNMAT_LOAD_DEFAULT_PATH` to override the default `matlab.mat` target when no filename argument is provided.
59- `S = load(filename)` loads the file but returns a struct instead of modifying the workspace. The struct fields mirror the variables stored in the MAT-file.
60- `load(filename, 'A', 'B')` restricts the operation to the listed variable names. String scalars, char vectors, string arrays, or cell arrays of character vectors are accepted.
61- `load(filename, '-regexp', '^foo', 'bar$')` selects variables whose names match any of the supplied regular expressions.
62- Repeated names are deduplicated so that the last occurrence wins, mirroring MATLAB's behavior.
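- Unsupported data classes trigger descriptive errors. RunMat currently supports double and complex numeric arrays, logical arrays, character arrays, string arrays (stored as cell-of-char data), scalar structs, and two-dimensional cells whose elements are composed of the supported types. Because string arrays share the cell-of-char representation, a cell whose elements are all character row vectors (for example, a cell array of character vectors saved by MATLAB) is loaded as a string array.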
64- Files saved on platforms that produce little-endian Level-5 MAT-files (MATLAB's default) are supported. Big-endian and compressed (`miCOMPRESSED`) files currently report an error.
65
66## `load` Function GPU Execution Behaviour
67`load` always reads data on the host. The resulting values start on the CPU. When RunMat Accelerate is active, auto-offload heuristics may later decide to promote tensors to the GPU if they participate in accelerated expressions, but no provider hooks are required during the `load` operation itself. GPU-resident variables that were saved earlier are gathered back to host memory as part of file serialisation, so loading them produces standard host values.
68
69## Examples of using the `load` function in MATLAB / RunMat
70
71### Load the entire file into the workspace
72```matlab
73load('results.mat');
74disp(norm(weights));
75```
76Expected outcome: every variable contained in `results.mat` becomes available in the caller's workspace.
77
78### Load a subset of variables by name
79```matlab
80load('sim_state.mat', 'state', 'time');
81plot(time, state);
82```
83Only `state` and `time` are created; other variables in the file are ignored.
84
85### Load variables using regular expressions
86```matlab
87load('checkpoint.mat', '-regexp', '^layer_\\d+$');
88```
89All variables whose names look like `layer_0`, `layer_1`, … are loaded.
90
91### Capture loaded variables in a struct without altering the workspace
92```matlab
93S = load('snapshot.mat');
94disp(fieldnames(S));
95```
96`S` contains one field per variable stored in `snapshot.mat`, leaving the workspace untouched.
97
98### Combine explicit names and regex filters
99```matlab
100model = load('model.mat', 'config', '-regexp', '^weights_(conv|fc)');
101```
102The returned struct includes the `config` variable and every weight matrix whose name matches either `weights_conv` or `weights_fc`.
103
104### Honour a custom default filename
105```matlab
106setenv('RUNMAT_LOAD_DEFAULT_PATH', fullfile(tempdir, 'autosave.mat'));
107load();
108```
109With no arguments, `load` falls back to the file specified by `RUNMAT_LOAD_DEFAULT_PATH`.
110
111### Load character and string data
112```matlab
113values = load('strings.mat', 'labels');
114disp(values.labels(1));
115```
116String arrays saved by RunMat are reconstructed faithfully from the underlying MAT-file representation.
117
118## GPU residency in RunMat (Do I need `gpuArray`?)
119No manual action is required. `load` always creates host values. When the auto-offload planner decides that downstream computations benefit from GPU execution, it will promote tensors automatically. You can still call `gpuArray` on loaded variables explicitly if you want to pin them to the device immediately.
120
121## FAQ
122
123### Does `load` support ASCII text files?
No. RunMat's `load` currently reads only MAT-files. Unlike MATLAB, whose `load` can also import numeric ASCII text files, RunMat reports an error for them. Read text and delimited files with `readmatrix`, `readtable`, or other file I/O utilities such as `fileread`.
125
126### How are structures handled?
127Structure scalars are reconstructed as `struct` values whose fields match the MAT-file content. Nested structs, cells, logical arrays, and numeric data are all supported.
128
129### Will `load` overwrite existing variables?
130Yes. When you call `load` without capturing the output struct, any variables with matching names in the caller's workspace are overwritten with the values from the MAT-file.
131
132### What happens if a requested variable is missing?
RunMat raises an error, `load: variable 'foo' was not found in the file`, and no variables are loaded. This is stricter than MATLAB, which only warns about the missing variable and still loads the others.
134
135### Can I load into a different workspace?
136Use MATLAB-compatible functions such as `assignin` (when available) if you need to populate a different scope explicitly. The `load` builtin itself targets the caller workspace by default.
137
138### How are GPU arrays handled?
139GPU-resident values are serialised to host data when saved. Loading the resulting MAT-file produces standard host arrays. Downstream acceleration is handled automatically by RunMat Accelerate.
140
141### How do I detect which variables were loaded?
142Use the struct form: `info = load(filename);` and then inspect `fieldnames(info)` or `isfield` to programmatically check what was present in the MAT-file.
143
144## See Also
145[save](./save), [who](../../introspection/who), [fileread](../filetext/fileread), [matfile](https://www.mathworks.com/help/matlab/ref/matfile.html)
146
147## Source & Feedback
148- Implementation: [`crates/runmat-runtime/src/builtins/io/mat/load.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/io/mat/load.rs)
149- Found a bug? [Open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with a minimal reproduction.
150"#;
151
152pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
153    name: "load",
154    op_kind: GpuOpKind::Custom("io-load"),
155    supported_precisions: &[],
156    broadcast: BroadcastSemantics::None,
157    provider_hooks: &[],
158    constant_strategy: ConstantStrategy::InlineLiteral,
159    residency: ResidencyPolicy::NewHandle,
160    nan_mode: ReductionNaN::Include,
161    two_pass_threshold: None,
162    workgroup_size: None,
163    accepts_nan_mode: false,
164    notes: "Reads MAT-files on the host and produces CPU-resident values. Providers are not involved until accelerated code later promotes the results.",
165};
166
167register_builtin_gpu_spec!(GPU_SPEC);
168
169pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
170    name: "load",
171    shape: ShapeRequirements::Any,
172    constant_strategy: ConstantStrategy::InlineLiteral,
173    elementwise: None,
174    reduction: None,
175    emits_nan: false,
176    notes: "File I/O is not eligible for fusion. Registration exists for documentation completeness only.",
177};
178
179register_builtin_fusion_spec!(FUSION_SPEC);
180
181#[cfg(feature = "doc_export")]
182register_builtin_doc_text!("load", DOC_MD);
183
184#[runtime_builtin(
185    name = "load",
186    category = "io/mat",
187    summary = "Load variables from a MAT-file.",
188    keywords = "load,mat,workspace",
189    accel = "cpu",
190    sink = true
191)]
192fn load_builtin(args: Vec<Value>) -> Result<Value, String> {
193    let eval = evaluate(&args)?;
194    Ok(eval.first_output())
195}
196
197#[derive(Clone, Debug)]
198pub struct LoadEval {
199    variables: Vec<(String, Value)>,
200}
201
202impl LoadEval {
203    pub fn first_output(&self) -> Value {
204        let mut st = StructValue::new();
205        for (name, value) in &self.variables {
206            st.fields.insert(name.clone(), value.clone());
207        }
208        Value::Struct(st)
209    }
210
211    pub fn variables(&self) -> &[(String, Value)] {
212        &self.variables
213    }
214
215    pub fn into_variables(self) -> Vec<(String, Value)> {
216        self.variables
217    }
218}
219
220struct LoadRequest {
221    variables: Vec<String>,
222    regex_patterns: Vec<Regex>,
223}
224
225pub fn evaluate(args: &[Value]) -> Result<LoadEval, String> {
226    let mut host_args = Vec::with_capacity(args.len());
227    for arg in args {
228        host_args.push(gather_if_needed(arg)?);
229    }
230
231    let invocation = parse_invocation(&host_args)?;
232
233    let mut path_value = if let Some(path) = invocation.path_value {
234        path
235    } else {
236        Value::from("matlab.mat")
237    };
238
239    if invocation.path_was_default {
240        if let Ok(override_path) = std::env::var("RUNMAT_LOAD_DEFAULT_PATH") {
241            path_value = Value::from(override_path);
242        }
243    }
244
245    let mut regex_patterns = Vec::with_capacity(invocation.regex_tokens.len());
246    for pattern in invocation.regex_tokens {
247        let regex = Regex::new(&pattern)
248            .map_err(|err| format!("load: invalid regular expression '{pattern}': {err}"))?;
249        regex_patterns.push(regex);
250    }
251
252    let request = LoadRequest {
253        variables: invocation.variables,
254        regex_patterns,
255    };
256    let path = normalise_path(&path_value)?;
257    let entries = read_mat_file(&path)?;
258
259    let selected = select_variables(&entries, &request)?;
260    Ok(LoadEval {
261        variables: selected,
262    })
263}
264
265struct ParsedInvocation {
266    path_value: Option<Value>,
267    path_was_default: bool,
268    variables: Vec<String>,
269    regex_tokens: Vec<String>,
270}
271
272fn parse_invocation(values: &[Value]) -> Result<ParsedInvocation, String> {
273    let mut path_value = None;
274    let mut path_was_default = false;
275    let mut variables = Vec::new();
276    let mut regex_tokens = Vec::new();
277    let mut idx = 0usize;
278    while idx < values.len() {
279        if let Some(flag) = option_token(&values[idx])? {
280            match flag.as_str() {
281                "-mat" => {
282                    idx += 1;
283                    continue;
284                }
285                "-regexp" => {
286                    idx += 1;
287                    if idx >= values.len() {
288                        return Err("load: '-regexp' requires at least one pattern".to_string());
289                    }
290                    while idx < values.len() {
291                        if option_token(&values[idx])?.is_some() {
292                            break;
293                        }
294                        let names = extract_names(&values[idx])?;
295                        if names.is_empty() {
296                            return Err(
297                                "load: '-regexp' requires non-empty pattern strings".to_string()
298                            );
299                        }
300                        regex_tokens.extend(names);
301                        idx += 1;
302                    }
303                    continue;
304                }
305                other => {
306                    return Err(format!("load: unsupported option '{other}'"));
307                }
308            }
309        } else {
310            if path_value.is_none() {
311                path_value = Some(values[idx].clone());
312                idx += 1;
313                continue;
314            }
315            let names = extract_names(&values[idx])?;
316            variables.extend(names);
317            idx += 1;
318        }
319    }
320
321    if path_value.is_none() {
322        path_was_default = true;
323    }
324
325    Ok(ParsedInvocation {
326        path_value,
327        path_was_default,
328        variables,
329        regex_tokens,
330    })
331}
332
333fn normalise_path(value: &Value) -> Result<PathBuf, String> {
334    let raw = value_to_string_scalar(value)
335        .ok_or_else(|| "load: filename must be a character vector or string scalar".to_string())?;
336    let mut path = PathBuf::from(raw);
337    if path.extension().is_none() {
338        path.set_extension("mat");
339    }
340    Ok(path)
341}
342
343fn select_variables(
344    entries: &[(String, Value)],
345    request: &LoadRequest,
346) -> Result<Vec<(String, Value)>, String> {
347    if request.variables.is_empty() && request.regex_patterns.is_empty() {
348        return Ok(entries.to_vec());
349    }
350
351    let mut by_name: HashMap<&str, &Value> = HashMap::with_capacity(entries.len());
352    for (name, value) in entries {
353        by_name.insert(name, value);
354    }
355
356    let mut selected = Vec::new();
357
358    for name in &request.variables {
359        let value = by_name
360            .get(name.as_str())
361            .ok_or_else(|| format!("load: variable '{name}' was not found in the file"))?;
362        insert_or_replace(&mut selected, name, (*value).clone());
363    }
364
365    if !request.regex_patterns.is_empty() {
366        let mut matched = 0usize;
367        for (name, value) in entries {
368            if request
369                .regex_patterns
370                .iter()
371                .any(|regex| regex.is_match(name))
372            {
373                matched += 1;
374                insert_or_replace(&mut selected, name, value.clone());
375            }
376        }
377        if matched == 0 && request.variables.is_empty() {
378            return Err("load: no variables matched '-regexp' patterns".to_string());
379        }
380    }
381
382    if selected.is_empty() {
383        return Err("load: no variables selected".to_string());
384    }
385
386    Ok(selected)
387}
388
389fn insert_or_replace(selected: &mut Vec<(String, Value)>, name: &str, value: Value) {
390    if let Some(entry) = selected.iter_mut().find(|(existing, _)| existing == name) {
391        entry.1 = value;
392    } else {
393        selected.push((name.to_string(), value));
394    }
395}
396
397pub(crate) fn read_mat_file(path: &Path) -> Result<Vec<(String, Value)>, String> {
398    let file = File::open(path)
399        .map_err(|err| format!("load: failed to open '{}': {err}", path.display()))?;
400    let mut reader = BufReader::new(file);
401
402    let mut header = [0u8; MAT_HEADER_LEN];
403    reader
404        .read_exact(&mut header)
405        .map_err(|err| format!("load: failed to read MAT-file header: {err}"))?;
406    if header[126] != b'I' || header[127] != b'M' {
407        return Err("load: file is not a MATLAB Level-5 MAT-file".to_string());
408    }
409
410    let mut variables = Vec::new();
411    while let Some(tagged) = read_tagged(&mut reader, true)? {
412        if tagged.data_type != MI_MATRIX {
413            continue;
414        }
415        let parsed = parse_matrix(&tagged.data)?;
416        let value = mat_array_to_value(parsed.array)?;
417        variables.push((parsed.name, value));
418    }
419    Ok(variables)
420}
421
422struct ParsedMatrix {
423    name: String,
424    array: MatArray,
425}
426
427fn parse_matrix(buffer: &[u8]) -> Result<ParsedMatrix, String> {
428    let mut cursor = Cursor::new(buffer);
429
430    let flags = read_tagged(&mut cursor, false)?
431        .ok_or_else(|| "load: matrix element missing array flags".to_string())?;
432    if flags.data_type != MI_UINT32 || flags.data.len() < 8 {
433        return Err("load: invalid array flags block".to_string());
434    }
435    let flags0 = u32::from_le_bytes(flags.data[0..4].try_into().unwrap());
436    let class_code = flags0 & 0xFF;
437    let mut class = MatClass::from_class_code(class_code)
438        .ok_or_else(|| "load: unsupported MATLAB class".to_string())?;
439    let is_logical = (flags0 & FLAG_LOGICAL) != 0;
440    let has_imag = (flags0 & FLAG_COMPLEX) != 0;
441    if matches!(class, MatClass::Double) && is_logical {
442        class = MatClass::Logical;
443    }
444
445    let dims_elem = read_tagged(&mut cursor, false)?
446        .ok_or_else(|| "load: matrix element missing dimensions".to_string())?;
447    if dims_elem.data_type != MI_INT32 {
448        return Err("load: dimension block must use MI_INT32".to_string());
449    }
450    if dims_elem.data.is_empty() || dims_elem.data.len() % 4 != 0 {
451        return Err("load: malformed dimension block".to_string());
452    }
453    let mut dims = Vec::with_capacity(dims_elem.data.len() / 4);
454    for chunk in dims_elem.data.chunks_exact(4) {
455        let value = i32::from_le_bytes(chunk.try_into().unwrap());
456        if value < 0 {
457            return Err("load: negative dimensions are not supported".to_string());
458        }
459        dims.push(value as usize);
460    }
461    if dims.is_empty() {
462        dims.push(1);
463        dims.push(1);
464    }
465
466    let name_elem = read_tagged(&mut cursor, false)?
467        .ok_or_else(|| "load: matrix element missing name".to_string())?;
468    let name = match name_elem.data_type {
469        MI_INT8 | MI_UINT8 => bytes_to_string(&name_elem.data),
470        MI_UINT16 => {
471            let mut bytes = Vec::with_capacity(name_elem.data.len());
472            for chunk in name_elem.data.chunks_exact(2) {
473                let code = u16::from_le_bytes(chunk.try_into().unwrap());
474                if code == 0 {
475                    break;
476                }
477                if let Some(ch) = char::from_u32(code as u32) {
478                    bytes.push(ch);
479                }
480            }
481            bytes.into_iter().collect()
482        }
483        _ => {
484            return Err("load: unsupported array name encoding".to_string());
485        }
486    };
487
488    let array = match class {
489        MatClass::Double => parse_double_array(&mut cursor, dims, has_imag)?,
490        MatClass::Logical => parse_logical_array(&mut cursor, dims)?,
491        MatClass::Char => parse_char_array(&mut cursor, dims)?,
492        MatClass::Cell => parse_cell_array(&mut cursor, dims)?,
493        MatClass::Struct => parse_struct(&mut cursor, dims)?,
494    };
495
496    Ok(ParsedMatrix { name, array })
497}
498
499fn parse_double_array(
500    cursor: &mut Cursor<&[u8]>,
501    dims: Vec<usize>,
502    has_imag: bool,
503) -> Result<MatArray, String> {
504    let real_elem = read_tagged(cursor, false)?
505        .ok_or_else(|| "load: numeric array missing real component".to_string())?;
506    if real_elem.data_type != MI_DOUBLE || real_elem.data.len() % 8 != 0 {
507        return Err("load: numeric data must be stored as MI_DOUBLE".to_string());
508    }
509    let mut real = Vec::with_capacity(real_elem.data.len() / 8);
510    for chunk in real_elem.data.chunks_exact(8) {
511        real.push(f64::from_le_bytes(chunk.try_into().unwrap()));
512    }
513
514    let imag = if has_imag {
515        let imag_elem = read_tagged(cursor, false)?
516            .ok_or_else(|| "load: numeric array missing imaginary component".to_string())?;
517        if imag_elem.data_type != MI_DOUBLE || imag_elem.data.len() % 8 != 0 {
518            return Err("load: imaginary component must be MI_DOUBLE".to_string());
519        }
520        let mut imag = Vec::with_capacity(imag_elem.data.len() / 8);
521        for chunk in imag_elem.data.chunks_exact(8) {
522            imag.push(f64::from_le_bytes(chunk.try_into().unwrap()));
523        }
524        Some(imag)
525    } else {
526        None
527    };
528
529    Ok(MatArray {
530        class: MatClass::Double,
531        dims,
532        data: MatData::Double { real, imag },
533    })
534}
535
536fn parse_logical_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
537    let elem = read_tagged(cursor, false)?
538        .ok_or_else(|| "load: logical array missing data block".to_string())?;
539    if elem.data_type != MI_UINT8 {
540        return Err("load: logical arrays must be stored as MI_UINT8".to_string());
541    }
542    Ok(MatArray {
543        class: MatClass::Logical,
544        dims,
545        data: MatData::Logical { data: elem.data },
546    })
547}
548
549fn parse_char_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
550    let elem = read_tagged(cursor, false)?
551        .ok_or_else(|| "load: character array missing data block".to_string())?;
552    if elem.data_type != MI_UINT16 {
553        return Err("load: character data must be stored as MI_UINT16".to_string());
554    }
555    if elem.data.len() % 2 != 0 {
556        return Err("load: malformed character data".to_string());
557    }
558    let mut data = Vec::with_capacity(elem.data.len() / 2);
559    for chunk in elem.data.chunks_exact(2) {
560        data.push(u16::from_le_bytes(chunk.try_into().unwrap()));
561    }
562    Ok(MatArray {
563        class: MatClass::Char,
564        dims,
565        data: MatData::Char { data },
566    })
567}
568
569fn parse_cell_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
570    let total: usize = dims
571        .iter()
572        .copied()
573        .fold(1usize, |acc, d| acc.saturating_mul(d));
574    let mut elements = Vec::with_capacity(total);
575    for _ in 0..total {
576        let elem = read_tagged(cursor, false)?
577            .ok_or_else(|| "load: cell element missing matrix payload".to_string())?;
578        if elem.data_type != MI_MATRIX {
579            return Err("load: cell elements must be matrices".to_string());
580        }
581        let parsed = parse_matrix(&elem.data)?;
582        elements.push(parsed.array);
583    }
584    Ok(MatArray {
585        class: MatClass::Cell,
586        dims,
587        data: MatData::Cell { elements },
588    })
589}
590
591fn parse_struct(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> Result<MatArray, String> {
592    if dims.len() != 2 || dims[0] != 1 || dims[1] != 1 {
593        return Err("load: struct arrays are not supported yet".to_string());
594    }
595    let len_elem = read_tagged(cursor, false)?
596        .ok_or_else(|| "load: struct missing maximum field length specifier".to_string())?;
597    if len_elem.data_type != MI_INT32 || len_elem.data.len() != 4 {
598        return Err("load: struct field length must be MI_INT32".to_string());
599    }
600    let max_len = i32::from_le_bytes(len_elem.data[..4].try_into().unwrap());
601    if max_len <= 0 {
602        return Err("load: struct field length must be positive".to_string());
603    }
604
605    let names_elem = read_tagged(cursor, false)?
606        .ok_or_else(|| "load: struct missing field name table".to_string())?;
607    if names_elem.data_type != MI_INT8 && names_elem.data_type != MI_UINT8 {
608        return Err("load: struct field names must be stored as MI_INT8/MI_UINT8".to_string());
609    }
610    if names_elem.data.len() % (max_len as usize) != 0 {
611        return Err("load: malformed struct field name table".to_string());
612    }
613    let field_count = names_elem.data.len() / (max_len as usize);
614    let mut field_names = Vec::with_capacity(field_count);
615    for i in 0..field_count {
616        let start = i * (max_len as usize);
617        let end = start + (max_len as usize);
618        let slice = &names_elem.data[start..end];
619        field_names.push(bytes_to_string(slice));
620    }
621
622    let mut field_values = Vec::with_capacity(field_count);
623    for _ in 0..field_count {
624        let elem = read_tagged(cursor, false)?
625            .ok_or_else(|| "load: struct field missing matrix payload".to_string())?;
626        if elem.data_type != MI_MATRIX {
627            return Err("load: struct fields must be matrices".to_string());
628        }
629        let parsed = parse_matrix(&elem.data)?;
630        field_values.push(parsed.array);
631    }
632
633    Ok(MatArray {
634        class: MatClass::Struct,
635        dims,
636        data: MatData::Struct {
637            field_names,
638            field_values,
639        },
640    })
641}
642
643fn mat_array_to_value(array: MatArray) -> Result<Value, String> {
644    match array.data {
645        MatData::Double { real, imag } => {
646            let len = real.len();
647            if let Some(imag) = imag {
648                if imag.len() != len {
649                    return Err("load: complex data has mismatched real/imag parts".to_string());
650                }
651                if len == 1 {
652                    Ok(Value::Complex(real[0], imag[0]))
653                } else {
654                    let mut pairs = Vec::with_capacity(len);
655                    for i in 0..len {
656                        pairs.push((real[i], imag[i]));
657                    }
658                    let tensor = ComplexTensor::new(pairs, array.dims.clone())
659                        .map_err(|e| format!("load: {e}"))?;
660                    Ok(Value::ComplexTensor(tensor))
661                }
662            } else if len == 1 {
663                Ok(Value::Num(real[0]))
664            } else {
665                let tensor =
666                    Tensor::new(real, array.dims.clone()).map_err(|e| format!("load: {e}"))?;
667                Ok(Value::Tensor(tensor))
668            }
669        }
670        MatData::Logical { data } => {
671            let total: usize = array
672                .dims
673                .iter()
674                .copied()
675                .fold(1usize, |acc, d| acc.saturating_mul(d));
676            if data.len() != total {
677                return Err("load: logical data length mismatch".to_string());
678            }
679            if total == 1 {
680                Ok(Value::Bool(data.first().copied().unwrap_or(0) != 0))
681            } else {
682                let logical = LogicalArray::new(data, array.dims.clone())
683                    .map_err(|e| format!("load: {e}"))?;
684                Ok(Value::LogicalArray(logical))
685            }
686        }
687        MatData::Char { data } => {
688            let rows = array.dims.first().copied().unwrap_or(1);
689            let cols = array.dims.get(1).copied().unwrap_or(1);
690            let mut chars = Vec::with_capacity(rows.saturating_mul(cols));
691            for code in data {
692                let ch = char::from_u32(code as u32).unwrap_or('\u{FFFD}');
693                chars.push(ch);
694            }
695            let char_array = CharArray::new(chars, rows, cols).map_err(|e| format!("load: {e}"))?;
696            Ok(Value::CharArray(char_array))
697        }
698        MatData::Cell { elements } => {
699            if let Some(strings) = cell_elements_to_strings(&elements) {
700                let string_array = StringArray::new(strings, array.dims.clone())
701                    .map_err(|e| format!("load: {e}"))?;
702                return Ok(Value::StringArray(string_array));
703            }
704            if array.dims.len() != 2 {
705                return Err(
706                    "load: cell arrays with more than two dimensions are not supported yet"
707                        .to_string(),
708                );
709            }
710            let rows = array.dims[0];
711            let cols = array.dims[1];
712            let expected = rows.saturating_mul(cols);
713            if elements.len() != expected {
714                return Err("load: cell array element count mismatch".to_string());
715            }
716            let mut converted = Vec::with_capacity(elements.len());
717            for elem in elements {
718                converted.push(mat_array_to_value(elem)?);
719            }
720            let mut row_major = vec![Value::Num(0.0); expected];
721            for col in 0..cols {
722                for row in 0..rows {
723                    let cm_idx = col * rows + row;
724                    let rm_idx = row * cols + col;
725                    row_major[rm_idx] = converted[cm_idx].clone();
726                }
727            }
728            make_cell(row_major, rows, cols)
729        }
730        MatData::Struct {
731            field_names,
732            field_values,
733        } => {
734            if field_names.len() != field_values.len() {
735                return Err("load: struct field metadata is inconsistent".to_string());
736            }
737            let mut st = StructValue::new();
738            for (name, value) in field_names.into_iter().zip(field_values.into_iter()) {
739                let converted = mat_array_to_value(value)?;
740                st.fields.insert(name, converted);
741            }
742            Ok(Value::Struct(st))
743        }
744    }
745}
746
747fn cell_elements_to_strings(elements: &[MatArray]) -> Option<Vec<String>> {
748    let mut strings = Vec::with_capacity(elements.len());
749    for element in elements {
750        if element.class != MatClass::Char {
751            return None;
752        }
753        let rows = element.dims.first().copied().unwrap_or(1);
754        if rows > 1 {
755            return None;
756        }
757        match &element.data {
758            MatData::Char { data } => strings.push(utf16_codes_to_string(data)),
759            _ => return None,
760        }
761    }
762    Some(strings)
763}
764
/// Decodes MATLAB character codes (UTF-16 code units) into a `String`.
///
/// Surrogate pairs are combined into a single character, so text outside the
/// Basic Multilingual Plane round-trips. Unpaired surrogates become U+FFFD.
/// Trailing NUL characters are stripped.
fn utf16_codes_to_string(data: &[u16]) -> String {
    let mut text: String = std::char::decode_utf16(data.iter().copied())
        .map(|unit| unit.unwrap_or('\u{FFFD}'))
        .collect();
    let kept = text.trim_end_matches('\0').len();
    text.truncate(kept);
    text
}
775
776fn option_token(value: &Value) -> Result<Option<String>, String> {
777    if let Some(token) = value_to_string_scalar(value) {
778        if token.starts_with('-') {
779            return Ok(Some(token.to_ascii_lowercase()));
780        }
781    }
782    Ok(None)
783}
784
785fn extract_names(value: &Value) -> Result<Vec<String>, String> {
786    match value {
787        Value::String(s) => Ok(vec![s.clone()]),
788        Value::CharArray(ca) => Ok(char_array_rows_as_strings(ca)),
789        Value::StringArray(sa) => Ok(sa.data.clone()),
790        Value::Cell(ca) => {
791            let mut names = Vec::with_capacity(ca.data.len());
792            for handle in &ca.data {
793                let inner = unsafe { &*handle.as_raw() };
794                let text = value_to_string_scalar(inner).ok_or_else(|| {
795                    "load: cell arrays used for variable selection must contain string scalars"
796                        .to_string()
797                })?;
798                names.push(text);
799            }
800            Ok(names)
801        }
802        other => {
803            let gathered = gather_if_needed(other)?;
804            extract_names(&gathered)
805        }
806    }
807}
808
809fn value_to_string_scalar(value: &Value) -> Option<String> {
810    match value {
811        Value::String(s) => Some(s.clone()),
812        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
813        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
814        _ => None,
815    }
816}
817
818fn char_array_rows_as_strings(ca: &CharArray) -> Vec<String> {
819    let mut rows = Vec::with_capacity(ca.rows);
820    for r in 0..ca.rows {
821        let mut row = String::with_capacity(ca.cols);
822        for c in 0..ca.cols {
823            let idx = r * ca.cols + c;
824            row.push(ca.data[idx]);
825        }
826        let trimmed = row.trim_end_matches([' ', '\0']).to_string();
827        rows.push(trimmed);
828    }
829    rows
830}
831
/// Decodes a NUL-terminated byte buffer into a `String`.
///
/// Only the bytes before the first NUL (or the whole buffer, if there is none)
/// are used. If those bytes are not valid UTF-8, an empty string is returned.
fn bytes_to_string(bytes: &[u8]) -> String {
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    match std::str::from_utf8(&bytes[..end]) {
        Ok(text) => text.to_owned(),
        Err(_) => String::new(),
    }
}
840
/// A single tagged data element decoded from a Level-5 MAT stream.
struct TaggedData {
    /// Data-type code from the element tag: the lower 16 bits for compact
    /// elements, the full 32-bit type word otherwise.
    data_type: u32,
    /// Element payload, with any trailing alignment padding already consumed
    /// and discarded.
    data: Vec<u8>,
}

/// Reads the next tagged data element from a little-endian Level-5 MAT stream.
///
/// Two tag layouts are handled:
/// * **Compact (small data element):** the upper 16 bits of the first tag word
///   are non-zero and hold the byte count; the lower 16 bits hold the type. The
///   payload occupies the next 4 bytes, of which at most 4 are kept.
/// * **Regular:** a 4-byte type word followed by a 4-byte length, then `length`
///   payload bytes, then zero padding up to the next 8-byte boundary.
///   `miCOMPRESSED` elements are the exception and carry no trailing padding.
///
/// When `allow_eof` is true and the stream ends before the first tag word has
/// been fully read, `Ok(None)` is returned so callers can detect a clean end of
/// file.
///
/// # Errors
/// Returns a `load:`-prefixed message when the stream ends mid-element, when
/// the declared length exceeds the available data, or on any other I/O error.
fn read_tagged<R: Read>(reader: &mut R, allow_eof: bool) -> Result<Option<TaggedData>, String> {
    // MAT type code for zlib-compressed elements. Their length is not rounded
    // up to an 8-byte boundary, so no padding follows them.
    const MI_COMPRESSED: u32 = 15;

    let mut type_bytes = [0u8; 4];
    if let Err(err) = reader.read_exact(&mut type_bytes) {
        if allow_eof && err.kind() == std::io::ErrorKind::UnexpectedEof {
            return Ok(None);
        }
        return Err(format!("load: failed to read MAT element header: {err}"));
    }

    let type_field = u32::from_le_bytes(type_bytes);
    let compact_len = (type_field >> 16) as usize;
    if compact_len != 0 {
        // Small data element: type and size share one word, payload is inline.
        let mut inline = [0u8; 4];
        reader
            .read_exact(&mut inline)
            .map_err(|err| format!("load: failed to read compact MAT element: {err}"))?;
        return Ok(Some(TaggedData {
            data_type: type_field & 0x0000_FFFF,
            data: inline[..compact_len.min(4)].to_vec(),
        }));
    }

    let mut len_bytes = [0u8; 4];
    reader
        .read_exact(&mut len_bytes)
        .map_err(|err| format!("load: failed to read MAT element length: {err}"))?;
    let length = u32::from_le_bytes(len_bytes);

    // The length comes straight from the file. Reading through `take` lets the
    // buffer grow only as far as real data exists, so a corrupt or hostile
    // length cannot force a multi-gigabyte allocation up front.
    let mut data = Vec::new();
    reader
        .by_ref()
        .take(u64::from(length))
        .read_to_end(&mut data)
        .map_err(|err| format!("load: failed to read MAT element body: {err}"))?;
    if data.len() as u64 != u64::from(length) {
        return Err(format!(
            "load: failed to read MAT element body: expected {length} bytes, found {}",
            data.len()
        ));
    }

    if type_field != MI_COMPRESSED {
        let padding = (8 - data.len() % 8) % 8;
        if padding != 0 {
            let mut pad = [0u8; 8];
            reader
                .read_exact(&mut pad[..padding])
                .map_err(|err| format!("load: failed to read MAT padding: {err}"))?;
        }
    }

    Ok(Some(TaggedData {
        data_type: type_field,
        data,
    }))
}
892
// Unit tests for `load`. Most tests populate a thread-local fake workspace,
// write it to a temporary MAT-file with the `save` builtin, and read it back
// through `evaluate`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workspace::WorkspaceResolver;
    use once_cell::sync::OnceCell;
    use runmat_builtins::StringArray;
    use std::cell::RefCell;
    use std::collections::HashMap;
    use tempfile::tempdir;

    // Per-thread fake workspace. The resolver registered below reads it on the
    // calling thread, so each test thread only sees the variables it set.
    thread_local! {
        static TEST_WORKSPACE: RefCell<HashMap<String, Value>> = RefCell::new(HashMap::new());
    }

    // Registers, once per process, a workspace resolver backed by
    // `TEST_WORKSPACE`. The snapshot is sorted by name so its order does not
    // depend on `HashMap` iteration order. No globals are reported.
    fn ensure_test_resolver() {
        static INIT: OnceCell<()> = OnceCell::new();
        INIT.get_or_init(|| {
            crate::workspace::register_workspace_resolver(WorkspaceResolver {
                lookup: |name| TEST_WORKSPACE.with(|slot| slot.borrow().get(name).cloned()),
                snapshot: || {
                    let mut entries: Vec<(String, Value)> =
                        TEST_WORKSPACE.with(|slot| slot.borrow().clone().into_iter().collect());
                    entries.sort_by(|a, b| a.0.cmp(&b.0));
                    entries
                },
                globals: || Vec::new(),
            });
        });
    }

    // Replaces the current thread's fake workspace with exactly `entries`.
    fn set_workspace(entries: &[(&str, Value)]) {
        TEST_WORKSPACE.with(|slot| {
            let mut map = slot.borrow_mut();
            map.clear();
            for (name, value) in entries {
                map.insert((*name).to_string(), value.clone());
            }
        });
    }

    #[test]
    fn load_roundtrip_numeric() {
        ensure_test_resolver();
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0], vec![2, 2]).unwrap();
        set_workspace(&[("A", Value::Tensor(tensor))]);

        let dir = tempdir().unwrap();
        let path = dir.path().join("numeric.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        // With only a filename, the first output is a struct of all variables.
        let eval =
            evaluate(&[Value::from(path.to_string_lossy().to_string())]).expect("load numeric");
        let struct_value = eval.first_output();
        match struct_value {
            Value::Struct(sv) => {
                assert!(sv.fields.contains_key("A"));
                match sv.fields.get("A").unwrap() {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![2, 2]);
                        assert_eq!(t.data, vec![1.0, 4.0, 2.0, 5.0]);
                    }
                    other => panic!("expected tensor, got {other:?}"),
                }
            }
            other => panic!("expected struct, got {other:?}"),
        }
    }

    #[test]
    fn load_selected_variables() {
        ensure_test_resolver();
        set_workspace(&[("signal", Value::Num(42.0)), ("noise", Value::Num(5.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("selection.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        let eval = evaluate(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("signal"),
        ])
        .expect("load selection");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "signal");
        assert!(matches!(vars[0].1, Value::Num(42.0)));
    }

    #[test]
    fn load_regex_selection() {
        ensure_test_resolver();
        set_workspace(&[
            ("w1", Value::Num(1.0)),
            ("w2", Value::Num(2.0)),
            ("bias", Value::Num(3.0)),
        ]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("regex.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        let eval = evaluate(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("-regexp"),
            Value::from("^w\\d$"),
        ])
        .expect("load regex");
        // Sort before comparing so the assertion does not depend on load order.
        let mut names: Vec<_> = eval.variables().iter().map(|(n, _)| n.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["w1".to_string(), "w2".to_string()]);
    }

    #[test]
    fn load_missing_variable_errors() {
        ensure_test_resolver();
        set_workspace(&[("existing", Value::Num(7.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("missing.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        let err = evaluate(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("missing"),
        ])
        .expect_err("expect missing variable error");
        assert!(err.contains("variable 'missing' was not found"));
    }

    #[test]
    fn load_string_array_roundtrip() {
        ensure_test_resolver();
        let strings = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
        set_workspace(&[("labels", Value::StringArray(strings))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("strings.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        let eval =
            evaluate(&[Value::from(path.to_string_lossy().to_string())]).expect("load strings");
        let struct_value = eval.first_output();
        match struct_value {
            Value::Struct(sv) => {
                let value = sv
                    .fields
                    .get("labels")
                    .expect("labels field missing in struct");
                match value {
                    Value::StringArray(sa) => {
                        assert_eq!(sa.shape, vec![1, 2]);
                        assert_eq!(sa.data, vec![String::from("foo"), String::from("bar")]);
                    }
                    other => panic!("expected string array, got {other:?}"),
                }
            }
            other => panic!("expected struct, got {other:?}"),
        }
    }

    #[test]
    fn load_option_before_filename() {
        ensure_test_resolver();
        set_workspace(&[("alpha", Value::Num(1.0)), ("beta", Value::Num(2.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("option_first.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        let eval = evaluate(&[
            Value::from("-mat"),
            Value::from(path.to_string_lossy().to_string()),
            Value::from("beta"),
        ])
        .expect("load with option first");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "beta");
        assert!(matches!(vars[0].1, Value::Num(2.0)));
    }

    #[test]
    fn load_char_array_names_trimmed() {
        ensure_test_resolver();
        set_workspace(&[("short", Value::Num(5.0)), ("longer", Value::Num(9.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("char_names.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        // Build a 2x6 char matrix whose rows are space-padded variable names.
        let cols = 6;
        let mut data = Vec::new();
        for name in ["short", "longer"] {
            let mut chars: Vec<char> = name.chars().collect();
            while chars.len() < cols {
                chars.push(' ');
            }
            data.extend(chars);
        }
        let name_array = CharArray::new(data, 2, cols).unwrap();

        let eval = evaluate(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::CharArray(name_array),
        ])
        .expect("load with char array names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 2);
        assert_eq!(vars[0].0, "short");
        assert!(matches!(vars[0].1, Value::Num(5.0)));
        assert_eq!(vars[1].0, "longer");
        assert!(matches!(vars[1].1, Value::Num(9.0)));
    }

    #[test]
    fn load_duplicate_names_last_wins() {
        ensure_test_resolver();
        set_workspace(&[("dup", Value::Num(11.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("duplicates.mat");
        let save_arg = Value::from(path.to_string_lossy().to_string());
        crate::call_builtin("save", std::slice::from_ref(&save_arg)).unwrap();

        // Requesting the same name twice must still yield a single variable.
        let eval = evaluate(&[
            Value::from(path.to_string_lossy().to_string()),
            Value::from("dup"),
            Value::from("dup"),
        ])
        .expect("load with duplicate names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "dup");
        assert!(matches!(vars[0].1, Value::Num(11.0)));
    }

    // Saves a GPU-resident tensor and checks that it loads back as a host
    // tensor. Returns early (passes) when no WGPU provider is available.
    #[test]
    #[cfg(feature = "wgpu")]
    fn load_wgpu_tensor_roundtrip() {
        ensure_test_resolver();
        if runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .is_err()
        {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };

        use runmat_accelerate_api::HostTensorView;

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload tensor");
        set_workspace(&[("gpu_var", Value::GpuTensor(handle))]);

        let dir = tempdir().unwrap();
        let path = dir.path().join("wgpu_load.mat");
        let save_args = vec![
            Value::from(path.to_string_lossy().to_string()),
            Value::from("gpu_var"),
        ];
        crate::call_builtin("save", &save_args).unwrap();

        let eval =
            evaluate(&[Value::from(path.to_string_lossy().to_string())]).expect("load wgpu file");
        let struct_value = eval.first_output();
        match struct_value {
            Value::Struct(sv) => match sv.fields.get("gpu_var") {
                Some(Value::Tensor(t)) => {
                    assert_eq!(t.shape, vec![2, 2]);
                    assert_eq!(t.data, tensor.data);
                }
                other => panic!("expected tensor, got {other:?}"),
            },
            other => panic!("expected struct, got {other:?}"),
        }
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_present() {
        let blocks = crate::builtins::common::test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}