use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ReductionNaN, ResidencyPolicy, ShapeRequirements,
};
use crate::builtins::structs::type_resolvers::rmfield_type;
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
use runmat_builtins::{
BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
CellArray, StringArray, StructValue, Value,
};
use runmat_macros::runtime_builtin;
use std::collections::HashSet;
// GPU metadata for `rmfield`, registered through the `register_gpu_spec` attribute.
// No precisions or provider hooks are listed; per the `notes` field this is a host-only
// struct metadata update and acceleration providers are not consulted.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::structs::core::rmfield")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: "rmfield",
op_kind: GpuOpKind::Custom("rmfield"),
supported_precisions: &[],
broadcast: BroadcastSemantics::None,
provider_hooks: &[],
constant_strategy: ConstantStrategy::InlineLiteral,
residency: ResidencyPolicy::InheritInputs,
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes: "Host-only struct metadata update; acceleration providers are not consulted.",
};
// Fusion metadata for `rmfield`, registered through the `register_fusion_spec` attribute.
// Neither an elementwise nor a reduction form is declared; per the `notes` field the
// metadata mutation makes fusion planners flush pending groups on the host.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::structs::core::rmfield")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
name: "rmfield",
shape: ShapeRequirements::Any,
constant_strategy: ConstantStrategy::InlineLiteral,
elementwise: None,
reduction: None,
emits_nan: false,
notes: "Metadata mutation forces fusion planners to flush pending groups on the host.",
};
/// Builtin name attached to every runtime error raised from this module.
const BUILTIN_NAME: &str = "rmfield";
/// Output descriptor shared by every `rmfield` signature: the single updated value `S`.
const RMFIELD_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
name: "S",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Updated struct or struct array.",
}];
/// Inputs for the `S = rmfield(S, field)` form: the target plus one string-scalar field name.
const RMFIELD_INPUTS_SCALAR: [BuiltinParamDescriptor; 2] = [
BuiltinParamDescriptor {
name: "S",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Input struct or struct array.",
},
BuiltinParamDescriptor {
name: "field",
ty: BuiltinParamType::StringScalar,
arity: BuiltinParamArity::Required,
default: None,
description: "Field name to remove.",
},
];
/// Inputs for the `S = rmfield(S, fields)` form: the target plus a collection of names
/// (string array or cell array), hence the `Any` parameter type on `fields`.
const RMFIELD_INPUTS_COLLECTION: [BuiltinParamDescriptor; 2] = [
BuiltinParamDescriptor {
name: "S",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Input struct or struct array.",
},
BuiltinParamDescriptor {
name: "fields",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "String array or cell array of field names.",
},
];
/// Inputs for the `S = rmfield(S, field, ...)` form: the target, a first field name, and
/// a variadic tail of additional names.
const RMFIELD_INPUTS_VARIADIC: [BuiltinParamDescriptor; 3] = [
BuiltinParamDescriptor {
name: "S",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Required,
default: None,
description: "Input struct or struct array.",
},
BuiltinParamDescriptor {
name: "field",
ty: BuiltinParamType::StringScalar,
arity: BuiltinParamArity::Required,
default: None,
description: "First field name to remove.",
},
BuiltinParamDescriptor {
name: "more_fields",
ty: BuiltinParamType::Any,
arity: BuiltinParamArity::Variadic,
default: None,
description: "Additional field names.",
},
];
/// The three documented call forms of `rmfield` (single name, name collection, variadic
/// names); all of them share `RMFIELD_OUTPUT`.
const RMFIELD_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
BuiltinSignatureDescriptor {
label: "S = rmfield(S, field)",
inputs: &RMFIELD_INPUTS_SCALAR,
outputs: &RMFIELD_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "S = rmfield(S, fields)",
inputs: &RMFIELD_INPUTS_COLLECTION,
outputs: &RMFIELD_OUTPUT,
},
BuiltinSignatureDescriptor {
label: "S = rmfield(S, field, ...)",
inputs: &RMFIELD_INPUTS_VARIADIC,
outputs: &RMFIELD_OUTPUT,
},
];
/// Raised by `parse_field_names` when the call carries no field-name arguments at all.
const RMFIELD_ERROR_NOT_ENOUGH_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.RMFIELD.NOT_ENOUGH_INPUTS",
identifier: Some("RunMat:rmfield:NotEnoughInputs"),
when: "No field-name arguments are supplied.",
message: "rmfield: not enough input arguments",
};
/// Raised by `rmfield_builtin` when the first argument is neither a struct nor a cell
/// array made up solely of structs.
const RMFIELD_ERROR_INVALID_TARGET: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.RMFIELD.INVALID_TARGET",
identifier: Some("RunMat:rmfield:InvalidTarget"),
when: "First input is not a struct or struct array.",
message: "rmfield: expected struct or struct array",
};
/// Raised when a field-name argument has an unsupported type or shape; this is the
/// descriptor that `FieldNameError::Type` maps to.
const RMFIELD_ERROR_FIELD_NAME_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.RMFIELD.FIELD_NAME_TYPE",
identifier: Some("RunMat:rmfield:FieldNameType"),
when: "Field-name argument has unsupported type or non-scalar shape.",
message: "rmfield: field names must be string scalars, character vectors, or single-element string arrays",
};
/// Raised when a supplied field name is empty; this is the descriptor that
/// `FieldNameError::Empty` maps to.
const RMFIELD_ERROR_FIELD_NAME_EMPTY: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.RMFIELD.FIELD_NAME_EMPTY",
identifier: Some("RunMat:rmfield:FieldNameEmpty"),
when: "A field name is empty.",
message: "rmfield: field names must be nonempty character vectors or strings",
};
/// Raised when a requested field is not present on the struct. `missing_field_error`
/// appends the quoted field name to this message.
const RMFIELD_ERROR_MISSING_FIELD: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.RMFIELD.MISSING_FIELD",
identifier: Some("RunMat:rmfield:MissingField"),
when: "At least one requested field does not exist on the input struct.",
message: "Reference to non-existent field",
};
/// Raised by `remove_fields_from_struct_array` when an element of the struct array is
/// not a struct.
const RMFIELD_ERROR_STRUCT_ARRAY_CONTENTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.RMFIELD.STRUCT_ARRAY_CONTENTS",
identifier: Some("RunMat:rmfield:StructArrayContents"),
when: "Struct-array input contains non-struct elements.",
message: "rmfield: expected struct array contents to be structs",
};
/// Raised when `CellArray::new_with_shape` rejects the updated elements; the underlying
/// error text is appended to this message.
const RMFIELD_ERROR_REBUILD_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
code: "RM.RMFIELD.REBUILD_FAILED",
identifier: Some("RunMat:rmfield:RebuildFailed"),
when: "Rebuilding the updated struct array failed.",
message: "rmfield: failed to rebuild struct array",
};
/// Every error descriptor `rmfield` can raise, as exposed through `RMFIELD_DESCRIPTOR`.
const RMFIELD_ERRORS: [BuiltinErrorDescriptor; 7] = [
RMFIELD_ERROR_NOT_ENOUGH_INPUTS,
RMFIELD_ERROR_INVALID_TARGET,
RMFIELD_ERROR_FIELD_NAME_TYPE,
RMFIELD_ERROR_FIELD_NAME_EMPTY,
RMFIELD_ERROR_MISSING_FIELD,
RMFIELD_ERROR_STRUCT_ARRAY_CONTENTS,
RMFIELD_ERROR_REBUILD_FAILED,
];
/// Public descriptor for `rmfield`, bundling its signatures and error catalogue. It is
/// referenced by path from the `runtime_builtin` attribute on `rmfield_builtin`.
pub const RMFIELD_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
signatures: &RMFIELD_SIGNATURES,
output_mode: BuiltinOutputMode::Fixed,
completion_policy: BuiltinCompletionPolicy::Public,
errors: &RMFIELD_ERRORS,
};
fn rmfield_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
rmfield_error_with_message(error.message, error)
}
/// Builds a runtime error carrying a custom `message`, tagged with the `rmfield` builtin
/// name and, when the descriptor has one, the descriptor's identifier.
fn rmfield_error_with_message(
    message: impl Into<String>,
    error: &'static BuiltinErrorDescriptor,
) -> RuntimeError {
    let tagged = build_runtime_error(message).with_builtin(BUILTIN_NAME);
    match error.identifier {
        Some(identifier) => tagged.with_identifier(identifier).build(),
        None => tagged.build(),
    }
}
// Entry point for the `rmfield` builtin.
//
// `value` is the target: either a scalar struct or a struct array, which is represented
// here as a cell array whose elements are all structs. `rest` holds the field-name
// arguments in any mix of the forms accepted by `collect_field_names`.
//
// Returns the target with the named fields removed. Errors:
// - field-name parsing errors from `parse_field_names` (reported first),
// - `RMFIELD_ERROR_INVALID_TARGET` when `value` is not a struct / struct array,
// - `RMFIELD_ERROR_MISSING_FIELD` when a named field does not exist.
//
// The target is validated even when the resolved name list is empty (e.g. an empty cell
// array of names), so a non-struct first argument is always rejected instead of being
// echoed back unchanged.
#[runtime_builtin(
    name = "rmfield",
    category = "structs/core",
    summary = "Remove one or more named fields from structs or struct arrays.",
    keywords = "rmfield,struct,remove field,struct array",
    type_resolver(rmfield_type),
    descriptor(crate::builtins::structs::core::rmfield::RMFIELD_DESCRIPTOR),
    builtin_path = "crate::builtins::structs::core::rmfield"
)]
async fn rmfield_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
    let names = parse_field_names(&rest)?;
    match value {
        // With an empty name list the removal loop does nothing and the struct is
        // handed back untouched, so no separate short-circuit is needed here.
        Value::Struct(st) => remove_fields_from_struct_owned(st, &names).map(Value::Struct),
        Value::Cell(cell) if is_struct_array(&cell) => {
            if names.is_empty() {
                // Nothing to remove: skip the per-element clone and rebuild.
                return Ok(Value::Cell(cell));
            }
            remove_fields_from_struct_array(&cell, &names).map(Value::Cell)
        }
        other => Err(rmfield_error_with_message(
            format!("{} (got {other:?})", RMFIELD_ERROR_INVALID_TARGET.message),
            &RMFIELD_ERROR_INVALID_TARGET,
        )),
    }
}
fn parse_field_names(args: &[Value]) -> BuiltinResult<Vec<String>> {
if args.is_empty() {
return Err(rmfield_error(&RMFIELD_ERROR_NOT_ENOUGH_INPUTS));
}
let mut names: Vec<String> = Vec::new();
for value in args {
names.extend(collect_field_names(value)?);
}
Ok(names)
}
/// Extracts the field names carried by one argument.
///
/// Strings, char arrays and single-element string arrays yield exactly one name;
/// string arrays of any other length and cell arrays yield one name per element.
/// Any other value type is rejected with `RMFIELD_ERROR_FIELD_NAME_TYPE`.
fn collect_field_names(value: &Value) -> BuiltinResult<Vec<String>> {
    // Shared path for every argument that must resolve to exactly one name.
    let single_name = |candidate: &Value| {
        expect_scalar_name(candidate)
            .map(|name| vec![name])
            .map_err(|kind| field_name_error(kind, None))
    };
    match value {
        Value::Cell(cell) => cell_to_names(cell),
        Value::StringArray(sa) if sa.data.len() != 1 => string_array_to_names(sa),
        Value::String(_) | Value::CharArray(_) | Value::StringArray(_) => single_name(value),
        other => Err(rmfield_error_with_message(
            format!("{} (got {other:?})", RMFIELD_ERROR_FIELD_NAME_TYPE.message),
            &RMFIELD_ERROR_FIELD_NAME_TYPE,
        )),
    }
}
/// Copies every element of a string array out as a field name.
///
/// The first empty element aborts with `RMFIELD_ERROR_FIELD_NAME_EMPTY`; the message
/// carries that element's 1-based position.
fn string_array_to_names(array: &StringArray) -> BuiltinResult<Vec<String>> {
    array
        .data
        .iter()
        .enumerate()
        .map(|(offset, entry)| {
            if entry.is_empty() {
                Err(rmfield_error_with_message(
                    format!(
                        "{} (string array element {})",
                        RMFIELD_ERROR_FIELD_NAME_EMPTY.message,
                        offset + 1
                    ),
                    &RMFIELD_ERROR_FIELD_NAME_EMPTY,
                ))
            } else {
                Ok(entry.clone())
            }
        })
        .collect()
}
/// Reads one field name from each element of a cell array.
///
/// Every element must itself be a single name as defined by `expect_scalar_name`;
/// the first offending element is reported with its 1-based position.
fn cell_to_names(cell: &CellArray) -> BuiltinResult<Vec<String>> {
    cell.data
        .iter()
        .enumerate()
        .map(|(offset, handle)| {
            // SAFETY: assumes each handle held by a live `CellArray` points at a valid
            // `Value` for as long as `cell` is borrowed — NOTE(review): confirm against
            // the handle type's contract.
            let element = unsafe { &*handle.as_raw() };
            expect_scalar_name(element).map_err(|kind| field_name_error(kind, Some(offset + 1)))
        })
        .collect()
}
/// Reason a value was rejected as a field name, before it is turned into a
/// `RuntimeError` by `field_name_error`.
#[derive(Clone, Copy)]
enum FieldNameError {
/// The value has an unsupported type or does not hold exactly one name.
Type,
/// The name is empty.
Empty,
}
fn describe_field_name_error(kind: FieldNameError) -> &'static str {
match kind {
FieldNameError::Type => RMFIELD_ERROR_FIELD_NAME_TYPE.message,
FieldNameError::Empty => RMFIELD_ERROR_FIELD_NAME_EMPTY.message,
}
}
fn field_name_error(kind: FieldNameError, cell_index: Option<usize>) -> RuntimeError {
let descriptor = match kind {
FieldNameError::Type => &RMFIELD_ERROR_FIELD_NAME_TYPE,
FieldNameError::Empty => &RMFIELD_ERROR_FIELD_NAME_EMPTY,
};
let mut message = String::from(describe_field_name_error(kind));
if let Some(index) = cell_index {
message.push_str(&format!(" (cell element {index})"));
}
rmfield_error_with_message(message, descriptor)
}
/// Interprets `value` as exactly one field name.
///
/// Accepted forms are a string, a char array with at most one row, and a string array
/// holding exactly one element.
///
/// # Errors
///
/// - `FieldNameError::Type` for a char array with more than one row, a string array
///   whose length is not one, or any other value type.
/// - `FieldNameError::Empty` when the resolved name has no characters. This includes a
///   zero-row char array, which is an empty name rather than a wrongly shaped one.
fn expect_scalar_name(value: &Value) -> Result<String, FieldNameError> {
    let text: String = match value {
        Value::String(s) => s.clone(),
        Value::CharArray(ca) => {
            // Only a multi-row char matrix is a shape error; a zero-row array falls
            // through to the emptiness check below.
            if ca.rows > 1 {
                return Err(FieldNameError::Type);
            }
            ca.data.iter().collect()
        }
        Value::StringArray(sa) => match sa.data.as_slice() {
            [only] => only.clone(),
            _ => return Err(FieldNameError::Type),
        },
        _ => return Err(FieldNameError::Type),
    };
    if text.is_empty() {
        Err(FieldNameError::Empty)
    } else {
        Ok(text)
    }
}
/// Removes each name in `names` from `st` and returns the updated struct.
///
/// A name that appears more than once is only removed on its first occurrence. A name
/// that is not present on the struct fails with the missing-field error.
fn remove_fields_from_struct_owned(
    mut st: StructValue,
    names: &[String],
) -> BuiltinResult<StructValue> {
    let mut already_handled: HashSet<&str> = HashSet::new();
    for field in names {
        // `insert` returns false for a repeated name, which short-circuits the removal.
        let first_occurrence = already_handled.insert(field.as_str());
        if first_occurrence && st.remove(field).is_none() {
            return Err(missing_field_error(field));
        }
    }
    Ok(st)
}
/// Removes `names` from every struct in a struct array and rebuilds the array with its
/// original shape.
///
/// An empty array is returned as a clone without touching `names`. Fails with
/// `RMFIELD_ERROR_STRUCT_ARRAY_CONTENTS` on a non-struct element, with the
/// missing-field error from the per-element removal, or with
/// `RMFIELD_ERROR_REBUILD_FAILED` when the cell array cannot be reconstructed.
fn remove_fields_from_struct_array(
    array: &CellArray,
    names: &[String],
) -> BuiltinResult<CellArray> {
    if array.data.is_empty() {
        return Ok(array.clone());
    }
    let revised_elements = array
        .data
        .iter()
        .map(|handle| {
            // SAFETY: assumes each handle held by a live `CellArray` points at a valid
            // `Value` for as long as `array` is borrowed — NOTE(review): confirm against
            // the handle type's contract.
            let element = unsafe { &*handle.as_raw() };
            match element {
                Value::Struct(st) => {
                    remove_fields_from_struct_owned(st.clone(), names).map(Value::Struct)
                }
                _ => Err(rmfield_error(&RMFIELD_ERROR_STRUCT_ARRAY_CONTENTS)),
            }
        })
        .collect::<BuiltinResult<Vec<Value>>>()?;
    match CellArray::new_with_shape(revised_elements, array.shape.clone()) {
        Ok(rebuilt) => Ok(rebuilt),
        Err(e) => Err(rmfield_error_with_message(
            format!("{}: {e}", RMFIELD_ERROR_REBUILD_FAILED.message),
            &RMFIELD_ERROR_REBUILD_FAILED,
        )),
    }
}
fn missing_field_error(name: &str) -> RuntimeError {
rmfield_error_with_message(
format!("{} '{name}'.", RMFIELD_ERROR_MISSING_FIELD.message),
&RMFIELD_ERROR_MISSING_FIELD,
)
}
/// Reports whether every element of `cell` is a struct. An empty cell array counts as a
/// struct array because there is no element to disqualify it.
fn is_struct_array(cell: &CellArray) -> bool {
    for handle in &cell.data {
        // SAFETY: assumes each handle held by a live `CellArray` points at a valid
        // `Value` for as long as `cell` is borrowed — NOTE(review): confirm against the
        // handle type's contract.
        let element = unsafe { &*handle.as_raw() };
        if !matches!(element, Value::Struct(_)) {
            return false;
        }
    }
    true
}
// Unit tests for the `rmfield` builtin. Each test drives the async entry point through
// `run_rmfield` and inspects either the returned value or the error message text.
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use runmat_builtins::{CellArray, CharArray, StringArray, StructValue, Value};
// Extracts the message text from a runtime error for substring assertions.
fn error_message(err: crate::RuntimeError) -> String {
err.message().to_string()
}
// Runs the async builtin to completion on the current thread.
fn run_rmfield(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
futures::executor::block_on(rmfield_builtin(value, rest))
}
#[cfg(feature = "wgpu")]
use runmat_accelerate_api::HostTensorView;
// Removing `score` from a two-field struct leaves `name` in place.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_removes_single_field_from_scalar_struct() {
let mut st = StructValue::new();
st.fields.insert("name".to_string(), Value::from("Ada"));
st.fields.insert("score".to_string(), Value::Num(42.0));
let result = run_rmfield(Value::Struct(st), vec![Value::from("score")]).expect("rmfield");
let Value::Struct(updated) = result else {
panic!("expected struct result");
};
assert!(!updated.fields.contains_key("score"));
assert!(updated.fields.contains_key("name"));
}
// A 1x2 cell array of names removes both listed fields and keeps the third.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_accepts_cell_array_of_field_names() {
let mut st = StructValue::new();
st.fields.insert("left".to_string(), Value::Num(1.0));
st.fields.insert("right".to_string(), Value::Num(2.0));
st.fields.insert("top".to_string(), Value::Num(3.0));
let cell =
CellArray::new(vec![Value::from("left"), Value::from("top")], 1, 2).expect("cell");
let result = run_rmfield(Value::Struct(st), vec![Value::Cell(cell)]).expect("rmfield");
let Value::Struct(updated) = result else {
panic!("expected struct result");
};
assert!(!updated.fields.contains_key("left"));
assert!(!updated.fields.contains_key("top"));
assert!(updated.fields.contains_key("right"));
}
// A two-element string array removes both listed fields and keeps the third.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_supports_string_array_names() {
let mut st = StructValue::new();
st.fields.insert("alpha".to_string(), Value::Num(1.0));
st.fields.insert("beta".to_string(), Value::Num(2.0));
st.fields.insert("gamma".to_string(), Value::Num(3.0));
let strings = StringArray::new(vec!["alpha".into(), "gamma".into()], vec![1, 2]).unwrap();
let result =
run_rmfield(Value::Struct(st), vec![Value::StringArray(strings)]).expect("rmfield");
let Value::Struct(updated) = result else {
panic!("expected struct result");
};
assert!(!updated.fields.contains_key("alpha"));
assert!(!updated.fields.contains_key("gamma"));
assert!(updated.fields.contains_key("beta"));
}
// Asking for a field the struct does not have yields the missing-field message with
// the quoted name.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_errors_when_field_missing() {
let mut st = StructValue::new();
st.fields.insert("name".to_string(), Value::from("Ada"));
let err =
error_message(run_rmfield(Value::Struct(st), vec![Value::from("id")]).unwrap_err());
assert!(
err.contains("Reference to non-existent field 'id'."),
"unexpected error: {err}"
);
}
// On a 1x2 struct array, the field is removed from every element.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_struct_array_roundtrip() {
let mut first = StructValue::new();
first.fields.insert("name".to_string(), Value::from("Ada"));
first.fields.insert("score".to_string(), Value::Num(90.0));
let mut second = StructValue::new();
second
.fields
.insert("name".to_string(), Value::from("Grace"));
second.fields.insert("score".to_string(), Value::Num(95.0));
let array = CellArray::new_with_shape(
vec![Value::Struct(first), Value::Struct(second)],
vec![1, 2],
)
.expect("struct array");
let result = run_rmfield(Value::Cell(array), vec![Value::from("score")]).expect("rmfield");
let Value::Cell(updated) = result else {
panic!("expected struct array");
};
for handle in &updated.data {
let value = unsafe { &*handle.as_raw() };
let Value::Struct(st) = value else {
panic!("expected struct element");
};
assert!(!st.fields.contains_key("score"));
assert!(st.fields.contains_key("name"));
}
}
// A field absent from the struct-array elements produces the missing-field message.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_struct_array_missing_field_errors() {
let mut first = StructValue::new();
first.fields.insert("id".to_string(), Value::Num(1.0));
let mut second = StructValue::new();
second.fields.insert("id".to_string(), Value::Num(2.0));
second.fields.insert("extra".to_string(), Value::Num(3.0));
let array = CellArray::new_with_shape(
vec![Value::Struct(first), Value::Struct(second)],
vec![1, 2],
)
.expect("struct array");
let err = error_message(
run_rmfield(Value::Cell(array), vec![Value::from("missing")]).unwrap_err(),
);
assert!(
err.contains("Reference to non-existent field 'missing'."),
"unexpected error: {err}"
);
}
// A numeric first argument is rejected with the invalid-target message.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_rejects_non_struct_inputs() {
let err =
error_message(run_rmfield(Value::Num(1.0), vec![Value::from("field")]).unwrap_err());
assert!(
err.contains("expected struct or struct array"),
"unexpected error: {err}"
);
}
// An empty field name is rejected with the empty-name message.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_produces_error_for_empty_field_name() {
let mut st = StructValue::new();
st.fields.insert("data".to_string(), Value::Num(1.0));
let err = error_message(run_rmfield(Value::Struct(st), vec![Value::from("")]).unwrap_err());
assert!(
err.contains("field names must be nonempty"),
"unexpected error: {err}"
);
}
// One call may mix a plain string value, a char row, a 1x1 string array and a cell
// array; together they remove all four fields.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_accepts_multiple_argument_forms() {
let mut st = StructValue::new();
st.fields.insert("alpha".to_string(), Value::Num(1.0));
st.fields.insert("beta".to_string(), Value::Num(2.0));
st.fields.insert("gamma".to_string(), Value::Num(3.0));
st.fields.insert("delta".to_string(), Value::Num(4.0));
let char_name = CharArray::new_row("beta");
let string_array =
StringArray::new(vec!["gamma".into()], vec![1, 1]).expect("string scalar array");
let cell = CellArray::new(vec![Value::from("delta")], 1, 1).expect("cell array of strings");
let result = run_rmfield(
Value::Struct(st),
vec![
Value::from("alpha"),
Value::CharArray(char_name),
Value::StringArray(string_array),
Value::Cell(cell),
],
)
.expect("rmfield");
let Value::Struct(updated) = result else {
panic!("expected struct result");
};
assert!(updated.fields.is_empty());
}
// Naming the same field twice succeeds; the repeat is not treated as a missing field.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_ignores_duplicate_field_names() {
let mut st = StructValue::new();
st.fields.insert("keep".to_string(), Value::Num(1.0));
st.fields.insert("drop".to_string(), Value::Num(2.0));
let result = run_rmfield(
Value::Struct(st),
vec![Value::from("drop"), Value::from("drop")],
)
.expect("rmfield");
let Value::Struct(updated) = result else {
panic!("expected struct result");
};
assert!(!updated.fields.contains_key("drop"));
assert!(updated.fields.contains_key("keep"));
}
// An empty cell array of names leaves the struct equal to the original.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_returns_original_when_no_names_supplied() {
let mut st = StructValue::new();
st.fields.insert("value".to_string(), Value::Num(10.0));
let empty = CellArray::new(Vec::new(), 0, 0).expect("empty cell array");
let original = st.clone();
let result =
run_rmfield(Value::Struct(st), vec![Value::Cell(empty)]).expect("rmfield empty");
assert_eq!(result, Value::Struct(original));
}
// Calling with no field-name arguments at all is a not-enough-inputs error.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn rmfield_requires_field_names() {
let mut st = StructValue::new();
st.fields.insert("value".to_string(), Value::Num(10.0));
let err = error_message(run_rmfield(Value::Struct(st), Vec::new()).unwrap_err());
assert!(
err.contains("rmfield: not enough input arguments"),
"unexpected error: {err}"
);
}
// Only built with the `wgpu` feature: a GPU tensor handle stored in a surviving field
// compares equal to the uploaded handle after another field is removed.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "wgpu")]
fn rmfield_preserves_gpu_handles() {
let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
);
let provider = runmat_accelerate_api::provider().expect("wgpu provider");
let view = HostTensorView {
data: &[1.0, 2.0],
shape: &[2, 1],
};
let handle = provider.upload(&view).expect("upload");
let mut st = StructValue::new();
st.fields
.insert("gpu".to_string(), Value::GpuTensor(handle.clone()));
st.fields.insert("remove".to_string(), Value::Num(5.0));
let result = run_rmfield(Value::Struct(st), vec![Value::from("remove")]).expect("rmfield");
let Value::Struct(updated) = result else {
panic!("expected struct result");
};
assert!(matches!(
updated.fields.get("gpu"),
Some(Value::GpuTensor(h)) if h == &handle
));
assert!(!updated.fields.contains_key("remove"));
}
}