Skip to main content

runmat_runtime/builtins/structs/core/
struct.rs

1//! MATLAB-compatible `struct` builtin.
2
3use crate::builtins::common::spec::{
4    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5    ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::struct_type;
8use runmat_builtins::{
9    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11    CellArray, CharArray, StructValue, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use crate::{build_runtime_error, BuiltinResult, RuntimeError};
16
/// GPU metadata for the `struct` builtin.
///
/// `struct` performs no device computation: it advertises no supported precisions and no
/// provider hooks, and uses a custom `"struct"` op kind. Field values (including GPU tensor
/// handles) are stored in the resulting struct as-is, so residency is inherited from the
/// inputs rather than forcing a gather.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "struct",
    op_kind: GpuOpKind::Custom("struct"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::InheritInputs,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-only construction; GPU values are preserved as handles without gathering.",
};
32
/// Fusion metadata for the `struct` builtin.
///
/// No elementwise or reduction template is supplied (both are `None`), so struct creation
/// is not fused; per `notes`, it acts as a fusion boundary while field values keep their
/// GPU residency.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "struct",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Struct creation breaks fusion planning but retains GPU residency for field values.",
};
43
/// One validated field/value pair collected from `struct(field, value, ...)` arguments.
struct FieldEntry {
    /// Field name, already checked by `parse_field_name`.
    name: String,
    /// How the value is applied to the element(s) of the result.
    value: FieldValue,
}

/// How a field value is applied when the struct (array) is built.
enum FieldValue {
    /// A value stored unchanged in every element of the result.
    Single(Value),
    /// A cell whose elements are distributed one per struct element, in `CellArray::data`
    /// order.
    Cell(CellArray),
}
53
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "struct";

/// Output descriptor shared by every `struct` signature.
const STRUCT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "S",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Scalar struct or struct array.",
}];

/// Inputs for `struct()`.
const STRUCT_INPUTS_EMPTY: [BuiltinParamDescriptor; 0] = [];
/// Inputs for `struct(template)` / `struct([])`.
const STRUCT_INPUTS_TEMPLATE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "template",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Existing struct/struct-array template or empty array for struct([]).",
}];
/// Inputs for `struct(field, value, ...)`: one required pair followed by any number of
/// additional pairs.
const STRUCT_INPUTS_PAIRS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "field",
        ty: BuiltinParamType::PropertyName,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Field name.",
    },
    BuiltinParamDescriptor {
        name: "value",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Field value or cell array of field values.",
    },
    BuiltinParamDescriptor {
        name: "name_value_pairs",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Additional field/value pairs.",
    },
];

/// Call signatures published through `STRUCT_DESCRIPTOR`.
const STRUCT_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
    BuiltinSignatureDescriptor {
        label: "S = struct()",
        inputs: &STRUCT_INPUTS_EMPTY,
        outputs: &STRUCT_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "S = struct(template)",
        inputs: &STRUCT_INPUTS_TEMPLATE,
        outputs: &STRUCT_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "S = struct(field, value, ...)",
        inputs: &STRUCT_INPUTS_PAIRS,
        outputs: &STRUCT_OUTPUT,
    },
];
113
// Error catalogue for `struct`. Every descriptor is raised through `struct_error` or
// `struct_error_with_message` and is listed in `STRUCT_ERRORS`, which `STRUCT_DESCRIPTOR`
// publishes.

/// Raised by `struct_builtin` when its single argument is not a struct, a cell, or an empty
/// tensor/logical array; the message is extended with the offending value.
const STRUCT_ERROR_INVALID_SINGLE_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.INVALID_SINGLE_INPUT",
    identifier: Some("RunMat:struct:InvalidSingleInput"),
    when: "Single input is neither struct, struct-array cell, nor empty numeric/logical array.",
    message:
        "struct: expected name/value pairs, an existing struct or struct array, or [] to create an empty struct array",
};

