1use crate::builtins::common::spec::{
4 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5 ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::struct_type;
8use runmat_builtins::{
9 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11 CellArray, CharArray, StructValue, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use crate::{build_runtime_error, BuiltinResult, RuntimeError};
16
17#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
18pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
19 name: "struct",
20 op_kind: GpuOpKind::Custom("struct"),
21 supported_precisions: &[],
22 broadcast: BroadcastSemantics::None,
23 provider_hooks: &[],
24 constant_strategy: ConstantStrategy::InlineLiteral,
25 residency: ResidencyPolicy::InheritInputs,
26 nan_mode: ReductionNaN::Include,
27 two_pass_threshold: None,
28 workgroup_size: None,
29 accepts_nan_mode: false,
30 notes: "Host-only construction; GPU values are preserved as handles without gathering.",
31};
32
33#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
34pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
35 name: "struct",
36 shape: ShapeRequirements::Any,
37 constant_strategy: ConstantStrategy::InlineLiteral,
38 elementwise: None,
39 reduction: None,
40 emits_nan: false,
41 notes: "Struct creation breaks fusion planning but retains GPU residency for field values.",
42};
43
44struct FieldEntry {
45 name: String,
46 value: FieldValue,
47}
48
49enum FieldValue {
50 Single(Value),
51 Cell(CellArray),
52}
53
/// Name attached to every runtime error raised by this builtin.
const BUILTIN_NAME: &str = "struct";
55
56const STRUCT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
57 name: "S",
58 ty: BuiltinParamType::Any,
59 arity: BuiltinParamArity::Required,
60 default: None,
61 description: "Scalar struct or struct array.",
62}];
63
64const STRUCT_INPUTS_EMPTY: [BuiltinParamDescriptor; 0] = [];
65const STRUCT_INPUTS_TEMPLATE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
66 name: "template",
67 ty: BuiltinParamType::Any,
68 arity: BuiltinParamArity::Required,
69 default: None,
70 description: "Existing struct/struct-array template or empty array for struct([]).",
71}];
72const STRUCT_INPUTS_PAIRS: [BuiltinParamDescriptor; 3] = [
73 BuiltinParamDescriptor {
74 name: "field",
75 ty: BuiltinParamType::PropertyName,
76 arity: BuiltinParamArity::Required,
77 default: None,
78 description: "Field name.",
79 },
80 BuiltinParamDescriptor {
81 name: "value",
82 ty: BuiltinParamType::Any,
83 arity: BuiltinParamArity::Required,
84 default: None,
85 description: "Field value or cell array of field values.",
86 },
87 BuiltinParamDescriptor {
88 name: "name_value_pairs",
89 ty: BuiltinParamType::Any,
90 arity: BuiltinParamArity::Variadic,
91 default: None,
92 description: "Additional field/value pairs.",
93 },
94];
95
96const STRUCT_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
97 BuiltinSignatureDescriptor {
98 label: "S = struct()",
99 inputs: &STRUCT_INPUTS_EMPTY,
100 outputs: &STRUCT_OUTPUT,
101 },
102 BuiltinSignatureDescriptor {
103 label: "S = struct(template)",
104 inputs: &STRUCT_INPUTS_TEMPLATE,
105 outputs: &STRUCT_OUTPUT,
106 },
107 BuiltinSignatureDescriptor {
108 label: "S = struct(field, value, ...)",
109 inputs: &STRUCT_INPUTS_PAIRS,
110 outputs: &STRUCT_OUTPUT,
111 },
112];
113
114const STRUCT_ERROR_INVALID_SINGLE_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
115 code: "RM.STRUCT.INVALID_SINGLE_INPUT",
116 identifier: Some("RunMat:struct:InvalidSingleInput"),
117 when: "Single input is neither struct, struct-array cell, nor empty numeric/logical array.",
118 message:
119 "struct: expected name/value pairs, an existing struct or struct array, or [] to create an empty struct array",
120};
121
122const STRUCT_ERROR_NAME_VALUE_PAIRS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
123 code: "RM.STRUCT.NAME_VALUE_PAIRS",
124 identifier: Some("RunMat:struct:NameValuePairs"),
125 when: "Name/value arguments are not supplied in complete pairs.",
126 message: "struct: expected name/value pairs",
127};
128
129const STRUCT_ERROR_CELL_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
130 code: "RM.STRUCT.CELL_SIZE_MISMATCH",
131 identifier: Some("RunMat:struct:CellSizeMismatch"),
132 when: "Cell value inputs for struct-array construction do not share the same shape.",
133 message: "struct: cell inputs must have matching sizes",
134};
135
136const STRUCT_ERROR_SIZE_OVERFLOW: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
137 code: "RM.STRUCT.SIZE_OVERFLOW",
138 identifier: Some("RunMat:struct:SizeOverflow"),
139 when: "Requested struct-array size exceeds platform limits.",
140 message: "struct: struct array size exceeds platform limits",
141};
142
143const STRUCT_ERROR_ASSEMBLE_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
144 code: "RM.STRUCT.ASSEMBLE_FAILED",
145 identifier: Some("RunMat:struct:AssembleFailed"),
146 when: "Internal struct-array assembly failed.",
147 message: "struct: failed to assemble struct array",
148};
149
150const STRUCT_ERROR_EMPTY_ARRAY_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
151 code: "RM.STRUCT.EMPTY_ARRAY_FAILED",
152 identifier: Some("RunMat:struct:EmptyArrayFailed"),
153 when: "Internal empty struct-array creation failed.",
154 message: "struct: failed to create empty struct array",
155};
156
157const STRUCT_ERROR_STRUCT_ARRAY_CONTENTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
158 code: "RM.STRUCT.STRUCT_ARRAY_CONTENTS",
159 identifier: Some("RunMat:struct:StructArrayContents"),
160 when: "Single-argument struct-array cell input contains non-struct values.",
161 message: "struct: single argument cell input must contain structs",
162};
163
164const STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
165 code: "RM.STRUCT.STRUCT_ARRAY_COPY_FAILED",
166 identifier: Some("RunMat:struct:StructArrayCopyFailed"),
167 when: "Copying a single-argument struct-array cell input failed.",
168 message: "struct: failed to copy struct array",
169};
170
171const STRUCT_ERROR_FIELD_NAME_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
172 code: "RM.STRUCT.FIELD_NAME_TYPE",
173 identifier: Some("RunMat:struct:FieldNameType"),
174 when: "Field name is not a string scalar or 1xN character vector.",
175 message: "struct: field names must be strings or character vectors",
176};
177
178const STRUCT_ERROR_FIELD_NAME_SCALAR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
179 code: "RM.STRUCT.FIELD_NAME_SCALAR",
180 identifier: Some("RunMat:struct:FieldNameScalar"),
181 when: "Field name char/string-array input is not scalar.",
182 message: "struct: field names must be scalar string arrays or character vectors",
183};
184
185const STRUCT_ERROR_FIELD_NAME_CHAR_VECTOR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
186 code: "RM.STRUCT.FIELD_NAME_CHAR_VECTOR",
187 identifier: Some("RunMat:struct:FieldNameCharVector"),
188 when: "Character-array field name input is not a 1-by-N character vector.",
189 message: "struct: field names must be 1-by-N character vectors",
190};
191
192const STRUCT_ERROR_FIELD_NAME_EMPTY: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
193 code: "RM.STRUCT.FIELD_NAME_EMPTY",
194 identifier: Some("RunMat:struct:FieldNameEmpty"),
195 when: "Field name is empty.",
196 message: "struct: field names must be nonempty",
197};
198
199const STRUCT_ERROR_FIELD_NAME_START_CHAR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
200 code: "RM.STRUCT.FIELD_NAME_START_CHAR",
201 identifier: Some("RunMat:struct:FieldNameStartChar"),
202 when: "Field name does not start with a letter or underscore.",
203 message: "struct: field names must begin with a letter or underscore",
204};
205
206const STRUCT_ERROR_FIELD_NAME_INVALID_CHAR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
207 code: "RM.STRUCT.FIELD_NAME_INVALID_CHAR",
208 identifier: Some("RunMat:struct:FieldNameInvalidChar"),
209 when: "Field name includes unsupported characters.",
210 message: "struct: invalid character in field name",
211};
212
213const STRUCT_ERRORS: [BuiltinErrorDescriptor; 14] = [
214 STRUCT_ERROR_INVALID_SINGLE_INPUT,
215 STRUCT_ERROR_NAME_VALUE_PAIRS,
216 STRUCT_ERROR_CELL_SIZE_MISMATCH,
217 STRUCT_ERROR_SIZE_OVERFLOW,
218 STRUCT_ERROR_ASSEMBLE_FAILED,
219 STRUCT_ERROR_EMPTY_ARRAY_FAILED,
220 STRUCT_ERROR_STRUCT_ARRAY_CONTENTS,
221 STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED,
222 STRUCT_ERROR_FIELD_NAME_TYPE,
223 STRUCT_ERROR_FIELD_NAME_SCALAR,
224 STRUCT_ERROR_FIELD_NAME_CHAR_VECTOR,
225 STRUCT_ERROR_FIELD_NAME_EMPTY,
226 STRUCT_ERROR_FIELD_NAME_START_CHAR,
227 STRUCT_ERROR_FIELD_NAME_INVALID_CHAR,
228];
229
230pub const STRUCT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
231 signatures: &STRUCT_SIGNATURES,
232 output_mode: BuiltinOutputMode::Fixed,
233 completion_policy: BuiltinCompletionPolicy::Public,
234 errors: &STRUCT_ERRORS,
235};
236
237fn struct_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
238 struct_error_with_message(error.message, error)
239}
240
241fn struct_error_with_message(
242 message: impl Into<String>,
243 error: &'static BuiltinErrorDescriptor,
244) -> RuntimeError {
245 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
246 if let Some(identifier) = error.identifier {
247 builder = builder.with_identifier(identifier);
248 }
249 builder.build()
250}
251
252#[runtime_builtin(
253 name = "struct",
254 category = "structs/core",
255 summary = "Create scalar structs or struct arrays from field/value inputs.",
256 keywords = "struct,structure,name-value,record",
257 type_resolver(struct_type),
258 descriptor(crate::builtins::structs::core::r#struct::STRUCT_DESCRIPTOR),
259 builtin_path = "crate::builtins::structs::core::r#struct"
260)]
261async fn struct_builtin(rest: Vec<Value>) -> BuiltinResult<Value> {
262 match rest.len() {
263 0 => Ok(Value::Struct(StructValue::new())),
264 1 => match rest.into_iter().next().unwrap() {
265 Value::Struct(existing) => Ok(Value::Struct(existing.clone())),
266 Value::Cell(cell) => clone_struct_array(&cell),
267 Value::Tensor(tensor) if tensor.data.is_empty() => empty_struct_array(),
268 Value::LogicalArray(logical) if logical.data.is_empty() => empty_struct_array(),
269 other => Err(struct_error_with_message(
270 format!(
271 "{} (got {other:?})",
272 STRUCT_ERROR_INVALID_SINGLE_INPUT.message
273 ),
274 &STRUCT_ERROR_INVALID_SINGLE_INPUT,
275 )),
276 },
277 len if len % 2 == 0 => build_from_pairs(rest),
278 _ => Err(struct_error(&STRUCT_ERROR_NAME_VALUE_PAIRS)),
279 }
280}
281
282fn build_from_pairs(args: Vec<Value>) -> BuiltinResult<Value> {
283 let mut entries: Vec<FieldEntry> = Vec::new();
284 let mut target_shape: Option<Vec<usize>> = None;
285
286 let mut iter = args.into_iter();
287 while let (Some(name_value), Some(field_value)) = (iter.next(), iter.next()) {
288 let field_name = parse_field_name(&name_value)?;
289 match field_value {
290 Value::Cell(cell) => {
291 let shape = cell.shape.clone();
292 if let Some(existing) = &target_shape {
293 if *existing != shape {
294 return Err(struct_error(&STRUCT_ERROR_CELL_SIZE_MISMATCH));
295 }
296 } else {
297 target_shape = Some(shape);
298 }
299 entries.push(FieldEntry {
300 name: field_name,
301 value: FieldValue::Cell(cell),
302 });
303 }
304 other => entries.push(FieldEntry {
305 name: field_name,
306 value: FieldValue::Single(other),
307 }),
308 }
309 }
310
311 if let Some(shape) = target_shape {
312 build_struct_array(entries, shape)
313 } else {
314 build_scalar_struct(entries)
315 }
316}
317
318fn build_scalar_struct(entries: Vec<FieldEntry>) -> BuiltinResult<Value> {
319 let mut fields = StructValue::new();
320 for entry in entries {
321 match entry.value {
322 FieldValue::Single(value) => {
323 fields.fields.insert(entry.name, value);
324 }
325 FieldValue::Cell(cell) => {
326 let shape = cell.shape.clone();
327 return build_struct_array(
328 vec![FieldEntry {
329 name: entry.name,
330 value: FieldValue::Cell(cell),
331 }],
332 shape,
333 );
334 }
335 }
336 }
337 Ok(Value::Struct(fields))
338}
339
340fn build_struct_array(entries: Vec<FieldEntry>, shape: Vec<usize>) -> BuiltinResult<Value> {
341 let total_len = shape
342 .iter()
343 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
344 .ok_or_else(|| struct_error(&STRUCT_ERROR_SIZE_OVERFLOW))?;
345
346 for entry in &entries {
347 if let FieldValue::Cell(cell) = &entry.value {
348 if cell.data.len() != total_len {
349 return Err(struct_error(&STRUCT_ERROR_CELL_SIZE_MISMATCH));
350 }
351 }
352 }
353
354 let mut structs: Vec<Value> = Vec::with_capacity(total_len);
355 for idx in 0..total_len {
356 let mut fields = StructValue::new();
357 for entry in &entries {
358 let value = match &entry.value {
359 FieldValue::Single(val) => val.clone(),
360 FieldValue::Cell(cell) => clone_cell_element(cell, idx)?,
361 };
362 fields.fields.insert(entry.name.clone(), value);
363 }
364 structs.push(Value::Struct(fields));
365 }
366
367 CellArray::new_with_shape(structs, shape)
368 .map(Value::Cell)
369 .map_err(|e| {
370 struct_error_with_message(
371 format!("{}: {e}", STRUCT_ERROR_ASSEMBLE_FAILED.message),
372 &STRUCT_ERROR_ASSEMBLE_FAILED,
373 )
374 })
375}
376
377fn clone_cell_element(cell: &CellArray, index: usize) -> BuiltinResult<Value> {
378 cell.data
379 .get(index)
380 .map(|ptr| unsafe { &*ptr.as_raw() }.clone())
381 .ok_or_else(|| struct_error(&STRUCT_ERROR_CELL_SIZE_MISMATCH))
382}
383
384fn empty_struct_array() -> BuiltinResult<Value> {
385 CellArray::new(Vec::new(), 0, 0)
386 .map(Value::Cell)
387 .map_err(|e| {
388 struct_error_with_message(
389 format!("{}: {e}", STRUCT_ERROR_EMPTY_ARRAY_FAILED.message),
390 &STRUCT_ERROR_EMPTY_ARRAY_FAILED,
391 )
392 })
393}
394
395fn clone_struct_array(array: &CellArray) -> BuiltinResult<Value> {
396 let mut values: Vec<Value> = Vec::with_capacity(array.data.len());
397 for (index, handle) in array.data.iter().enumerate() {
398 let value = unsafe { &*handle.as_raw() }.clone();
399 if !matches!(value, Value::Struct(_)) {
400 return Err(struct_error_with_message(
401 format!(
402 "{} (element {} is not a struct)",
403 STRUCT_ERROR_STRUCT_ARRAY_CONTENTS.message,
404 index + 1
405 ),
406 &STRUCT_ERROR_STRUCT_ARRAY_CONTENTS,
407 ));
408 }
409 values.push(value);
410 }
411 CellArray::new_with_shape(values, array.shape.clone())
412 .map(Value::Cell)
413 .map_err(|e| {
414 struct_error_with_message(
415 format!("{}: {e}", STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED.message),
416 &STRUCT_ERROR_STRUCT_ARRAY_COPY_FAILED,
417 )
418 })
419}
420
421fn parse_field_name(value: &Value) -> BuiltinResult<String> {
422 let text = match value {
423 Value::String(s) => s.clone(),
424 Value::StringArray(sa) => {
425 if sa.data.len() == 1 {
426 sa.data[0].clone()
427 } else {
428 return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_SCALAR));
429 }
430 }
431 Value::CharArray(ca) => char_array_to_string(ca)?,
432 _ => return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_TYPE)),
433 };
434
435 validate_field_name(&text)?;
436 Ok(text)
437}
438
439fn char_array_to_string(ca: &CharArray) -> BuiltinResult<String> {
440 if ca.rows > 1 {
441 return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_CHAR_VECTOR));
442 }
443 let mut out = String::with_capacity(ca.data.len());
444 for ch in &ca.data {
445 out.push(*ch);
446 }
447 Ok(out)
448}
449
450fn validate_field_name(name: &str) -> BuiltinResult<()> {
451 if name.is_empty() {
452 return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_EMPTY));
453 }
454 let mut chars = name.chars();
455 let Some(first) = chars.next() else {
456 return Err(struct_error(&STRUCT_ERROR_FIELD_NAME_EMPTY));
457 };
458 if !is_first_char_valid(first) {
459 return Err(struct_error_with_message(
460 format!(
461 "{} (got '{name}')",
462 STRUCT_ERROR_FIELD_NAME_START_CHAR.message
463 ),
464 &STRUCT_ERROR_FIELD_NAME_START_CHAR,
465 ));
466 }
467 if let Some(bad) = chars.find(|c| !is_subsequent_char_valid(*c)) {
468 return Err(struct_error_with_message(
469 format!(
470 "{} ('{bad}' in '{name}')",
471 STRUCT_ERROR_FIELD_NAME_INVALID_CHAR.message
472 ),
473 &STRUCT_ERROR_FIELD_NAME_INVALID_CHAR,
474 ));
475 }
476 Ok(())
477}
478
/// True for characters allowed at the start of a field name: ASCII letters or `_`.
fn is_first_char_valid(c: char) -> bool {
    matches!(c, '_' | 'a'..='z' | 'A'..='Z')
}
482
/// True for characters allowed after the first position: ASCII letters, digits, or `_`.
fn is_subsequent_char_valid(c: char) -> bool {
    matches!(c, '_' | 'a'..='z' | 'A'..='Z' | '0'..='9')
}
486
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::{CellArray, IntValue, StringArray, StructValue, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;

    /// Extracts the message text from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async builtin to completion on the current thread.
    fn run_struct(args: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(struct_builtin(args))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty() {
        let Value::Struct(s) = run_struct(Vec::new()).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty_from_empty_matrix() {
        let tensor = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        let value = run_struct(vec![Value::Tensor(tensor)]).expect("struct([])");
        match value {
            Value::Cell(cell) => {
                assert_eq!(cell.rows, 0);
                assert_eq!(cell.cols, 0);
                assert!(cell.data.is_empty());
            }
            other => panic!("expected empty struct array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_name_value_pairs() {
        let args = vec![
            Value::from("name"),
            Value::from("Ada"),
            Value::from("score"),
            Value::Int(IntValue::I32(42)),
        ];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert_eq!(s.fields.len(), 2);
        assert!(matches!(s.fields.get("name"), Some(Value::String(v)) if v == "Ada"));
        assert!(matches!(
            s.fields.get("score"),
            Some(Value::Int(IntValue::I32(42)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_from_cells() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let ages = CellArray::new(
            vec![Value::Int(IntValue::I32(36)), Value::Int(IntValue::I32(45))],
            1,
            2,
        )
        .unwrap();
        let result = run_struct(vec![
            Value::from("name"),
            Value::Cell(names),
            Value::from("age"),
            Value::Cell(ages),
        ])
        .expect("struct array");
        let structs = expect_struct_array(result);
        assert_eq!(structs.len(), 2);
        assert!(matches!(
            structs[0].fields.get("name"),
            Some(Value::String(v)) if v == "Ada"
        ));
        assert!(matches!(
            structs[1].fields.get("age"),
            Some(Value::Int(IntValue::I32(45)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_replicates_scalars() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let result = run_struct(vec![
            Value::from("name"),
            Value::Cell(names),
            Value::from("department"),
            Value::from("Research"),
        ])
        .expect("struct array");
        let structs = expect_struct_array(result);
        assert_eq!(structs.len(), 2);
        for entry in structs {
            assert!(matches!(
                entry.fields.get("department"),
                Some(Value::String(v)) if v == "Research"
            ));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_cell_size_mismatch_errors() {
        let names = CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap();
        let scores = CellArray::new(vec![Value::Int(IntValue::I32(1))], 1, 1).unwrap();
        let err = error_message(
            run_struct(vec![
                Value::from("name"),
                Value::Cell(names),
                Value::from("score"),
                Value::Cell(scores),
            ])
            .unwrap_err(),
        );
        assert!(err.contains("matching sizes"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_overwrites_duplicates() {
        let args = vec![
            Value::from("version"),
            Value::Int(IntValue::I32(1)),
            Value::from("version"),
            Value::Int(IntValue::I32(2)),
        ];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert_eq!(s.fields.len(), 1);
        assert!(matches!(
            s.fields.get("version"),
            Some(Value::Int(IntValue::I32(2)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_odd_arguments() {
        let err = error_message(run_struct(vec![Value::from("name")]).unwrap_err());
        assert!(err.contains("name/value pairs"));
    }

    // Three arguments take the odd-count path rather than the single-input path.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_three_arguments() {
        let err = error_message(
            run_struct(vec![Value::from("a"), Value::Num(1.0), Value::from("b")]).unwrap_err(),
        );
        assert!(err.contains("name/value pairs"));
        assert!(!err.contains("existing struct"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_single_input() {
        let err = error_message(run_struct(vec![Value::Num(1.0)]).unwrap_err());
        assert!(err.contains("existing struct or struct array"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_field_name() {
        let err = error_message(
            run_struct(vec![Value::from("1bad"), Value::Int(IntValue::I32(1))]).unwrap_err(),
        );
        assert!(err.contains("begin with a letter or underscore"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_field_character() {
        let err = error_message(
            run_struct(vec![Value::from("bad-name"), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("invalid character in field name"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_empty_field_name() {
        let err = error_message(run_struct(vec![Value::from(""), Value::Num(1.0)]).unwrap_err());
        assert!(err.contains("nonempty"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_non_text_field_name() {
        let err = error_message(
            run_struct(vec![Value::Num(1.0), Value::Int(IntValue::I32(1))]).unwrap_err(),
        );
        assert!(err.contains("strings or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_multi_row_char_name() {
        let chars = CharArray::new("abcd".chars().collect(), 2, 2).unwrap();
        let err = error_message(
            run_struct(vec![Value::CharArray(chars), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("1-by-N character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_non_scalar_string_array_name() {
        let sa = StringArray::new(vec!["a".to_string(), "b".to_string()], vec![1, 2]).unwrap();
        let err = error_message(
            run_struct(vec![Value::StringArray(sa), Value::Num(1.0)]).unwrap_err(),
        );
        assert!(err.contains("scalar string arrays"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_char_vector_name() {
        let chars = CharArray::new("field".chars().collect(), 1, 5).unwrap();
        let args = vec![Value::CharArray(chars), Value::Num(1.0)];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_string_scalar_name() {
        let sa = StringArray::new(vec!["field".to_string()], vec![1]).unwrap();
        let args = vec![Value::StringArray(sa), Value::Num(1.0)];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(s.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_allows_existing_struct_copy() {
        let mut base = StructValue::new();
        base.fields
            .insert("id".to_string(), Value::Int(IntValue::I32(7)));
        let copy = run_struct(vec![Value::Struct(base.clone())]).expect("struct");
        assert_eq!(copy, Value::Struct(base));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_copies_struct_array_argument() {
        let mut proto = StructValue::new();
        proto
            .fields
            .insert("id".into(), Value::Int(IntValue::I32(7)));
        let struct_array = CellArray::new(
            vec![
                Value::Struct(proto.clone()),
                Value::Struct(proto.clone()),
                Value::Struct(proto.clone()),
            ],
            1,
            3,
        )
        .unwrap();
        let original = struct_array.clone();
        let result = run_struct(vec![Value::Cell(struct_array)]).expect("struct array clone");
        let cloned = expect_struct_array(result);
        let baseline = expect_struct_array(Value::Cell(original));
        assert_eq!(cloned, baseline);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_cell_argument_without_structs() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = error_message(run_struct(vec![Value::Cell(cell)]).unwrap_err());
        assert!(err.contains("must contain structs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_preserves_gpu_tensor_handles() {
        let handle = GpuTensorHandle {
            shape: vec![2, 2],
            device_id: 1,
            buffer_id: 99,
        };
        let args = vec![Value::from("data"), Value::GpuTensor(handle.clone())];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(matches!(s.fields.get("data"), Some(Value::GpuTensor(h)) if h == &handle));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_preserves_gpu_handles() {
        let first = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 11,
        };
        let second = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 12,
        };
        let cell = CellArray::new(
            vec![
                Value::GpuTensor(first.clone()),
                Value::GpuTensor(second.clone()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_struct(vec![Value::from("payload"), Value::Cell(cell)])
            .expect("struct array gpu handles");
        let structs = expect_struct_array(result);
        assert!(matches!(
            structs[0].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &first
        ));
        assert!(matches!(
            structs[1].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &second
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn struct_preserves_gpu_handles_with_registered_provider() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let host = HostTensorView {
            data: &[1.0, 2.0],
            shape: &[2, 1],
        };
        let handle = provider.upload(&host).expect("upload");
        let args = vec![Value::from("gpu"), Value::GpuTensor(handle.clone())];
        let Value::Struct(s) = run_struct(args).expect("struct") else {
            panic!("expected struct value");
        };
        assert!(matches!(s.fields.get("gpu"), Some(Value::GpuTensor(h)) if h == &handle));
    }

    /// Flattens a struct-array cell (or a scalar struct) into its struct elements,
    /// panicking on any non-struct element.
    fn expect_struct_array(value: Value) -> Vec<StructValue> {
        match value {
            Value::Cell(cell) => cell
                .data
                .iter()
                .map(|ptr| unsafe { &*ptr.as_raw() }.clone())
                .map(|value| match value {
                    Value::Struct(st) => st,
                    other => panic!("expected struct element, got {other:?}"),
                })
                .collect(),
            Value::Struct(st) => vec![st],
            other => panic!("expected struct or struct array, got {other:?}"),
        }
    }
}