1use crate::builtins::common::spec::{
4 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
5 ReductionNaN, ResidencyPolicy, ShapeRequirements,
6};
7use crate::builtins::structs::type_resolvers::rmfield_type;
8use crate::{build_runtime_error, BuiltinResult, RuntimeError};
9use runmat_builtins::{CellArray, StringArray, StructValue, Value};
10use runmat_macros::runtime_builtin;
11use std::collections::HashSet;
12
13#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::rmfield")]
14pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
15 name: "rmfield",
16 op_kind: GpuOpKind::Custom("rmfield"),
17 supported_precisions: &[],
18 broadcast: BroadcastSemantics::None,
19 provider_hooks: &[],
20 constant_strategy: ConstantStrategy::InlineLiteral,
21 residency: ResidencyPolicy::InheritInputs,
22 nan_mode: ReductionNaN::Include,
23 two_pass_threshold: None,
24 workgroup_size: None,
25 accepts_nan_mode: false,
26 notes: "Host-only struct metadata update; acceleration providers are not consulted.",
27};
28
29#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::rmfield")]
30pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
31 name: "rmfield",
32 shape: ShapeRequirements::Any,
33 constant_strategy: ConstantStrategy::InlineLiteral,
34 elementwise: None,
35 reduction: None,
36 emits_nan: false,
37 notes: "Metadata mutation forces fusion planners to flush pending groups on the host.",
38};
39
40fn rmfield_flow(message: impl Into<String>) -> RuntimeError {
41 build_runtime_error(message).with_builtin("rmfield").build()
42}
43
44#[runtime_builtin(
45 name = "rmfield",
46 category = "structs/core",
47 summary = "Remove one or more fields from scalar structs or struct arrays.",
48 keywords = "rmfield,struct,remove field,struct array",
49 type_resolver(rmfield_type),
50 builtin_path = "crate::builtins::structs::core::rmfield"
51)]
52async fn rmfield_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
53 let names = parse_field_names(&rest)?;
54 if names.is_empty() {
55 return Ok(value);
56 }
57
58 match value {
59 Value::Struct(st) => {
60 let updated = remove_fields_from_struct_owned(st, &names)?;
61 Ok(Value::Struct(updated))
62 }
63 Value::Cell(cell) if is_struct_array(&cell) => {
64 let updated = remove_fields_from_struct_array(&cell, &names)?;
65 Ok(Value::Cell(updated))
66 }
67 other => Err(rmfield_flow(format!(
68 "rmfield: expected struct or struct array, got {other:?}"
69 ))),
70 }
71}
72
73fn parse_field_names(args: &[Value]) -> BuiltinResult<Vec<String>> {
74 if args.is_empty() {
75 return Err(rmfield_flow("rmfield: not enough input arguments"));
76 }
77 let mut names: Vec<String> = Vec::new();
78 for value in args {
79 names.extend(collect_field_names(value)?);
80 }
81 Ok(names)
82}
83
84fn collect_field_names(value: &Value) -> BuiltinResult<Vec<String>> {
85 match value {
86 Value::String(_) | Value::CharArray(_) => expect_scalar_name(value)
87 .map(|name| vec![name])
88 .map_err(|err| rmfield_flow(format!("rmfield: {}", describe_field_name_error(err)))),
89 Value::StringArray(sa) => {
90 if sa.data.len() == 1 {
91 expect_scalar_name(value)
92 .map(|name| vec![name])
93 .map_err(|err| {
94 rmfield_flow(format!("rmfield: {}", describe_field_name_error(err)))
95 })
96 } else {
97 string_array_to_names(sa)
98 }
99 }
100 Value::Cell(cell) => cell_to_names(cell),
101 other => Err(rmfield_flow(format!(
102 "rmfield: field names must be strings or character vectors (got {other:?})"
103 ))),
104 }
105}
106
107fn string_array_to_names(array: &StringArray) -> BuiltinResult<Vec<String>> {
108 let mut names = Vec::with_capacity(array.data.len());
109 for (index, name) in array.data.iter().enumerate() {
110 if name.is_empty() {
111 return Err(rmfield_flow(format!(
112 "rmfield: field names must be nonempty character vectors or strings (string array element {})",
113 index + 1
114 )));
115 }
116 names.push(name.clone());
117 }
118 Ok(names)
119}
120
121fn cell_to_names(cell: &CellArray) -> BuiltinResult<Vec<String>> {
122 let mut output = Vec::with_capacity(cell.data.len());
123 for (index, handle) in cell.data.iter().enumerate() {
124 let value = unsafe { &*handle.as_raw() };
125 let name = expect_scalar_name(value).map_err(|err| {
126 rmfield_flow(format!(
127 "rmfield: {} (cell element {})",
128 describe_field_name_error(err),
129 index + 1
130 ))
131 })?;
132 output.push(name);
133 }
134 Ok(output)
135}
136
/// Reason a value could not be used as a single field name.
///
/// `Debug`, `PartialEq` and `Eq` are derived so that results carrying this error can be unwrapped
/// and compared directly in assertions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldNameError {
    /// The value is not a string scalar, character row vector, or single-element string array.
    Type,
    /// The value has an acceptable type but contains no characters.
    Empty,
}
142
143fn describe_field_name_error(kind: FieldNameError) -> &'static str {
144 match kind {
145 FieldNameError::Type => {
146 "field names must be string scalars, character vectors, or single-element string arrays"
147 }
148 FieldNameError::Empty => "field names must be nonempty character vectors or strings",
149 }
150}
151
152fn expect_scalar_name(value: &Value) -> Result<String, FieldNameError> {
153 match value {
154 Value::String(s) => {
155 if s.is_empty() {
156 Err(FieldNameError::Empty)
157 } else {
158 Ok(s.clone())
159 }
160 }
161 Value::CharArray(ca) => {
162 if ca.rows != 1 {
163 return Err(FieldNameError::Type);
164 }
165 let text: String = ca.data.iter().collect();
166 if text.is_empty() {
167 Err(FieldNameError::Empty)
168 } else {
169 Ok(text)
170 }
171 }
172 Value::StringArray(sa) => {
173 if sa.data.len() != 1 {
174 return Err(FieldNameError::Type);
175 }
176 let text = sa.data[0].clone();
177 if text.is_empty() {
178 Err(FieldNameError::Empty)
179 } else {
180 Ok(text)
181 }
182 }
183 _ => Err(FieldNameError::Type),
184 }
185}
186
187fn remove_fields_from_struct_owned(
188 mut st: StructValue,
189 names: &[String],
190) -> BuiltinResult<StructValue> {
191 let mut seen: HashSet<&str> = HashSet::new();
192 for name in names {
193 if !seen.insert(name.as_str()) {
194 continue;
195 }
196 if st.remove(name).is_none() {
197 return Err(missing_field_error(name));
198 }
199 }
200 Ok(st)
201}
202
203fn remove_fields_from_struct_array(
204 array: &CellArray,
205 names: &[String],
206) -> BuiltinResult<CellArray> {
207 if array.data.is_empty() {
208 return Ok(array.clone());
209 }
210
211 let mut updated: Vec<Value> = Vec::with_capacity(array.data.len());
212 for handle in &array.data {
213 let value = unsafe { &*handle.as_raw() };
214 let Value::Struct(st) = value else {
215 return Err(rmfield_flow(
216 "rmfield: expected struct array contents to be structs",
217 ));
218 };
219 let revised = remove_fields_from_struct_owned(st.clone(), names)?;
220 updated.push(Value::Struct(revised));
221 }
222 CellArray::new_with_shape(updated, array.shape.clone())
223 .map_err(|e| rmfield_flow(format!("rmfield: failed to rebuild struct array: {e}")))
224}
225
226fn missing_field_error(name: &str) -> RuntimeError {
227 rmfield_flow(format!("Reference to non-existent field '{name}'."))
228}
229
230fn is_struct_array(cell: &CellArray) -> bool {
231 cell.data
232 .iter()
233 .all(|handle| matches!(unsafe { &*handle.as_raw() }, Value::Struct(_)))
234}
235
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CellArray, CharArray, StringArray, StructValue, Value};

    /// Extracts the message text from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Drives the async builtin to completion on the current thread.
    fn run_rmfield(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(rmfield_builtin(value, rest))
    }

    /// Builds a scalar struct from `(name, value)` pairs, inserted in the given order.
    fn struct_of(entries: Vec<(&str, Value)>) -> StructValue {
        let mut st = StructValue::new();
        for (key, value) in entries {
            st.fields.insert(key.to_string(), value);
        }
        st
    }

    /// Builds a 1x2 struct array from two scalar structs.
    fn pair_struct_array(first: StructValue, second: StructValue) -> CellArray {
        CellArray::new_with_shape(
            vec![Value::Struct(first), Value::Struct(second)],
            vec![1, 2],
        )
        .expect("struct array")
    }

    /// Unwraps a scalar struct result, panicking on any other value kind.
    fn expect_struct(value: Value) -> StructValue {
        let Value::Struct(st) = value else {
            panic!("expected struct result");
        };
        st
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_error_contains(result: BuiltinResult<Value>, needle: &str) {
        let err = error_message(result.unwrap_err());
        assert!(err.contains(needle), "unexpected error: {err}");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_removes_single_field_from_scalar_struct() {
        let st = struct_of(vec![
            ("name", Value::from("Ada")),
            ("score", Value::Num(42.0)),
        ]);
        let updated = expect_struct(
            run_rmfield(Value::Struct(st), vec![Value::from("score")]).expect("rmfield"),
        );
        assert!(!updated.fields.contains_key("score"));
        assert!(updated.fields.contains_key("name"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_accepts_cell_array_of_field_names() {
        let st = struct_of(vec![
            ("left", Value::Num(1.0)),
            ("right", Value::Num(2.0)),
            ("top", Value::Num(3.0)),
        ]);
        let cell =
            CellArray::new(vec![Value::from("left"), Value::from("top")], 1, 2).expect("cell");
        let updated = expect_struct(
            run_rmfield(Value::Struct(st), vec![Value::Cell(cell)]).expect("rmfield"),
        );
        assert!(!updated.fields.contains_key("left"));
        assert!(!updated.fields.contains_key("top"));
        assert!(updated.fields.contains_key("right"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_supports_string_array_names() {
        let st = struct_of(vec![
            ("alpha", Value::Num(1.0)),
            ("beta", Value::Num(2.0)),
            ("gamma", Value::Num(3.0)),
        ]);
        let strings = StringArray::new(vec!["alpha".into(), "gamma".into()], vec![1, 2]).unwrap();
        let updated = expect_struct(
            run_rmfield(Value::Struct(st), vec![Value::StringArray(strings)]).expect("rmfield"),
        );
        assert!(!updated.fields.contains_key("alpha"));
        assert!(!updated.fields.contains_key("gamma"));
        assert!(updated.fields.contains_key("beta"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_errors_when_field_missing() {
        let st = struct_of(vec![("name", Value::from("Ada"))]);
        assert_error_contains(
            run_rmfield(Value::Struct(st), vec![Value::from("id")]),
            "Reference to non-existent field 'id'.",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_struct_array_roundtrip() {
        let first = struct_of(vec![
            ("name", Value::from("Ada")),
            ("score", Value::Num(90.0)),
        ]);
        let second = struct_of(vec![
            ("name", Value::from("Grace")),
            ("score", Value::Num(95.0)),
        ]);
        let array = pair_struct_array(first, second);

        let result = run_rmfield(Value::Cell(array), vec![Value::from("score")]).expect("rmfield");
        let Value::Cell(updated) = result else {
            panic!("expected struct array");
        };
        for handle in &updated.data {
            let value = unsafe { &*handle.as_raw() };
            let Value::Struct(st) = value else {
                panic!("expected struct element");
            };
            assert!(!st.fields.contains_key("score"));
            assert!(st.fields.contains_key("name"));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_struct_array_missing_field_errors() {
        let first = struct_of(vec![("id", Value::Num(1.0))]);
        let second = struct_of(vec![("id", Value::Num(2.0)), ("extra", Value::Num(3.0))]);
        let array = pair_struct_array(first, second);

        assert_error_contains(
            run_rmfield(Value::Cell(array), vec![Value::from("missing")]),
            "Reference to non-existent field 'missing'.",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_rejects_non_struct_inputs() {
        assert_error_contains(
            run_rmfield(Value::Num(1.0), vec![Value::from("field")]),
            "expected struct or struct array",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_produces_error_for_empty_field_name() {
        let st = struct_of(vec![("data", Value::Num(1.0))]);
        assert_error_contains(
            run_rmfield(Value::Struct(st), vec![Value::from("")]),
            "field names must be nonempty",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_accepts_multiple_argument_forms() {
        let st = struct_of(vec![
            ("alpha", Value::Num(1.0)),
            ("beta", Value::Num(2.0)),
            ("gamma", Value::Num(3.0)),
            ("delta", Value::Num(4.0)),
        ]);

        // One name per supported argument form: string, char row, string array, cell array.
        let char_name = CharArray::new_row("beta");
        let string_array =
            StringArray::new(vec!["gamma".into()], vec![1, 1]).expect("string scalar array");
        let cell = CellArray::new(vec![Value::from("delta")], 1, 1).expect("cell array of strings");
        let arguments = vec![
            Value::from("alpha"),
            Value::CharArray(char_name),
            Value::StringArray(string_array),
            Value::Cell(cell),
        ];

        let updated = expect_struct(run_rmfield(Value::Struct(st), arguments).expect("rmfield"));

        assert!(updated.fields.is_empty());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_ignores_duplicate_field_names() {
        let st = struct_of(vec![("keep", Value::Num(1.0)), ("drop", Value::Num(2.0))]);
        let updated = expect_struct(
            run_rmfield(
                Value::Struct(st),
                vec![Value::from("drop"), Value::from("drop")],
            )
            .expect("rmfield"),
        );
        assert!(!updated.fields.contains_key("drop"));
        assert!(updated.fields.contains_key("keep"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_returns_original_when_no_names_supplied() {
        let st = struct_of(vec![("value", Value::Num(10.0))]);
        let original = st.clone();
        let empty = CellArray::new(Vec::new(), 0, 0).expect("empty cell array");
        let result =
            run_rmfield(Value::Struct(st), vec![Value::Cell(empty)]).expect("rmfield empty");
        assert_eq!(result, Value::Struct(original));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn rmfield_requires_field_names() {
        let st = struct_of(vec![("value", Value::Num(10.0))]);
        assert_error_contains(
            run_rmfield(Value::Struct(st), Vec::new()),
            "rmfield: not enough input arguments",
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn rmfield_preserves_gpu_handles() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let view = HostTensorView {
            data: &[1.0, 2.0],
            shape: &[2, 1],
        };
        let handle = provider.upload(&view).expect("upload");

        let st = struct_of(vec![
            ("gpu", Value::GpuTensor(handle.clone())),
            ("remove", Value::Num(5.0)),
        ]);

        let updated = expect_struct(
            run_rmfield(Value::Struct(st), vec![Value::from("remove")]).expect("rmfield"),
        );

        // The untouched field must still hold the very same GPU handle.
        assert!(matches!(
            updated.fields.get("gpu"),
            Some(Value::GpuTensor(h)) if h == &handle
        ));
        assert!(!updated.fields.contains_key("remove"));
    }
}