/// Raised by `struct_builtin` for an odd number (three or more) of arguments.
const STRUCT_ERROR_NAME_VALUE_PAIRS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.NAME_VALUE_PAIRS",
    identifier: Some("RunMat:struct:NameValuePairs"),
    when: "Name/value arguments are not supplied in complete pairs.",
    message: "struct: expected name/value pairs",
};

/// Raised when cell-valued fields differ in shape, or a cell's element count does not match
/// the struct-array size.
const STRUCT_ERROR_CELL_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.CELL_SIZE_MISMATCH",
    identifier: Some("RunMat:struct:CellSizeMismatch"),
    when: "Cell value inputs for struct-array construction do not share the same shape.",
    message: "struct: cell inputs must have matching sizes",
};

/// Raised by `build_struct_array` when the product of the dimensions overflows `usize`.
const STRUCT_ERROR_SIZE_OVERFLOW: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.SIZE_OVERFLOW",
    identifier: Some("RunMat:struct:SizeOverflow"),
    when: "Requested struct-array size exceeds platform limits.",
    message: "struct: struct array size exceeds platform limits",
};

/// Raised when `CellArray::new_with_shape` rejects the assembled struct array; the
/// underlying error text is appended.
const STRUCT_ERROR_ASSEMBLE_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.ASSEMBLE_FAILED",
    identifier: Some("RunMat:struct:AssembleFailed"),
    when: "Internal struct-array assembly failed.",
    message: "struct: failed to assemble struct array",
};

/// Raised when the 0x0 cell for `struct([])` cannot be created; the underlying error text is
/// appended.
const STRUCT_ERROR_EMPTY_ARRAY_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.EMPTY_ARRAY_FAILED",
    identifier: Some("RunMat:struct:EmptyArrayFailed"),
    when: "Internal empty struct-array creation failed.",
    message: "struct: failed to create empty struct array",
};

/// Raised by `clone_struct_array`; the message names the first non-struct element (1-based).
const STRUCT_ERROR_STRUCT_ARRAY_CONTENTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.STRUCT_ARRAY_CONTENTS",
    identifier: Some("RunMat:struct:StructArrayContents"),
    when: "Single-argument struct-array cell input contains non-struct values.",
    message: "struct: single argument cell input must contain structs",
};

/// Raised when a copied struct array cannot be reassembled with its original shape.
const STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.STRUCT_ARRAY_COPY_FAILED",
    identifier: Some("RunMat:struct:StructArrayCopyFailed"),
    when: "Copying a single-argument struct-array cell input failed.",
    message: "struct: failed to copy struct array",
};

/// Raised by `parse_field_name` for values that are not strings, string arrays, or char
/// arrays.
const STRUCT_ERROR_FIELD_NAME_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.FIELD_NAME_TYPE",
    identifier: Some("RunMat:struct:FieldNameType"),
    when: "Field name is not a string scalar or 1xN character vector.",
    message: "struct: field names must be strings or character vectors",
};

/// Raised for a string-array field name that does not hold exactly one element.
const STRUCT_ERROR_FIELD_NAME_SCALAR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.FIELD_NAME_SCALAR",
    identifier: Some("RunMat:struct:FieldNameScalar"),
    when: "Field name char/string-array input is not scalar.",
    message: "struct: field names must be scalar string arrays or character vectors",
};

/// Raised for a char-array field name with more than one row.
const STRUCT_ERROR_FIELD_NAME_CHAR_VECTOR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.FIELD_NAME_CHAR_VECTOR",
    identifier: Some("RunMat:struct:FieldNameCharVector"),
    when: "Character-array field name input is not a 1-by-N character vector.",
    message: "struct: field names must be 1-by-N character vectors",
};

/// Raised for an empty field name.
const STRUCT_ERROR_FIELD_NAME_EMPTY: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.FIELD_NAME_EMPTY",
    identifier: Some("RunMat:struct:FieldNameEmpty"),
    when: "Field name is empty.",
    message: "struct: field names must be nonempty",
};

/// Raised when the first character is not an ASCII letter or underscore; the message
/// includes the rejected name.
const STRUCT_ERROR_FIELD_NAME_START_CHAR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.FIELD_NAME_START_CHAR",
    identifier: Some("RunMat:struct:FieldNameStartChar"),
    when: "Field name does not start with a letter or underscore.",
    message: "struct: field names must begin with a letter or underscore",
};

