1use crate::builtins::common::spec::{
4 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5 ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::struct_type;
8use runmat_builtins::{CellArray, CharArray, StructValue, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::{build_runtime_error, BuiltinResult, RuntimeError};
12
13#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
14pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
15 name: "struct",
16 op_kind: GpuOpKind::Custom("struct"),
17 supported_precisions: &[],
18 broadcast: BroadcastSemantics::None,
19 provider_hooks: &[],
20 constant_strategy: ConstantStrategy::InlineLiteral,
21 residency: ResidencyPolicy::InheritInputs,
22 nan_mode: ReductionNaN::Include,
23 two_pass_threshold: None,
24 workgroup_size: None,
25 accepts_nan_mode: false,
26 notes: "Host-only construction; GPU values are preserved as handles without gathering.",
27};
28
29#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::r#struct")]
30pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
31 name: "struct",
32 shape: ShapeRequirements::Any,
33 constant_strategy: ConstantStrategy::InlineLiteral,
34 elementwise: None,
35 reduction: None,
36 emits_nan: false,
37 notes: "Struct creation breaks fusion planning but retains GPU residency for field values.",
38};
39
40struct FieldEntry {
41 name: String,
42 value: FieldValue,
43}
44
45enum FieldValue {
46 Single(Value),
47 Cell(CellArray),
48}
49
50fn struct_flow(message: impl Into<String>) -> RuntimeError {
51 build_runtime_error(message).with_builtin("struct").build()
52}
53
54#[runtime_builtin(
55 name = "struct",
56 category = "structs/core",
57 summary = "Create scalar structs or struct arrays from name/value pairs.",
58 keywords = "struct,structure,name-value,record",
59 type_resolver(struct_type),
60 builtin_path = "crate::builtins::structs::core::r#struct"
61)]
62async fn struct_builtin(rest: Vec<Value>) -> BuiltinResult<Value> {
63 match rest.len() {
64 0 => Ok(Value::Struct(StructValue::new())),
65 1 => match rest.into_iter().next().unwrap() {
66 Value::Struct(existing) => Ok(Value::Struct(existing.clone())),
67 Value::Cell(cell) => clone_struct_array(&cell),
68 Value::Tensor(tensor) if tensor.data.is_empty() => empty_struct_array(),
69 Value::LogicalArray(logical) if logical.data.is_empty() => empty_struct_array(),
70 other => Err(struct_flow(format!(
71 "struct: expected name/value pairs, an existing struct or struct array, or [] to create an empty struct array (got {other:?})"
72 ))),
73 },
74 len if len % 2 == 0 => build_from_pairs(rest),
75 _ => Err(struct_flow("struct: expected name/value pairs")),
76 }
77}
78
79fn build_from_pairs(args: Vec<Value>) -> BuiltinResult<Value> {
80 let mut entries: Vec<FieldEntry> = Vec::new();
81 let mut target_shape: Option<Vec<usize>> = None;
82
83 let mut iter = args.into_iter();
84 while let (Some(name_value), Some(field_value)) = (iter.next(), iter.next()) {
85 let field_name = parse_field_name(&name_value)?;
86 match field_value {
87 Value::Cell(cell) => {
88 let shape = cell.shape.clone();
89 if let Some(existing) = &target_shape {
90 if *existing != shape {
91 return Err(struct_flow("struct: cell inputs must have matching sizes"));
92 }
93 } else {
94 target_shape = Some(shape);
95 }
96 entries.push(FieldEntry {
97 name: field_name,
98 value: FieldValue::Cell(cell),
99 });
100 }
101 other => entries.push(FieldEntry {
102 name: field_name,
103 value: FieldValue::Single(other),
104 }),
105 }
106 }
107
108 if let Some(shape) = target_shape {
109 build_struct_array(entries, shape)
110 } else {
111 build_scalar_struct(entries)
112 }
113}
114
115fn build_scalar_struct(entries: Vec<FieldEntry>) -> BuiltinResult<Value> {
116 let mut fields = StructValue::new();
117 for entry in entries {
118 match entry.value {
119 FieldValue::Single(value) => {
120 fields.fields.insert(entry.name, value);
121 }
122 FieldValue::Cell(cell) => {
123 let shape = cell.shape.clone();
124 return build_struct_array(
125 vec![FieldEntry {
126 name: entry.name,
127 value: FieldValue::Cell(cell),
128 }],
129 shape,
130 );
131 }
132 }
133 }
134 Ok(Value::Struct(fields))
135}
136
137fn build_struct_array(entries: Vec<FieldEntry>, shape: Vec<usize>) -> BuiltinResult<Value> {
138 let total_len = shape
139 .iter()
140 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
141 .ok_or_else(|| struct_flow("struct: struct array size exceeds platform limits"))?;
142
143 for entry in &entries {
144 if let FieldValue::Cell(cell) = &entry.value {
145 if cell.data.len() != total_len {
146 return Err(struct_flow("struct: cell inputs must have matching sizes"));
147 }
148 }
149 }
150
151 let mut structs: Vec<Value> = Vec::with_capacity(total_len);
152 for idx in 0..total_len {
153 let mut fields = StructValue::new();
154 for entry in &entries {
155 let value = match &entry.value {
156 FieldValue::Single(val) => val.clone(),
157 FieldValue::Cell(cell) => clone_cell_element(cell, idx)?,
158 };
159 fields.fields.insert(entry.name.clone(), value);
160 }
161 structs.push(Value::Struct(fields));
162 }
163
164 CellArray::new_with_shape(structs, shape)
165 .map(Value::Cell)
166 .map_err(|e| struct_flow(format!("struct: failed to assemble struct array: {e}")))
167}
168
169fn clone_cell_element(cell: &CellArray, index: usize) -> BuiltinResult<Value> {
170 cell.data
171 .get(index)
172 .map(|ptr| unsafe { &*ptr.as_raw() }.clone())
173 .ok_or_else(|| struct_flow("struct: cell inputs must have matching sizes"))
174}
175
176fn empty_struct_array() -> BuiltinResult<Value> {
177 CellArray::new(Vec::new(), 0, 0)
178 .map(Value::Cell)
179 .map_err(|e| struct_flow(format!("struct: failed to create empty struct array: {e}")))
180}
181
182fn clone_struct_array(array: &CellArray) -> BuiltinResult<Value> {
183 let mut values: Vec<Value> = Vec::with_capacity(array.data.len());
184 for (index, handle) in array.data.iter().enumerate() {
185 let value = unsafe { &*handle.as_raw() }.clone();
186 if !matches!(value, Value::Struct(_)) {
187 return Err(struct_flow(format!(
188 "struct: single argument cell input must contain structs (element {} is not a struct)",
189 index + 1
190 )));
191 }
192 values.push(value);
193 }
194 CellArray::new_with_shape(values, array.shape.clone())
195 .map(Value::Cell)
196 .map_err(|e| struct_flow(format!("struct: failed to copy struct array: {e}")))
197}
198
199fn parse_field_name(value: &Value) -> BuiltinResult<String> {
200 let text = match value {
201 Value::String(s) => s.clone(),
202 Value::StringArray(sa) => {
203 if sa.data.len() == 1 {
204 sa.data[0].clone()
205 } else {
206 return Err(struct_flow(
207 "struct: field names must be scalar string arrays or character vectors",
208 ));
209 }
210 }
211 Value::CharArray(ca) => char_array_to_string(ca)?,
212 _ => {
213 return Err(struct_flow(
214 "struct: field names must be strings or character vectors",
215 ))
216 }
217 };
218
219 validate_field_name(&text)?;
220 Ok(text)
221}
222
223fn char_array_to_string(ca: &CharArray) -> BuiltinResult<String> {
224 if ca.rows > 1 {
225 return Err(struct_flow(
226 "struct: field names must be 1-by-N character vectors",
227 ));
228 }
229 let mut out = String::with_capacity(ca.data.len());
230 for ch in &ca.data {
231 out.push(*ch);
232 }
233 Ok(out)
234}
235
236fn validate_field_name(name: &str) -> BuiltinResult<()> {
237 if name.is_empty() {
238 return Err(struct_flow("struct: field names must be nonempty"));
239 }
240 let mut chars = name.chars();
241 let Some(first) = chars.next() else {
242 return Err(struct_flow("struct: field names must be nonempty"));
243 };
244 if !is_first_char_valid(first) {
245 return Err(struct_flow(format!(
246 "struct: field names must begin with a letter or underscore (got '{name}')"
247 )));
248 }
249 if let Some(bad) = chars.find(|c| !is_subsequent_char_valid(*c)) {
250 return Err(struct_flow(format!(
251 "struct: invalid character '{bad}' in field name '{name}'"
252 )));
253 }
254 Ok(())
255}
256
/// Returns `true` when `c` may start a field name: an underscore or an ASCII letter.
fn is_first_char_valid(c: char) -> bool {
    matches!(c, '_' | 'a'..='z' | 'A'..='Z')
}
260
/// Returns `true` when `c` may follow the first character of a field name:
/// an underscore, an ASCII letter, or an ASCII digit.
fn is_subsequent_char_valid(c: char) -> bool {
    matches!(c, '_' | '0'..='9' | 'a'..='z' | 'A'..='Z')
}
264
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::{CellArray, IntValue, StringArray, StructValue, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;

    /// Extracts the message text from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async builtin to completion on the current thread.
    fn run_struct(args: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(struct_builtin(args))
    }

    /// Runs the builtin and requires the result to be a scalar struct.
    fn scalar_struct(args: Vec<Value>) -> StructValue {
        match run_struct(args).expect("struct") {
            Value::Struct(record) => record,
            _ => panic!("expected struct value"),
        }
    }

    /// Runs the builtin, requires it to fail, and returns the error message.
    fn failure_text(args: Vec<Value>) -> String {
        error_message(run_struct(args).unwrap_err())
    }

    /// Shorthand for a 32-bit integer value.
    fn int32(n: i32) -> Value {
        Value::Int(IntValue::I32(n))
    }

    /// A 1-by-2 cell array holding the strings "Ada" and "Grace".
    fn name_row() -> CellArray {
        CellArray::new(vec![Value::from("Ada"), Value::from("Grace")], 1, 2).unwrap()
    }

    /// Unpacks a struct or a cell array of structs into a vector of structs.
    fn expect_struct_array(value: Value) -> Vec<StructValue> {
        let cell = match value {
            Value::Struct(st) => return vec![st],
            Value::Cell(cell) => cell,
            other => panic!("expected struct or struct array, got {other:?}"),
        };
        let mut records = Vec::with_capacity(cell.data.len());
        for ptr in cell.data.iter() {
            match unsafe { &*ptr.as_raw() }.clone() {
                Value::Struct(st) => records.push(st),
                other => panic!("expected struct element, got {other:?}"),
            }
        }
        records
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty() {
        let created = scalar_struct(Vec::new());
        assert!(created.fields.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_empty_from_empty_matrix() {
        let empty = Tensor::new(Vec::new(), vec![0, 0]).unwrap();
        match run_struct(vec![Value::Tensor(empty)]).expect("struct([])") {
            Value::Cell(cell) => {
                assert_eq!(cell.rows, 0);
                assert_eq!(cell.cols, 0);
                assert!(cell.data.is_empty());
            }
            other => panic!("expected empty struct array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_name_value_pairs() {
        let record = scalar_struct(vec![
            Value::from("name"),
            Value::from("Ada"),
            Value::from("score"),
            int32(42),
        ]);
        assert_eq!(record.fields.len(), 2);
        assert!(matches!(record.fields.get("name"), Some(Value::String(v)) if v == "Ada"));
        assert!(matches!(
            record.fields.get("score"),
            Some(Value::Int(IntValue::I32(42)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_from_cells() {
        let ages = CellArray::new(vec![int32(36), int32(45)], 1, 2).unwrap();
        let built = run_struct(vec![
            Value::from("name"),
            Value::Cell(name_row()),
            Value::from("age"),
            Value::Cell(ages),
        ])
        .expect("struct array");
        let records = expect_struct_array(built);
        assert_eq!(records.len(), 2);
        assert!(matches!(
            records[0].fields.get("name"),
            Some(Value::String(v)) if v == "Ada"
        ));
        assert!(matches!(
            records[1].fields.get("age"),
            Some(Value::Int(IntValue::I32(45)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_replicates_scalars() {
        let built = run_struct(vec![
            Value::from("name"),
            Value::Cell(name_row()),
            Value::from("department"),
            Value::from("Research"),
        ])
        .expect("struct array");
        let records = expect_struct_array(built);
        assert_eq!(records.len(), 2);
        for record in records {
            assert!(matches!(
                record.fields.get("department"),
                Some(Value::String(v)) if v == "Research"
            ));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_cell_size_mismatch_errors() {
        let scores = CellArray::new(vec![int32(1)], 1, 1).unwrap();
        let message = failure_text(vec![
            Value::from("name"),
            Value::Cell(name_row()),
            Value::from("score"),
            Value::Cell(scores),
        ]);
        assert!(message.contains("matching sizes"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_overwrites_duplicates() {
        let record = scalar_struct(vec![
            Value::from("version"),
            int32(1),
            Value::from("version"),
            int32(2),
        ]);
        assert_eq!(record.fields.len(), 1);
        assert!(matches!(
            record.fields.get("version"),
            Some(Value::Int(IntValue::I32(2)))
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_odd_arguments() {
        let message = failure_text(vec![Value::from("name")]);
        assert!(message.contains("name/value pairs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_invalid_field_name() {
        let message = failure_text(vec![Value::from("1bad"), int32(1)]);
        assert!(message.contains("begin with a letter or underscore"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_non_text_field_name() {
        let message = failure_text(vec![Value::Num(1.0), int32(1)]);
        assert!(message.contains("strings or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_char_vector_name() {
        let chars = CharArray::new("field".chars().collect(), 1, 5).unwrap();
        let record = scalar_struct(vec![Value::CharArray(chars), Value::Num(1.0)]);
        assert!(record.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_accepts_string_scalar_name() {
        let scalar = StringArray::new(vec!["field".to_string()], vec![1]).unwrap();
        let record = scalar_struct(vec![Value::StringArray(scalar), Value::Num(1.0)]);
        assert!(record.fields.contains_key("field"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_allows_existing_struct_copy() {
        let mut base = StructValue::new();
        base.fields.insert("id".to_string(), int32(7));
        let copy = run_struct(vec![Value::Struct(base.clone())]).expect("struct");
        assert_eq!(copy, Value::Struct(base));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_copies_struct_array_argument() {
        let mut proto = StructValue::new();
        proto.fields.insert("id".into(), int32(7));
        let elements: Vec<Value> = (0..3).map(|_| Value::Struct(proto.clone())).collect();
        let struct_array = CellArray::new(elements, 1, 3).unwrap();
        let original = struct_array.clone();
        let copied = run_struct(vec![Value::Cell(struct_array)]).expect("struct array clone");
        assert_eq!(
            expect_struct_array(copied),
            expect_struct_array(Value::Cell(original))
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_rejects_cell_argument_without_structs() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let message = failure_text(vec![Value::Cell(cell)]);
        assert!(message.contains("must contain structs"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_preserves_gpu_tensor_handles() {
        let handle = GpuTensorHandle {
            shape: vec![2, 2],
            device_id: 1,
            buffer_id: 99,
        };
        let record = scalar_struct(vec![Value::from("data"), Value::GpuTensor(handle.clone())]);
        assert!(matches!(record.fields.get("data"), Some(Value::GpuTensor(h)) if h == &handle));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn struct_struct_array_preserves_gpu_handles() {
        let first = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 11,
        };
        let second = GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 2,
            buffer_id: 12,
        };
        let payloads = vec![
            Value::GpuTensor(first.clone()),
            Value::GpuTensor(second.clone()),
        ];
        let cell = CellArray::new(payloads, 1, 2).unwrap();
        let built = run_struct(vec![Value::from("payload"), Value::Cell(cell)])
            .expect("struct array gpu handles");
        let records = expect_struct_array(built);
        assert!(matches!(
            records[0].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &first
        ));
        assert!(matches!(
            records[1].fields.get("payload"),
            Some(Value::GpuTensor(h)) if h == &second
        ));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn struct_preserves_gpu_handles_with_registered_provider() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let host = HostTensorView {
            data: &[1.0, 2.0],
            shape: &[2, 1],
        };
        let handle = provider.upload(&host).expect("upload");
        let record = scalar_struct(vec![Value::from("gpu"), Value::GpuTensor(handle.clone())]);
        assert!(matches!(record.fields.get("gpu"), Some(Value::GpuTensor(h)) if h == &handle));
    }
}