1use crate::builtins::common::spec::{
4 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5 ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::isfield_type;
8use runmat_builtins::{CellArray, LogicalArray, StructValue, Value};
9use runmat_macros::runtime_builtin;
10use std::collections::HashSet;
11
12use crate::{build_runtime_error, BuiltinResult, RuntimeError};
13
/// GPU metadata registered for the `isfield` builtin.
///
/// The spec lists no supported precisions and no provider hooks. As its `notes` field states,
/// the check is a host-only metadata inspection and acceleration providers do not participate.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::isfield")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "isfield",
    op_kind: GpuOpKind::Custom("isfield"),
    // Empty: no device precision is advertised for this builtin.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    // Empty: no acceleration-provider hook is registered.
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::InheritInputs,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-only metadata check; acceleration providers do not participate.",
};
29
/// Fusion metadata registered for the `isfield` builtin.
///
/// No elementwise or reduction template is provided (`elementwise`/`reduction` are `None`).
/// Per its `notes`, the builtin acts as a fusion barrier because it inspects struct metadata
/// on the host.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::isfield")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "isfield",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Acts as a fusion barrier because it inspects struct metadata on the host.",
};
40
41fn isfield_flow(message: impl Into<String>) -> RuntimeError {
42 build_runtime_error(message).with_builtin("isfield").build()
43}
44
45#[runtime_builtin(
46 name = "isfield",
47 category = "structs/core",
48 summary = "Test whether a struct or struct array defines specific field names.",
49 keywords = "isfield,struct,field existence",
50 type_resolver(isfield_type),
51 builtin_path = "crate::builtins::structs::core::isfield"
52)]
53async fn isfield_builtin(target: Value, names: Value) -> BuiltinResult<Value> {
54 let context = classify_struct(&target)?;
55 let parsed = parse_field_names(names)?;
56 match context {
57 StructContext::Struct(struct_value) => evaluate_scalar(struct_value, parsed),
58 StructContext::StructArray(cell) => evaluate_struct_array(cell, parsed),
59 StructContext::NonStruct => evaluate_non_struct(parsed),
60 }
61}
62
/// How `isfield` interprets its first argument; produced by `classify_struct`.
#[derive(Clone, Copy)]
enum StructContext<'a> {
    /// A single `Value::Struct`.
    Struct(&'a StructValue),
    /// A `Value::Cell` whose elements are all structs (including an empty cell).
    StructArray(&'a CellArray),
    /// Any other value; every queried name reports `false`.
    NonStruct,
}
69
70fn classify_struct<'a>(value: &'a Value) -> BuiltinResult<StructContext<'a>> {
71 match value {
72 Value::Struct(st) => Ok(StructContext::Struct(st)),
73 Value::Cell(cell) => {
74 if cell.data.is_empty() {
75 return Ok(StructContext::StructArray(cell));
76 }
77 if cell
78 .data
79 .iter()
80 .all(|handle| matches!(unsafe { &*handle.as_raw() }, Value::Struct(_)))
81 {
82 Ok(StructContext::StructArray(cell))
83 } else {
84 Ok(StructContext::NonStruct)
85 }
86 }
87 _ => Ok(StructContext::NonStruct),
88 }
89}
90
/// The normalized `names` argument of `isfield`.
enum ParsedNames {
    /// A single field name. The answer is a scalar `Value::Bool`.
    Scalar(String),
    /// Several field names. The answer is a `Value::LogicalArray` with one entry per name.
    Array {
        /// Names in the element order handed to `LogicalArray::new`.
        names: Vec<String>,
        /// Shape given to the resulting logical array.
        shape: Vec<usize>,
    },
}
98
99fn parse_field_names(names: Value) -> BuiltinResult<ParsedNames> {
100 match names {
101 Value::String(s) => Ok(ParsedNames::Scalar(s)),
102 Value::CharArray(ca) => {
103 if ca.rows == 1 {
104 Ok(ParsedNames::Scalar(ca.data.iter().collect()))
105 } else {
106 Err(field_name_type_error())
107 }
108 }
109 Value::StringArray(sa) => Ok(ParsedNames::Array {
110 names: sa.data.clone(),
111 shape: sa.shape.clone(),
112 }),
113 Value::Cell(cell) => Ok(ParsedNames::Array {
114 names: collect_cell_names(&cell)?,
115 shape: if cell.shape.is_empty() {
116 vec![cell.rows, cell.cols]
117 } else {
118 cell.shape.clone()
119 },
120 }),
121 other => match try_single_field_name(&other)? {
122 Some(name) => Ok(ParsedNames::Scalar(name)),
123 None => Err(field_name_type_error()),
124 },
125 }
126}
127
128fn try_single_field_name(value: &Value) -> BuiltinResult<Option<String>> {
129 match value {
130 Value::String(s) => Ok(Some(s.clone())),
131 Value::CharArray(ca) => {
132 if ca.rows == 1 {
133 Ok(Some(ca.data.iter().collect()))
134 } else {
135 Err(field_name_type_error())
136 }
137 }
138 Value::StringArray(sa) => {
139 if sa.data.len() == 1 {
140 Ok(Some(sa.data[0].clone()))
141 } else {
142 Err(field_name_type_error())
143 }
144 }
145 _ => Ok(None),
146 }
147}
148
149fn evaluate_scalar(struct_value: &StructValue, names: ParsedNames) -> BuiltinResult<Value> {
150 match names {
151 ParsedNames::Scalar(name) => Ok(Value::Bool(struct_value.fields.contains_key(&name))),
152 ParsedNames::Array { names, shape } => {
153 let mut bits = Vec::with_capacity(names.len());
154 for name in names {
155 bits.push(if struct_value.fields.contains_key(&name) {
156 1
157 } else {
158 0
159 });
160 }
161 let logical = LogicalArray::new(bits, shape)
162 .map_err(|e| isfield_flow(format!("isfield: {e}")))?;
163 Ok(Value::LogicalArray(logical))
164 }
165 }
166}
167
168fn evaluate_struct_array(cell: &CellArray, names: ParsedNames) -> BuiltinResult<Value> {
169 let fields = struct_array_field_intersection(cell)?;
170 match names {
171 ParsedNames::Scalar(name) => Ok(Value::Bool(fields.contains(&name))),
172 ParsedNames::Array { names, shape } => {
173 let mut bits = Vec::with_capacity(names.len());
174 for name in names {
175 bits.push(if fields.contains(&name) { 1 } else { 0 });
176 }
177 let logical = LogicalArray::new(bits, shape)
178 .map_err(|e| isfield_flow(format!("isfield: {e}")))?;
179 Ok(Value::LogicalArray(logical))
180 }
181 }
182}
183
184fn evaluate_non_struct(names: ParsedNames) -> BuiltinResult<Value> {
185 match names {
186 ParsedNames::Scalar(_) => Ok(Value::Bool(false)),
187 ParsedNames::Array { names, shape } => {
188 let logical = LogicalArray::new(vec![0; names.len()], shape)
189 .map_err(|e| isfield_flow(format!("isfield: {e}")))?;
190 Ok(Value::LogicalArray(logical))
191 }
192 }
193}
194
195fn struct_array_field_intersection(cell: &CellArray) -> BuiltinResult<HashSet<String>> {
196 if cell.data.is_empty() {
197 return Ok(HashSet::new());
198 }
199
200 let mut iter = cell.data.iter();
201 let first = unsafe { &*iter.next().unwrap().as_raw() };
202 let Value::Struct(first_struct) = first else {
203 return Err(isfield_flow(
204 "isfield: struct array elements must be structs",
205 ));
206 };
207 let mut fields: HashSet<String> = first_struct.fields.keys().cloned().collect();
208
209 for handle in iter {
210 let value = unsafe { &*handle.as_raw() };
211 let Value::Struct(struct_value) = value else {
212 return Err(isfield_flow(
213 "isfield: struct array elements must be structs",
214 ));
215 };
216 fields.retain(|name| struct_value.fields.contains_key(name));
217 if fields.is_empty() {
218 break;
219 }
220 }
221
222 Ok(fields)
223}
224
225fn collect_cell_names(cell: &CellArray) -> BuiltinResult<Vec<String>> {
226 let total = cell.data.len();
227 if total == 0 {
228 return Ok(Vec::new());
229 }
230
231 let shape = if cell.shape.is_empty() {
232 vec![cell.rows, cell.cols]
233 } else {
234 cell.shape.clone()
235 };
236
237 let mut names = Vec::with_capacity(total);
238 let row_strides = row_major_strides(&shape);
239 for idx in 0..total {
240 let coords = column_major_coordinates(idx, &shape);
241 let mut row_index = 0usize;
242 for (coord, stride) in coords.iter().zip(row_strides.iter()) {
243 row_index += coord * stride;
244 }
245 let value = unsafe { &*cell.data[row_index].as_raw() };
246 names.push(value_to_field_name(value)?);
247 }
248 Ok(names)
249}
250
/// Returns row-major strides for `shape`, with the last dimension varying fastest.
///
/// Zero-sized dimensions contribute a factor of 1, so no stride collapses to zero. The running
/// product saturates at `usize::MAX` instead of overflowing. An empty `shape` yields an empty
/// vector.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0usize; shape.len()];
    let mut running = 1usize;
    for (slot, &extent) in strides.iter_mut().zip(shape).rev() {
        *slot = running;
        running = running.saturating_mul(extent.max(1));
    }
    strides
}
263
/// Splits a column-major linear `index` into one coordinate per dimension of `shape`, with the
/// first dimension varying fastest.
///
/// A zero-sized dimension gets coordinate 0 and consumes no part of the index. An empty
/// `shape` yields an empty vector.
fn column_major_coordinates(mut index: usize, shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .map(|&extent| match extent {
            0 => 0,
            _ => {
                let coord = index % extent;
                index /= extent;
                coord
            }
        })
        .collect()
}
279
280fn value_to_field_name(value: &Value) -> BuiltinResult<String> {
281 match value {
282 Value::String(s) => Ok(s.clone()),
283 Value::CharArray(ca) => {
284 if ca.rows == 1 {
285 Ok(ca.data.iter().collect())
286 } else {
287 Err(field_name_type_error())
288 }
289 }
290 Value::StringArray(sa) => {
291 if sa.data.len() == 1 {
292 Ok(sa.data[0].clone())
293 } else {
294 Err(field_name_type_error())
295 }
296 }
297 other => Err(isfield_flow(format!(
298 "isfield: cell array elements must be character vectors or strings (got {other:?})"
299 ))),
300 }
301}
302
303fn field_name_type_error() -> RuntimeError {
304 isfield_flow(
305 "isfield: field names must be strings, string arrays, or cell arrays of character vectors",
306 )
307}
308
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{CellArray, CharArray, StringArray, StructValue};

    /// Extracts the message text from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Runs the async builtin to completion on the current thread.
    fn run_isfield(target: Value, names: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(isfield_builtin(target, names))
    }

    /// Builds a struct from `(field, value)` pairs.
    fn struct_of(entries: Vec<(&str, Value)>) -> StructValue {
        let mut st = StructValue::new();
        for (field, value) in entries {
            st.fields.insert(field.to_string(), value);
        }
        st
    }

    /// Builds a 1xN char row vector holding `text`.
    fn char_row(text: &str) -> CharArray {
        let chars: Vec<char> = text.chars().collect();
        let width = chars.len();
        CharArray::new(chars, 1, width).unwrap()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_scalar_struct_single_name() {
        let st = struct_of(vec![("name", Value::from("Ada"))]);
        let present = run_isfield(Value::Struct(st.clone()), Value::from("name")).unwrap();
        assert_eq!(present, Value::Bool(true));
        let absent = run_isfield(Value::Struct(st), Value::from("score")).unwrap();
        assert_eq!(absent, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_char_array_single_row() {
        let st = struct_of(vec![("alpha", Value::Num(1.0))]);
        let result = run_isfield(Value::Struct(st), Value::CharArray(char_row("alpha"))).unwrap();
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_struct_cell_names_produces_logical_array() {
        let st = struct_of(vec![
            ("name", Value::from("Ada")),
            ("score", Value::from(42.0)),
        ]);
        let names = CellArray::new(
            vec![
                Value::from("name"),
                Value::from("department"),
                Value::from("score"),
                Value::from("email"),
            ],
            2,
            2,
        )
        .unwrap();
        let expected = LogicalArray::new(vec![1, 1, 0, 0], vec![2, 2]).expect("logical array");
        let result = run_isfield(Value::Struct(st), Value::Cell(names)).expect("isfield");
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_cell_mixed_string_types() {
        let st = struct_of(vec![("name", Value::from("Ada")), ("id", Value::from(7.0))]);
        let cell = CellArray::new(
            vec![
                Value::from("name"),
                Value::CharArray(char_row("id")),
                Value::from("department"),
            ],
            1,
            3,
        )
        .unwrap();
        let expected = LogicalArray::new(vec![1, 1, 0], vec![1, 3]).unwrap();
        let result = run_isfield(Value::Struct(st), Value::Cell(cell)).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_struct_array_intersection() {
        let first = struct_of(vec![
            ("name", Value::from("Ada")),
            ("id", Value::from(101.0)),
        ]);
        let second = struct_of(vec![("name", Value::from("Grace"))]);
        let struct_array = CellArray::new_with_shape(
            vec![Value::Struct(first), Value::Struct(second)],
            vec![1, 2],
        )
        .unwrap();

        let cases = [("id", false), ("name", true)];
        for &(field, expected) in cases.iter() {
            let result = run_isfield(Value::Cell(struct_array.clone()), Value::from(field))
                .expect("isfield");
            assert_eq!(result, Value::Bool(expected));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_non_struct_returns_false() {
        let outcome = run_isfield(Value::Num(5.0), Value::from("field")).unwrap();
        assert_eq!(outcome, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_string_array_names() {
        let st = struct_of(vec![("alpha", Value::Num(1.0)), ("beta", Value::Num(2.0))]);
        let names = StringArray::new(vec!["alpha".into(), "gamma".into()], vec![2, 1]).unwrap();
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array");
        let result = run_isfield(Value::Struct(st), Value::StringArray(names)).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_invalid_name_type_errors() {
        let st = struct_of(vec![("alpha", Value::Num(1.0))]);
        let failure = run_isfield(Value::Struct(st), Value::from(5_i32)).unwrap_err();
        assert!(error_message(failure).contains("field names must be strings"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_char_matrix_errors() {
        let st = struct_of(vec![("alpha", Value::Num(1.0))]);
        let matrix = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let failure = run_isfield(Value::Struct(st), Value::CharArray(matrix)).unwrap_err();
        assert!(error_message(failure).contains("field names must be strings"));
    }
}