Skip to main content

runmat_runtime/builtins/structs/core/
isfield.rs

1//! MATLAB-compatible `isfield` builtin that reports whether structs contain a field.
2
3use crate::builtins::common::spec::{
4    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5    ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::isfield_type;
8use runmat_builtins::{CellArray, LogicalArray, StructValue, Value};
9use runmat_macros::runtime_builtin;
10use std::collections::HashSet;
11
12use crate::{build_runtime_error, BuiltinResult, RuntimeError};
13
14#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::isfield")]
15pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
16    name: "isfield",
17    op_kind: GpuOpKind::Custom("isfield"),
18    supported_precisions: &[],
19    broadcast: BroadcastSemantics::None,
20    provider_hooks: &[],
21    constant_strategy: ConstantStrategy::InlineLiteral,
22    residency: ResidencyPolicy::InheritInputs,
23    nan_mode: ReductionNaN::Include,
24    two_pass_threshold: None,
25    workgroup_size: None,
26    accepts_nan_mode: false,
27    notes: "Host-only metadata check; acceleration providers do not participate.",
28};
29
30#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::isfield")]
31pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
32    name: "isfield",
33    shape: ShapeRequirements::Any,
34    constant_strategy: ConstantStrategy::InlineLiteral,
35    elementwise: None,
36    reduction: None,
37    emits_nan: false,
38    notes: "Acts as a fusion barrier because it inspects struct metadata on the host.",
39};
40
41fn isfield_flow(message: impl Into<String>) -> RuntimeError {
42    build_runtime_error(message).with_builtin("isfield").build()
43}
44
45#[runtime_builtin(
46    name = "isfield",
47    category = "structs/core",
48    summary = "Test whether a struct or struct array defines specific field names.",
49    keywords = "isfield,struct,field existence",
50    type_resolver(isfield_type),
51    builtin_path = "crate::builtins::structs::core::isfield"
52)]
53async fn isfield_builtin(target: Value, names: Value) -> BuiltinResult<Value> {
54    let context = classify_struct(&target)?;
55    let parsed = parse_field_names(names)?;
56    match context {
57        StructContext::Struct(struct_value) => evaluate_scalar(struct_value, parsed),
58        StructContext::StructArray(cell) => evaluate_struct_array(cell, parsed),
59        StructContext::NonStruct => evaluate_non_struct(parsed),
60    }
61}
62
63#[derive(Clone, Copy)]
64enum StructContext<'a> {
65    Struct(&'a StructValue),
66    StructArray(&'a CellArray),
67    NonStruct,
68}
69
70fn classify_struct<'a>(value: &'a Value) -> BuiltinResult<StructContext<'a>> {
71    match value {
72        Value::Struct(st) => Ok(StructContext::Struct(st)),
73        Value::Cell(cell) => {
74            if cell.data.is_empty() {
75                return Ok(StructContext::StructArray(cell));
76            }
77            if cell
78                .data
79                .iter()
80                .all(|handle| matches!(unsafe { &*handle.as_raw() }, Value::Struct(_)))
81            {
82                Ok(StructContext::StructArray(cell))
83            } else {
84                Ok(StructContext::NonStruct)
85            }
86        }
87        _ => Ok(StructContext::NonStruct),
88    }
89}
90
/// Field names requested by the caller, normalized from the second `isfield` argument.
enum ParsedNames {
    /// A single name; the result is a logical scalar.
    Scalar(String),
    /// Several names plus the shape the logical result must take.
    Array {
        /// Dimensions of the original names argument.
        shape: Vec<usize>,
        /// Names in the same linear order as the resulting logical array's data.
        names: Vec<String>,
    },
}
98
99fn parse_field_names(names: Value) -> BuiltinResult<ParsedNames> {
100    match names {
101        Value::String(s) => Ok(ParsedNames::Scalar(s)),
102        Value::CharArray(ca) => {
103            if ca.rows == 1 {
104                Ok(ParsedNames::Scalar(ca.data.iter().collect()))
105            } else {
106                Err(field_name_type_error())
107            }
108        }
109        Value::StringArray(sa) => Ok(ParsedNames::Array {
110            names: sa.data.clone(),
111            shape: sa.shape.clone(),
112        }),
113        Value::Cell(cell) => Ok(ParsedNames::Array {
114            names: collect_cell_names(&cell)?,
115            shape: if cell.shape.is_empty() {
116                vec![cell.rows, cell.cols]
117            } else {
118                cell.shape.clone()
119            },
120        }),
121        other => match try_single_field_name(&other)? {
122            Some(name) => Ok(ParsedNames::Scalar(name)),
123            None => Err(field_name_type_error()),
124        },
125    }
126}
127
128fn try_single_field_name(value: &Value) -> BuiltinResult<Option<String>> {
129    match value {
130        Value::String(s) => Ok(Some(s.clone())),
131        Value::CharArray(ca) => {
132            if ca.rows == 1 {
133                Ok(Some(ca.data.iter().collect()))
134            } else {
135                Err(field_name_type_error())
136            }
137        }
138        Value::StringArray(sa) => {
139            if sa.data.len() == 1 {
140                Ok(Some(sa.data[0].clone()))
141            } else {
142                Err(field_name_type_error())
143            }
144        }
145        _ => Ok(None),
146    }
147}
148
149fn evaluate_scalar(struct_value: &StructValue, names: ParsedNames) -> BuiltinResult<Value> {
150    match names {
151        ParsedNames::Scalar(name) => Ok(Value::Bool(struct_value.fields.contains_key(&name))),
152        ParsedNames::Array { names, shape } => {
153            let mut bits = Vec::with_capacity(names.len());
154            for name in names {
155                bits.push(if struct_value.fields.contains_key(&name) {
156                    1
157                } else {
158                    0
159                });
160            }
161            let logical = LogicalArray::new(bits, shape)
162                .map_err(|e| isfield_flow(format!("isfield: {e}")))?;
163            Ok(Value::LogicalArray(logical))
164        }
165    }
166}
167
168fn evaluate_struct_array(cell: &CellArray, names: ParsedNames) -> BuiltinResult<Value> {
169    let fields = struct_array_field_intersection(cell)?;
170    match names {
171        ParsedNames::Scalar(name) => Ok(Value::Bool(fields.contains(&name))),
172        ParsedNames::Array { names, shape } => {
173            let mut bits = Vec::with_capacity(names.len());
174            for name in names {
175                bits.push(if fields.contains(&name) { 1 } else { 0 });
176            }
177            let logical = LogicalArray::new(bits, shape)
178                .map_err(|e| isfield_flow(format!("isfield: {e}")))?;
179            Ok(Value::LogicalArray(logical))
180        }
181    }
182}
183
184fn evaluate_non_struct(names: ParsedNames) -> BuiltinResult<Value> {
185    match names {
186        ParsedNames::Scalar(_) => Ok(Value::Bool(false)),
187        ParsedNames::Array { names, shape } => {
188            let logical = LogicalArray::new(vec![0; names.len()], shape)
189                .map_err(|e| isfield_flow(format!("isfield: {e}")))?;
190            Ok(Value::LogicalArray(logical))
191        }
192    }
193}
194
195fn struct_array_field_intersection(cell: &CellArray) -> BuiltinResult<HashSet<String>> {
196    if cell.data.is_empty() {
197        return Ok(HashSet::new());
198    }
199
200    let mut iter = cell.data.iter();
201    let first = unsafe { &*iter.next().unwrap().as_raw() };
202    let Value::Struct(first_struct) = first else {
203        return Err(isfield_flow(
204            "isfield: struct array elements must be structs",
205        ));
206    };
207    let mut fields: HashSet<String> = first_struct.fields.keys().cloned().collect();
208
209    for handle in iter {
210        let value = unsafe { &*handle.as_raw() };
211        let Value::Struct(struct_value) = value else {
212            return Err(isfield_flow(
213                "isfield: struct array elements must be structs",
214            ));
215        };
216        fields.retain(|name| struct_value.fields.contains_key(name));
217        if fields.is_empty() {
218            break;
219        }
220    }
221
222    Ok(fields)
223}
224
225fn collect_cell_names(cell: &CellArray) -> BuiltinResult<Vec<String>> {
226    let total = cell.data.len();
227    if total == 0 {
228        return Ok(Vec::new());
229    }
230
231    let shape = if cell.shape.is_empty() {
232        vec![cell.rows, cell.cols]
233    } else {
234        cell.shape.clone()
235    };
236
237    let mut names = Vec::with_capacity(total);
238    let row_strides = row_major_strides(&shape);
239    for idx in 0..total {
240        let coords = column_major_coordinates(idx, &shape);
241        let mut row_index = 0usize;
242        for (coord, stride) in coords.iter().zip(row_strides.iter()) {
243            row_index += coord * stride;
244        }
245        let value = unsafe { &*cell.data[row_index].as_raw() };
246        names.push(value_to_field_name(value)?);
247    }
248    Ok(names)
249}
250
/// Computes row-major (last dimension fastest) strides for `shape`.
///
/// Zero-length dimensions contribute a factor of 1 so later strides stay
/// non-zero, and products saturate at `usize::MAX` instead of overflowing.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1usize, |running, &dim| {
            let current = *running;
            *running = running.saturating_mul(dim.max(1));
            Some(current)
        })
        .collect();
    strides.reverse();
    strides
}
263
/// Converts a column-major (first dimension fastest) linear `index` into
/// per-dimension coordinates for `shape`.
///
/// Zero-length dimensions always get coordinate 0 and do not consume any of
/// the remaining index.
fn column_major_coordinates(index: usize, shape: &[usize]) -> Vec<usize> {
    let mut remaining = index;
    shape
        .iter()
        .map(|&dim| {
            if dim == 0 {
                0
            } else {
                let coord = remaining % dim;
                remaining /= dim;
                coord
            }
        })
        .collect()
}
279
280fn value_to_field_name(value: &Value) -> BuiltinResult<String> {
281    match value {
282        Value::String(s) => Ok(s.clone()),
283        Value::CharArray(ca) => {
284            if ca.rows == 1 {
285                Ok(ca.data.iter().collect())
286            } else {
287                Err(field_name_type_error())
288            }
289        }
290        Value::StringArray(sa) => {
291            if sa.data.len() == 1 {
292                Ok(sa.data[0].clone())
293            } else {
294                Err(field_name_type_error())
295            }
296        }
297        other => Err(isfield_flow(format!(
298            "isfield: cell array elements must be character vectors or strings (got {other:?})"
299        ))),
300    }
301}
302
303fn field_name_type_error() -> RuntimeError {
304    isfield_flow(
305        "isfield: field names must be strings, string arrays, or cell arrays of character vectors",
306    )
307}
308
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{CellArray, CharArray, StringArray, StructValue};

    /// Returns the message carried by a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async builtin to completion so the tests stay synchronous.
    fn run_isfield(target: Value, names: Value) -> BuiltinResult<Value> {
        let call = isfield_builtin(target, names);
        futures::executor::block_on(call)
    }

    /// Builds a scalar struct populated with the given `(field, value)` pairs, in order.
    fn struct_with(entries: Vec<(&str, Value)>) -> StructValue {
        let mut built = StructValue::new();
        for (field, value) in entries {
            built.fields.insert(field.to_string(), value);
        }
        built
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_scalar_struct_single_name() {
        let person = struct_with(vec![("name", Value::from("Ada"))]);
        for (field, expected) in [("name", true), ("score", false)] {
            let outcome = run_isfield(Value::Struct(person.clone()), Value::from(field)).unwrap();
            assert_eq!(outcome, Value::Bool(expected));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_char_array_single_row() {
        let target = struct_with(vec![("alpha", Value::Num(1.0))]);
        let letters: Vec<char> = "alpha".chars().collect();
        let row = CharArray::new(letters, 1, 5).unwrap();
        let outcome = run_isfield(Value::Struct(target), Value::CharArray(row)).unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_struct_cell_names_produces_logical_array() {
        let record = struct_with(vec![
            ("name", Value::from("Ada")),
            ("score", Value::from(42.0)),
        ]);
        let queried = CellArray::new(
            vec![
                Value::from("name"),
                Value::from("department"),
                Value::from("score"),
                Value::from("email"),
            ],
            2,
            2,
        )
        .unwrap();
        let outcome = run_isfield(Value::Struct(record), Value::Cell(queried)).expect("isfield");
        // The cell is filled row by row, while the logical result is laid out
        // column by column: (name, score, department, email) -> 1, 1, 0, 0.
        let expected = LogicalArray::new(vec![1, 1, 0, 0], vec![2, 2]).expect("logical array");
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_cell_mixed_string_types() {
        let record = struct_with(vec![
            ("name", Value::from("Ada")),
            ("id", Value::from(7.0)),
        ]);
        let id_chars = CharArray::new("id".chars().collect(), 1, 2).unwrap();
        let queried = CellArray::new(
            vec![
                Value::from("name"),
                Value::CharArray(id_chars),
                Value::from("department"),
            ],
            1,
            3,
        )
        .unwrap();
        let outcome = run_isfield(Value::Struct(record), Value::Cell(queried)).unwrap();
        let expected = LogicalArray::new(vec![1, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_struct_array_intersection() {
        let first = struct_with(vec![
            ("name", Value::from("Ada")),
            ("id", Value::from(101.0)),
        ]);
        let second = struct_with(vec![("name", Value::from("Grace"))]);

        let roster = CellArray::new_with_shape(
            vec![Value::Struct(first), Value::Struct(second)],
            vec![1, 2],
        )
        .unwrap();

        // Only fields shared by every element count as present.
        for (field, expected) in [("id", false), ("name", true)] {
            let outcome =
                run_isfield(Value::Cell(roster.clone()), Value::from(field)).expect("isfield");
            assert_eq!(outcome, Value::Bool(expected));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_non_struct_returns_false() {
        let outcome = run_isfield(Value::Num(5.0), Value::from("field")).unwrap();
        assert_eq!(outcome, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_string_array_names() {
        let record = struct_with(vec![
            ("alpha", Value::Num(1.0)),
            ("beta", Value::Num(2.0)),
        ]);
        let queried =
            StringArray::new(vec!["alpha".into(), "gamma".into()], vec![2, 1]).unwrap();
        let outcome = run_isfield(Value::Struct(record), Value::StringArray(queried)).unwrap();
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).expect("logical array");
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_invalid_name_type_errors() {
        let record = struct_with(vec![("alpha", Value::Num(1.0))]);
        let failure = run_isfield(Value::Struct(record), Value::from(5_i32)).unwrap_err();
        let text = error_message(failure);
        assert!(text.contains("field names must be strings"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn isfield_char_matrix_errors() {
        let record = struct_with(vec![("alpha", Value::Num(1.0))]);
        let matrix = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let failure = run_isfield(Value::Struct(record), Value::CharArray(matrix)).unwrap_err();
        let text = error_message(failure);
        assert!(text.contains("field names must be strings"));
    }
}