Skip to main content

runmat_runtime/builtins/structs/core/
struct.rs

1//! MATLAB-compatible `struct` builtin.
2
3use crate::builtins::common::spec::{
4    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5    ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::struct_type;
8use runmat_builtins::{CellArray, CharArray, StructValue, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::{build_runtime_error, BuiltinResult, RuntimeError};
12
13#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
14pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
15    name: "struct",
16    op_kind: GpuOpKind::Custom("struct"),
17    supported_precisions: &[],
18    broadcast: BroadcastSemantics::None,
19    provider_hooks: &[],
20    constant_strategy: ConstantStrategy::InlineLiteral,
21    residency: ResidencyPolicy::InheritInputs,
22    nan_mode: ReductionNaN::Include,
23    two_pass_threshold: None,
24    workgroup_size: None,
25    accepts_nan_mode: false,
26    notes: "Host-only construction; GPU values are preserved as handles without gathering.",
27};
28
29#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
30pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
31    name: "struct",
32    shape: ShapeRequirements::Any,
33    constant_strategy: ConstantStrategy::InlineLiteral,
34    elementwise: None,
35    reduction: None,
36    emits_nan: false,
37    notes: "Struct creation breaks fusion planning but retains GPU residency for field values.",
38};
39
40struct FieldEntry {
41    name: String,
42    value: FieldValue,
43}
44
45enum FieldValue {
46    Single(Value),
47    Cell(CellArray),
48}
49
50fn struct_flow(message: impl Into<String>) -> RuntimeError {
51    build_runtime_error(message).with_builtin("struct").build()
52}
53
54#[runtime_builtin(
55    name = "struct",
56    category = "structs/core",
57    summary = "Create scalar structs or struct arrays from name/value pairs.",
58    keywords = "struct,structure,name-value,record",
59    type_resolver(struct_type),
60    builtin_path = "crate::builtins::structs::core::r#struct"
61)]
62async fn struct_builtin(rest: Vec<Value>) -> BuiltinResult<Value> {
63    match rest.len() {
64        0 => Ok(Value::Struct(StructValue::new())),
65        1 => match rest.into_iter().next().unwrap() {
66            Value::Struct(existing) => Ok(Value::Struct(existing.clone())),
67            Value::Cell(cell) => clone_struct_array(&cell),
68            Value::Tensor(tensor) if tensor.data.is_empty() => empty_struct_array(),
69            Value::LogicalArray(logical) if logical.data.is_empty() => empty_struct_array(),
70            other => Err(struct_flow(format!(
71                "struct: expected name/value pairs, an existing struct or struct array, or [] to create an empty struct array (got {other:?})"
72            ))),
73        },
74        len if len % 2 == 0 => build_from_pairs(rest),
75        _ => Err(struct_flow("struct: expected name/value pairs")),
76    }
77}
78
79fn build_from_pairs(args: Vec<Value>) -> BuiltinResult<Value> {
80    let mut entries: Vec<FieldEntry> = Vec::new();
81    let mut target_shape: Option<Vec<usize>> = None;
82
83    let mut iter = args.into_iter();
84    while let (Some(name_value), Some(field_value)) = (iter.next(), iter.next()) {
85        let field_name = parse_field_name(&name_value)?;
86        match field_value {
87            Value::Cell(cell) => {
88                let shape = cell.shape.clone();
89                if let Some(existing) = &target_shape {
90                    if *existing != shape {
91                        return Err(struct_flow("struct: cell inputs must have matching sizes"));
92                    }
93                } else {
94                    target_shape = Some(shape);
95                }
96                entries.push(FieldEntry {
97                    name: field_name,
98                    value: FieldValue::Cell(cell),
99                });
100            }
101            other => entries.push(FieldEntry {
102                name: field_name,
103                value: FieldValue::Single(other),
104            }),
105        }
106    }
107
108    if let Some(shape) = target_shape {
109        build_struct_array(entries, shape)
110    } else {
111        build_scalar_struct(entries)
112    }
113}
114
115fn build_scalar_struct(entries: Vec<FieldEntry>) -> BuiltinResult<Value> {
116    let mut fields = StructValue::new();
117    for entry in entries {
118        match entry.value {
119            FieldValue::Single(value) => {
120                fields.fields.insert(entry.name, value);
121            }
122            FieldValue::Cell(cell) => {
123                let shape = cell.shape.clone();
124                return build_struct_array(
125                    vec![FieldEntry {
126                        name: entry.name,
127                        value: FieldValue::Cell(cell),
128                    }],
129                    shape,
130                );
131            }
132        }
133    }
134    Ok(Value::Struct(fields))
135}
136
137fn build_struct_array(entries: Vec<FieldEntry>, shape: Vec<usize>) -> BuiltinResult<Value> {
138    let total_len = shape
139        .iter()
140        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
141        .ok_or_else(|| struct_flow("struct: struct array size exceeds platform limits"))?;
142
143    for entry in &entries {
144        if let FieldValue::Cell(cell) = &entry.value {
145            if cell.data.len() != total_len {
146                return Err(struct_flow("struct: cell inputs must have matching sizes"));
147            }
148        }
149    }
150
151    let mut structs: Vec<Value> = Vec::with_capacity(total_len);
152    for idx in 0..total_len {
153        let mut fields = StructValue::new();
154        for entry in &entries {
155            let value = match &entry.value {
156                FieldValue::Single(val) => val.clone(),
157                FieldValue::Cell(cell) => clone_cell_element(cell, idx)?,
158            };
159            fields.fields.insert(entry.name.clone(), value);
160        }
161        structs.push(Value::Struct(fields));
162    }
163
164    CellArray::new_with_shape(structs, shape)
165        .map(Value::Cell)
166        .map_err(|e| struct_flow(format!("struct: failed to assemble struct array: {e}")))
167}
168
169fn clone_cell_element(cell: &CellArray, index: usize) -> BuiltinResult<Value> {
170    cell.data
171        .get(index)
172        .map(|ptr| unsafe { &*ptr.as_raw() }.clone())
173        .ok_or_else(|| struct_flow("struct: cell inputs must have matching sizes"))
174}
175
176fn empty_struct_array() -> BuiltinResult<Value> {
177    CellArray::new(Vec::new(), 0, 0)
178        .map(Value::Cell)
179        .map_err(|e| struct_flow(format!("struct: failed to create empty struct array: {e}")))
180}
181
182fn clone_struct_array(array: &CellArray) -> BuiltinResult<Value> {
183    let mut values: Vec<Value> = Vec::with_capacity(array.data.len());
184    for (index, handle) in array.data.iter().enumerate() {
185        let value = unsafe { &*handle.as_raw() }.clone();
186        if !matches!(value, Value::Struct(_)) {
187            return Err(struct_flow(format!(
188                "struct: single argument cell input must contain structs (element {} is not a struct)",
189                index + 1
190            )));
191        }
192        values.push(value);
193    }
194    CellArray::new_with_shape(values, array.shape.clone())
195        .map(Value::Cell)
196        .map_err(|e| struct_flow(format!("struct: failed to copy struct array: {e}")))
197}
198
199fn parse_field_name(value: &Value) -> BuiltinResult<String> {
200    let text = match value {
201        Value::String(s) => s.clone(),
202        Value::StringArray(sa) => {
203            if sa.data.len() == 1 {
204                sa.data[0].clone()
205            } else {
206                return Err(struct_flow(
207                    "struct: field names must be scalar string arrays or character vectors",
208                ));
209            }
210        }
211        Value::CharArray(ca) => char_array_to_string(ca)?,
212        _ => {
213            return Err(struct_flow(
214                "struct: field names must be strings or character vectors",
215            ))
216        }
217    };
218
219    validate_field_name(&text)?;
220    Ok(text)
221}
222
223fn char_array_to_string(ca: &CharArray) -> BuiltinResult<String> {
224    if ca.rows > 1 {
225        return Err(struct_flow(
226            "struct: field names must be 1-by-N character vectors",
227        ));
228    }
229    let mut out = String::with_capacity(ca.data.len());
230    for ch in &ca.data {
231        out.push(*ch);
232    }
233    Ok(out)
234}
235
236fn validate_field_name(name: &str) -> BuiltinResult<()> {
237    if name.is_empty() {
238        return Err(struct_flow("struct: field names must be nonempty"));
239    }
240    let mut chars = name.chars();
241    let Some(first) = chars.next() else {
242        return Err(struct_flow("struct: field names must be nonempty"));
243    };
244    if !is_first_char_valid(first) {
245        return Err(struct_flow(format!(
246            "struct: field names must begin with a letter or underscore (got '{name}')"
247        )));
248    }
249    if let Some(bad) = chars.find(|c| !is_subsequent_char_valid(*c)) {
250        return Err(struct_flow(format!(
251            "struct: invalid character '{bad}' in field name '{name}'"
252        )));
253    }
254    Ok(())
255}
256
/// Returns `true` if `c` may start a field name (an ASCII letter or `_`).
fn is_first_char_valid(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '_')
}
260
/// Returns `true` if `c` may appear after the first character of a field
/// name (an ASCII letter, ASCII digit, or `_`).
fn is_subsequent_char_valid(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}
264
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::{CellArray, IntValue, StringArray, StructValue, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;

    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Runs the async builtin to completion on the current thread.
    fn run_struct(args: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(struct_builtin(args))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty() {
        let Value::Struct(s) = run_struct(Vec::new()).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty_from_empty_matrix() {
        let tensor = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let value = run_struct(vec![Value::Tensor(tensor)]).expect("struct([])");
        match value {
            Value::Cell(cell) => {
                assert_eq!(cell.rows, 0);
                assert_eq!(cell.cols, 0);
                assert!(cell.data.is_empty());
            }
            other => panic!("expected empty struct array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_name_value_pairs() {
        let args = vec![
            Value::from("name"),
            Value::from("Ada"),
            Value::from("score"),
            Value::Int(IntValue::I32(42)),
        ];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert_eq!(s.fields.len(), 2);
        assert!(matches!(s.fields.get("name"), Some(Value::String(v)) if v == "Ada"));
        assert!(matches!(
            s.fields.get("score"),
            Some(Value::Int(IntValue::I32(42)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_from_cells() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let ages = CellArray::new(
            vec![Value::Int(IntValue::I32(36)), Value::Int(IntValue::I32(45))],
            1,
            2,
        )
        .unwrap();
        let result = run_struct(vec![
            Value::from("name"),
            Value::Cell(names),
            Value::from("age"),
            Value::Cell(ages),
        ])
        .expect("struct array");
        let structs = expect_struct_array(result);
        assert_eq!(structs.len(), 2);
        assert!(matches!(
            structs[0].fields.get("name"),
            Some(Value::String(v)) if v == "Ada"
        ));
        assert!(matches!(
            structs[1].fields.get("age"),
            Some(Value::Int(IntValue::I32(45)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_replicates_scalars() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let result = run_struct(vec![
            Value::from("name"),
            Value::Cell(names),
            Value::from("department"),
            Value::from("Research"),
        ])
        .expect("struct array");
        let structs = expect_struct_array(result);
        assert_eq!(structs.len(), 2);
        for entry in structs {
            assert!(matches!(
                entry.fields.get("department"),
                Some(Value::String(v)) if v == "Research"
            ));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_cell_size_mismatch_errors() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let scores = CellArray::new(vec![Value::Int(IntValue::I32(1))], 1, 1).unwrap();
        let err = error_message(
            run_struct(vec![
                Value::from("name"),
                Value::Cell(names),
                Value::from("score"),
                Value::Cell(scores),
            ])
            .unwrap_err(),
        );
        assert!(err.contains("matching sizes"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty_cell_value_yields_empty_struct_array() {
        let empty = CellArray::new(Vec::new(), 0, 0).unwrap();
        let value = run_struct(vec![Value::from("name"), Value::Cell(empty)])
            .expect("struct with empty cell value");
        match value {
            Value::Cell(cell) => assert!(cell.data.is_empty()),
            other => panic!("expected empty struct array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_overwrites_duplicates() {
        let args = vec![
            Value::from("version"),
            Value::Int(IntValue::I32(1)),
            Value::from("version"),
            Value::Int(IntValue::I32(2)),
        ];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert_eq!(s.fields.len(), 1);
        assert!(matches!(
            s.fields.get("version"),
            Some(Value::Int(IntValue::I32(2)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_odd_arguments() {
        let err = error_message(run_struct(vec![Value::from("name")]).unwrap_err());
        assert!(err.contains("name/value pairs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_three_arguments() {
        let err = error_message(
            run_struct(vec![
                Value::from("a"),
                Value::Num(1.0),
                Value::from("b"),
            ])
            .unwrap_err(),
        );
        assert!(err.contains("name/value pairs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_non_struct_single_argument() {
        let err = error_message(run_struct(vec![Value::Num(1.0)]).unwrap_err());
        assert!(err.contains("name/value pairs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_field_name() {
        let err = error_message(
            run_struct(vec![Value::from("1bad"), Value::Int(IntValue::I32(1))]).unwrap_err(),
        );
        assert!(err.contains("begin with a letter or underscore"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_character_in_field_name() {
        let err = error_message(
            run_struct(vec![Value::from("bad-name"), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("invalid character '-'"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_empty_field_name() {
        let err = error_message(run_struct(vec![Value::from(""), Value::Num(1.0)]).unwrap_err());
        assert!(err.contains("nonempty"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_non_text_field_name() {
        let err = error_message(
            run_struct(vec![Value::Num(1.0), Value::Int(IntValue::I32(1))]).unwrap_err(),
        );
        assert!(err.contains("strings or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_multirow_char_field_name() {
        let chars = CharArray::new("abcd".chars().collect(), 2, 2).unwrap();
        let err = error_message(
            run_struct(vec![Value::CharArray(chars), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("1-by-N character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_nonscalar_string_array_field_name() {
        let sa = StringArray::new(vec!["a".to_string(), "b".to_string()], vec![1, 2]).unwrap();
        let err = error_message(
            run_struct(vec![Value::StringArray(sa), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("scalar string arrays"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_char_vector_name() {
        let chars = CharArray::new("field".chars().collect(), 1, 5).unwrap();
        let args = vec![Value::CharArray(chars), Value::Num(1.0)];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_string_scalar_name() {
        let sa = StringArray::new(vec!["field".to_string()], vec![1]).unwrap();
        let args = vec![Value::StringArray(sa), Value::Num(1.0)];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_allows_existing_struct_copy() {
        let mut base = StructValue::new();
        base.fields
            .insert("id".to_string(), Value::Int(IntValue::I32(7)));
        let copy = run_struct(vec![Value::Struct(base.clone())]).expect("struct");
        assert_eq!(copy, Value::Struct(base));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_copies_struct_array_argument() {
        let mut proto = StructValue::new();
        proto
            .fields
            .insert("id".into(), Value::Int(IntValue::I32(7)));
        let struct_array = CellArray::new(
            vec![
                Value::Struct(proto.clone()),
                Value::Struct(proto.clone()),
                Value::Struct(proto.clone()),
            ],
            1,
            3,
        )
        .unwrap();
        let original = struct_array.clone();
        let result = run_struct(vec![Value::Cell(struct_array)]).expect("struct array clone");
        let cloned = expect_struct_array(result);
        let baseline = expect_struct_array(Value::Cell(original));
        assert_eq!(cloned, baseline);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_cell_argument_without_structs() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = error_message(run_struct(vec![Value::Cell(cell)]).unwrap_err());
        assert!(err.contains("must contain structs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_preserves_gpu_tensor_handles() {
        let handle = GpuTensorHandle {
            shape: vec![2, 2],
            device_id: 1,
            buffer_id: 99,
        };
        let args = vec![Value::from("data"), Value::GpuTensor(handle.clone())];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(matches!(s.fields.get("data"), Some(Value::GpuTensor(h)) if h == &handle));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_preserves_gpu_handles() {
        let first = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 11,
        };
        let second = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 12,
        };
        let cell = CellArray::new(
            vec![
                Value::GpuTensor(first.clone()),
                Value::GpuTensor(second.clone()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_struct(vec![Value::from("payload"), Value::Cell(cell)])
            .expect("struct array gpu handles");
        let structs = expect_struct_array(result);
        assert!(matches!(
            structs[0].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &first
        ));
        assert!(matches!(
            structs[1].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &second
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn struct_preserves_gpu_handles_with_registered_provider() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let host = HostTensorView {
            data: &[1.0, 2.0],
            shape: &[2, 1],
        };
        let handle = provider.upload(&host).expect("upload");
        let args = vec![Value::from("gpu"), Value::GpuTensor(handle.clone())];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(matches!(s.fields.get("gpu"), Some(Value::GpuTensor(h)) if h == &handle));
    }

    /// Flattens a struct array (a cell of structs) or a scalar struct into a
    /// `Vec<StructValue>`, panicking on anything else.
    fn expect_struct_array(value: Value) -> Vec<StructValue> {
        match value {
            Value::Cell(cell) => cell
                .data
                .iter()
                // SAFETY: assumes each handle's raw pointer stays valid while
                // `cell` is alive — TODO confirm against the CellArray handle contract.
                .map(|ptr| unsafe { &*ptr.as_raw() }.clone())
                .map(|value| match value {
                    Value::Struct(st) => st,
                    other => panic!("expected struct element, got {other:?}"),
                })
                .collect(),
            Value::Struct(st) => vec![st],
            other => panic!("expected struct or struct array, got {other:?}"),
        }
    }
}