Skip to main content

runmat_runtime/builtins/io/mat/
load.rs

1//! MATLAB-compatible `load` builtin for RunMat.
2
3use std::collections::HashMap;
4use std::io::{BufReader, Cursor, Read};
5use std::path::{Path, PathBuf};
6
7use regex::Regex;
8use runmat_builtins::{
9    CharArray, ComplexTensor, LogicalArray, StringArray, StructValue, Tensor, Value,
10};
11use runmat_filesystem::File;
12use runmat_macros::runtime_builtin;
13
14use super::format::{
15    MatArray, MatClass, MatData, FLAG_COMPLEX, FLAG_LOGICAL, MAT_HEADER_LEN, MI_DOUBLE, MI_INT32,
16    MI_INT8, MI_MATRIX, MI_UINT16, MI_UINT32, MI_UINT8,
17};
18use crate::builtins::common::spec::{
19    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
20    ReductionNaN, ResidencyPolicy, ShapeRequirements,
21};
22use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
23
// GPU metadata registered for the `load` builtin.
//
// The spec lists no supported precisions and no provider hooks; as the `notes`
// field states, MAT-files are read on the host and the results are CPU-resident.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::mat::load")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "load",
    op_kind: GpuOpKind::Custom("io-load"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Reads MAT-files on the host and produces CPU-resident values. Providers are not involved until accelerated code later promotes the results.",
};
39
// Fusion metadata registered for the `load` builtin.
//
// No elementwise or reduction descriptor is provided; per the `notes` field the
// registration exists for documentation completeness only.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::mat::load")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "load",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "File I/O is not eligible for fusion. Registration exists for documentation completeness only.",
};
50
51#[runtime_builtin(
52    name = "load",
53    category = "io/mat",
54    summary = "Load variables from a MAT-file.",
55    keywords = "load,mat,workspace",
56    accel = "cpu",
57    sink = true,
58    type_resolver(crate::builtins::io::type_resolvers::load_type),
59    builtin_path = "crate::builtins::io::mat::load"
60)]
61async fn load_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
62    let eval = evaluate(&args).await?;
63
64    // current_output_count() is set by the dispatcher only for multi-output Unpack patterns
65    // like `[a, b] = load(...)`. Guard against requesting more than one struct output.
66    if let Some(n) = crate::output_count::current_output_count() {
67        if n > 1 {
68            return Err(load_error("load supports at most one output argument"));
69        }
70    }
71
72    // The VM sets output_context::requested_output_count() at every call site before
73    // dispatching:
74    //   Some(0) → statement-level call (result is discarded or printed without capture)
75    //             → assign loaded variables directly into the caller's workspace.
76    //   Some(1) → single-output assignment `S = load(...)` → return a struct.
77    //   None    → called outside the VM (e.g. directly from Rust) → return a struct.
78    if crate::output_context::requested_output_count() == Some(0) {
79        for (name, value) in eval.variables() {
80            crate::workspace::assign(name, value.clone()).map_err(|err| load_error(err))?;
81        }
82        return Ok(Value::OutputList(Vec::new()));
83    }
84
85    Ok(eval.first_output())
86}
87
/// Result of evaluating a `load` request: the selected variables as
/// `(name, value)` pairs, in the order they were selected.
#[derive(Clone, Debug)]
pub struct LoadEval {
    // Selected variables; populated by `evaluate` from `select_variables`.
    variables: Vec<(String, Value)>,
}
92
93impl LoadEval {
94    pub fn first_output(&self) -> Value {
95        let mut st = StructValue::new();
96        for (name, value) in &self.variables {
97            st.fields.insert(name.clone(), value.clone());
98        }
99        Value::Struct(st)
100    }
101
102    pub fn variables(&self) -> &[(String, Value)] {
103        &self.variables
104    }
105
106    pub fn into_variables(self) -> Vec<(String, Value)> {
107        self.variables
108    }
109}
110
/// Variable selection criteria consumed by `select_variables`.
struct LoadRequest {
    // Exact variable names; each one must exist in the file.
    variables: Vec<String>,
    // Compiled `-regexp` patterns matched against variable names.
    regex_patterns: Vec<Regex>,
}
115
/// Builtin name attached to errors built by `load_error` and `load_error_with_source`.
const BUILTIN_NAME: &str = "load";
117
118fn load_error(message: impl Into<String>) -> RuntimeError {
119    build_runtime_error(message)
120        .with_builtin(BUILTIN_NAME)
121        .build()
122}
123
124fn load_error_with_source(
125    message: impl Into<String>,
126    source: impl std::error::Error + Send + Sync + 'static,
127) -> RuntimeError {
128    build_runtime_error(message)
129        .with_builtin(BUILTIN_NAME)
130        .with_source(source)
131        .build()
132}
133
134pub async fn evaluate(args: &[Value]) -> BuiltinResult<LoadEval> {
135    let mut host_args = Vec::with_capacity(args.len());
136    for arg in args {
137        host_args.push(gather_if_needed_async(arg).await?);
138    }
139
140    let invocation = parse_invocation(&host_args).await?;
141
142    let mut path_value = if let Some(path) = invocation.path_value {
143        path
144    } else {
145        Value::from("matlab.mat")
146    };
147
148    if invocation.path_was_default {
149        if let Ok(override_path) = std::env::var("RUNMAT_LOAD_DEFAULT_PATH") {
150            path_value = Value::from(override_path);
151        }
152    }
153
154    let mut regex_patterns = Vec::with_capacity(invocation.regex_tokens.len());
155    for pattern in invocation.regex_tokens {
156        let regex = Regex::new(&pattern).map_err(|err| {
157            load_error_with_source(
158                format!("load: invalid regular expression '{pattern}': {err}"),
159                err,
160            )
161        })?;
162        regex_patterns.push(regex);
163    }
164
165    let request = LoadRequest {
166        variables: invocation.variables,
167        regex_patterns,
168    };
169    let path = normalise_path(&path_value)?;
170    let entries = read_mat_file(&path)?;
171
172    let selected = select_variables(&entries, &request)?;
173    Ok(LoadEval {
174        variables: selected,
175    })
176}
177
/// Arguments of a `load` call after option parsing (see `parse_invocation`).
struct ParsedInvocation {
    // First non-option argument, if any; interpreted as the filename.
    path_value: Option<Value>,
    // True when no filename argument was supplied.
    path_was_default: bool,
    // Variable names requested explicitly.
    variables: Vec<String>,
    // Raw pattern strings that followed `-regexp`.
    regex_tokens: Vec<String>,
}
184
185async fn parse_invocation(values: &[Value]) -> BuiltinResult<ParsedInvocation> {
186    let mut path_value = None;
187    let mut path_was_default = false;
188    let mut variables = Vec::new();
189    let mut regex_tokens = Vec::new();
190    let mut idx = 0usize;
191    while idx < values.len() {
192        if let Some(flag) = option_token(&values[idx])? {
193            match flag.as_str() {
194                "-mat" => {
195                    idx += 1;
196                    continue;
197                }
198                "-regexp" => {
199                    idx += 1;
200                    if idx >= values.len() {
201                        return Err(load_error("load: '-regexp' requires at least one pattern"));
202                    }
203                    while idx < values.len() {
204                        if option_token(&values[idx])?.is_some() {
205                            break;
206                        }
207                        let names = extract_names(&values[idx]).await?;
208                        if names.is_empty() {
209                            return Err(load_error(
210                                "load: '-regexp' requires non-empty pattern strings",
211                            ));
212                        }
213                        regex_tokens.extend(names);
214                        idx += 1;
215                    }
216                    continue;
217                }
218                other => {
219                    return Err(load_error(format!("load: unsupported option '{other}'")));
220                }
221            }
222        } else {
223            if path_value.is_none() {
224                path_value = Some(values[idx].clone());
225                idx += 1;
226                continue;
227            }
228            let names = extract_names(&values[idx]).await?;
229            variables.extend(names);
230            idx += 1;
231        }
232    }
233
234    if path_value.is_none() {
235        path_was_default = true;
236    }
237
238    Ok(ParsedInvocation {
239        path_value,
240        path_was_default,
241        variables,
242        regex_tokens,
243    })
244}
245
246fn normalise_path(value: &Value) -> BuiltinResult<PathBuf> {
247    let raw = value_to_string_scalar(value)
248        .ok_or_else(|| load_error("load: filename must be a character vector or string scalar"))?;
249    let mut path = PathBuf::from(raw);
250    if path.extension().is_none() {
251        path.set_extension("mat");
252    }
253    Ok(path)
254}
255
256fn select_variables(
257    entries: &[(String, Value)],
258    request: &LoadRequest,
259) -> BuiltinResult<Vec<(String, Value)>> {
260    if request.variables.is_empty() && request.regex_patterns.is_empty() {
261        return Ok(entries.to_vec());
262    }
263
264    let mut by_name: HashMap<&str, &Value> = HashMap::with_capacity(entries.len());
265    for (name, value) in entries {
266        by_name.insert(name, value);
267    }
268
269    let mut selected = Vec::new();
270
271    for name in &request.variables {
272        let value = by_name.get(name.as_str()).ok_or_else(|| {
273            load_error(format!("load: variable '{name}' was not found in the file"))
274        })?;
275        insert_or_replace(&mut selected, name, (*value).clone());
276    }
277
278    if !request.regex_patterns.is_empty() {
279        let mut matched = 0usize;
280        for (name, value) in entries {
281            if request
282                .regex_patterns
283                .iter()
284                .any(|regex| regex.is_match(name))
285            {
286                matched += 1;
287                insert_or_replace(&mut selected, name, value.clone());
288            }
289        }
290        if matched == 0 && request.variables.is_empty() {
291            return Err(load_error("load: no variables matched '-regexp' patterns"));
292        }
293    }
294
295    if selected.is_empty() {
296        return Err(load_error("load: no variables selected"));
297    }
298
299    Ok(selected)
300}
301
302fn insert_or_replace(selected: &mut Vec<(String, Value)>, name: &str, value: Value) {
303    if let Some(entry) = selected.iter_mut().find(|(existing, _)| existing == name) {
304        entry.1 = value;
305    } else {
306        selected.push((name.to_string(), value));
307    }
308}
309
310pub(crate) fn read_mat_file_for_builtin(
311    path: &Path,
312    builtin: &str,
313) -> crate::BuiltinResult<Vec<(String, Value)>> {
314    match read_mat_file(path) {
315        Ok(entries) => Ok(entries),
316        Err(err) => {
317            let message = err.message().replacen("load:", &format!("{builtin}:"), 1);
318            let mut builder = build_runtime_error(message).with_builtin(builtin);
319            if let Some(identifier) = err.identifier() {
320                builder = builder.with_identifier(identifier);
321            }
322            Err(builder.with_source(err).build())
323        }
324    }
325}
326
327pub(crate) fn read_mat_file(path: &Path) -> BuiltinResult<Vec<(String, Value)>> {
328    let file = File::open(path).map_err(|err| {
329        load_error_with_source(
330            format!("load: failed to open '{}': {err}", path.display()),
331            err,
332        )
333    })?;
334    let mut reader = BufReader::new(file);
335    read_mat_reader(&mut reader)
336}
337
338pub fn decode_workspace_from_mat_bytes(bytes: &[u8]) -> BuiltinResult<Vec<(String, Value)>> {
339    let mut cursor = Cursor::new(bytes);
340    read_mat_reader(&mut cursor)
341}
342
343fn read_mat_reader<R: Read>(reader: &mut R) -> BuiltinResult<Vec<(String, Value)>> {
344    let mut header = [0u8; MAT_HEADER_LEN];
345    reader.read_exact(&mut header).map_err(|err| {
346        load_error_with_source(format!("load: failed to read MAT-file header: {err}"), err)
347    })?;
348    if header[126] != b'I' || header[127] != b'M' {
349        return Err(load_error("load: file is not a MATLAB Level-5 MAT-file"));
350    }
351
352    let mut variables = Vec::new();
353    while let Some(tagged) = read_tagged(reader, true)? {
354        if tagged.data_type != MI_MATRIX {
355            continue;
356        }
357        let parsed = parse_matrix(&tagged.data)?;
358        let value = mat_array_to_value(parsed.array)?;
359        variables.push((parsed.name, value));
360    }
361    Ok(variables)
362}
363
/// A decoded matrix element: the array name and its contents.
struct ParsedMatrix {
    // Array name read from the element's name sub-element.
    name: String,
    // Decoded array payload, still in MAT representation.
    array: MatArray,
}
368
/// Parses the body of an `MI_MATRIX` element.
///
/// Reads, in order, the array-flags block, the dimensions block and the array
/// name, then hands the cursor to the class-specific parser for the payload.
///
/// # Errors
///
/// Fails when a sub-element is missing or malformed, when the class code is not
/// recognised by `MatClass::from_class_code`, or when the payload parser fails.
fn parse_matrix(buffer: &[u8]) -> BuiltinResult<ParsedMatrix> {
    let mut cursor = Cursor::new(buffer);

    let flags = read_tagged(&mut cursor, false)?
        .ok_or_else(|| load_error("load: matrix element missing array flags"))?;
    if flags.data_type != MI_UINT32 || flags.data.len() < 8 {
        return Err(load_error("load: invalid array flags block"));
    }
    // Only the first little-endian word is inspected: the low byte is the class
    // code, the logical/complex attributes are tested with the FLAG_* masks.
    let flags0 = u32::from_le_bytes(flags.data[0..4].try_into().unwrap());
    let class_code = flags0 & 0xFF;
    let mut class = MatClass::from_class_code(class_code)
        .ok_or_else(|| load_error("load: unsupported MATLAB class"))?;
    let is_logical = (flags0 & FLAG_LOGICAL) != 0;
    let has_imag = (flags0 & FLAG_COMPLEX) != 0;
    // A double-class array carrying the logical flag is decoded as a logical array.
    if matches!(class, MatClass::Double) && is_logical {
        class = MatClass::Logical;
    }

    let dims_elem = read_tagged(&mut cursor, false)?
        .ok_or_else(|| load_error("load: matrix element missing dimensions"))?;
    if dims_elem.data_type != MI_INT32 {
        return Err(load_error("load: dimension block must use MI_INT32"));
    }
    if dims_elem.data.is_empty() || dims_elem.data.len() % 4 != 0 {
        return Err(load_error("load: malformed dimension block"));
    }
    let mut dims = Vec::with_capacity(dims_elem.data.len() / 4);
    for chunk in dims_elem.data.chunks_exact(4) {
        let value = i32::from_le_bytes(chunk.try_into().unwrap());
        if value < 0 {
            return Err(load_error("load: negative dimensions are not supported"));
        }
        dims.push(value as usize);
    }
    // NOTE(review): an empty dimension block is already rejected above, so this
    // 1x1 fallback looks unreachable — confirm before removing.
    if dims.is_empty() {
        dims.push(1);
        dims.push(1);
    }

    let name_elem = read_tagged(&mut cursor, false)?
        .ok_or_else(|| load_error("load: matrix element missing name"))?;
    let name = match name_elem.data_type {
        MI_INT8 | MI_UINT8 => bytes_to_string(&name_elem.data),
        MI_UINT16 => {
            // Decode 16-bit units up to the first zero; units that are not valid
            // scalar values on their own are dropped.
            let mut bytes = Vec::with_capacity(name_elem.data.len());
            for chunk in name_elem.data.chunks_exact(2) {
                let code = u16::from_le_bytes(chunk.try_into().unwrap());
                if code == 0 {
                    break;
                }
                if let Some(ch) = char::from_u32(code as u32) {
                    bytes.push(ch);
                }
            }
            bytes.into_iter().collect()
        }
        _ => {
            return Err(load_error("load: unsupported array name encoding"));
        }
    };

    // The remaining sub-elements depend on the array class.
    let array = match class {
        MatClass::Double => parse_double_array(&mut cursor, dims, has_imag)?,
        MatClass::Logical => parse_logical_array(&mut cursor, dims)?,
        MatClass::Char => parse_char_array(&mut cursor, dims)?,
        MatClass::Cell => parse_cell_array(&mut cursor, dims)?,
        MatClass::Struct => parse_struct(&mut cursor, dims)?,
    };

    Ok(ParsedMatrix { name, array })
}
440
441fn parse_double_array(
442    cursor: &mut Cursor<&[u8]>,
443    dims: Vec<usize>,
444    has_imag: bool,
445) -> BuiltinResult<MatArray> {
446    let real_elem = read_tagged(cursor, false)?
447        .ok_or_else(|| load_error("load: numeric array missing real component"))?;
448    if real_elem.data_type != MI_DOUBLE || real_elem.data.len() % 8 != 0 {
449        return Err(load_error("load: numeric data must be stored as MI_DOUBLE"));
450    }
451    let mut real = Vec::with_capacity(real_elem.data.len() / 8);
452    for chunk in real_elem.data.chunks_exact(8) {
453        real.push(f64::from_le_bytes(chunk.try_into().unwrap()));
454    }
455
456    let imag = if has_imag {
457        let imag_elem = read_tagged(cursor, false)?
458            .ok_or_else(|| load_error("load: numeric array missing imaginary component"))?;
459        if imag_elem.data_type != MI_DOUBLE || imag_elem.data.len() % 8 != 0 {
460            return Err(load_error("load: imaginary component must be MI_DOUBLE"));
461        }
462        let mut imag = Vec::with_capacity(imag_elem.data.len() / 8);
463        for chunk in imag_elem.data.chunks_exact(8) {
464            imag.push(f64::from_le_bytes(chunk.try_into().unwrap()));
465        }
466        Some(imag)
467    } else {
468        None
469    };
470
471    Ok(MatArray {
472        class: MatClass::Double,
473        dims,
474        data: MatData::Double { real, imag },
475    })
476}
477
478fn parse_logical_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
479    let elem = read_tagged(cursor, false)?
480        .ok_or_else(|| load_error("load: logical array missing data block"))?;
481    if elem.data_type != MI_UINT8 {
482        return Err(load_error(
483            "load: logical arrays must be stored as MI_UINT8",
484        ));
485    }
486    Ok(MatArray {
487        class: MatClass::Logical,
488        dims,
489        data: MatData::Logical { data: elem.data },
490    })
491}
492
493fn parse_char_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
494    let elem = read_tagged(cursor, false)?
495        .ok_or_else(|| load_error("load: character array missing data block"))?;
496    if elem.data_type != MI_UINT16 {
497        return Err(load_error(
498            "load: character data must be stored as MI_UINT16",
499        ));
500    }
501    if elem.data.len() % 2 != 0 {
502        return Err(load_error("load: malformed character data"));
503    }
504    let mut data = Vec::with_capacity(elem.data.len() / 2);
505    for chunk in elem.data.chunks_exact(2) {
506        data.push(u16::from_le_bytes(chunk.try_into().unwrap()));
507    }
508    Ok(MatArray {
509        class: MatClass::Char,
510        dims,
511        data: MatData::Char { data },
512    })
513}
514
515fn parse_cell_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
516    let total: usize = dims
517        .iter()
518        .copied()
519        .fold(1usize, |acc, d| acc.saturating_mul(d));
520    let mut elements = Vec::with_capacity(total);
521    for _ in 0..total {
522        let elem = read_tagged(cursor, false)?
523            .ok_or_else(|| load_error("load: cell element missing matrix payload"))?;
524        if elem.data_type != MI_MATRIX {
525            return Err(load_error("load: cell elements must be matrices"));
526        }
527        let parsed = parse_matrix(&elem.data)?;
528        elements.push(parsed.array);
529    }
530    Ok(MatArray {
531        class: MatClass::Cell,
532        dims,
533        data: MatData::Cell { elements },
534    })
535}
536
537fn parse_struct(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
538    if dims.len() != 2 || dims[0] != 1 || dims[1] != 1 {
539        return Err(load_error("load: struct arrays are not supported yet"));
540    }
541    let len_elem = read_tagged(cursor, false)?
542        .ok_or_else(|| load_error("load: struct missing maximum field length specifier"))?;
543    if len_elem.data_type != MI_INT32 || len_elem.data.len() != 4 {
544        return Err(load_error("load: struct field length must be MI_INT32"));
545    }
546    let max_len = i32::from_le_bytes(len_elem.data[..4].try_into().unwrap());
547    if max_len <= 0 {
548        return Err(load_error("load: struct field length must be positive"));
549    }
550
551    let names_elem = read_tagged(cursor, false)?
552        .ok_or_else(|| load_error("load: struct missing field name table"))?;
553    if names_elem.data_type != MI_INT8 && names_elem.data_type != MI_UINT8 {
554        return Err(load_error(
555            "load: struct field names must be stored as MI_INT8/MI_UINT8",
556        ));
557    }
558    if names_elem.data.len() % (max_len as usize) != 0 {
559        return Err(load_error("load: malformed struct field name table"));
560    }
561    let field_count = names_elem.data.len() / (max_len as usize);
562    let mut field_names = Vec::with_capacity(field_count);
563    for i in 0..field_count {
564        let start = i * (max_len as usize);
565        let end = start + (max_len as usize);
566        let slice = &names_elem.data[start..end];
567        field_names.push(bytes_to_string(slice));
568    }
569
570    let mut field_values = Vec::with_capacity(field_count);
571    for _ in 0..field_count {
572        let elem = read_tagged(cursor, false)?
573            .ok_or_else(|| load_error("load: struct field missing matrix payload"))?;
574        if elem.data_type != MI_MATRIX {
575            return Err(load_error("load: struct fields must be matrices"));
576        }
577        let parsed = parse_matrix(&elem.data)?;
578        field_values.push(parsed.array);
579    }
580
581    Ok(MatArray {
582        class: MatClass::Struct,
583        dims,
584        data: MatData::Struct {
585            field_names,
586            field_values,
587        },
588    })
589}
590
/// Converts a decoded MAT array into a runtime `Value`.
///
/// Single-element doubles, complex doubles and logicals become scalar values;
/// everything else becomes the matching array type. Cells and structs are
/// converted recursively.
///
/// # Errors
///
/// Fails when the payload is inconsistent with its metadata, when a cell array
/// has more than two dimensions, or when the target container rejects the data.
fn mat_array_to_value(array: MatArray) -> BuiltinResult<Value> {
    match array.data {
        MatData::Double { real, imag } => {
            let len = real.len();
            if let Some(imag) = imag {
                if imag.len() != len {
                    return Err(load_error(
                        "load: complex data has mismatched real/imag parts",
                    ));
                }
                // A single element is returned as a scalar regardless of `dims`.
                if len == 1 {
                    Ok(Value::Complex(real[0], imag[0]))
                } else {
                    let mut pairs = Vec::with_capacity(len);
                    for i in 0..len {
                        pairs.push((real[i], imag[i]));
                    }
                    let tensor = ComplexTensor::new(pairs, array.dims.clone())
                        .map_err(|e| load_error(format!("load: {e}")))?;
                    Ok(Value::ComplexTensor(tensor))
                }
            } else if len == 1 {
                Ok(Value::Num(real[0]))
            } else {
                let tensor = Tensor::new(real, array.dims.clone())
                    .map_err(|e| load_error(format!("load: {e}")))?;
                Ok(Value::Tensor(tensor))
            }
        }
        MatData::Logical { data } => {
            // The stored byte count must match the element count implied by `dims`.
            let total: usize = array
                .dims
                .iter()
                .copied()
                .fold(1usize, |acc, d| acc.saturating_mul(d));
            if data.len() != total {
                return Err(load_error("load: logical data length mismatch"));
            }
            if total == 1 {
                Ok(Value::Bool(data.first().copied().unwrap_or(0) != 0))
            } else {
                let logical = LogicalArray::new(data, array.dims.clone())
                    .map_err(|e| load_error(format!("load: {e}")))?;
                Ok(Value::LogicalArray(logical))
            }
        }
        MatData::Char { data } => {
            // Only the first two dimensions are used; missing ones default to 1.
            let rows = array.dims.first().copied().unwrap_or(1);
            let cols = array.dims.get(1).copied().unwrap_or(1);
            // NOTE(review): code units are passed on in stored order, while the cell
            // branch below reorders column-major data to row-major — confirm
            // whether multi-row character arrays need the same treatment.
            let mut chars = Vec::with_capacity(rows.saturating_mul(cols));
            for code in data {
                // Each 16-bit unit maps to one char; invalid units become U+FFFD.
                let ch = char::from_u32(code as u32).unwrap_or('\u{FFFD}');
                chars.push(ch);
            }
            let char_array =
                CharArray::new(chars, rows, cols).map_err(|e| load_error(format!("load: {e}")))?;
            Ok(Value::CharArray(char_array))
        }
        MatData::Cell { elements } => {
            // A cell whose elements are all character rows (including a cell with
            // no elements at all) is returned as a string array.
            if let Some(strings) = cell_elements_to_strings(&elements) {
                let string_array = StringArray::new(strings, array.dims.clone())
                    .map_err(|e| load_error(format!("load: {e}")))?;
                return Ok(Value::StringArray(string_array));
            }
            if array.dims.len() != 2 {
                return Err(load_error(
                    "load: cell arrays with more than two dimensions are not supported yet",
                ));
            }
            let rows = array.dims[0];
            let cols = array.dims[1];
            let expected = rows.saturating_mul(cols);
            if elements.len() != expected {
                return Err(load_error("load: cell array element count mismatch"));
            }
            let mut converted = Vec::with_capacity(elements.len());
            for elem in elements {
                converted.push(mat_array_to_value(elem)?);
            }
            // Stored elements are indexed column-major (col * rows + row); the
            // values handed to `make_cell` are laid out row-major (row * cols + col).
            let mut row_major = vec![Value::Num(0.0); expected];
            for col in 0..cols {
                for row in 0..rows {
                    let cm_idx = col * rows + row;
                    let rm_idx = row * cols + col;
                    row_major[rm_idx] = converted[cm_idx].clone();
                }
            }
            make_cell(row_major, rows, cols).map_err(|err| load_error(format!("load: {err}")))
        }
        MatData::Struct {
            field_names,
            field_values,
        } => {
            if field_names.len() != field_values.len() {
                return Err(load_error("load: struct field metadata is inconsistent"));
            }
            let mut st = StructValue::new();
            for (name, value) in field_names.into_iter().zip(field_values.into_iter()) {
                let converted = mat_array_to_value(value)?;
                st.fields.insert(name, converted);
            }
            Ok(Value::Struct(st))
        }
    }
}
696
697fn cell_elements_to_strings(elements: &[MatArray]) -> Option<Vec<String>> {
698    let mut strings = Vec::with_capacity(elements.len());
699    for element in elements {
700        if element.class != MatClass::Char {
701            return None;
702        }
703        let rows = element.dims.first().copied().unwrap_or(1);
704        if rows > 1 {
705            return None;
706        }
707        match &element.data {
708            MatData::Char { data } => strings.push(utf16_codes_to_string(data)),
709            _ => return None,
710        }
711    }
712    Some(strings)
713}
714
/// Decodes UTF-16 code units into a `String`, dropping trailing NUL characters.
///
/// Surrogate pairs are combined into the character they encode; a surrogate
/// without its partner is replaced by U+FFFD. NUL characters that are not at
/// the end are kept.
fn utf16_codes_to_string(data: &[u16]) -> String {
    let mut text: String = char::decode_utf16(data.iter().copied())
        .map(|unit| unit.unwrap_or('\u{FFFD}'))
        .collect();
    // Strip NUL padding from the end without reallocating.
    let kept = text.trim_end_matches('\0').len();
    text.truncate(kept);
    text
}
725
726fn option_token(value: &Value) -> BuiltinResult<Option<String>> {
727    if let Some(token) = value_to_string_scalar(value) {
728        if token.starts_with('-') {
729            return Ok(Some(token.to_ascii_lowercase()));
730        }
731    }
732    Ok(None)
733}
734
735#[async_recursion::async_recursion(?Send)]
736async fn extract_names(value: &Value) -> BuiltinResult<Vec<String>> {
737    match value {
738        Value::String(s) => Ok(vec![s.clone()]),
739        Value::CharArray(ca) => Ok(char_array_rows_as_strings(ca)),
740        Value::StringArray(sa) => Ok(sa.data.clone()),
741        Value::Cell(ca) => {
742            let mut names = Vec::with_capacity(ca.data.len());
743            for handle in &ca.data {
744                let inner = unsafe { &*handle.as_raw() };
745                let text = value_to_string_scalar(inner).ok_or_else(|| {
746                    load_error(
747                        "load: cell arrays used for variable selection must contain string scalars",
748                    )
749                })?;
750                names.push(text);
751            }
752            Ok(names)
753        }
754        other => {
755            let gathered = gather_if_needed_async(other).await?;
756            extract_names(&gathered).await
757        }
758    }
759}
760
761fn value_to_string_scalar(value: &Value) -> Option<String> {
762    match value {
763        Value::String(s) => Some(s.clone()),
764        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
765        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
766        _ => None,
767    }
768}
769
770fn char_array_rows_as_strings(ca: &CharArray) -> Vec<String> {
771    let mut rows = Vec::with_capacity(ca.rows);
772    for r in 0..ca.rows {
773        let mut row = String::with_capacity(ca.cols);
774        for c in 0..ca.cols {
775            let idx = r * ca.cols + c;
776            row.push(ca.data[idx]);
777        }
778        let trimmed = row.trim_end_matches([' ', '\0']).to_string();
779        rows.push(trimmed);
780    }
781    rows
782}
783
/// Decodes the bytes before the first NUL as UTF-8. Invalid UTF-8 yields an
/// empty string.
fn bytes_to_string(bytes: &[u8]) -> String {
    let end = bytes
        .iter()
        .position(|&byte| byte == 0)
        .unwrap_or(bytes.len());
    match std::str::from_utf8(&bytes[..end]) {
        Ok(text) => text.to_owned(),
        Err(_) => String::new(),
    }
}
792
/// One tagged data element as returned by `read_tagged`.
struct TaggedData {
    // Data type code taken from the element tag.
    data_type: u32,
    // Element payload, without any alignment padding.
    data: Vec<u8>,
}
797
798fn read_tagged<R: Read>(reader: &mut R, allow_eof: bool) -> BuiltinResult<Option<TaggedData>> {
799    let mut type_bytes = [0u8; 4];
800    match reader.read_exact(&mut type_bytes) {
801        Ok(()) => {}
802        Err(err) => {
803            if allow_eof && err.kind() == std::io::ErrorKind::UnexpectedEof {
804                return Ok(None);
805            }
806            return Err(load_error_with_source(
807                format!("load: failed to read MAT element header: {err}"),
808                err,
809            ));
810        }
811    }
812
813    let type_field = u32::from_le_bytes(type_bytes);
814    if (type_field & 0xFFFF0000) != 0 {
815        let data_type = type_field & 0x0000FFFF;
816        let num_bytes = ((type_field & 0xFFFF0000) >> 16) as usize;
817        let mut inline = [0u8; 4];
818        reader.read_exact(&mut inline).map_err(|err| {
819            load_error_with_source(
820                format!("load: failed to read compact MAT element: {err}"),
821                err,
822            )
823        })?;
824        let mut data = inline[..num_bytes.min(4)].to_vec();
825        data.truncate(num_bytes.min(4));
826        Ok(Some(TaggedData { data_type, data }))
827    } else {
828        let mut len_bytes = [0u8; 4];
829        reader.read_exact(&mut len_bytes).map_err(|err| {
830            load_error_with_source(
831                format!("load: failed to read MAT element length: {err}"),
832                err,
833            )
834        })?;
835        let length = u32::from_le_bytes(len_bytes) as usize;
836        let mut data = vec![0u8; length];
837        reader.read_exact(&mut data).map_err(|err| {
838            load_error_with_source(format!("load: failed to read MAT element body: {err}"), err)
839        })?;
840        let padding = (8 - (length % 8)) % 8;
841        if padding != 0 {
842            let mut pad = vec![0u8; padding];
843            reader.read_exact(&mut pad).map_err(|err| {
844                load_error_with_source(format!("load: failed to read MAT padding: {err}"), err)
845            })?;
846        }
847        Ok(Some(TaggedData {
848            data_type: type_field,
849            data,
850        }))
851    }
852}
853
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::workspace::WorkspaceResolver;
    use futures::executor::block_on;
    use runmat_builtins::StringArray;
    use runmat_thread_local::runmat_thread_local;
    use std::cell::RefCell;
    use std::collections::HashMap;
    use tempfile::tempdir;

    // Per-thread stand-in for the interpreter workspace that `save` reads from.
    runmat_thread_local! {
        static TEST_WORKSPACE: RefCell<HashMap<String, Value>> = RefCell::new(HashMap::new());
    }

    /// Registers a workspace resolver backed by `TEST_WORKSPACE`.
    fn ensure_test_resolver() {
        crate::workspace::register_workspace_resolver(WorkspaceResolver {
            lookup: |name| TEST_WORKSPACE.with(|slot| slot.borrow().get(name).cloned()),
            snapshot: || {
                TEST_WORKSPACE.with(|slot| {
                    let mut entries: Vec<(String, Value)> = slot
                        .borrow()
                        .iter()
                        .map(|(name, value)| (name.clone(), value.clone()))
                        .collect();
                    // Sort by name so the snapshot order is deterministic.
                    entries.sort_by(|a, b| a.0.cmp(&b.0));
                    entries
                })
            },
            globals: Vec::new,
            assign: None,
            clear: None,
            remove: None,
        });
    }

    /// Replaces the contents of the test workspace with `entries`.
    fn set_workspace(entries: &[(&str, Value)]) {
        TEST_WORKSPACE.with(|slot| {
            let mut map = slot.borrow_mut();
            map.clear();
            map.extend(
                entries
                    .iter()
                    .map(|(name, value)| ((*name).to_owned(), value.clone())),
            );
        });
    }

    fn workspace_guard() -> std::sync::MutexGuard<'static, ()> {
        crate::workspace::test_guard()
    }

    /// Asserts that `result` is an error whose message contains `snippet`.
    fn assert_error_contains<T>(result: crate::BuiltinResult<T>, snippet: &str) {
        let Err(err) = result else {
            panic!("expected error containing '{snippet}'");
        };
        assert!(
            err.message().contains(snippet),
            "expected error to contain '{snippet}', got '{}'",
            err.message()
        );
    }

    /// Converts a filesystem path into the string `Value` the builtins accept.
    fn path_value(path: &std::path::Path) -> Value {
        Value::from(path.to_string_lossy().to_string())
    }

    /// Invokes the `save` builtin with `path` as its only argument.
    fn save_workspace(path: &std::path::Path) {
        let save_arg = path_value(path);
        block_on(crate::call_builtin_async(
            "save",
            std::slice::from_ref(&save_arg),
        ))
        .unwrap();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_roundtrip_numeric() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0], vec![2, 2]).unwrap();
        set_workspace(&[("A", Value::Tensor(tensor))]);

        let dir = tempdir().unwrap();
        let path = dir.path().join("numeric.mat");
        save_workspace(&path);

        let eval = block_on(evaluate(&[path_value(&path)])).expect("load numeric");
        let output = eval.first_output();
        let Value::Struct(sv) = &output else {
            panic!("expected struct, got {output:?}");
        };
        assert!(sv.fields.contains_key("A"));
        let field = sv.fields.get("A").unwrap();
        let Value::Tensor(t) = field else {
            panic!("expected tensor, got {field:?}");
        };
        assert_eq!(t.shape, vec![2, 2]);
        assert_eq!(t.data, vec![1.0, 4.0, 2.0, 5.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_selected_variables() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("signal", Value::Num(42.0)), ("noise", Value::Num(5.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("selection.mat");
        save_workspace(&path);

        let load_args = [path_value(&path), Value::from("signal")];
        let eval = block_on(evaluate(&load_args)).expect("load selection");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "signal");
        assert!(matches!(vars[0].1, Value::Num(42.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_regex_selection() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[
            ("w1", Value::Num(1.0)),
            ("w2", Value::Num(2.0)),
            ("bias", Value::Num(3.0)),
        ]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("regex.mat");
        save_workspace(&path);

        let load_args = [
            path_value(&path),
            Value::from("-regexp"),
            Value::from("^w\\d$"),
        ];
        let eval = block_on(evaluate(&load_args)).expect("load regex");
        let mut names: Vec<_> = eval
            .variables()
            .iter()
            .map(|(name, _)| name.clone())
            .collect();
        names.sort();
        assert_eq!(names, vec!["w1".to_string(), "w2".to_string()]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_missing_variable_errors() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("existing", Value::Num(7.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("missing.mat");
        save_workspace(&path);

        let load_args = [path_value(&path), Value::from("missing")];
        assert_error_contains(
            block_on(evaluate(&load_args)),
            "variable 'missing' was not found",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_string_array_roundtrip() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        let strings = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
        set_workspace(&[("labels", Value::StringArray(strings))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("strings.mat");
        save_workspace(&path);

        let eval = block_on(evaluate(&[path_value(&path)])).expect("load strings");
        let output = eval.first_output();
        let Value::Struct(sv) = &output else {
            panic!("expected struct, got {output:?}");
        };
        let labels = sv
            .fields
            .get("labels")
            .expect("labels field missing in struct");
        let Value::StringArray(sa) = labels else {
            panic!("expected string array, got {labels:?}");
        };
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, vec!["foo".to_string(), "bar".to_string()]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_option_before_filename() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("alpha", Value::Num(1.0)), ("beta", Value::Num(2.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("option_first.mat");
        save_workspace(&path);

        // The `-mat` flag precedes the file name here.
        let load_args = [Value::from("-mat"), path_value(&path), Value::from("beta")];
        let eval = block_on(evaluate(&load_args)).expect("load with option first");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "beta");
        assert!(matches!(vars[0].1, Value::Num(2.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_char_array_names_trimmed() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("short", Value::Num(5.0)), ("longer", Value::Num(9.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("char_names.mat");
        save_workspace(&path);

        // Build a 2 x 6 char matrix, one name per row, space-padded on the right.
        let cols: usize = 6;
        let data: Vec<char> = ["short", "longer"]
            .iter()
            .copied()
            .flat_map(|name| {
                let padding = cols.saturating_sub(name.chars().count());
                name.chars().chain(std::iter::repeat(' ').take(padding))
            })
            .collect();
        let name_array = CharArray::new(data, 2, cols).unwrap();

        let load_args = [path_value(&path), Value::CharArray(name_array)];
        let eval = block_on(evaluate(&load_args)).expect("load with char array names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 2);
        assert_eq!(vars[0].0, "short");
        assert!(matches!(vars[0].1, Value::Num(5.0)));
        assert_eq!(vars[1].0, "longer");
        assert!(matches!(vars[1].1, Value::Num(9.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_duplicate_names_last_wins() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("dup", Value::Num(11.0))]);
        let dir = tempdir().unwrap();
        let path = dir.path().join("duplicates.mat");
        save_workspace(&path);

        // Requesting the same name twice must still yield a single variable.
        let load_args = [path_value(&path), Value::from("dup"), Value::from("dup")];
        let eval = block_on(evaluate(&load_args)).expect("load with duplicate names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "dup");
        assert!(matches!(vars[0].1, Value::Num(11.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn load_wgpu_tensor_roundtrip() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        // Return early (test passes vacuously) when no provider can be registered.
        let registration = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        if registration.is_err() {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };

        use runmat_accelerate_api::HostTensorView;

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload tensor");
        set_workspace(&[("gpu_var", Value::GpuTensor(handle))]);

        let dir = tempdir().unwrap();
        let path = dir.path().join("wgpu_load.mat");
        let save_args = [path_value(&path), Value::from("gpu_var")];
        block_on(crate::call_builtin_async("save", &save_args)).unwrap();

        let eval = block_on(evaluate(&[path_value(&path)])).expect("load wgpu file");
        let output = eval.first_output();
        let Value::Struct(sv) = &output else {
            panic!("expected struct, got {output:?}");
        };
        // The loaded value is expected to be a host tensor, not a GPU handle.
        match sv.fields.get("gpu_var") {
            Some(Value::Tensor(t)) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, tensor.data);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }
}