Skip to main content

runmat_runtime/builtins/io/mat/
load.rs

1//! MATLAB-compatible `load` builtin for RunMat.
2
3use std::collections::HashMap;
4use std::io::{BufReader, Cursor, Read};
5use std::path::{Path, PathBuf};
6
7use regex::Regex;
8use runmat_builtins::{
9    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11    CharArray, ComplexTensor, LogicalArray, StringArray, StructValue, Tensor, Value,
12};
13use runmat_filesystem::File;
14use runmat_macros::runtime_builtin;
15
16use super::format::{
17    MatArray, MatClass, MatData, FLAG_COMPLEX, FLAG_LOGICAL, MAT_HEADER_LEN, MI_DOUBLE, MI_INT32,
18    MI_INT8, MI_MATRIX, MI_UINT16, MI_UINT32, MI_UINT8,
19};
20use crate::builtins::common::spec::{
21    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
22    ReductionNaN, ResidencyPolicy, ShapeRequirements,
23};
24use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
25
/// Output parameter table referenced by every entry in `LOAD_SIGNATURES`:
/// a single required output named `S`.
26const LOAD_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27    name: "S",
28    ty: BuiltinParamType::Any,
29    arity: BuiltinParamArity::Required,
30    default: None,
31    description: "Struct containing the loaded variables.",
32}];
/// Empty input table for the zero-argument form `S = load()`.
33const LOAD_INPUTS_NONE: [BuiltinParamDescriptor; 0] = [];
/// Input table for `S = load(filename)`: one string-scalar filename whose
/// documented default is `"matlab.mat"`.
34const LOAD_INPUTS_FILENAME: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
35    name: "filename",
36    ty: BuiltinParamType::StringScalar,
37    arity: BuiltinParamArity::Required,
38    default: Some("\"matlab.mat\""),
39    description: "MAT-file path.",
40}];
/// Input table for `S = load(filename, varName1, varName2, ...)`: a filename
/// followed by a variadic list of variable names to load.
41const LOAD_INPUTS_FILENAME_VARS: [BuiltinParamDescriptor; 2] = [
42    BuiltinParamDescriptor {
43        name: "filename",
44        ty: BuiltinParamType::StringScalar,
45        arity: BuiltinParamArity::Required,
46        default: Some("\"matlab.mat\""),
47        description: "MAT-file path.",
48    },
49    BuiltinParamDescriptor {
50        name: "varName",
51        ty: BuiltinParamType::StringScalar,
52        arity: BuiltinParamArity::Variadic,
53        default: None,
54        description: "Variable names to load.",
55    },
56];
/// Input table for `S = load(filename, "-regexp", pattern1, ...)`: a filename,
/// the `-regexp` option token, and a variadic list of regex patterns.
57const LOAD_INPUTS_FILENAME_REGEXP: [BuiltinParamDescriptor; 3] = [
58    BuiltinParamDescriptor {
59        name: "filename",
60        ty: BuiltinParamType::StringScalar,
61        arity: BuiltinParamArity::Required,
62        default: Some("\"matlab.mat\""),
63        description: "MAT-file path.",
64    },
65    BuiltinParamDescriptor {
66        name: "option",
67        ty: BuiltinParamType::StringScalar,
68        arity: BuiltinParamArity::Required,
69        default: Some("\"-regexp\""),
70        description: "Regular-expression selection option.",
71    },
72    BuiltinParamDescriptor {
73        name: "pattern",
74        ty: BuiltinParamType::StringScalar,
75        arity: BuiltinParamArity::Variadic,
76        default: None,
77        description: "Regex patterns matched against variable names.",
78    },
79];
/// Input table for the generic `S = load(option, value, ...)` form: variadic
/// option tokens interleaved with variadic option arguments/selectors.
80const LOAD_INPUTS_OPTIONS: [BuiltinParamDescriptor; 2] = [
81    BuiltinParamDescriptor {
82        name: "option",
83        ty: BuiltinParamType::StringScalar,
84        arity: BuiltinParamArity::Variadic,
85        default: None,
86        description: "Compatibility options such as '-mat' and '-regexp'.",
87    },
88    BuiltinParamDescriptor {
89        name: "value",
90        ty: BuiltinParamType::Any,
91        arity: BuiltinParamArity::Variadic,
92        default: None,
93        description: "Option arguments and variable selectors.",
94    },
95];
/// The five documented call forms of `load`. Each entry pairs a display label
/// with its input table; all of them share `LOAD_OUTPUT` as the output table.
96const LOAD_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
97    BuiltinSignatureDescriptor {
98        label: "S = load()",
99        inputs: &LOAD_INPUTS_NONE,
100        outputs: &LOAD_OUTPUT,
101    },
102    BuiltinSignatureDescriptor {
103        label: "S = load(filename)",
104        inputs: &LOAD_INPUTS_FILENAME,
105        outputs: &LOAD_OUTPUT,
106    },
107    BuiltinSignatureDescriptor {
108        label: "S = load(filename, varName1, varName2, ...)",
109        inputs: &LOAD_INPUTS_FILENAME_VARS,
110        outputs: &LOAD_OUTPUT,
111    },
112    BuiltinSignatureDescriptor {
113        label: "S = load(filename, \"-regexp\", pattern1, ...)",
114        inputs: &LOAD_INPUTS_FILENAME_REGEXP,
115        outputs: &LOAD_OUTPUT,
116    },
117    BuiltinSignatureDescriptor {
118        label: "S = load(option, value, ...)",
119        inputs: &LOAD_INPUTS_OPTIONS,
120        outputs: &LOAD_OUTPUT,
121    },
122];
/// Error descriptor: the arguments do not match any supported invocation form.
123const LOAD_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
124    code: "RM.LOAD.INVALID_ARGUMENT",
125    identifier: Some("RunMat:load:InvalidArgument"),
126    when: "Arguments do not match a supported load invocation form.",
127    message: "load: invalid argument",
128};
/// Error descriptor: an option token or its argument is invalid.
129const LOAD_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
130    code: "RM.LOAD.INVALID_OPTION",
131    identifier: Some("RunMat:load:InvalidOption"),
132    when: "An option token or option argument is invalid.",
133    message: "load: invalid option",
134};
/// Error descriptor: the filename is invalid or cannot be normalized.
135const LOAD_ERROR_FILENAME: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
136    code: "RM.LOAD.FILENAME",
137    identifier: Some("RunMat:load:Filename"),
138    when: "Filename is invalid or cannot be normalized.",
139    message: "load: invalid filename",
140};
/// Error descriptor: requested variables are missing or nothing was selected.
141const LOAD_ERROR_SELECTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
142    code: "RM.LOAD.SELECTION",
143    identifier: Some("RunMat:load:Selection"),
144    when: "Requested variables are missing or no variables are selected.",
145    message: "load: variable selection failed",
146};
/// Error descriptor: the MAT-file could not be read or decoded.
147const LOAD_ERROR_IO: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
148    code: "RM.LOAD.IO",
149    identifier: Some("RunMat:load:Io"),
150    when: "MAT-file data cannot be read or decoded.",
151    message: "load: MAT-file I/O failure",
152};
/// Error descriptor: the caller requested more outputs than `load` supports.
153const LOAD_ERROR_OUTPUT_COUNT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
154    code: "RM.LOAD.OUTPUT_COUNT",
155    identifier: Some("RunMat:load:OutputCount"),
156    when: "Caller requests more outputs than supported by load.",
157    message: "load: unsupported output count",
158};
/// Error descriptor: statement-form `load` failed to assign into the workspace.
159const LOAD_ERROR_WORKSPACE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
160    code: "RM.LOAD.WORKSPACE",
161    identifier: Some("RunMat:load:Workspace"),
162    when: "Statement-form load cannot assign values into workspace.",
163    message: "load: workspace assignment failed",
164};
/// All error descriptors above, collected for `LOAD_DESCRIPTOR.errors`.
165const LOAD_ERRORS: [BuiltinErrorDescriptor; 7] = [
166    LOAD_ERROR_INVALID_ARGUMENT,
167    LOAD_ERROR_INVALID_OPTION,
168    LOAD_ERROR_FILENAME,
169    LOAD_ERROR_SELECTION,
170    LOAD_ERROR_IO,
171    LOAD_ERROR_OUTPUT_COUNT,
172    LOAD_ERROR_WORKSPACE,
173];
/// Public descriptor for the `load` builtin, bundling its signatures and error
/// table. It is referenced by path from the `#[runtime_builtin]` attribute on
/// `load_builtin`.
174pub const LOAD_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
175    signatures: &LOAD_SIGNATURES,
176    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
177    completion_policy: BuiltinCompletionPolicy::Public,
178    errors: &LOAD_ERRORS,
179};
180
/// GPU registration record for `load`. It declares no supported precisions and
/// no provider hooks; per the `notes` field, the builtin reads MAT-files on the
/// host and yields CPU-resident values.
181#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::mat::load")]
182pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
183    name: "load",
184    op_kind: GpuOpKind::Custom("io-load"),
185    supported_precisions: &[],
186    broadcast: BroadcastSemantics::None,
187    provider_hooks: &[],
188    constant_strategy: ConstantStrategy::InlineLiteral,
189    residency: ResidencyPolicy::NewHandle,
190    nan_mode: ReductionNaN::Include,
191    two_pass_threshold: None,
192    workgroup_size: None,
193    accepts_nan_mode: false,
194    notes: "Reads MAT-files on the host and produces CPU-resident values. Providers are not involved until accelerated code later promotes the results.",
195};
196
/// Fusion registration record for `load`. It declares neither an elementwise
/// nor a reduction kernel; per the `notes` field, the entry exists for
/// documentation completeness only.
197#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::mat::load")]
198pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
199    name: "load",
200    shape: ShapeRequirements::Any,
201    constant_strategy: ConstantStrategy::InlineLiteral,
202    elementwise: None,
203    reduction: None,
204    emits_nan: false,
205    notes: "File I/O is not eligible for fusion. Registration exists for documentation completeness only.",
206};
207
208#[runtime_builtin(
209    name = "load",
210    category = "io/mat",
211    summary = "Load variables from a MAT-file.",
212    keywords = "load,mat,workspace",
213    accel = "cpu",
214    sink = true,
215    type_resolver(crate::builtins::io::type_resolvers::load_type),
216    descriptor(crate::builtins::io::mat::load::LOAD_DESCRIPTOR),
217    builtin_path = "crate::builtins::io::mat::load"
218)]
219async fn load_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
220    let eval = evaluate(&args).await?;
221
222    // current_output_count() is set by the dispatcher only for multi-output Unpack patterns
223    // like `[a, b] = load(...)`. Guard against requesting more than one struct output.
224    if let Some(n) = crate::output_count::current_output_count() {
225        if n > 1 {
226            return Err(load_error_with(
227                &LOAD_ERROR_OUTPUT_COUNT,
228                "load supports at most one output argument",
229            ));
230        }
231    }
232
233    // The VM sets output_context::requested_output_count() at every call site before
234    // dispatching:
235    //   Some(0) → statement-level call (result is discarded or printed without capture)
236    //             → assign loaded variables directly into the caller's workspace.
237    //   Some(1) → single-output assignment `S = load(...)` → return a struct.
238    //   None    → called outside the VM (e.g. directly from Rust) → return a struct.
239    if crate::output_context::requested_output_count() == Some(0) {
240        for (name, value) in eval.variables() {
241            crate::workspace::assign(name, value.clone())
242                .map_err(|err| load_error_with(&LOAD_ERROR_WORKSPACE, err))?;
243        }
244        return Ok(Value::OutputList(Vec::new()));
245    }
246
247    Ok(eval.first_output())
248}
249
250#[derive(Clone, Debug)]
251pub struct LoadEval {
252    variables: Vec<(String, Value)>,
253}
254
255impl LoadEval {
256    pub fn first_output(&self) -> Value {
257        let mut st = StructValue::new();
258        for (name, value) in &self.variables {
259            st.fields.insert(name.clone(), value.clone());
260        }
261        Value::Struct(st)
262    }
263
264    pub fn variables(&self) -> &[(String, Value)] {
265        &self.variables
266    }
267
268    pub fn into_variables(self) -> Vec<(String, Value)> {
269        self.variables
270    }
271}
272
/// Variable selection criteria handed to `select_variables`.
273struct LoadRequest {
    // Exact variable names that must be present in the file.
274    variables: Vec<String>,
    // Compiled `-regexp` patterns matched against variable names.
275    regex_patterns: Vec<Regex>,
276}
277
/// Builtin name attached to every runtime error built by the helpers in this file.
278const BUILTIN_NAME: &str = "load";
279
280fn load_error(message: impl Into<String>) -> RuntimeError {
281    load_error_with(&LOAD_ERROR_INVALID_ARGUMENT, message)
282}
283
284fn load_error_with(
285    error: &'static BuiltinErrorDescriptor,
286    message: impl Into<String>,
287) -> RuntimeError {
288    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
289    if let Some(identifier) = error.identifier {
290        builder = builder.with_identifier(identifier);
291    }
292    builder.build()
293}
294
295fn load_error_with_source(
296    error: &'static BuiltinErrorDescriptor,
297    message: impl Into<String>,
298    source: impl std::error::Error + Send + Sync + 'static,
299) -> RuntimeError {
300    let mut builder = build_runtime_error(message)
301        .with_builtin(BUILTIN_NAME)
302        .with_source(source);
303    if let Some(identifier) = error.identifier {
304        builder = builder.with_identifier(identifier);
305    }
306    builder.build()
307}
308
309pub async fn evaluate(args: &[Value]) -> BuiltinResult<LoadEval> {
310    let mut host_args = Vec::with_capacity(args.len());
311    for arg in args {
312        host_args.push(gather_if_needed_async(arg).await?);
313    }
314
315    let invocation = parse_invocation(&host_args).await?;
316
317    let mut path_value = if let Some(path) = invocation.path_value {
318        path
319    } else {
320        Value::from("matlab.mat")
321    };
322
323    if invocation.path_was_default {
324        if let Ok(override_path) = std::env::var("RUNMAT_LOAD_DEFAULT_PATH") {
325            path_value = Value::from(override_path);
326        }
327    }
328
329    let mut regex_patterns = Vec::with_capacity(invocation.regex_tokens.len());
330    for pattern in invocation.regex_tokens {
331        let regex = Regex::new(&pattern).map_err(|err| {
332            load_error_with_source(
333                &LOAD_ERROR_INVALID_OPTION,
334                format!("load: invalid regular expression '{pattern}': {err}"),
335                err,
336            )
337        })?;
338        regex_patterns.push(regex);
339    }
340
341    let request = LoadRequest {
342        variables: invocation.variables,
343        regex_patterns,
344    };
345    let path = normalise_path(&path_value)?;
346    let entries = read_mat_file(&path).await?;
347
348    let selected = select_variables(&entries, &request)?;
349    Ok(LoadEval {
350        variables: selected,
351    })
352}
353
/// Argument list of a `load` call after `parse_invocation` has split it apart.
354struct ParsedInvocation {
    // First non-option argument, if any; used as the file path.
355    path_value: Option<Value>,
    // True when no path argument was supplied.
356    path_was_default: bool,
    // Variable names given after the path.
357    variables: Vec<String>,
    // Uncompiled pattern strings collected after `-regexp`.
358    regex_tokens: Vec<String>,
359}
360
361async fn parse_invocation(values: &[Value]) -> BuiltinResult<ParsedInvocation> {
362    let mut path_value = None;
363    let mut path_was_default = false;
364    let mut variables = Vec::new();
365    let mut regex_tokens = Vec::new();
366    let mut idx = 0usize;
367    while idx < values.len() {
368        if let Some(flag) = option_token(&values[idx])? {
369            match flag.as_str() {
370                "-mat" => {
371                    idx += 1;
372                    continue;
373                }
374                "-regexp" => {
375                    idx += 1;
376                    if idx >= values.len() {
377                        return Err(load_error_with(
378                            &LOAD_ERROR_INVALID_OPTION,
379                            "load: '-regexp' requires at least one pattern",
380                        ));
381                    }
382                    while idx < values.len() {
383                        if option_token(&values[idx])?.is_some() {
384                            break;
385                        }
386                        let names = extract_names(&values[idx]).await?;
387                        if names.is_empty() {
388                            return Err(load_error_with(
389                                &LOAD_ERROR_INVALID_OPTION,
390                                "load: '-regexp' requires non-empty pattern strings",
391                            ));
392                        }
393                        regex_tokens.extend(names);
394                        idx += 1;
395                    }
396                    continue;
397                }
398                other => {
399                    return Err(load_error_with(
400                        &LOAD_ERROR_INVALID_OPTION,
401                        format!("load: unsupported option '{other}'"),
402                    ));
403                }
404            }
405        } else {
406            if path_value.is_none() {
407                path_value = Some(values[idx].clone());
408                idx += 1;
409                continue;
410            }
411            let names = extract_names(&values[idx]).await?;
412            variables.extend(names);
413            idx += 1;
414        }
415    }
416
417    if path_value.is_none() {
418        path_was_default = true;
419    }
420
421    Ok(ParsedInvocation {
422        path_value,
423        path_was_default,
424        variables,
425        regex_tokens,
426    })
427}
428
429fn normalise_path(value: &Value) -> BuiltinResult<PathBuf> {
430    let raw = value_to_string_scalar(value).ok_or_else(|| {
431        load_error_with(
432            &LOAD_ERROR_FILENAME,
433            "load: filename must be a character vector or string scalar",
434        )
435    })?;
436    let mut path = PathBuf::from(raw);
437    if path.extension().is_none() {
438        path.set_extension("mat");
439    }
440    Ok(path)
441}
442
443fn select_variables(
444    entries: &[(String, Value)],
445    request: &LoadRequest,
446) -> BuiltinResult<Vec<(String, Value)>> {
447    if request.variables.is_empty() && request.regex_patterns.is_empty() {
448        return Ok(entries.to_vec());
449    }
450
451    let mut by_name: HashMap<&str, &Value> = HashMap::with_capacity(entries.len());
452    for (name, value) in entries {
453        by_name.insert(name, value);
454    }
455
456    let mut selected = Vec::new();
457
458    for name in &request.variables {
459        let value = by_name.get(name.as_str()).ok_or_else(|| {
460            load_error_with(
461                &LOAD_ERROR_SELECTION,
462                format!("load: variable '{name}' was not found in the file"),
463            )
464        })?;
465        insert_or_replace(&mut selected, name, (*value).clone());
466    }
467
468    if !request.regex_patterns.is_empty() {
469        let mut matched = 0usize;
470        for (name, value) in entries {
471            if request
472                .regex_patterns
473                .iter()
474                .any(|regex| regex.is_match(name))
475            {
476                matched += 1;
477                insert_or_replace(&mut selected, name, value.clone());
478            }
479        }
480        if matched == 0 && request.variables.is_empty() {
481            return Err(load_error_with(
482                &LOAD_ERROR_SELECTION,
483                "load: no variables matched '-regexp' patterns",
484            ));
485        }
486    }
487
488    if selected.is_empty() {
489        return Err(load_error_with(
490            &LOAD_ERROR_SELECTION,
491            "load: no variables selected",
492        ));
493    }
494
495    Ok(selected)
496}
497
498fn insert_or_replace(selected: &mut Vec<(String, Value)>, name: &str, value: Value) {
499    if let Some(entry) = selected.iter_mut().find(|(existing, _)| existing == name) {
500        entry.1 = value;
501    } else {
502        selected.push((name.to_string(), value));
503    }
504}
505
506pub(crate) async fn read_mat_file_for_builtin(
507    path: &Path,
508    builtin: &str,
509) -> crate::BuiltinResult<Vec<(String, Value)>> {
510    match read_mat_file(path).await {
511        Ok(entries) => Ok(entries),
512        Err(err) => {
513            let message = err.message().replacen("load:", &format!("{builtin}:"), 1);
514            let mut builder = build_runtime_error(message).with_builtin(builtin);
515            if let Some(identifier) = err.identifier() {
516                builder = builder.with_identifier(identifier);
517            }
518            Err(builder.with_source(err).build())
519        }
520    }
521}
522
523pub(crate) async fn read_mat_file(path: &Path) -> BuiltinResult<Vec<(String, Value)>> {
524    let file = File::open_async(path).await.map_err(|err| {
525        load_error_with_source(
526            &LOAD_ERROR_IO,
527            format!("load: failed to open '{}': {err}", path.display()),
528            err,
529        )
530    })?;
531    let mut reader = BufReader::new(file);
532    read_mat_reader(&mut reader)
533}
534
535pub fn decode_workspace_from_mat_bytes(bytes: &[u8]) -> BuiltinResult<Vec<(String, Value)>> {
536    let mut cursor = Cursor::new(bytes);
537    read_mat_reader(&mut cursor)
538}
539
540fn read_mat_reader<R: Read>(reader: &mut R) -> BuiltinResult<Vec<(String, Value)>> {
541    let mut header = [0u8; MAT_HEADER_LEN];
542    reader.read_exact(&mut header).map_err(|err| {
543        load_error_with_source(
544            &LOAD_ERROR_IO,
545            format!("load: failed to read MAT-file header: {err}"),
546            err,
547        )
548    })?;
549    if header[126] != b'I' || header[127] != b'M' {
550        return Err(load_error("load: file is not a MATLAB Level-5 MAT-file"));
551    }
552
553    let mut variables = Vec::new();
554    while let Some(tagged) = read_tagged(reader, true)? {
555        if tagged.data_type != MI_MATRIX {
556            continue;
557        }
558        let parsed = parse_matrix(&tagged.data)?;
559        let value = mat_array_to_value(parsed.array)?;
560        variables.push((parsed.name, value));
561    }
562    Ok(variables)
563}
564
/// A decoded matrix element: its variable name plus the array payload.
565struct ParsedMatrix {
566    name: String,
567    array: MatArray,
568}
569
570fn parse_matrix(buffer: &[u8]) -> BuiltinResult<ParsedMatrix> {
571    let mut cursor = Cursor::new(buffer);
572
573    let flags = read_tagged(&mut cursor, false)?
574        .ok_or_else(|| load_error("load: matrix element missing array flags"))?;
575    if flags.data_type != MI_UINT32 || flags.data.len() < 8 {
576        return Err(load_error("load: invalid array flags block"));
577    }
578    let flags0 = u32::from_le_bytes(flags.data[0..4].try_into().unwrap());
579    let class_code = flags0 & 0xFF;
580    let mut class = MatClass::from_class_code(class_code)
581        .ok_or_else(|| load_error("load: unsupported MATLAB class"))?;
582    let is_logical = (flags0 & FLAG_LOGICAL) != 0;
583    let has_imag = (flags0 & FLAG_COMPLEX) != 0;
584    if matches!(class, MatClass::Double) && is_logical {
585        class = MatClass::Logical;
586    }
587
588    let dims_elem = read_tagged(&mut cursor, false)?
589        .ok_or_else(|| load_error("load: matrix element missing dimensions"))?;
590    if dims_elem.data_type != MI_INT32 {
591        return Err(load_error("load: dimension block must use MI_INT32"));
592    }
593    if dims_elem.data.is_empty() || dims_elem.data.len() % 4 != 0 {
594        return Err(load_error("load: malformed dimension block"));
595    }
596    let mut dims = Vec::with_capacity(dims_elem.data.len() / 4);
597    for chunk in dims_elem.data.chunks_exact(4) {
598        let value = i32::from_le_bytes(chunk.try_into().unwrap());
599        if value < 0 {
600            return Err(load_error("load: negative dimensions are not supported"));
601        }
602        dims.push(value as usize);
603    }
604    if dims.is_empty() {
605        dims.push(1);
606        dims.push(1);
607    }
608
609    let name_elem = read_tagged(&mut cursor, false)?
610        .ok_or_else(|| load_error("load: matrix element missing name"))?;
611    let name = match name_elem.data_type {
612        MI_INT8 | MI_UINT8 => bytes_to_string(&name_elem.data),
613        MI_UINT16 => {
614            let mut bytes = Vec::with_capacity(name_elem.data.len());
615            for chunk in name_elem.data.chunks_exact(2) {
616                let code = u16::from_le_bytes(chunk.try_into().unwrap());
617                if code == 0 {
618                    break;
619                }
620                if let Some(ch) = char::from_u32(code as u32) {
621                    bytes.push(ch);
622                }
623            }
624            bytes.into_iter().collect()
625        }
626        _ => {
627            return Err(load_error("load: unsupported array name encoding"));
628        }
629    };
630
631    let array = match class {
632        MatClass::Double => parse_double_array(&mut cursor, dims, has_imag)?,
633        MatClass::Logical => parse_logical_array(&mut cursor, dims)?,
634        MatClass::Char => parse_char_array(&mut cursor, dims)?,
635        MatClass::Cell => parse_cell_array(&mut cursor, dims)?,
636        MatClass::Struct => parse_struct(&mut cursor, dims)?,
637    };
638
639    Ok(ParsedMatrix { name, array })
640}
641
642fn parse_double_array(
643    cursor: &mut Cursor<&[u8]>,
644    dims: Vec<usize>,
645    has_imag: bool,
646) -> BuiltinResult<MatArray> {
647    let real_elem = read_tagged(cursor, false)?
648        .ok_or_else(|| load_error("load: numeric array missing real component"))?;
649    if real_elem.data_type != MI_DOUBLE || real_elem.data.len() % 8 != 0 {
650        return Err(load_error("load: numeric data must be stored as MI_DOUBLE"));
651    }
652    let mut real = Vec::with_capacity(real_elem.data.len() / 8);
653    for chunk in real_elem.data.chunks_exact(8) {
654        real.push(f64::from_le_bytes(chunk.try_into().unwrap()));
655    }
656
657    let imag = if has_imag {
658        let imag_elem = read_tagged(cursor, false)?
659            .ok_or_else(|| load_error("load: numeric array missing imaginary component"))?;
660        if imag_elem.data_type != MI_DOUBLE || imag_elem.data.len() % 8 != 0 {
661            return Err(load_error("load: imaginary component must be MI_DOUBLE"));
662        }
663        let mut imag = Vec::with_capacity(imag_elem.data.len() / 8);
664        for chunk in imag_elem.data.chunks_exact(8) {
665            imag.push(f64::from_le_bytes(chunk.try_into().unwrap()));
666        }
667        Some(imag)
668    } else {
669        None
670    };
671
672    Ok(MatArray {
673        class: MatClass::Double,
674        dims,
675        data: MatData::Double { real, imag },
676    })
677}
678
679fn parse_logical_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
680    let elem = read_tagged(cursor, false)?
681        .ok_or_else(|| load_error("load: logical array missing data block"))?;
682    if elem.data_type != MI_UINT8 {
683        return Err(load_error(
684            "load: logical arrays must be stored as MI_UINT8",
685        ));
686    }
687    Ok(MatArray {
688        class: MatClass::Logical,
689        dims,
690        data: MatData::Logical { data: elem.data },
691    })
692}
693
694fn parse_char_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
695    let elem = read_tagged(cursor, false)?
696        .ok_or_else(|| load_error("load: character array missing data block"))?;
697    if elem.data_type != MI_UINT16 {
698        return Err(load_error(
699            "load: character data must be stored as MI_UINT16",
700        ));
701    }
702    if elem.data.len() % 2 != 0 {
703        return Err(load_error("load: malformed character data"));
704    }
705    let mut data = Vec::with_capacity(elem.data.len() / 2);
706    for chunk in elem.data.chunks_exact(2) {
707        data.push(u16::from_le_bytes(chunk.try_into().unwrap()));
708    }
709    Ok(MatArray {
710        class: MatClass::Char,
711        dims,
712        data: MatData::Char { data },
713    })
714}
715
716fn parse_cell_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
717    let total: usize = dims
718        .iter()
719        .copied()
720        .fold(1usize, |acc, d| acc.saturating_mul(d));
721    let mut elements = Vec::with_capacity(total);
722    for _ in 0..total {
723        let elem = read_tagged(cursor, false)?
724            .ok_or_else(|| load_error("load: cell element missing matrix payload"))?;
725        if elem.data_type != MI_MATRIX {
726            return Err(load_error("load: cell elements must be matrices"));
727        }
728        let parsed = parse_matrix(&elem.data)?;
729        elements.push(parsed.array);
730    }
731    Ok(MatArray {
732        class: MatClass::Cell,
733        dims,
734        data: MatData::Cell { elements },
735    })
736}
737
738fn parse_struct(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
739    if dims.len() != 2 || dims[0] != 1 || dims[1] != 1 {
740        return Err(load_error("load: struct arrays are not supported yet"));
741    }
742    let len_elem = read_tagged(cursor, false)?
743        .ok_or_else(|| load_error("load: struct missing maximum field length specifier"))?;
744    if len_elem.data_type != MI_INT32 || len_elem.data.len() != 4 {
745        return Err(load_error("load: struct field length must be MI_INT32"));
746    }
747    let max_len = i32::from_le_bytes(len_elem.data[..4].try_into().unwrap());
748    if max_len <= 0 {
749        return Err(load_error("load: struct field length must be positive"));
750    }
751
752    let names_elem = read_tagged(cursor, false)?
753        .ok_or_else(|| load_error("load: struct missing field name table"))?;
754    if names_elem.data_type != MI_INT8 && names_elem.data_type != MI_UINT8 {
755        return Err(load_error(
756            "load: struct field names must be stored as MI_INT8/MI_UINT8",
757        ));
758    }
759    if names_elem.data.len() % (max_len as usize) != 0 {
760        return Err(load_error("load: malformed struct field name table"));
761    }
762    let field_count = names_elem.data.len() / (max_len as usize);
763    let mut field_names = Vec::with_capacity(field_count);
764    for i in 0..field_count {
765        let start = i * (max_len as usize);
766        let end = start + (max_len as usize);
767        let slice = &names_elem.data[start..end];
768        field_names.push(bytes_to_string(slice));
769    }
770
771    let mut field_values = Vec::with_capacity(field_count);
772    for _ in 0..field_count {
773        let elem = read_tagged(cursor, false)?
774            .ok_or_else(|| load_error("load: struct field missing matrix payload"))?;
775        if elem.data_type != MI_MATRIX {
776            return Err(load_error("load: struct fields must be matrices"));
777        }
778        let parsed = parse_matrix(&elem.data)?;
779        field_values.push(parsed.array);
780    }
781
782    Ok(MatArray {
783        class: MatClass::Struct,
784        dims,
785        data: MatData::Struct {
786            field_names,
787            field_values,
788        },
789    })
790}
791
792fn mat_array_to_value(array: MatArray) -> BuiltinResult<Value> {
793    match array.data {
794        MatData::Double { real, imag } => {
795            let len = real.len();
796            if let Some(imag) = imag {
797                if imag.len() != len {
798                    return Err(load_error(
799                        "load: complex data has mismatched real/imag parts",
800                    ));
801                }
802                if len == 1 {
803                    Ok(Value::Complex(real[0], imag[0]))
804                } else {
805                    let mut pairs = Vec::with_capacity(len);
806                    for i in 0..len {
807                        pairs.push((real[i], imag[i]));
808                    }
809                    let tensor = ComplexTensor::new(pairs, array.dims.clone())
810                        .map_err(|e| load_error(format!("load: {e}")))?;
811                    Ok(Value::ComplexTensor(tensor))
812                }
813            } else if len == 1 {
814                Ok(Value::Num(real[0]))
815            } else {
816                let tensor = Tensor::new(real, array.dims.clone())
817                    .map_err(|e| load_error(format!("load: {e}")))?;
818                Ok(Value::Tensor(tensor))
819            }
820        }
821        MatData::Logical { data } => {
822            let total: usize = array
823                .dims
824                .iter()
825                .copied()
826                .fold(1usize, |acc, d| acc.saturating_mul(d));
827            if data.len() != total {
828                return Err(load_error("load: logical data length mismatch"));
829            }
830            if total == 1 {
831                Ok(Value::Bool(data.first().copied().unwrap_or(0) != 0))
832            } else {
833                let logical = LogicalArray::new(data, array.dims.clone())
834                    .map_err(|e| load_error(format!("load: {e}")))?;
835                Ok(Value::LogicalArray(logical))
836            }
837        }
838        MatData::Char { data } => {
839            let rows = array.dims.first().copied().unwrap_or(1);
840            let cols = array.dims.get(1).copied().unwrap_or(1);
841            let mut chars = Vec::with_capacity(rows.saturating_mul(cols));
842            for code in data {
843                let ch = char::from_u32(code as u32).unwrap_or('\u{FFFD}');
844                chars.push(ch);
845            }
846            let char_array =
847                CharArray::new(chars, rows, cols).map_err(|e| load_error(format!("load: {e}")))?;
848            Ok(Value::CharArray(char_array))
849        }
850        MatData::Cell { elements } => {
851            if let Some(strings) = cell_elements_to_strings(&elements) {
852                let string_array = StringArray::new(strings, array.dims.clone())
853                    .map_err(|e| load_error(format!("load: {e}")))?;
854                return Ok(Value::StringArray(string_array));
855            }
856            if array.dims.len() != 2 {
857                return Err(load_error(
858                    "load: cell arrays with more than two dimensions are not supported yet",
859                ));
860            }
861            let rows = array.dims[0];
862            let cols = array.dims[1];
863            let expected = rows.saturating_mul(cols);
864            if elements.len() != expected {
865                return Err(load_error("load: cell array element count mismatch"));
866            }
867            let mut converted = Vec::with_capacity(elements.len());
868            for elem in elements {
869                converted.push(mat_array_to_value(elem)?);
870            }
871            let mut row_major = vec![Value::Num(0.0); expected];
872            for col in 0..cols {
873                for row in 0..rows {
874                    let cm_idx = col * rows + row;
875                    let rm_idx = row * cols + col;
876                    row_major[rm_idx] = converted[cm_idx].clone();
877                }
878            }
879            make_cell(row_major, rows, cols).map_err(|err| load_error(format!("load: {err}")))
880        }
881        MatData::Struct {
882            field_names,
883            field_values,
884        } => {
885            if field_names.len() != field_values.len() {
886                return Err(load_error("load: struct field metadata is inconsistent"));
887            }
888            let mut st = StructValue::new();
889            for (name, value) in field_names.into_iter().zip(field_values.into_iter()) {
890                let converted = mat_array_to_value(value)?;
891                st.fields.insert(name, converted);
892            }
893            Ok(Value::Struct(st))
894        }
895    }
896}
897
898fn cell_elements_to_strings(elements: &[MatArray]) -> Option<Vec<String>> {
899    let mut strings = Vec::with_capacity(elements.len());
900    for element in elements {
901        if element.class != MatClass::Char {
902            return None;
903        }
904        let rows = element.dims.first().copied().unwrap_or(1);
905        if rows > 1 {
906            return None;
907        }
908        match &element.data {
909            MatData::Char { data } => strings.push(utf16_codes_to_string(data)),
910            _ => return None,
911        }
912    }
913    Some(strings)
914}
915
/// Decodes UTF-16 code units into a `String`, dropping trailing NUL characters.
///
/// Valid surrogate pairs are combined into the single character they encode.
/// Each unpaired surrogate is replaced by U+FFFD. NUL characters that are not
/// at the end of the text are preserved.
fn utf16_codes_to_string(data: &[u16]) -> String {
    let mut text: String = char::decode_utf16(data.iter().copied())
        .map(|unit| unit.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect();
    // Only trailing NULs are stripped; compute the kept byte length first so the
    // truncation happens in place without a second allocation.
    let kept = text.trim_end_matches('\0').len();
    text.truncate(kept);
    text
}
926
927fn option_token(value: &Value) -> BuiltinResult<Option<String>> {
928    if let Some(token) = value_to_string_scalar(value) {
929        if token.starts_with('-') {
930            return Ok(Some(token.to_ascii_lowercase()));
931        }
932    }
933    Ok(None)
934}
935
936#[async_recursion::async_recursion(?Send)]
937async fn extract_names(value: &Value) -> BuiltinResult<Vec<String>> {
938    match value {
939        Value::String(s) => Ok(vec![s.clone()]),
940        Value::CharArray(ca) => Ok(char_array_rows_as_strings(ca)),
941        Value::StringArray(sa) => Ok(sa.data.clone()),
942        Value::Cell(ca) => {
943            let mut names = Vec::with_capacity(ca.data.len());
944            for handle in &ca.data {
945                let inner = unsafe { &*handle.as_raw() };
946                let text = value_to_string_scalar(inner).ok_or_else(|| {
947                    load_error(
948                        "load: cell arrays used for variable selection must contain string scalars",
949                    )
950                })?;
951                names.push(text);
952            }
953            Ok(names)
954        }
955        other => {
956            let gathered = gather_if_needed_async(other).await?;
957            extract_names(&gathered).await
958        }
959    }
960}
961
962fn value_to_string_scalar(value: &Value) -> Option<String> {
963    match value {
964        Value::String(s) => Some(s.clone()),
965        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
966        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
967        _ => None,
968    }
969}
970
971fn char_array_rows_as_strings(ca: &CharArray) -> Vec<String> {
972    let mut rows = Vec::with_capacity(ca.rows);
973    for r in 0..ca.rows {
974        let mut row = String::with_capacity(ca.cols);
975        for c in 0..ca.cols {
976            let idx = r * ca.cols + c;
977            row.push(ca.data[idx]);
978        }
979        let trimmed = row.trim_end_matches([' ', '\0']).to_string();
980        rows.push(trimmed);
981    }
982    rows
983}
984
/// Decodes a NUL-terminated byte buffer as UTF-8.
///
/// Only the bytes before the first zero byte are considered (the whole slice
/// when there is none). If those bytes are not valid UTF-8 the result is an
/// empty string.
fn bytes_to_string(bytes: &[u8]) -> String {
    let end = bytes
        .iter()
        .position(|&byte| byte == 0)
        .unwrap_or(bytes.len());
    std::str::from_utf8(&bytes[..end])
        .map(str::to_owned)
        .unwrap_or_default()
}
993
/// One data element read from a MAT stream by `read_tagged`.
struct TaggedData {
    /// Data type code taken from the element tag.
    data_type: u32,
    /// Payload bytes of the element; alignment padding is not included.
    data: Vec<u8>,
}
998
999fn read_tagged<R: Read>(reader: &mut R, allow_eof: bool) -> BuiltinResult<Option<TaggedData>> {
1000    let mut type_bytes = [0u8; 4];
1001    match reader.read_exact(&mut type_bytes) {
1002        Ok(()) => {}
1003        Err(err) => {
1004            if allow_eof && err.kind() == std::io::ErrorKind::UnexpectedEof {
1005                return Ok(None);
1006            }
1007            return Err(load_error_with_source(
1008                &LOAD_ERROR_IO,
1009                format!("load: failed to read MAT element header: {err}"),
1010                err,
1011            ));
1012        }
1013    }
1014
1015    let type_field = u32::from_le_bytes(type_bytes);
1016    if (type_field & 0xFFFF0000) != 0 {
1017        let data_type = type_field & 0x0000FFFF;
1018        let num_bytes = ((type_field & 0xFFFF0000) >> 16) as usize;
1019        let mut inline = [0u8; 4];
1020        reader.read_exact(&mut inline).map_err(|err| {
1021            load_error_with_source(
1022                &LOAD_ERROR_IO,
1023                format!("load: failed to read compact MAT element: {err}"),
1024                err,
1025            )
1026        })?;
1027        let mut data = inline[..num_bytes.min(4)].to_vec();
1028        data.truncate(num_bytes.min(4));
1029        Ok(Some(TaggedData { data_type, data }))
1030    } else {
1031        let mut len_bytes = [0u8; 4];
1032        reader.read_exact(&mut len_bytes).map_err(|err| {
1033            load_error_with_source(
1034                &LOAD_ERROR_IO,
1035                format!("load: failed to read MAT element length: {err}"),
1036                err,
1037            )
1038        })?;
1039        let length = u32::from_le_bytes(len_bytes) as usize;
1040        let mut data = vec![0u8; length];
1041        reader.read_exact(&mut data).map_err(|err| {
1042            load_error_with_source(
1043                &LOAD_ERROR_IO,
1044                format!("load: failed to read MAT element body: {err}"),
1045                err,
1046            )
1047        })?;
1048        let padding = (8 - (length % 8)) % 8;
1049        if padding != 0 {
1050            let mut pad = vec![0u8; padding];
1051            reader.read_exact(&mut pad).map_err(|err| {
1052                load_error_with_source(
1053                    &LOAD_ERROR_IO,
1054                    format!("load: failed to read MAT padding: {err}"),
1055                    err,
1056                )
1057            })?;
1058        }
1059        Ok(Some(TaggedData {
1060            data_type: type_field,
1061            data,
1062        }))
1063    }
1064}
1065
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::workspace::WorkspaceResolver;
    use futures::executor::block_on;
    use runmat_builtins::StringArray;
    use runmat_thread_local::runmat_thread_local;
    use std::cell::RefCell;
    use std::collections::HashMap;
    use tempfile::tempdir;

    runmat_thread_local! {
        static TEST_WORKSPACE: RefCell<HashMap<String, Value>> = RefCell::new(HashMap::new());
    }

    /// Registers a workspace resolver that reads from the thread-local `TEST_WORKSPACE`.
    fn ensure_test_resolver() {
        crate::workspace::register_workspace_resolver(WorkspaceResolver {
            lookup: |name| TEST_WORKSPACE.with(|slot| slot.borrow().get(name).cloned()),
            snapshot: || {
                let mut listing: Vec<(String, Value)> = TEST_WORKSPACE.with(|slot| {
                    slot.borrow()
                        .iter()
                        .map(|(name, value)| (name.clone(), value.clone()))
                        .collect()
                });
                // Sort by name so the snapshot order does not depend on hash order.
                listing.sort_by(|left, right| left.0.cmp(&right.0));
                listing
            },
            globals: || Vec::new(),
            assign: None,
            clear: None,
            remove: None,
        });
    }

    /// Replaces the contents of the test workspace with `entries`.
    fn set_workspace(entries: &[(&str, Value)]) {
        TEST_WORKSPACE.with(|slot| {
            let mut map = slot.borrow_mut();
            map.clear();
            map.extend(
                entries
                    .iter()
                    .map(|(name, value)| ((*name).to_string(), value.clone())),
            );
        });
    }

    /// Acquires the lock returned by `crate::workspace::test_guard`.
    fn workspace_guard() -> std::sync::MutexGuard<'static, ()> {
        crate::workspace::test_guard()
    }

    /// Panics unless `result` is an error whose message contains `snippet`.
    fn assert_error_contains<T>(result: crate::BuiltinResult<T>, snippet: &str) {
        let Err(err) = result else {
            panic!("expected error containing '{snippet}'");
        };
        assert!(
            err.message().contains(snippet),
            "expected error to contain '{snippet}', got '{}'",
            err.message()
        );
    }

    /// Calls the `save` builtin with `file_name` inside `dir` as its only argument and
    /// returns that same path argument so it can be handed to `load`.
    fn save_workspace_into(dir: &std::path::Path, file_name: &str) -> Value {
        let target = dir.join(file_name);
        let file_arg = Value::from(target.to_string_lossy().to_string());
        block_on(crate::call_builtin_async(
            "save",
            std::slice::from_ref(&file_arg),
        ))
        .unwrap();
        file_arg
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = LOAD_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"S = load()"));
        assert!(labels.contains(&"S = load(filename)"));
        assert!(labels.contains(&"S = load(filename, varName1, varName2, ...)"));
        assert!(labels.contains(&"S = load(filename, \"-regexp\", pattern1, ...)"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_roundtrip_numeric() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        let matrix = Tensor::new(vec![1.0, 4.0, 2.0, 5.0], vec![2, 2]).unwrap();
        set_workspace(&[("A", Value::Tensor(matrix))]);

        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "numeric.mat");

        let eval = block_on(evaluate(&[file_arg])).expect("load numeric");
        match eval.first_output() {
            Value::Struct(sv) => {
                assert!(sv.fields.contains_key("A"));
                match sv.fields.get("A").unwrap() {
                    Value::Tensor(t) => {
                        assert_eq!(t.shape, vec![2, 2]);
                        assert_eq!(t.data, vec![1.0, 4.0, 2.0, 5.0]);
                    }
                    other => panic!("expected tensor, got {other:?}"),
                }
            }
            other => panic!("expected struct, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_selected_variables() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("signal", Value::Num(42.0)), ("noise", Value::Num(5.0))]);
        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "selection.mat");

        let eval =
            block_on(evaluate(&[file_arg, Value::from("signal")])).expect("load selection");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "signal");
        assert!(matches!(vars[0].1, Value::Num(42.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_regex_selection() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[
            ("w1", Value::Num(1.0)),
            ("w2", Value::Num(2.0)),
            ("bias", Value::Num(3.0)),
        ]);
        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "regex.mat");

        let eval = block_on(evaluate(&[
            file_arg,
            Value::from("-regexp"),
            Value::from("^w\\d$"),
        ]))
        .expect("load regex");
        let mut names: Vec<_> = eval.variables().iter().map(|(n, _)| n.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["w1".to_string(), "w2".to_string()]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_missing_variable_errors() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("existing", Value::Num(7.0))]);
        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "missing.mat");

        assert_error_contains(
            block_on(evaluate(&[file_arg, Value::from("missing")])),
            "variable 'missing' was not found",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_string_array_roundtrip() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        let strings = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
        set_workspace(&[("labels", Value::StringArray(strings))]);
        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "strings.mat");

        let eval = block_on(evaluate(&[file_arg])).expect("load strings");
        match eval.first_output() {
            Value::Struct(sv) => {
                let value = sv
                    .fields
                    .get("labels")
                    .expect("labels field missing in struct");
                match value {
                    Value::StringArray(sa) => {
                        assert_eq!(sa.shape, vec![1, 2]);
                        assert_eq!(sa.data, vec![String::from("foo"), String::from("bar")]);
                    }
                    other => panic!("expected string array, got {other:?}"),
                }
            }
            other => panic!("expected struct, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_option_before_filename() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("alpha", Value::Num(1.0)), ("beta", Value::Num(2.0))]);
        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "option_first.mat");

        let eval = block_on(evaluate(&[
            Value::from("-mat"),
            file_arg,
            Value::from("beta"),
        ]))
        .expect("load with option first");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "beta");
        assert!(matches!(vars[0].1, Value::Num(2.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_char_array_names_trimmed() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("short", Value::Num(5.0)), ("longer", Value::Num(9.0))]);
        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "char_names.mat");

        // Build a 2 x 6 character matrix, one space-padded name per row.
        let width: usize = 6;
        let mut padded = Vec::with_capacity(2 * width);
        for name in ["short", "longer"] {
            let missing = width.saturating_sub(name.chars().count());
            padded.extend(name.chars());
            padded.extend(std::iter::repeat(' ').take(missing));
        }
        let name_array = CharArray::new(padded, 2, width).unwrap();

        let eval = block_on(evaluate(&[file_arg, Value::CharArray(name_array)]))
            .expect("load with char array names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 2);
        assert_eq!(vars[0].0, "short");
        assert!(matches!(vars[0].1, Value::Num(5.0)));
        assert_eq!(vars[1].0, "longer");
        assert!(matches!(vars[1].1, Value::Num(9.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn load_duplicate_names_last_wins() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        set_workspace(&[("dup", Value::Num(11.0))]);
        let dir = tempdir().unwrap();
        let file_arg = save_workspace_into(dir.path(), "duplicates.mat");

        let eval = block_on(evaluate(&[
            file_arg,
            Value::from("dup"),
            Value::from("dup"),
        ]))
        .expect("load with duplicate names");
        let vars = eval.variables();
        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].0, "dup");
        assert!(matches!(vars[0].1, Value::Num(11.0)));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn load_wgpu_tensor_roundtrip() {
        let _guard = workspace_guard();
        ensure_test_resolver();
        // Return early (without failing) when no wgpu provider can be registered.
        let registration = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        if registration.is_err() {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };

        use runmat_accelerate_api::HostTensorView;

        let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload tensor");
        set_workspace(&[("gpu_var", Value::GpuTensor(handle))]);

        let dir = tempdir().unwrap();
        let path = dir.path().join("wgpu_load.mat");
        let save_args = vec![
            Value::from(path.to_string_lossy().to_string()),
            Value::from("gpu_var"),
        ];
        block_on(crate::call_builtin_async("save", &save_args)).unwrap();

        let eval = block_on(evaluate(&[Value::from(path.to_string_lossy().to_string())]))
            .expect("load wgpu file");
        match eval.first_output() {
            Value::Struct(sv) => match sv.fields.get("gpu_var") {
                Some(Value::Tensor(t)) => {
                    assert_eq!(t.shape, vec![2, 2]);
                    assert_eq!(t.data, tensor.data);
                }
                other => panic!("expected tensor, got {other:?}"),
            },
            other => panic!("expected struct, got {other:?}"),
        }
    }
}