Skip to main content

runmat_runtime/builtins/structs/core/
rmfield.rs

1//! MATLAB-compatible `rmfield` builtin that removes fields from structs and struct arrays.
2
3use crate::builtins::common::spec::{
4    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5    ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::rmfield_type;
8use crate::{build_runtime_error, BuiltinResult, RuntimeError};
9use runmat_builtins::{CellArray, StringArray, StructValue, Value};
10use runmat_macros::runtime_builtin;
11use std::collections::HashSet;
12
13#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::rmfield")]
14pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
15    name: "rmfield",
16    op_kind: GpuOpKind::Custom("rmfield"),
17    supported_precisions: &[],
18    broadcast: BroadcastSemantics::None,
19    provider_hooks: &[],
20    constant_strategy: ConstantStrategy::InlineLiteral,
21    residency: ResidencyPolicy::InheritInputs,
22    nan_mode: ReductionNaN::Include,
23    two_pass_threshold: None,
24    workgroup_size: None,
25    accepts_nan_mode: false,
26    notes: "Host-only struct metadata update; acceleration providers are not consulted.",
27};
28
29#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::rmfield")]
30pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
31    name: "rmfield",
32    shape: ShapeRequirements::Any,
33    constant_strategy: ConstantStrategy::InlineLiteral,
34    elementwise: None,
35    reduction: None,
36    emits_nan: false,
37    notes: "Metadata mutation forces fusion planners to flush pending groups on the host.",
38};
39
40fn rmfield_flow(message: impl Into<String>) -> RuntimeError {
41    build_runtime_error(message).with_builtin("rmfield").build()
42}
43
44#[runtime_builtin(
45    name = "rmfield",
46    category = "structs/core",
47    summary = "Remove one or more fields from scalar structs or struct arrays.",
48    keywords = "rmfield,struct,remove field,struct array",
49    type_resolver(rmfield_type),
50    builtin_path = "crate::builtins::structs::core::rmfield"
51)]
52async fn rmfield_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
53    let names = parse_field_names(&rest)?;
54    if names.is_empty() {
55        return Ok(value);
56    }
57
58    match value {
59        Value::Struct(st) => {
60            let updated = remove_fields_from_struct_owned(st, &names)?;
61            Ok(Value::Struct(updated))
62        }
63        Value::Cell(cell) if is_struct_array(&cell) => {
64            let updated = remove_fields_from_struct_array(&cell, &names)?;
65            Ok(Value::Cell(updated))
66        }
67        other => Err(rmfield_flow(format!(
68            "rmfield: expected struct or struct array, got {other:?}"
69        ))),
70    }
71}
72
73fn parse_field_names(args: &[Value]) -> BuiltinResult<Vec<String>> {
74    if args.is_empty() {
75        return Err(rmfield_flow("rmfield: not enough input arguments"));
76    }
77    let mut names: Vec<String> = Vec::new();
78    for value in args {
79        names.extend(collect_field_names(value)?);
80    }
81    Ok(names)
82}
83
84fn collect_field_names(value: &Value) -> BuiltinResult<Vec<String>> {
85    match value {
86        Value::String(_) | Value::CharArray(_) => expect_scalar_name(value)
87            .map(|name| vec![name])
88            .map_err(|err| rmfield_flow(format!("rmfield: {}", describe_field_name_error(err)))),
89        Value::StringArray(sa) => {
90            if sa.data.len() == 1 {
91                expect_scalar_name(value)
92                    .map(|name| vec![name])
93                    .map_err(|err| {
94                        rmfield_flow(format!("rmfield: {}", describe_field_name_error(err)))
95                    })
96            } else {
97                string_array_to_names(sa)
98            }
99        }
100        Value::Cell(cell) => cell_to_names(cell),
101        other => Err(rmfield_flow(format!(
102            "rmfield: field names must be strings or character vectors (got {other:?})"
103        ))),
104    }
105}
106
107fn string_array_to_names(array: &StringArray) -> BuiltinResult<Vec<String>> {
108    let mut names = Vec::with_capacity(array.data.len());
109    for (index, name) in array.data.iter().enumerate() {
110        if name.is_empty() {
111            return Err(rmfield_flow(format!(
112                "rmfield: field names must be nonempty character vectors or strings (string array element {})",
113                index + 1
114            )));
115        }
116        names.push(name.clone());
117    }
118    Ok(names)
119}
120
121fn cell_to_names(cell: &CellArray) -> BuiltinResult<Vec<String>> {
122    let mut output = Vec::with_capacity(cell.data.len());
123    for (index, handle) in cell.data.iter().enumerate() {
124        let value = unsafe { &*handle.as_raw() };
125        let name = expect_scalar_name(value).map_err(|err| {
126            rmfield_flow(format!(
127                "rmfield: {} (cell element {})",
128                describe_field_name_error(err),
129                index + 1
130            ))
131        })?;
132        output.push(name);
133    }
134    Ok(output)
135}
136
/// Reason a value could not be used as a single field name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldNameError {
    /// The value is not a string scalar, a single-row char array, or a one-element string array.
    Type,
    /// The value is text-like but contains no characters.
    Empty,
}
142
143fn describe_field_name_error(kind: FieldNameError) -> &'static str {
144    match kind {
145        FieldNameError::Type => {
146            "field names must be string scalars, character vectors, or single-element string arrays"
147        }
148        FieldNameError::Empty => "field names must be nonempty character vectors or strings",
149    }
150}
151
152fn expect_scalar_name(value: &Value) -> Result<String, FieldNameError> {
153    match value {
154        Value::String(s) => {
155            if s.is_empty() {
156                Err(FieldNameError::Empty)
157            } else {
158                Ok(s.clone())
159            }
160        }
161        Value::CharArray(ca) => {
162            if ca.rows != 1 {
163                return Err(FieldNameError::Type);
164            }
165            let text: String = ca.data.iter().collect();
166            if text.is_empty() {
167                Err(FieldNameError::Empty)
168            } else {
169                Ok(text)
170            }
171        }
172        Value::StringArray(sa) => {
173            if sa.data.len() != 1 {
174                return Err(FieldNameError::Type);
175            }
176            let text = sa.data[0].clone();
177            if text.is_empty() {
178                Err(FieldNameError::Empty)
179            } else {
180                Ok(text)
181            }
182        }
183        _ => Err(FieldNameError::Type),
184    }
185}
186
187fn remove_fields_from_struct_owned(
188    mut st: StructValue,
189    names: &[String],
190) -> BuiltinResult<StructValue> {
191    let mut seen: HashSet<&str> = HashSet::new();
192    for name in names {
193        if !seen.insert(name.as_str()) {
194            continue;
195        }
196        if st.remove(name).is_none() {
197            return Err(missing_field_error(name));
198        }
199    }
200    Ok(st)
201}
202
203fn remove_fields_from_struct_array(
204    array: &CellArray,
205    names: &[String],
206) -> BuiltinResult<CellArray> {
207    if array.data.is_empty() {
208        return Ok(array.clone());
209    }
210
211    let mut updated: Vec<Value> = Vec::with_capacity(array.data.len());
212    for handle in &array.data {
213        let value = unsafe { &*handle.as_raw() };
214        let Value::Struct(st) = value else {
215            return Err(rmfield_flow(
216                "rmfield: expected struct array contents to be structs",
217            ));
218        };
219        let revised = remove_fields_from_struct_owned(st.clone(), names)?;
220        updated.push(Value::Struct(revised));
221    }
222    CellArray::new_with_shape(updated, array.shape.clone())
223        .map_err(|e| rmfield_flow(format!("rmfield: failed to rebuild struct array: {e}")))
224}
225
226fn missing_field_error(name: &str) -> RuntimeError {
227    rmfield_flow(format!("Reference to non-existent field '{name}'."))
228}
229
230fn is_struct_array(cell: &CellArray) -> bool {
231    cell.data
232        .iter()
233        .all(|handle| matches!(unsafe { &*handle.as_raw() }, Value::Struct(_)))
234}
235
/// Unit tests for the `rmfield` builtin, covering its argument forms, error paths and
/// struct-array handling.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CellArray, CharArray, StringArray, StructValue, Value};

    /// Extracts the message text of a runtime error for substring assertions.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async builtin to completion on the current thread.
    fn run_rmfield(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(rmfield_builtin(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_removes_single_field_from_scalar_struct() {
        let mut st = StructValue::new();
        st.fields.insert("name".to_string(), Value::from("Ada"));
        st.fields.insert("score".to_string(), Value::Num(42.0));
        let result = run_rmfield(Value::Struct(st), vec![Value::from("score")]).expect("rmfield");
        let Value::Struct(updated) = result else {
            panic!("expected struct result");
        };
        assert!(!updated.fields.contains_key("score"));
        assert!(updated.fields.contains_key("name"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_accepts_cell_array_of_field_names() {
        let mut st = StructValue::new();
        st.fields.insert("left".to_string(), Value::Num(1.0));
        st.fields.insert("right".to_string(), Value::Num(2.0));
        st.fields.insert("top".to_string(), Value::Num(3.0));
        let cell =
            CellArray::new(vec![Value::from("left"), Value::from("top")], 1, 2).expect("cell");
        let result = run_rmfield(Value::Struct(st), vec![Value::Cell(cell)]).expect("rmfield");
        let Value::Struct(updated) = result else {
            panic!("expected struct result");
        };
        assert!(!updated.fields.contains_key("left"));
        assert!(!updated.fields.contains_key("top"));
        assert!(updated.fields.contains_key("right"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_supports_string_array_names() {
        let mut st = StructValue::new();
        st.fields.insert("alpha".to_string(), Value::Num(1.0));
        st.fields.insert("beta".to_string(), Value::Num(2.0));
        st.fields.insert("gamma".to_string(), Value::Num(3.0));
        let strings = StringArray::new(vec!["alpha".into(), "gamma".into()], vec![1, 2]).unwrap();
        let result =
            run_rmfield(Value::Struct(st), vec![Value::StringArray(strings)]).expect("rmfield");
        let Value::Struct(updated) = result else {
            panic!("expected struct result");
        };
        assert!(!updated.fields.contains_key("alpha"));
        assert!(!updated.fields.contains_key("gamma"));
        assert!(updated.fields.contains_key("beta"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_errors_when_field_missing() {
        let mut st = StructValue::new();
        st.fields.insert("name".to_string(), Value::from("Ada"));
        let err =
            error_message(run_rmfield(Value::Struct(st), vec![Value::from("id")]).unwrap_err());
        assert!(
            err.contains("Reference to non-existent field 'id'."),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_struct_array_roundtrip() {
        let mut first = StructValue::new();
        first.fields.insert("name".to_string(), Value::from("Ada"));
        first.fields.insert("score".to_string(), Value::Num(90.0));

        let mut second = StructValue::new();
        second
            .fields
            .insert("name".to_string(), Value::from("Grace"));
        second.fields.insert("score".to_string(), Value::Num(95.0));

        let array = CellArray::new_with_shape(
            vec![Value::Struct(first), Value::Struct(second)],
            vec![1, 2],
        )
        .expect("struct array");

        let result = run_rmfield(Value::Cell(array), vec![Value::from("score")]).expect("rmfield");
        let Value::Cell(updated) = result else {
            panic!("expected struct array");
        };
        for handle in &updated.data {
            let value = unsafe { &*handle.as_raw() };
            let Value::Struct(st) = value else {
                panic!("expected struct element");
            };
            assert!(!st.fields.contains_key("score"));
            assert!(st.fields.contains_key("name"));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_struct_array_missing_field_errors() {
        let mut first = StructValue::new();
        first.fields.insert("id".to_string(), Value::Num(1.0));
        let mut second = StructValue::new();
        second.fields.insert("id".to_string(), Value::Num(2.0));
        second.fields.insert("extra".to_string(), Value::Num(3.0));

        let array = CellArray::new_with_shape(
            vec![Value::Struct(first), Value::Struct(second)],
            vec![1, 2],
        )
        .expect("struct array");

        let err = error_message(
            run_rmfield(Value::Cell(array), vec![Value::from("missing")]).unwrap_err(),
        );
        assert!(
            err.contains("Reference to non-existent field 'missing'."),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_rejects_non_struct_inputs() {
        let err =
            error_message(run_rmfield(Value::Num(1.0), vec![Value::from("field")]).unwrap_err());
        assert!(
            err.contains("expected struct or struct array"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_produces_error_for_empty_field_name() {
        let mut st = StructValue::new();
        st.fields.insert("data".to_string(), Value::Num(1.0));
        let err = error_message(run_rmfield(Value::Struct(st), vec![Value::from("")]).unwrap_err());
        assert!(
            err.contains("field names must be nonempty"),
            "unexpected error: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_accepts_multiple_argument_forms() {
        let mut st = StructValue::new();
        st.fields.insert("alpha".to_string(), Value::Num(1.0));
        st.fields.insert("beta".to_string(), Value::Num(2.0));
        st.fields.insert("gamma".to_string(), Value::Num(3.0));
        st.fields.insert("delta".to_string(), Value::Num(4.0));

        let char_name = CharArray::new_row("beta");
        let string_array =
            StringArray::new(vec!["gamma".into()], vec![1, 1]).expect("string scalar array");
        let cell = CellArray::new(vec![Value::from("delta")], 1, 1).expect("cell array of strings");

        let result = run_rmfield(
            Value::Struct(st),
            vec![
                Value::from("alpha"),
                Value::CharArray(char_name),
                Value::StringArray(string_array),
                Value::Cell(cell),
            ],
        )
        .expect("rmfield");

        let Value::Struct(updated) = result else {
            panic!("expected struct result");
        };

        assert!(updated.fields.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_ignores_duplicate_field_names() {
        let mut st = StructValue::new();
        st.fields.insert("keep".to_string(), Value::Num(1.0));
        st.fields.insert("drop".to_string(), Value::Num(2.0));
        let result = run_rmfield(
            Value::Struct(st),
            vec![Value::from("drop"), Value::from("drop")],
        )
        .expect("rmfield");
        let Value::Struct(updated) = result else {
            panic!("expected struct result");
        };
        assert!(!updated.fields.contains_key("drop"));
        assert!(updated.fields.contains_key("keep"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_returns_original_when_no_names_supplied() {
        let mut st = StructValue::new();
        st.fields.insert("value".to_string(), Value::Num(10.0));
        let empty = CellArray::new(Vec::new(), 0, 0).expect("empty cell array");
        let original = st.clone();
        let result =
            run_rmfield(Value::Struct(st), vec![Value::Cell(empty)]).expect("rmfield empty");
        assert_eq!(result, Value::Struct(original));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_requires_field_names() {
        let mut st = StructValue::new();
        st.fields.insert("value".to_string(), Value::Num(10.0));
        let err = error_message(run_rmfield(Value::Struct(st), Vec::new()).unwrap_err());
        assert!(
            err.contains("rmfield: not enough input arguments"),
            "unexpected error: {err}"
        );
    }

    /// A cell whose elements are not all structs is not a struct array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_rejects_cell_with_non_struct_elements() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).expect("cell");
        let err =
            error_message(run_rmfield(Value::Cell(cell), vec![Value::from("field")]).unwrap_err());
        assert!(
            err.contains("expected struct or struct array"),
            "unexpected error: {err}"
        );
    }

    /// Numeric values are not accepted as field names.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_rejects_non_text_field_names() {
        let mut st = StructValue::new();
        st.fields.insert("value".to_string(), Value::Num(1.0));
        let err =
            error_message(run_rmfield(Value::Struct(st), vec![Value::Num(2.0)]).unwrap_err());
        assert!(
            err.contains("field names must be strings or character vectors"),
            "unexpected error: {err}"
        );
    }

    /// Errors from a cell of names identify the 1-based offending element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_reports_offending_cell_element() {
        let mut st = StructValue::new();
        st.fields.insert("a".to_string(), Value::Num(1.0));
        let cell = CellArray::new(vec![Value::from("a"), Value::Num(2.0)], 1, 2).expect("cell");
        let err = error_message(
            run_rmfield(Value::Struct(st), vec![Value::Cell(cell)]).unwrap_err(),
        );
        assert!(err.contains("(cell element 2)"), "unexpected error: {err}");
    }

    /// Errors from a multi-element string array identify the 1-based empty element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_reports_empty_string_array_element() {
        let mut st = StructValue::new();
        st.fields.insert("a".to_string(), Value::Num(1.0));
        let strings = StringArray::new(vec!["a".into(), "".into()], vec![1, 2]).expect("strings");
        let err = error_message(
            run_rmfield(Value::Struct(st), vec![Value::StringArray(strings)]).unwrap_err(),
        );
        assert!(
            err.contains("string array element 2"),
            "unexpected error: {err}"
        );
    }

    /// Removing a field from a struct array keeps the array's shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_struct_array_preserves_shape() {
        let mut first = StructValue::new();
        first.fields.insert("x".to_string(), Value::Num(1.0));
        first.fields.insert("y".to_string(), Value::Num(2.0));
        let mut second = StructValue::new();
        second.fields.insert("x".to_string(), Value::Num(3.0));
        second.fields.insert("y".to_string(), Value::Num(4.0));

        let array = CellArray::new_with_shape(
            vec![Value::Struct(first), Value::Struct(second)],
            vec![2, 1],
        )
        .expect("struct array");
        let original_shape = array.shape.clone();

        let result = run_rmfield(Value::Cell(array), vec![Value::from("y")]).expect("rmfield");
        let Value::Cell(updated) = result else {
            panic!("expected struct array");
        };
        assert_eq!(updated.shape, original_shape);
        assert_eq!(updated.data.len(), 2);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn rmfield_preserves_gpu_handles() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &[1.0, 2.0],
            shape: &[2, 1],
        };
        let handle = provider.upload(&view).expect("upload");

        let mut st = StructValue::new();
        st.fields
            .insert("gpu".to_string(), Value::GpuTensor(handle.clone()));
        st.fields.insert("remove".to_string(), Value::Num(5.0));

        let result = run_rmfield(Value::Struct(st), vec![Value::from("remove")]).expect("rmfield");

        let Value::Struct(updated) = result else {
            panic!("expected struct result");
        };

        assert!(matches!(
            updated.fields.get("gpu"),
            Some(Value::GpuTensor(h)) if h == &handle
        ));
        assert!(!updated.fields.contains_key("remove"));
    }
}