/// Raised when a later character is not an ASCII letter, digit, or underscore; the message
/// includes the offending character and the name.
const STRUCT_ERROR_FIELD_NAME_INVALID_CHAR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRUCT.FIELD_NAME_INVALID_CHAR",
    identifier: Some("RunMat:struct:FieldNameInvalidChar"),
    when: "Field name includes unsupported characters.",
    message: "struct: invalid character in field name",
};

/// Every error descriptor above, published through `STRUCT_DESCRIPTOR`. Keep the array
/// length in sync when adding a descriptor.
const STRUCT_ERRORS: [BuiltinErrorDescriptor; 14] = [
    STRUCT_ERROR_INVALID_SINGLE_INPUT,
    STRUCT_ERROR_NAME_VALUE_PAIRS,
    STRUCT_ERROR_CELL_SIZE_MISMATCH,
    STRUCT_ERROR_SIZE_OVERFLOW,
    STRUCT_ERROR_ASSEMBLE_FAILED,
    STRUCT_ERROR_EMPTY_ARRAY_FAILED,
    STRUCT_ERROR_STRUCT_ARRAY_CONTENTS,
    STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED,
    STRUCT_ERROR_FIELD_NAME_TYPE,
    STRUCT_ERROR_FIELD_NAME_SCALAR,
    STRUCT_ERROR_FIELD_NAME_CHAR_VECTOR,
    STRUCT_ERROR_FIELD_NAME_EMPTY,
    STRUCT_ERROR_FIELD_NAME_START_CHAR,
    STRUCT_ERROR_FIELD_NAME_INVALID_CHAR,
];
229
/// Descriptor for `struct`, referenced by the `runtime_builtin` attribute on
/// `struct_builtin`. It bundles the call signatures and error catalogue, declares a fixed
/// output mode, and exposes the builtin under the public completion policy.
pub const STRUCT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &STRUCT_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &STRUCT_ERRORS,
};
236
237fn struct_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
238    struct_error_with_message(error.message, error)
239}
240
241fn struct_error_with_message(
242    message: impl Into<String>,
243    error: &'static BuiltinErrorDescriptor,
244) -> RuntimeError {
245    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
246    if let Some(identifier) = error.identifier {
247        builder = builder.with_identifier(identifier);
248    }
249    builder.build()
250}
251
252#[runtime_builtin(
253    name = "struct",
254    category = "structs/core",
255    summary = "Create scalar structs or struct arrays from field/value inputs.",
256    keywords = "struct,structure,name-value,record",
257    type_resolver(struct_type),
258    descriptor(crate::builtins::structs::core::r#struct::STRUCT_DESCRIPTOR),
259    builtin_path = "crate::builtins::structs::core::r#struct"
260)]
261async fn struct_builtin(rest: Vec<Value>) -> BuiltinResult<Value> {
262    match rest.len() {
263        0 => Ok(Value::Struct(StructValue::new())),
264        1 => match rest.into_iter().next().unwrap() {
265            Value::Struct(existing) => Ok(Value::Struct(existing.clone())),
266            Value::Cell(cell) => clone_struct_array(&cell),
267            Value::Tensor(tensor) if tensor.data.is_empty() => empty_struct_array(),
268            Value::LogicalArray(logical) if logical.data.is_empty() => empty_struct_array(),
269            other => Err(struct_error_with_message(
270                format!(
271                    "{} (got {other:?})",
272                    STRUCT_ERROR_INVALID_SINGLE_INPUT.message
273                ),
274                &STRUCT_ERROR_INVALID_SINGLE_INPUT,
275            )),
276        },
277        len if len % 2 == 0 => build_from_pairs(rest),
278        _ => Err(struct_error(&STRUCT_ERROR_NAME_VALUE_PAIRS)),
279    }
280}
281
282fn build_from_pairs(args: Vec<Value>) -> BuiltinResult<Value> {
283    let mut entries: Vec<FieldEntry> = Vec::new();
284    let mut target_shape: Option<Vec<usize>> = None;
285
286    let mut iter = args.into_iter();
287    while let (Some(name_value), Some(field_value)) = (iter.next(), iter.next()) {
288        let field_name = parse_field_name(&name_value)?;
289        match field_value {
290            Value::Cell(cell) => {
291                let shape = cell.shape.clone();
292                if let Some(existing) = &target_shape {
293                    if *existing != shape {
294                        return Err(struct_error(&STRUCT_ERROR_CELL_SIZE_MISMATCH));
295                    }
296                } else {
297                    target_shape = Some(shape);
298                }
299                entries.push(FieldEntry {
300                    name: field_name,
301                    value: FieldValue::Cell(cell),
302                });
303            }
304            other => entries.push(FieldEntry {
305                name: field_name,
306                value: FieldValue::Single(other),
307            }),
308        }
309    }
310
311    if let Some(shape) = target_shape {
312        build_struct_array(entries, shape)
313    } else {
314        build_scalar_struct(entries)
315    }
316}
317
318fn build_scalar_struct(entries: Vec<FieldEntry>) -> BuiltinResult<Value> {
319    let mut fields = StructValue::new();
320    for entry in entries {
321        match entry.value {
322            FieldValue::Single(value) => {
323                fields.fields.insert(entry.name, value);
324            }
325            FieldValue::Cell(cell) => {
326                let shape = cell.shape.clone();
327                return build_struct_array(
328                    vec![FieldEntry {
329                        name: entry.name,
330                        value: FieldValue::Cell(cell),
331                    }],
332                    shape,
333                );
334            }
335        }
336    }
337    Ok(Value::Struct(fields))
338}
339
340fn build_struct_array(entries: Vec<FieldEntry>, shape: Vec<usize>) -> BuiltinResult<Value> {
341    let total_len = shape
342        .iter()
343        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
344        .ok_or_else(|| struct_error(&STRUCT_ERROR_SIZE_OVERFLOW))?;
345
346    for entry in &entries {
347        if let FieldValue::Cell(cell) = &entry.value {
348            if cell.data.len() != total_len {
349                return Err(struct_error(&STRUCT_ERROR_CELL_SIZE_MISMATCH));
350            }
351        }
352    }
353
354    let mut structs: Vec<Value> = Vec::with_capacity(total_len);
355    for idx in 0..total_len {
356        let mut fields = StructValue::new();
357        for entry in &entries {
358            let value = match &entry.value {
359                FieldValue::Single(val) => val.clone(),
360                FieldValue::Cell(cell) => clone_cell_element(cell, idx)?,
361            };
362            fields.fields.insert(entry.name.clone(), value);
363        }
364        structs.push(Value::Struct(fields));
365    }
366
367    CellArray::new_with_shape(structs, shape)
368        .map(Value::Cell)
369        .map_err(|e| {
370            struct_error_with_message(
371                format!("{}: {e}", STRUCT_ERROR_ASSEMBLE_FAILED.message),
372                &STRUCT_ERROR_ASSEMBLE_FAILED,
373            )
374        })
375}
376
377fn clone_cell_element(cell: &CellArray, index: usize) -> BuiltinResult<Value> {
378    cell.data
379        .get(index)
380        .map(|ptr| unsafe { &*ptr.as_raw() }.clone())
381        .ok_or_else(|| struct_error(&STRUCT_ERROR_CELL_SIZE_MISMATCH))
382}
383
384fn empty_struct_array() -> BuiltinResult<Value> {
385    CellArray::new(Vec::new(), 0, 0)
386        .map(Value::Cell)
387        .map_err(|e| {
388            struct_error_with_message(
389                format!("{}: {e}", STRUCT_ERROR_EMPTY_ARRAY_FAILED.message),
390                &STRUCT_ERROR_EMPTY_ARRAY_FAILED,
391            )
392        })
393}
394
395fn clone_struct_array(array: &CellArray) -> BuiltinResult<Value> {
396    let mut values: Vec<Value> = Vec::with_capacity(array.data.len());
397    for (index, handle) in array.data.iter().enumerate() {
398        let value = unsafe { &*handle.as_raw() }.clone();
399        if !matches!(value, Value::Struct(_)) {
400            return Err(struct_error_with_message(
401                format!(
402                    "{} (element {} is not a struct)",
403                    STRUCT_ERROR_STRUCT_ARRAY_CONTENTS.message,
404                    index + 1
405                ),
406                &STRUCT_ERROR_STRUCT_ARRAY_CONTENTS,
407            ));
408        }
409        values.push(value);
410    }
411    CellArray::new_with_shape(values, array.shape.clone())
412        .map(Value::Cell)
413        .map_err(|e| {
414            struct_error_with_message(
415                format!("{}: {e}", STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED.message),
416                &STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED,
417            )
418        })
419}
420
421fn parse_field_name(value: &Value) -> BuiltinResult<String> {
422    let text = match value {
423        Value::String(s) => s.clone(),
424        Value::StringArray(sa) => {
425            if sa.data.len() == 1 {
426                sa.data[0].clone()
427            } else {
428                return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_SCALAR));
429            }
430        }
431        Value::CharArray(ca) => char_array_to_string(ca)?,
432        _ => return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_TYPE)),
433    };
434
435    validate_field_name(&text)?;
436    Ok(text)
437}
438
439fn char_array_to_string(ca: &CharArray) -> BuiltinResult<String> {
440    if ca.rows > 1 {
441        return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_CHAR_VECTOR));
442    }
443    let mut out = String::with_capacity(ca.data.len());
444    for ch in &ca.data {
445        out.push(*ch);
446    }
447    Ok(out)
448}
449
450fn validate_field_name(name: &str) -> BuiltinResult<()> {
451    if name.is_empty() {
452        return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_EMPTY));
453    }
454    let mut chars = name.chars();
455    let Some(first) = chars.next() else {
456        return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_EMPTY));
457    };
458    if !is_first_char_valid(first) {
459        return Err(struct_error_with_message(
460            format!(
461                "{} (got '{name}')",
462                STRUCT_ERROR_FIELD_NAME_START_CHAR.message
463            ),
464            &STRUCT_ERROR_FIELD_NAME_START_CHAR,
465        ));
466    }
467    if let Some(bad) = chars.find(|c| !is_subsequent_char_valid(*c)) {
468        return Err(struct_error_with_message(
469            format!(
470                "{} ('{bad}' in '{name}')",
471                STRUCT_ERROR_FIELD_NAME_INVALID_CHAR.message
472            ),
473            &STRUCT_ERROR_FIELD_NAME_INVALID_CHAR,
474        ));
475    }
476    Ok(())
477}
478
/// Returns `true` if `c` may begin a field name: an ASCII letter or `_`.
fn is_first_char_valid(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '_')
}
482
/// Returns `true` if `c` may appear after the first character of a field name: an ASCII
/// letter, an ASCII digit, or `_`.
fn is_subsequent_char_valid(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}
486
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `struct` builtin: scalar construction, struct arrays built from
    //! cells, single-argument forms, field-name validation, and GPU handle preservation.

    use super::*;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::{CellArray, IntValue, StringArray, StructValue, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;

    /// Extracts the human-readable message from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Runs the async builtin to completion on the current thread.
    fn run_struct(args: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(struct_builtin(args))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty() {
        let Value::Struct(s) = run_struct(Vec::new()).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty_from_empty_matrix() {
        let tensor = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let value = run_struct(vec![Value::Tensor(tensor)]).expect("struct([])");
        match value {
            Value::Cell(cell) => {
                assert_eq!(cell.rows, 0);
                assert_eq!(cell.cols, 0);
                assert!(cell.data.is_empty());
            }
            other => panic!("expected empty struct array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_name_value_pairs() {
        let args = vec![
            Value::from("name"),
            Value::from("Ada"),
            Value::from("score"),
            Value::Int(IntValue::I32(42)),
        ];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert_eq!(s.fields.len(), 2);
        assert!(matches!(s.fields.get("name"), Some(Value::String(v)) if v == "Ada"));
        assert!(matches!(
            s.fields.get("score"),
            Some(Value::Int(IntValue::I32(42)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_from_cells() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let ages = CellArray::new(
            vec![Value::Int(IntValue::I32(36)), Value::Int(IntValue::I32(45))],
            1,
            2,
        )
        .unwrap();
        let result = run_struct(vec![
            Value::from("name"),
            Value::Cell(names),
            Value::from("age"),
            Value::Cell(ages),
        ])
        .expect("struct array");
        let structs = expect_struct_array(result);
        assert_eq!(structs.len(), 2);
        assert!(matches!(
            structs[0].fields.get("name"),
            Some(Value::String(v)) if v == "Ada"
        ));
        assert!(matches!(
            structs[1].fields.get("age"),
            Some(Value::Int(IntValue::I32(45)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_replicates_scalars() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let result = run_struct(vec![
            Value::from("name"),
            Value::Cell(names),
            Value::from("department"),
            Value::from("Research"),
        ])
        .expect("struct array");
        let structs = expect_struct_array(result);
        assert_eq!(structs.len(), 2);
        for entry in structs {
            assert!(matches!(
                entry.fields.get("department"),
                Some(Value::String(v)) if v == "Research"
            ));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_preserves_column_shape() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 2, 1).unwrap();
        let expected_shape = names.shape.clone();
        let result =
            run_struct(vec![Value::from("name"), Value::Cell(names)]).expect("struct array");
        match result {
            Value::Cell(cell) => assert_eq!(cell.shape, expected_shape),
            other => panic!("expected struct array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_cell_size_mismatch_errors() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let scores = CellArray::new(vec![Value::Int(IntValue::I32(1))], 1, 1).unwrap();
        let err = error_message(
            run_struct(vec![
                Value::from("name"),
                Value::Cell(names),
                Value::from("score"),
                Value::Cell(scores),
            ])
            .unwrap_err(),
        );
        assert!(err.contains("matching sizes"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_overwrites_duplicates() {
        let args = vec![
            Value::from("version"),
            Value::Int(IntValue::I32(1)),
            Value::from("version"),
            Value::Int(IntValue::I32(2)),
        ];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert_eq!(s.fields.len(), 1);
        assert!(matches!(
            s.fields.get("version"),
            Some(Value::Int(IntValue::I32(2)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_odd_arguments() {
        let err = error_message(run_struct(vec![Value::from("name")]).unwrap_err());
        assert!(err.contains("name/value pairs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_odd_arguments_after_pairs() {
        let err = error_message(
            run_struct(vec![Value::from("a"), Value::Num(1.0), Value::from("b")]).unwrap_err(),
        );
        assert!(err.contains("name/value pairs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_single_input() {
        let err = error_message(run_struct(vec![Value::Num(1.0)]).unwrap_err());
        assert!(err.contains("an existing struct or struct array"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_field_name() {
        let err = error_message(
            run_struct(vec![Value::from("1bad"), Value::Int(IntValue::I32(1))]).unwrap_err(),
        );
        assert!(err.contains("begin with a letter or underscore"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_character_in_field_name() {
        let err = error_message(
            run_struct(vec![Value::from("bad-name"), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("invalid character in field name"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_empty_field_name() {
        let err = error_message(run_struct(vec![Value::from(""), Value::Num(1.0)]).unwrap_err());
        assert!(err.contains("must be nonempty"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_non_text_field_name() {
        let err = error_message(
            run_struct(vec![Value::Num(1.0), Value::Int(IntValue::I32(1))]).unwrap_err(),
        );
        assert!(err.contains("strings or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_multirow_char_name() {
        let chars = CharArray::new("abcd".chars().collect(), 2, 2).unwrap();
        let err = error_message(
            run_struct(vec![Value::CharArray(chars), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("1-by-N character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_nonscalar_string_array_name() {
        let sa = StringArray::new(vec!["a".to_string(), "b".to_string()], vec![1, 2]).unwrap();
        let err = error_message(
            run_struct(vec![Value::StringArray(sa), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("scalar string arrays"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_char_vector_name() {
        let chars = CharArray::new("field".chars().collect(), 1, 5).unwrap();
        let args = vec![Value::CharArray(chars), Value::Num(1.0)];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_string_scalar_name() {
        let sa = StringArray::new(vec!["field".to_string()], vec![1]).unwrap();
        let args = vec![Value::StringArray(sa), Value::Num(1.0)];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_allows_existing_struct_copy() {
        let mut base = StructValue::new();
        base.fields
            .insert("id".to_string(), Value::Int(IntValue::I32(7)));
        let copy = run_struct(vec![Value::Struct(base.clone())]).expect("struct");
        assert_eq!(copy, Value::Struct(base));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_copies_struct_array_argument() {
        let mut proto = StructValue::new();
        proto
            .fields
            .insert("id".into(), Value::Int(IntValue::I32(7)));
        let struct_array = CellArray::new(
            vec![
                Value::Struct(proto.clone()),
                Value::Struct(proto.clone()),
                Value::Struct(proto.clone()),
            ],
            1,
            3,
        )
        .unwrap();
        let original = struct_array.clone();
        let result = run_struct(vec![Value::Cell(struct_array)]).expect("struct array clone");
        let cloned = expect_struct_array(result);
        let baseline = expect_struct_array(Value::Cell(original));
        assert_eq!(cloned, baseline);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_cell_argument_without_structs() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = error_message(run_struct(vec![Value::Cell(cell)]).unwrap_err());
        assert!(err.contains("must contain structs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_reports_non_struct_element_position() {
        let cell = CellArray::new(
            vec![Value::Struct(StructValue::new()), Value::Num(1.0)],
            1,
            2,
        )
        .unwrap();
        let err = error_message(run_struct(vec![Value::Cell(cell)]).unwrap_err());
        assert!(err.contains("element 2 is not a struct"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_preserves_gpu_tensor_handles() {
        let handle = GpuTensorHandle {
            shape: vec![2, 2],
            device_id: 1,
            buffer_id: 99,
        };
        let args = vec![Value::from("data"), Value::GpuTensor(handle.clone())];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(matches!(s.fields.get("data"), Some(Value::GpuTensor(h)) if h == &handle));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_preserves_gpu_handles() {
        let first = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 11,
        };
        let second = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 12,
        };
        let cell = CellArray::new(
            vec![
                Value::GpuTensor(first.clone()),
                Value::GpuTensor(second.clone()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_struct(vec![Value::from("payload"), Value::Cell(cell)])
            .expect("struct array gpu handles");
        let structs = expect_struct_array(result);
        assert!(matches!(
            structs[0].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &first
        ));
        assert!(matches!(
            structs[1].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &second
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn struct_preserves_gpu_handles_with_registered_provider() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let host = HostTensorView {
            data: &[1.0, 2.0],
            shape: &[2, 1],
        };
        let handle = provider.upload(&host).expect("upload");
        let args = vec![Value::from("gpu"), Value::GpuTensor(handle.clone())];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(matches!(s.fields.get("gpu"), Some(Value::GpuTensor(h)) if h == &handle));
    }

    /// Unpacks a struct-array result (cell of structs) or a scalar struct into a vector of
    /// scalar structs, panicking on any other value.
    fn expect_struct_array(value: Value) -> Vec<StructValue> {
        match value {
            Value::Cell(cell) => cell
                .data
                .iter()
                // SAFETY: NOTE(review) — assumes each handle points to a live `Value` kept
                // alive by `cell`, which outlives this iteration; confirm against the handle
                // type's contract.
                .map(|ptr| unsafe { &*ptr.as_raw() }.clone())
                .map(|value| match value {
                    Value::Struct(st) => st,
                    other => panic!("expected struct element, got {other:?}"),
                })
                .collect(),
            Value::Struct(st) => vec![st],
            other => panic!("expected struct or struct array, got {other:?}"),
        }
    }
}
808}