1use std::collections::HashMap;
4use std::io::{BufReader, Cursor, Read};
5use std::path::{Path, PathBuf};
6
7use regex::Regex;
8use runmat_builtins::{
9 CharArray, ComplexTensor, LogicalArray, StringArray, StructValue, Tensor, Value,
10};
11use runmat_filesystem::File;
12use runmat_macros::runtime_builtin;
13
14use super::format::{
15 MatArray, MatClass, MatData, FLAG_COMPLEX, FLAG_LOGICAL, MAT_HEADER_LEN, MI_DOUBLE, MI_INT32,
16 MI_INT8, MI_MATRIX, MI_UINT16, MI_UINT32, MI_UINT8,
17};
18use crate::builtins::common::spec::{
19 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
20 ReductionNaN, ResidencyPolicy, ShapeRequirements,
21};
22use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
23
// GPU metadata for `load`. No precisions or provider hooks are declared: the
// builtin reads MAT-files on the host and its results start out CPU-resident.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::mat::load")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "load",
    op_kind: GpuOpKind::Custom("io-load"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Reads MAT-files on the host and produces CPU-resident values. Providers are not involved until accelerated code later promotes the results.",
};
39
// Fusion metadata for `load`. File I/O is never fused (no elementwise or
// reduction template); the entry exists so the builtin is documented uniformly.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::mat::load")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "load",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "File I/O is not eligible for fusion. Registration exists for documentation completeness only.",
};
50
51#[runtime_builtin(
52 name = "load",
53 category = "io/mat",
54 summary = "Load variables from a MAT-file.",
55 keywords = "load,mat,workspace",
56 accel = "cpu",
57 sink = true,
58 type_resolver(crate::builtins::io::type_resolvers::load_type),
59 builtin_path = "crate::builtins::io::mat::load"
60)]
61async fn load_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
62 let eval = evaluate(&args).await?;
63
64 if let Some(n) = crate::output_count::current_output_count() {
67 if n > 1 {
68 return Err(load_error("load supports at most one output argument"));
69 }
70 }
71
72 if crate::output_context::requested_output_count() == Some(0) {
79 for (name, value) in eval.variables() {
80 crate::workspace::assign(name, value.clone()).map_err(|err| load_error(err))?;
81 }
82 return Ok(Value::OutputList(Vec::new()));
83 }
84
85 Ok(eval.first_output())
86}
87
/// Outcome of evaluating a `load` call: the selected variables, in the order
/// produced by `select_variables`.
#[derive(Clone, Debug)]
pub struct LoadEval {
    // `(name, value)` pairs to return as a struct or assign into the workspace.
    variables: Vec<(String, Value)>,
}
92
93impl LoadEval {
94 pub fn first_output(&self) -> Value {
95 let mut st = StructValue::new();
96 for (name, value) in &self.variables {
97 st.fields.insert(name.clone(), value.clone());
98 }
99 Value::Struct(st)
100 }
101
102 pub fn variables(&self) -> &[(String, Value)] {
103 &self.variables
104 }
105
106 pub fn into_variables(self) -> Vec<(String, Value)> {
107 self.variables
108 }
109}
110
/// Variable-selection criteria derived from the `load` arguments.
struct LoadRequest {
    /// Exact variable names listed after the filename; each must exist in the file.
    variables: Vec<String>,
    /// Compiled `-regexp` patterns; a variable is selected if any pattern matches its name.
    regex_patterns: Vec<Regex>,
}
115
/// Builtin name attached to errors built by `load_error` and `load_error_with_source`.
const BUILTIN_NAME: &str = "load";
117
118fn load_error(message: impl Into<String>) -> RuntimeError {
119 build_runtime_error(message)
120 .with_builtin(BUILTIN_NAME)
121 .build()
122}
123
124fn load_error_with_source(
125 message: impl Into<String>,
126 source: impl std::error::Error + Send + Sync + 'static,
127) -> RuntimeError {
128 build_runtime_error(message)
129 .with_builtin(BUILTIN_NAME)
130 .with_source(source)
131 .build()
132}
133
134pub async fn evaluate(args: &[Value]) -> BuiltinResult<LoadEval> {
135 let mut host_args = Vec::with_capacity(args.len());
136 for arg in args {
137 host_args.push(gather_if_needed_async(arg).await?);
138 }
139
140 let invocation = parse_invocation(&host_args).await?;
141
142 let mut path_value = if let Some(path) = invocation.path_value {
143 path
144 } else {
145 Value::from("matlab.mat")
146 };
147
148 if invocation.path_was_default {
149 if let Ok(override_path) = std::env::var("RUNMAT_LOAD_DEFAULT_PATH") {
150 path_value = Value::from(override_path);
151 }
152 }
153
154 let mut regex_patterns = Vec::with_capacity(invocation.regex_tokens.len());
155 for pattern in invocation.regex_tokens {
156 let regex = Regex::new(&pattern).map_err(|err| {
157 load_error_with_source(
158 format!("load: invalid regular expression '{pattern}': {err}"),
159 err,
160 )
161 })?;
162 regex_patterns.push(regex);
163 }
164
165 let request = LoadRequest {
166 variables: invocation.variables,
167 regex_patterns,
168 };
169 let path = normalise_path(&path_value)?;
170 let entries = read_mat_file(&path)?;
171
172 let selected = select_variables(&entries, &request)?;
173 Ok(LoadEval {
174 variables: selected,
175 })
176}
177
/// Raw result of argument parsing, before the default path and regex compilation are applied.
struct ParsedInvocation {
    /// First non-option argument, if any; validated later by `normalise_path`.
    path_value: Option<Value>,
    /// `true` when no filename was supplied, enabling the `RUNMAT_LOAD_DEFAULT_PATH` override.
    path_was_default: bool,
    /// Variable names given after the filename.
    variables: Vec<String>,
    /// Uncompiled patterns given after `-regexp`.
    regex_tokens: Vec<String>,
}
184
185async fn parse_invocation(values: &[Value]) -> BuiltinResult<ParsedInvocation> {
186 let mut path_value = None;
187 let mut path_was_default = false;
188 let mut variables = Vec::new();
189 let mut regex_tokens = Vec::new();
190 let mut idx = 0usize;
191 while idx < values.len() {
192 if let Some(flag) = option_token(&values[idx])? {
193 match flag.as_str() {
194 "-mat" => {
195 idx += 1;
196 continue;
197 }
198 "-regexp" => {
199 idx += 1;
200 if idx >= values.len() {
201 return Err(load_error("load: '-regexp' requires at least one pattern"));
202 }
203 while idx < values.len() {
204 if option_token(&values[idx])?.is_some() {
205 break;
206 }
207 let names = extract_names(&values[idx]).await?;
208 if names.is_empty() {
209 return Err(load_error(
210 "load: '-regexp' requires non-empty pattern strings",
211 ));
212 }
213 regex_tokens.extend(names);
214 idx += 1;
215 }
216 continue;
217 }
218 other => {
219 return Err(load_error(format!("load: unsupported option '{other}'")));
220 }
221 }
222 } else {
223 if path_value.is_none() {
224 path_value = Some(values[idx].clone());
225 idx += 1;
226 continue;
227 }
228 let names = extract_names(&values[idx]).await?;
229 variables.extend(names);
230 idx += 1;
231 }
232 }
233
234 if path_value.is_none() {
235 path_was_default = true;
236 }
237
238 Ok(ParsedInvocation {
239 path_value,
240 path_was_default,
241 variables,
242 regex_tokens,
243 })
244}
245
246fn normalise_path(value: &Value) -> BuiltinResult<PathBuf> {
247 let raw = value_to_string_scalar(value)
248 .ok_or_else(|| load_error("load: filename must be a character vector or string scalar"))?;
249 let mut path = PathBuf::from(raw);
250 if path.extension().is_none() {
251 path.set_extension("mat");
252 }
253 Ok(path)
254}
255
256fn select_variables(
257 entries: &[(String, Value)],
258 request: &LoadRequest,
259) -> BuiltinResult<Vec<(String, Value)>> {
260 if request.variables.is_empty() && request.regex_patterns.is_empty() {
261 return Ok(entries.to_vec());
262 }
263
264 let mut by_name: HashMap<&str, &Value> = HashMap::with_capacity(entries.len());
265 for (name, value) in entries {
266 by_name.insert(name, value);
267 }
268
269 let mut selected = Vec::new();
270
271 for name in &request.variables {
272 let value = by_name.get(name.as_str()).ok_or_else(|| {
273 load_error(format!("load: variable '{name}' was not found in the file"))
274 })?;
275 insert_or_replace(&mut selected, name, (*value).clone());
276 }
277
278 if !request.regex_patterns.is_empty() {
279 let mut matched = 0usize;
280 for (name, value) in entries {
281 if request
282 .regex_patterns
283 .iter()
284 .any(|regex| regex.is_match(name))
285 {
286 matched += 1;
287 insert_or_replace(&mut selected, name, value.clone());
288 }
289 }
290 if matched == 0 && request.variables.is_empty() {
291 return Err(load_error("load: no variables matched '-regexp' patterns"));
292 }
293 }
294
295 if selected.is_empty() {
296 return Err(load_error("load: no variables selected"));
297 }
298
299 Ok(selected)
300}
301
302fn insert_or_replace(selected: &mut Vec<(String, Value)>, name: &str, value: Value) {
303 if let Some(entry) = selected.iter_mut().find(|(existing, _)| existing == name) {
304 entry.1 = value;
305 } else {
306 selected.push((name.to_string(), value));
307 }
308}
309
310pub(crate) fn read_mat_file_for_builtin(
311 path: &Path,
312 builtin: &str,
313) -> crate::BuiltinResult<Vec<(String, Value)>> {
314 match read_mat_file(path) {
315 Ok(entries) => Ok(entries),
316 Err(err) => {
317 let message = err.message().replacen("load:", &format!("{builtin}:"), 1);
318 let mut builder = build_runtime_error(message).with_builtin(builtin);
319 if let Some(identifier) = err.identifier() {
320 builder = builder.with_identifier(identifier);
321 }
322 Err(builder.with_source(err).build())
323 }
324 }
325}
326
327pub(crate) fn read_mat_file(path: &Path) -> BuiltinResult<Vec<(String, Value)>> {
328 let file = File::open(path).map_err(|err| {
329 load_error_with_source(
330 format!("load: failed to open '{}': {err}", path.display()),
331 err,
332 )
333 })?;
334 let mut reader = BufReader::new(file);
335 read_mat_reader(&mut reader)
336}
337
338pub fn decode_workspace_from_mat_bytes(bytes: &[u8]) -> BuiltinResult<Vec<(String, Value)>> {
339 let mut cursor = Cursor::new(bytes);
340 read_mat_reader(&mut cursor)
341}
342
343fn read_mat_reader<R: Read>(reader: &mut R) -> BuiltinResult<Vec<(String, Value)>> {
344 let mut header = [0u8; MAT_HEADER_LEN];
345 reader.read_exact(&mut header).map_err(|err| {
346 load_error_with_source(format!("load: failed to read MAT-file header: {err}"), err)
347 })?;
348 if header[126] != b'I' || header[127] != b'M' {
349 return Err(load_error("load: file is not a MATLAB Level-5 MAT-file"));
350 }
351
352 let mut variables = Vec::new();
353 while let Some(tagged) = read_tagged(reader, true)? {
354 if tagged.data_type != MI_MATRIX {
355 continue;
356 }
357 let parsed = parse_matrix(&tagged.data)?;
358 let value = mat_array_to_value(parsed.array)?;
359 variables.push((parsed.name, value));
360 }
361 Ok(variables)
362}
363
/// A decoded `miMATRIX` element: its name and its array payload.
struct ParsedMatrix {
    /// Name decoded from the array-name sub-element (may be empty).
    name: String,
    /// Class, dimensions and data decoded from the remaining sub-elements.
    array: MatArray,
}
368
369fn parse_matrix(buffer: &[u8]) -> BuiltinResult<ParsedMatrix> {
370 let mut cursor = Cursor::new(buffer);
371
372 let flags = read_tagged(&mut cursor, false)?
373 .ok_or_else(|| load_error("load: matrix element missing array flags"))?;
374 if flags.data_type != MI_UINT32 || flags.data.len() < 8 {
375 return Err(load_error("load: invalid array flags block"));
376 }
377 let flags0 = u32::from_le_bytes(flags.data[0..4].try_into().unwrap());
378 let class_code = flags0 & 0xFF;
379 let mut class = MatClass::from_class_code(class_code)
380 .ok_or_else(|| load_error("load: unsupported MATLAB class"))?;
381 let is_logical = (flags0 & FLAG_LOGICAL) != 0;
382 let has_imag = (flags0 & FLAG_COMPLEX) != 0;
383 if matches!(class, MatClass::Double) && is_logical {
384 class = MatClass::Logical;
385 }
386
387 let dims_elem = read_tagged(&mut cursor, false)?
388 .ok_or_else(|| load_error("load: matrix element missing dimensions"))?;
389 if dims_elem.data_type != MI_INT32 {
390 return Err(load_error("load: dimension block must use MI_INT32"));
391 }
392 if dims_elem.data.is_empty() || dims_elem.data.len() % 4 != 0 {
393 return Err(load_error("load: malformed dimension block"));
394 }
395 let mut dims = Vec::with_capacity(dims_elem.data.len() / 4);
396 for chunk in dims_elem.data.chunks_exact(4) {
397 let value = i32::from_le_bytes(chunk.try_into().unwrap());
398 if value < 0 {
399 return Err(load_error("load: negative dimensions are not supported"));
400 }
401 dims.push(value as usize);
402 }
403 if dims.is_empty() {
404 dims.push(1);
405 dims.push(1);
406 }
407
408 let name_elem = read_tagged(&mut cursor, false)?
409 .ok_or_else(|| load_error("load: matrix element missing name"))?;
410 let name = match name_elem.data_type {
411 MI_INT8 | MI_UINT8 => bytes_to_string(&name_elem.data),
412 MI_UINT16 => {
413 let mut bytes = Vec::with_capacity(name_elem.data.len());
414 for chunk in name_elem.data.chunks_exact(2) {
415 let code = u16::from_le_bytes(chunk.try_into().unwrap());
416 if code == 0 {
417 break;
418 }
419 if let Some(ch) = char::from_u32(code as u32) {
420 bytes.push(ch);
421 }
422 }
423 bytes.into_iter().collect()
424 }
425 _ => {
426 return Err(load_error("load: unsupported array name encoding"));
427 }
428 };
429
430 let array = match class {
431 MatClass::Double => parse_double_array(&mut cursor, dims, has_imag)?,
432 MatClass::Logical => parse_logical_array(&mut cursor, dims)?,
433 MatClass::Char => parse_char_array(&mut cursor, dims)?,
434 MatClass::Cell => parse_cell_array(&mut cursor, dims)?,
435 MatClass::Struct => parse_struct(&mut cursor, dims)?,
436 };
437
438 Ok(ParsedMatrix { name, array })
439}
440
441fn parse_double_array(
442 cursor: &mut Cursor<&[u8]>,
443 dims: Vec<usize>,
444 has_imag: bool,
445) -> BuiltinResult<MatArray> {
446 let real_elem = read_tagged(cursor, false)?
447 .ok_or_else(|| load_error("load: numeric array missing real component"))?;
448 if real_elem.data_type != MI_DOUBLE || real_elem.data.len() % 8 != 0 {
449 return Err(load_error("load: numeric data must be stored as MI_DOUBLE"));
450 }
451 let mut real = Vec::with_capacity(real_elem.data.len() / 8);
452 for chunk in real_elem.data.chunks_exact(8) {
453 real.push(f64::from_le_bytes(chunk.try_into().unwrap()));
454 }
455
456 let imag = if has_imag {
457 let imag_elem = read_tagged(cursor, false)?
458 .ok_or_else(|| load_error("load: numeric array missing imaginary component"))?;
459 if imag_elem.data_type != MI_DOUBLE || imag_elem.data.len() % 8 != 0 {
460 return Err(load_error("load: imaginary component must be MI_DOUBLE"));
461 }
462 let mut imag = Vec::with_capacity(imag_elem.data.len() / 8);
463 for chunk in imag_elem.data.chunks_exact(8) {
464 imag.push(f64::from_le_bytes(chunk.try_into().unwrap()));
465 }
466 Some(imag)
467 } else {
468 None
469 };
470
471 Ok(MatArray {
472 class: MatClass::Double,
473 dims,
474 data: MatData::Double { real, imag },
475 })
476}
477
478fn parse_logical_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
479 let elem = read_tagged(cursor, false)?
480 .ok_or_else(|| load_error("load: logical array missing data block"))?;
481 if elem.data_type != MI_UINT8 {
482 return Err(load_error(
483 "load: logical arrays must be stored as MI_UINT8",
484 ));
485 }
486 Ok(MatArray {
487 class: MatClass::Logical,
488 dims,
489 data: MatData::Logical { data: elem.data },
490 })
491}
492
493fn parse_char_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
494 let elem = read_tagged(cursor, false)?
495 .ok_or_else(|| load_error("load: character array missing data block"))?;
496 if elem.data_type != MI_UINT16 {
497 return Err(load_error(
498 "load: character data must be stored as MI_UINT16",
499 ));
500 }
501 if elem.data.len() % 2 != 0 {
502 return Err(load_error("load: malformed character data"));
503 }
504 let mut data = Vec::with_capacity(elem.data.len() / 2);
505 for chunk in elem.data.chunks_exact(2) {
506 data.push(u16::from_le_bytes(chunk.try_into().unwrap()));
507 }
508 Ok(MatArray {
509 class: MatClass::Char,
510 dims,
511 data: MatData::Char { data },
512 })
513}
514
515fn parse_cell_array(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
516 let total: usize = dims
517 .iter()
518 .copied()
519 .fold(1usize, |acc, d| acc.saturating_mul(d));
520 let mut elements = Vec::with_capacity(total);
521 for _ in 0..total {
522 let elem = read_tagged(cursor, false)?
523 .ok_or_else(|| load_error("load: cell element missing matrix payload"))?;
524 if elem.data_type != MI_MATRIX {
525 return Err(load_error("load: cell elements must be matrices"));
526 }
527 let parsed = parse_matrix(&elem.data)?;
528 elements.push(parsed.array);
529 }
530 Ok(MatArray {
531 class: MatClass::Cell,
532 dims,
533 data: MatData::Cell { elements },
534 })
535}
536
537fn parse_struct(cursor: &mut Cursor<&[u8]>, dims: Vec<usize>) -> BuiltinResult<MatArray> {
538 if dims.len() != 2 || dims[0] != 1 || dims[1] != 1 {
539 return Err(load_error("load: struct arrays are not supported yet"));
540 }
541 let len_elem = read_tagged(cursor, false)?
542 .ok_or_else(|| load_error("load: struct missing maximum field length specifier"))?;
543 if len_elem.data_type != MI_INT32 || len_elem.data.len() != 4 {
544 return Err(load_error("load: struct field length must be MI_INT32"));
545 }
546 let max_len = i32::from_le_bytes(len_elem.data[..4].try_into().unwrap());
547 if max_len <= 0 {
548 return Err(load_error("load: struct field length must be positive"));
549 }
550
551 let names_elem = read_tagged(cursor, false)?
552 .ok_or_else(|| load_error("load: struct missing field name table"))?;
553 if names_elem.data_type != MI_INT8 && names_elem.data_type != MI_UINT8 {
554 return Err(load_error(
555 "load: struct field names must be stored as MI_INT8/MI_UINT8",
556 ));
557 }
558 if names_elem.data.len() % (max_len as usize) != 0 {
559 return Err(load_error("load: malformed struct field name table"));
560 }
561 let field_count = names_elem.data.len() / (max_len as usize);
562 let mut field_names = Vec::with_capacity(field_count);
563 for i in 0..field_count {
564 let start = i * (max_len as usize);
565 let end = start + (max_len as usize);
566 let slice = &names_elem.data[start..end];
567 field_names.push(bytes_to_string(slice));
568 }
569
570 let mut field_values = Vec::with_capacity(field_count);
571 for _ in 0..field_count {
572 let elem = read_tagged(cursor, false)?
573 .ok_or_else(|| load_error("load: struct field missing matrix payload"))?;
574 if elem.data_type != MI_MATRIX {
575 return Err(load_error("load: struct fields must be matrices"));
576 }
577 let parsed = parse_matrix(&elem.data)?;
578 field_values.push(parsed.array);
579 }
580
581 Ok(MatArray {
582 class: MatClass::Struct,
583 dims,
584 data: MatData::Struct {
585 field_names,
586 field_values,
587 },
588 })
589}
590
591fn mat_array_to_value(array: MatArray) -> BuiltinResult<Value> {
592 match array.data {
593 MatData::Double { real, imag } => {
594 let len = real.len();
595 if let Some(imag) = imag {
596 if imag.len() != len {
597 return Err(load_error(
598 "load: complex data has mismatched real/imag parts",
599 ));
600 }
601 if len == 1 {
602 Ok(Value::Complex(real[0], imag[0]))
603 } else {
604 let mut pairs = Vec::with_capacity(len);
605 for i in 0..len {
606 pairs.push((real[i], imag[i]));
607 }
608 let tensor = ComplexTensor::new(pairs, array.dims.clone())
609 .map_err(|e| load_error(format!("load: {e}")))?;
610 Ok(Value::ComplexTensor(tensor))
611 }
612 } else if len == 1 {
613 Ok(Value::Num(real[0]))
614 } else {
615 let tensor = Tensor::new(real, array.dims.clone())
616 .map_err(|e| load_error(format!("load: {e}")))?;
617 Ok(Value::Tensor(tensor))
618 }
619 }
620 MatData::Logical { data } => {
621 let total: usize = array
622 .dims
623 .iter()
624 .copied()
625 .fold(1usize, |acc, d| acc.saturating_mul(d));
626 if data.len() != total {
627 return Err(load_error("load: logical data length mismatch"));
628 }
629 if total == 1 {
630 Ok(Value::Bool(data.first().copied().unwrap_or(0) != 0))
631 } else {
632 let logical = LogicalArray::new(data, array.dims.clone())
633 .map_err(|e| load_error(format!("load: {e}")))?;
634 Ok(Value::LogicalArray(logical))
635 }
636 }
637 MatData::Char { data } => {
638 let rows = array.dims.first().copied().unwrap_or(1);
639 let cols = array.dims.get(1).copied().unwrap_or(1);
640 let mut chars = Vec::with_capacity(rows.saturating_mul(cols));
641 for code in data {
642 let ch = char::from_u32(code as u32).unwrap_or('\u{FFFD}');
643 chars.push(ch);
644 }
645 let char_array =
646 CharArray::new(chars, rows, cols).map_err(|e| load_error(format!("load: {e}")))?;
647 Ok(Value::CharArray(char_array))
648 }
649 MatData::Cell { elements } => {
650 if let Some(strings) = cell_elements_to_strings(&elements) {
651 let string_array = StringArray::new(strings, array.dims.clone())
652 .map_err(|e| load_error(format!("load: {e}")))?;
653 return Ok(Value::StringArray(string_array));
654 }
655 if array.dims.len() != 2 {
656 return Err(load_error(
657 "load: cell arrays with more than two dimensions are not supported yet",
658 ));
659 }
660 let rows = array.dims[0];
661 let cols = array.dims[1];
662 let expected = rows.saturating_mul(cols);
663 if elements.len() != expected {
664 return Err(load_error("load: cell array element count mismatch"));
665 }
666 let mut converted = Vec::with_capacity(elements.len());
667 for elem in elements {
668 converted.push(mat_array_to_value(elem)?);
669 }
670 let mut row_major = vec![Value::Num(0.0); expected];
671 for col in 0..cols {
672 for row in 0..rows {
673 let cm_idx = col * rows + row;
674 let rm_idx = row * cols + col;
675 row_major[rm_idx] = converted[cm_idx].clone();
676 }
677 }
678 make_cell(row_major, rows, cols).map_err(|err| load_error(format!("load: {err}")))
679 }
680 MatData::Struct {
681 field_names,
682 field_values,
683 } => {
684 if field_names.len() != field_values.len() {
685 return Err(load_error("load: struct field metadata is inconsistent"));
686 }
687 let mut st = StructValue::new();
688 for (name, value) in field_names.into_iter().zip(field_values.into_iter()) {
689 let converted = mat_array_to_value(value)?;
690 st.fields.insert(name, converted);
691 }
692 Ok(Value::Struct(st))
693 }
694 }
695}
696
697fn cell_elements_to_strings(elements: &[MatArray]) -> Option<Vec<String>> {
698 let mut strings = Vec::with_capacity(elements.len());
699 for element in elements {
700 if element.class != MatClass::Char {
701 return None;
702 }
703 let rows = element.dims.first().copied().unwrap_or(1);
704 if rows > 1 {
705 return None;
706 }
707 match &element.data {
708 MatData::Char { data } => strings.push(utf16_codes_to_string(data)),
709 _ => return None,
710 }
711 }
712 Some(strings)
713}
714
/// Decodes UTF-16 code units (MAT character data) into a `String`, dropping
/// trailing NULs.
///
/// Surrogate pairs are combined into a single character. Unpaired surrogates
/// become U+FFFD. NULs inside the text are kept; only trailing ones are removed.
fn utf16_codes_to_string(data: &[u16]) -> String {
    let mut text: String = char::decode_utf16(data.iter().copied())
        .map(|unit| unit.unwrap_or(char::REPLACEMENT_CHARACTER))
        .collect();
    let kept = text.trim_end_matches('\0').len();
    text.truncate(kept);
    text
}
725
726fn option_token(value: &Value) -> BuiltinResult<Option<String>> {
727 if let Some(token) = value_to_string_scalar(value) {
728 if token.starts_with('-') {
729 return Ok(Some(token.to_ascii_lowercase()));
730 }
731 }
732 Ok(None)
733}
734
735#[async_recursion::async_recursion(?Send)]
736async fn extract_names(value: &Value) -> BuiltinResult<Vec<String>> {
737 match value {
738 Value::String(s) => Ok(vec![s.clone()]),
739 Value::CharArray(ca) => Ok(char_array_rows_as_strings(ca)),
740 Value::StringArray(sa) => Ok(sa.data.clone()),
741 Value::Cell(ca) => {
742 let mut names = Vec::with_capacity(ca.data.len());
743 for handle in &ca.data {
744 let inner = unsafe { &*handle.as_raw() };
745 let text = value_to_string_scalar(inner).ok_or_else(|| {
746 load_error(
747 "load: cell arrays used for variable selection must contain string scalars",
748 )
749 })?;
750 names.push(text);
751 }
752 Ok(names)
753 }
754 other => {
755 let gathered = gather_if_needed_async(other).await?;
756 extract_names(&gathered).await
757 }
758 }
759}
760
761fn value_to_string_scalar(value: &Value) -> Option<String> {
762 match value {
763 Value::String(s) => Some(s.clone()),
764 Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
765 Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
766 _ => None,
767 }
768}
769
770fn char_array_rows_as_strings(ca: &CharArray) -> Vec<String> {
771 let mut rows = Vec::with_capacity(ca.rows);
772 for r in 0..ca.rows {
773 let mut row = String::with_capacity(ca.cols);
774 for c in 0..ca.cols {
775 let idx = r * ca.cols + c;
776 row.push(ca.data[idx]);
777 }
778 let trimmed = row.trim_end_matches([' ', '\0']).to_string();
779 rows.push(trimmed);
780 }
781 rows
782}
783
/// Decodes a NUL-terminated byte field (an array or struct field name) as text.
///
/// Bytes from the first NUL onward are ignored. Invalid UTF-8 sequences are
/// replaced with U+FFFD rather than discarding the whole name, which would
/// silently turn distinct names into `""`.
fn bytes_to_string(bytes: &[u8]) -> String {
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    String::from_utf8_lossy(&bytes[..end]).into_owned()
}
792
/// One MAT data element after its tag has been decoded by `read_tagged`.
struct TaggedData {
    /// The `mi*` data-type code (only the lower 16 bits for compact elements).
    data_type: u32,
    /// Element payload; any alignment padding that follows it is not included.
    data: Vec<u8>,
}
797
798fn read_tagged<R: Read>(reader: &mut R, allow_eof: bool) -> BuiltinResult<Option<TaggedData>> {
799 let mut type_bytes = [0u8; 4];
800 match reader.read_exact(&mut type_bytes) {
801 Ok(()) => {}
802 Err(err) => {
803 if allow_eof && err.kind() == std::io::ErrorKind::UnexpectedEof {
804 return Ok(None);
805 }
806 return Err(load_error_with_source(
807 format!("load: failed to read MAT element header: {err}"),
808 err,
809 ));
810 }
811 }
812
813 let type_field = u32::from_le_bytes(type_bytes);
814 if (type_field & 0xFFFF0000) != 0 {
815 let data_type = type_field & 0x0000FFFF;
816 let num_bytes = ((type_field & 0xFFFF0000) >> 16) as usize;
817 let mut inline = [0u8; 4];
818 reader.read_exact(&mut inline).map_err(|err| {
819 load_error_with_source(
820 format!("load: failed to read compact MAT element: {err}"),
821 err,
822 )
823 })?;
824 let mut data = inline[..num_bytes.min(4)].to_vec();
825 data.truncate(num_bytes.min(4));
826 Ok(Some(TaggedData { data_type, data }))
827 } else {
828 let mut len_bytes = [0u8; 4];
829 reader.read_exact(&mut len_bytes).map_err(|err| {
830 load_error_with_source(
831 format!("load: failed to read MAT element length: {err}"),
832 err,
833 )
834 })?;
835 let length = u32::from_le_bytes(len_bytes) as usize;
836 let mut data = vec![0u8; length];
837 reader.read_exact(&mut data).map_err(|err| {
838 load_error_with_source(format!("load: failed to read MAT element body: {err}"), err)
839 })?;
840 let padding = (8 - (length % 8)) % 8;
841 if padding != 0 {
842 let mut pad = vec![0u8; padding];
843 reader.read_exact(&mut pad).map_err(|err| {
844 load_error_with_source(format!("load: failed to read MAT padding: {err}"), err)
845 })?;
846 }
847 Ok(Some(TaggedData {
848 data_type: type_field,
849 data,
850 }))
851 }
852}
853
854#[cfg(test)]
855pub(crate) mod tests {
856 use super::*;
857 use crate::workspace::WorkspaceResolver;
858 use futures::executor::block_on;
859 use runmat_builtins::StringArray;
860 use runmat_thread_local::runmat_thread_local;
861 use std::cell::RefCell;
862 use std::collections::HashMap;
863 use tempfile::tempdir;
864
865 runmat_thread_local! {
866 static TEST_WORKSPACE: RefCell<HashMap<String, Value>> = RefCell::new(HashMap::new());
867 }
868
869 fn ensure_test_resolver() {
870 crate::workspace::register_workspace_resolver(WorkspaceResolver {
871 lookup: |name| TEST_WORKSPACE.with(|slot| slot.borrow().get(name).cloned()),
872 snapshot: || {
873 let mut entries: Vec<(String, Value)> =
874 TEST_WORKSPACE.with(|slot| slot.borrow().clone().into_iter().collect());
875 entries.sort_by(|a, b| a.0.cmp(&b.0));
876 entries
877 },
878 globals: || Vec::new(),
879 assign: None,
880 clear: None,
881 remove: None,
882 });
883 }
884
885 fn set_workspace(entries: &[(&str, Value)]) {
886 TEST_WORKSPACE.with(|slot| {
887 let mut map = slot.borrow_mut();
888 map.clear();
889 for (name, value) in entries {
890 map.insert((*name).to_string(), value.clone());
891 }
892 });
893 }
894
895 fn workspace_guard() -> std::sync::MutexGuard<'static, ()> {
896 crate::workspace::test_guard()
897 }
898
899 fn assert_error_contains<T>(result: crate::BuiltinResult<T>, snippet: &str) {
900 match result {
901 Err(err) => {
902 assert!(
903 err.message().contains(snippet),
904 "expected error to contain '{snippet}', got '{}'",
905 err.message()
906 );
907 }
908 Ok(_) => panic!("expected error containing '{snippet}'"),
909 }
910 }
911
912 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
913 #[test]
914 fn load_roundtrip_numeric() {
915 let _guard = workspace_guard();
916 ensure_test_resolver();
917 let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0], vec![2, 2]).unwrap();
918 set_workspace(&[("A", Value::Tensor(tensor))]);
919
920 let dir = tempdir().unwrap();
921 let path = dir.path().join("numeric.mat");
922 let save_arg = Value::from(path.to_string_lossy().to_string());
923 block_on(crate::call_builtin_async(
924 "save",
925 std::slice::from_ref(&save_arg),
926 ))
927 .unwrap();
928
929 let eval = block_on(evaluate(&[Value::from(path.to_string_lossy().to_string())]))
930 .expect("load numeric");
931 let struct_value = eval.first_output();
932 match struct_value {
933 Value::Struct(sv) => {
934 assert!(sv.fields.contains_key("A"));
935 match sv.fields.get("A").unwrap() {
936 Value::Tensor(t) => {
937 assert_eq!(t.shape, vec![2, 2]);
938 assert_eq!(t.data, vec![1.0, 4.0, 2.0, 5.0]);
939 }
940 other => panic!("expected tensor, got {other:?}"),
941 }
942 }
943 other => panic!("expected struct, got {other:?}"),
944 }
945 }
946
947 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
948 #[test]
949 fn load_selected_variables() {
950 let _guard = workspace_guard();
951 ensure_test_resolver();
952 set_workspace(&[("signal", Value::Num(42.0)), ("noise", Value::Num(5.0))]);
953 let dir = tempdir().unwrap();
954 let path = dir.path().join("selection.mat");
955 let save_arg = Value::from(path.to_string_lossy().to_string());
956 block_on(crate::call_builtin_async(
957 "save",
958 std::slice::from_ref(&save_arg),
959 ))
960 .unwrap();
961
962 let eval = block_on(evaluate(&[
963 Value::from(path.to_string_lossy().to_string()),
964 Value::from("signal"),
965 ]))
966 .expect("load selection");
967 let vars = eval.variables();
968 assert_eq!(vars.len(), 1);
969 assert_eq!(vars[0].0, "signal");
970 assert!(matches!(vars[0].1, Value::Num(42.0)));
971 }
972
973 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
974 #[test]
975 fn load_regex_selection() {
976 let _guard = workspace_guard();
977 ensure_test_resolver();
978 set_workspace(&[
979 ("w1", Value::Num(1.0)),
980 ("w2", Value::Num(2.0)),
981 ("bias", Value::Num(3.0)),
982 ]);
983 let dir = tempdir().unwrap();
984 let path = dir.path().join("regex.mat");
985 let save_arg = Value::from(path.to_string_lossy().to_string());
986 block_on(crate::call_builtin_async(
987 "save",
988 std::slice::from_ref(&save_arg),
989 ))
990 .unwrap();
991
992 let eval = block_on(evaluate(&[
993 Value::from(path.to_string_lossy().to_string()),
994 Value::from("-regexp"),
995 Value::from("^w\\d$"),
996 ]))
997 .expect("load regex");
998 let mut names: Vec<_> = eval.variables().iter().map(|(n, _)| n.clone()).collect();
999 names.sort();
1000 assert_eq!(names, vec!["w1".to_string(), "w2".to_string()]);
1001 }
1002
1003 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1004 #[test]
1005 fn load_missing_variable_errors() {
1006 let _guard = workspace_guard();
1007 ensure_test_resolver();
1008 set_workspace(&[("existing", Value::Num(7.0))]);
1009 let dir = tempdir().unwrap();
1010 let path = dir.path().join("missing.mat");
1011 let save_arg = Value::from(path.to_string_lossy().to_string());
1012 block_on(crate::call_builtin_async(
1013 "save",
1014 std::slice::from_ref(&save_arg),
1015 ))
1016 .unwrap();
1017
1018 assert_error_contains(
1019 block_on(evaluate(&[
1020 Value::from(path.to_string_lossy().to_string()),
1021 Value::from("missing"),
1022 ])),
1023 "variable 'missing' was not found",
1024 );
1025 }
1026
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn load_string_array_roundtrip() {
    // Saving a 1x2 string array and loading the whole file (single struct
    // output) must reproduce the array's shape and elements exactly.
    let _guard = workspace_guard();
    ensure_test_resolver();
    let labels = StringArray::new(vec!["foo".into(), "bar".into()], vec![1, 2]).unwrap();
    set_workspace(&[("labels", Value::StringArray(labels))]);

    let scratch = tempdir().unwrap();
    let mat_name = scratch
        .path()
        .join("strings.mat")
        .to_string_lossy()
        .to_string();
    block_on(crate::call_builtin_async("save", &[Value::from(mat_name.clone())])).unwrap();

    let eval = block_on(evaluate(&[Value::from(mat_name)])).expect("load strings");
    let loaded = eval.first_output();

    // Guard clauses replace the nested match: each shape check panics with
    // the same diagnostics as before when the value is not what we expect.
    let Value::Struct(sv) = &loaded else {
        panic!("expected struct, got {loaded:?}");
    };
    let field = sv
        .fields
        .get("labels")
        .expect("labels field missing in struct");
    let Value::StringArray(sa) = field else {
        panic!("expected string array, got {field:?}");
    };

    assert_eq!(sa.shape, vec![1, 2]);
    assert_eq!(sa.data, vec![String::from("foo"), String::from("bar")]);
}
1063
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn load_option_before_filename() {
    // An option flag ("-mat") placed ahead of the filename must not be taken
    // as the filename, and the trailing name still restricts what is loaded.
    let _guard = workspace_guard();
    ensure_test_resolver();
    set_workspace(&[("alpha", Value::Num(1.0)), ("beta", Value::Num(2.0))]);

    let scratch = tempdir().unwrap();
    let mat_name = scratch
        .path()
        .join("option_first.mat")
        .to_string_lossy()
        .to_string();
    block_on(crate::call_builtin_async("save", &[Value::from(mat_name.clone())])).unwrap();

    let request = [Value::from("-mat"), Value::from(mat_name), Value::from("beta")];
    let eval = block_on(evaluate(&request)).expect("load with option first");

    let vars = eval.variables();
    assert_eq!(vars.len(), 1);
    let (name, value) = &vars[0];
    assert_eq!(name, "beta");
    assert!(matches!(value, Value::Num(2.0)));
}
1090
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn load_char_array_names_trimmed() {
    // Variable names may be passed as a blank-padded char matrix, one name per
    // row. The padding must be ignored when resolving names, and the loaded
    // variables must come back in row order.
    let _guard = workspace_guard();
    ensure_test_resolver();
    set_workspace(&[("short", Value::Num(5.0)), ("longer", Value::Num(9.0))]);

    let scratch = tempdir().unwrap();
    let mat_name = scratch
        .path()
        .join("char_names.mat")
        .to_string_lossy()
        .to_string();
    block_on(crate::call_builtin_async("save", &[Value::from(mat_name.clone())])).unwrap();

    // Left-align each name in a field `width` chars wide. `format!` pads with
    // spaces but never truncates, matching the original push-until-full loop.
    let width = 6;
    let rows: String = ["short", "longer"]
        .iter()
        .map(|name| format!("{name:<width$}"))
        .collect();
    let name_matrix = CharArray::new(rows.chars().collect(), 2, width).unwrap();

    let eval = block_on(evaluate(&[
        Value::from(mat_name),
        Value::CharArray(name_matrix),
    ]))
    .expect("load with char array names");

    let vars = eval.variables();
    assert_eq!(vars.len(), 2);
    for ((name, value), (want_name, want_num)) in vars.iter().zip([("short", 5.0), ("longer", 9.0)]) {
        assert_eq!(name, want_name);
        assert!(matches!(value, Value::Num(n) if *n == want_num));
    }
}
1129
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn load_duplicate_names_last_wins() {
    // Requesting the same variable twice (`load file dup dup`) must yield a
    // single binding, not two. Both requests resolve to the same saved value,
    // so this verifies de-duplication rather than which request takes effect.
    let _guard = workspace_guard();
    ensure_test_resolver();
    set_workspace(&[("dup", Value::Num(11.0))]);

    let scratch = tempdir().unwrap();
    let mat_name = scratch
        .path()
        .join("duplicates.mat")
        .to_string_lossy()
        .to_string();
    block_on(crate::call_builtin_async("save", &[Value::from(mat_name.clone())])).unwrap();

    let request = [Value::from(mat_name), Value::from("dup"), Value::from("dup")];
    let eval = block_on(evaluate(&request)).expect("load with duplicate names");

    let vars = eval.variables();
    assert_eq!(vars.len(), 1);
    let (name, value) = &vars[0];
    assert_eq!(name, "dup");
    assert!(matches!(value, Value::Num(11.0)));
}
1156
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "wgpu")]
fn load_wgpu_tensor_roundtrip() {
    // A workspace variable backed by a GPU tensor handle is saved to a
    // MAT-file. Loading that file must return a host `Value::Tensor` with the
    // original shape and data.
    let _guard = workspace_guard();
    ensure_test_resolver();
    // Return early (the test passes vacuously) when the wgpu provider cannot
    // be registered or no provider is available afterwards.
    if runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    )
    .is_err()
    {
        return;
    }
    let Some(provider) = runmat_accelerate_api::provider() else {
        return;
    };

    use runmat_accelerate_api::HostTensorView;

    let tensor = Tensor::new(vec![0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap();
    let view = HostTensorView {
        data: &tensor.data,
        shape: &tensor.shape,
    };
    let handle = provider.upload(&view).expect("upload tensor");
    set_workspace(&[("gpu_var", Value::GpuTensor(handle))]);

    let dir = tempdir().unwrap();
    let path = dir.path().join("wgpu_load.mat");
    // The arguments are only ever borrowed as a slice, so a fixed-size array
    // is enough; a `vec!` here is a needless heap allocation (clippy::useless_vec).
    let save_args = [
        Value::from(path.to_string_lossy().to_string()),
        Value::from("gpu_var"),
    ];
    block_on(crate::call_builtin_async("save", &save_args)).unwrap();

    let eval = block_on(evaluate(&[Value::from(path.to_string_lossy().to_string())]))
        .expect("load wgpu file");
    let struct_value = eval.first_output();
    match struct_value {
        Value::Struct(sv) => match sv.fields.get("gpu_var") {
            Some(Value::Tensor(t)) => {
                assert_eq!(t.shape, vec![2, 2]);
                assert_eq!(t.data, tensor.data);
            }
            other => panic!("expected tensor, got {other:?}"),
        },
        other => panic!("expected struct, got {other:?}"),
    }
}
1206}