Skip to main content

runmat_runtime/builtins/strings/transform/
erase.rs

//! MATLAB-compatible `erase` builtin for RunMat: removes substring occurrences from text. GPU-resident inputs are gathered to host memory before processing.
2use runmat_builtins::{CellArray, CharArray, StringArray, Value};
3use runmat_macros::runtime_builtin;
4
5use crate::builtins::common::map_control_flow_with_builtin;
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
11use crate::builtins::strings::type_resolvers::text_preserve_type;
12use crate::{
13    build_runtime_error, gather_if_needed_async, make_cell_with_shape, BuiltinResult, RuntimeError,
14};
15
// GPU registration metadata for `erase`. No device implementation is declared (empty
// `supported_precisions` and `provider_hooks`); the residency policy is `GatherImmediately`,
// and the notes record that GPU-resident inputs are gathered to host memory before the
// substrings are removed on the CPU.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::erase")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "erase",
    // Tagged as a custom "string-transform" operation rather than a built-in numeric op kind.
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    // NOTE(review): the NaN/reduction fields below look like required-field defaults for a
    // text builtin that never reduces numbers — confirm they have no effect here.
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before substrings are removed.",
};
32
// Fusion metadata for `erase`. No elementwise or reduction template is supplied, so there is
// nothing for a fusion plan to inline; the notes state the builtin is not fusion-eligible.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::erase")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "erase",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // Text output only; no numeric NaN values are produced.
    emits_nan: false,
    notes:
        "String manipulation builtin; not eligible for fusion plans and always gathers GPU inputs before execution.",
};
44
/// Builtin name used to tag errors and to prefix propagated construction failures.
const BUILTIN_NAME: &str = "erase";
/// Error reported when the text (first) argument is not a supported text container.
const ARG_TYPE_ERROR: &str =
    "erase: first argument must be a string array, character array, or cell array of character vectors";
/// Error reported when the pattern (second) argument is not a supported text container.
const PATTERN_TYPE_ERROR: &str =
    "erase: second argument must be a string array, character array, or cell array of character vectors";
/// Error reported when a cell element (in either argument) is not a single piece of text.
const CELL_ELEMENT_ERROR: &str =
    "erase: cell array elements must be string scalars or character vectors";
52
/// Build a `RuntimeError` carrying `message`, attributed to the `erase` builtin.
fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message)
        .with_builtin(BUILTIN_NAME)
        .build()
}
58
/// Attribute an error returned while gathering inputs to the `erase` builtin.
///
/// Delegates to `map_control_flow_with_builtin`; presumably control-flow signals pass
/// through while ordinary errors gain the builtin name — confirm against that helper.
fn map_flow(err: RuntimeError) -> RuntimeError {
    map_control_flow_with_builtin(err, BUILTIN_NAME)
}
62
63#[runtime_builtin(
64    name = "erase",
65    category = "strings/transform",
66    summary = "Remove substring occurrences from strings, character arrays, and cell arrays.",
67    keywords = "erase,remove substring,strings,character array,text",
68    accel = "sink",
69    type_resolver(text_preserve_type),
70    builtin_path = "crate::builtins::strings::transform::erase"
71)]
72async fn erase_builtin(text: Value, pattern: Value) -> BuiltinResult<Value> {
73    let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
74    let pattern = gather_if_needed_async(&pattern).await.map_err(map_flow)?;
75
76    let patterns = PatternList::from_value(&pattern)?;
77
78    match text {
79        Value::String(s) => Ok(Value::String(erase_string_scalar(s, &patterns))),
80        Value::StringArray(sa) => erase_string_array(sa, &patterns),
81        Value::CharArray(ca) => erase_char_array(ca, &patterns),
82        Value::Cell(cell) => erase_cell_array(cell, &patterns),
83        _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
84    }
85}
86
87struct PatternList {
88    entries: Vec<String>,
89}
90
91impl PatternList {
92    fn from_value(value: &Value) -> BuiltinResult<Self> {
93        let entries = match value {
94            Value::String(text) => vec![text.clone()],
95            Value::StringArray(array) => array.data.clone(),
96            Value::CharArray(array) => {
97                if array.rows == 0 {
98                    Vec::new()
99                } else {
100                    let mut list = Vec::with_capacity(array.rows);
101                    for row in 0..array.rows {
102                        list.push(char_row_to_string_slice(&array.data, array.cols, row));
103                    }
104                    list
105                }
106            }
107            Value::Cell(cell) => {
108                let mut list = Vec::with_capacity(cell.data.len());
109                for handle in &cell.data {
110                    match &**handle {
111                        Value::String(text) => list.push(text.clone()),
112                        Value::StringArray(sa) if sa.data.len() == 1 => {
113                            list.push(sa.data[0].clone());
114                        }
115                        Value::CharArray(ca) if ca.rows == 0 => list.push(String::new()),
116                        Value::CharArray(ca) if ca.rows == 1 => {
117                            list.push(char_row_to_string_slice(&ca.data, ca.cols, 0));
118                        }
119                        Value::CharArray(_) => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
120                        _ => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
121                    }
122                }
123                list
124            }
125            _ => return Err(runtime_error_for(PATTERN_TYPE_ERROR)),
126        };
127        Ok(Self { entries })
128    }
129
130    fn apply(&self, input: &str) -> String {
131        if self.entries.is_empty() {
132            return input.to_string();
133        }
134        let mut current = input.to_string();
135        for pattern in &self.entries {
136            if pattern.is_empty() {
137                continue;
138            }
139            if current.is_empty() {
140                break;
141            }
142            current = current.replace(pattern, "");
143        }
144        current
145    }
146}
147
148fn erase_string_scalar(text: String, patterns: &PatternList) -> String {
149    if is_missing_string(&text) {
150        text
151    } else {
152        patterns.apply(&text)
153    }
154}
155
156fn erase_string_array(array: StringArray, patterns: &PatternList) -> BuiltinResult<Value> {
157    let StringArray { data, shape, .. } = array;
158    let mut erased = Vec::with_capacity(data.len());
159    for entry in data {
160        if is_missing_string(&entry) {
161            erased.push(entry);
162        } else {
163            erased.push(patterns.apply(&entry));
164        }
165    }
166    StringArray::new(erased, shape)
167        .map(Value::StringArray)
168        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
169}
170
171fn erase_char_array(array: CharArray, patterns: &PatternList) -> BuiltinResult<Value> {
172    let CharArray { data, rows, cols } = array;
173    if rows == 0 {
174        return Ok(Value::CharArray(CharArray { data, rows, cols }));
175    }
176
177    let mut processed: Vec<String> = Vec::with_capacity(rows);
178    let mut target_cols = 0usize;
179    for row in 0..rows {
180        let slice = char_row_to_string_slice(&data, cols, row);
181        let erased = patterns.apply(&slice);
182        let len = erased.chars().count();
183        if len > target_cols {
184            target_cols = len;
185        }
186        processed.push(erased);
187    }
188
189    let mut flattened: Vec<char> = Vec::with_capacity(rows * target_cols);
190    for row_text in processed {
191        let mut chars: Vec<char> = row_text.chars().collect();
192        if chars.len() < target_cols {
193            chars.resize(target_cols, ' ');
194        }
195        flattened.extend(chars);
196    }
197
198    CharArray::new(flattened, rows, target_cols)
199        .map(Value::CharArray)
200        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
201}
202
203fn erase_cell_array(cell: CellArray, patterns: &PatternList) -> BuiltinResult<Value> {
204    let shape = cell.shape.clone();
205    let mut values = Vec::with_capacity(cell.data.len());
206    for handle in &cell.data {
207        values.push(erase_cell_element(handle, patterns)?);
208    }
209    make_cell_with_shape(values, shape)
210        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
211}
212
213fn erase_cell_element(value: &Value, patterns: &PatternList) -> BuiltinResult<Value> {
214    match value {
215        Value::String(text) => Ok(Value::String(erase_string_scalar(text.clone(), patterns))),
216        Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(erase_string_scalar(
217            sa.data[0].clone(),
218            patterns,
219        ))),
220        Value::CharArray(ca) if ca.rows == 0 => Ok(Value::CharArray(ca.clone())),
221        Value::CharArray(ca) if ca.rows == 1 => {
222            let slice = char_row_to_string_slice(&ca.data, ca.cols, 0);
223            let erased = patterns.apply(&slice);
224            Ok(Value::CharArray(CharArray::new_row(&erased)))
225        }
226        Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
227        _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
228    }
229}
230
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Synchronous wrapper so tests can drive the async builtin without an async runtime.
    fn erase_builtin(text: Value, pattern: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::erase_builtin(text, pattern))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_scalar_single_pattern() {
        let result = erase_builtin(
            Value::String("RunMat runtime".into()),
            Value::String(" runtime".into()),
        )
        .expect("erase");
        assert_eq!(result, Value::String("RunMat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_array_multiple_patterns() {
        let strings = StringArray::new(
            vec!["gpu".into(), "cpu".into(), "<missing>".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = erase_builtin(
            Value::StringArray(strings),
            Value::StringArray(StringArray::new(vec!["g".into(), "c".into()], vec![2, 1]).unwrap()),
        )
        .expect("erase");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("pu"),
                        String::from("pu"),
                        String::from("<missing>")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_array_shape_mismatch_applies_all_patterns() {
        let strings =
            StringArray::new(vec!["GPU kernel".into(), "CPU kernel".into()], vec![2, 1]).unwrap();
        let patterns = StringArray::new(vec!["GPU ".into(), "CPU ".into()], vec![1, 2]).unwrap();
        let result = erase_builtin(Value::StringArray(strings), Value::StringArray(patterns))
            .expect("erase");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(
                    sa.data,
                    vec![String::from("kernel"), String::from("kernel")]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_adjusts_width() {
        let chars = CharArray::new("matrix".chars().collect(), 1, 6).unwrap();
        let result =
            erase_builtin(Value::CharArray(chars), Value::String("tr".into())).expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 4);
                let expected: Vec<char> = "maix".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_handles_full_removal() {
        let chars = CharArray::new_row("abc");
        let result =
            erase_builtin(Value::CharArray(chars), Value::String("abc".into())).expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected empty char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_multiple_rows_sequential_patterns() {
        let chars = CharArray::new(
            vec![
                'G', 'P', 'U', ' ', 'p', 'i', 'p', 'e', 'l', 'i', 'n', 'e', 'C', 'P', 'U', ' ',
                'p', 'i', 'p', 'e', 'l', 'i', 'n', 'e',
            ],
            2,
            12,
        )
        .unwrap();
        let patterns = CharArray::new_row("GPU ");
        let result =
            erase_builtin(Value::CharArray(chars), Value::CharArray(patterns)).expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 12);
                let first = char_row_to_string_slice(&out.data, out.cols, 0);
                let second = char_row_to_string_slice(&out.data, out.cols, 1);
                assert_eq!(first.trim_end(), "pipeline");
                assert_eq!(second.trim_end(), "CPU pipeline");
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_width_counts_characters() {
        // Width must be measured in characters, not UTF-8 bytes.
        let chars = CharArray::new_row("héllo");
        let result =
            erase_builtin(Value::CharArray(chars), Value::String("é".into())).expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 4);
                let expected: Vec<char> = "hllo".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Planner")),
                Value::String("GPU Fusion".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = erase_builtin(
            Value::Cell(cell),
            Value::Cell(
                CellArray::new(
                    vec![
                        Value::String("Kernel ".into()),
                        Value::String("GPU ".into()),
                    ],
                    1,
                    2,
                )
                .unwrap(),
            ),
        )
        .expect("erase");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(first, Value::CharArray(CharArray::new_row("Planner")));
                assert_eq!(second, Value::String("Fusion".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_cell_array_preserves_shape() {
        let cell = CellArray::new(
            vec![
                Value::String("alpha".into()),
                Value::String("beta".into()),
                Value::String("gamma".into()),
                Value::String("delta".into()),
            ],
            2,
            2,
        )
        .unwrap();
        let patterns = StringArray::new(vec!["a".into()], vec![1, 1]).unwrap();
        let result = erase_builtin(Value::Cell(cell), Value::StringArray(patterns)).expect("erase");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 2);
                assert_eq!(out.get(0, 0).unwrap(), Value::String("lph".into()));
                assert_eq!(out.get(1, 1).unwrap(), Value::String("delt".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_preserves_missing_string() {
        let result = erase_builtin(
            Value::String("<missing>".into()),
            Value::String("missing".into()),
        )
        .expect("erase");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_allows_empty_pattern_list() {
        let strings = StringArray::new(vec!["alpha".into(), "beta".into()], vec![2, 1]).unwrap();
        let pattern = StringArray::new(Vec::<String>::new(), vec![0, 0]).unwrap();
        let result = erase_builtin(
            Value::StringArray(strings.clone()),
            Value::StringArray(pattern),
        )
        .expect("erase");
        assert_eq!(result, Value::StringArray(strings));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_ignores_empty_pattern() {
        let result = erase_builtin(
            Value::String("RunMat".into()),
            Value::String(String::new()),
        )
        .expect("erase");
        assert_eq!(result, Value::String("RunMat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_invalid_first_argument() {
        let err = erase_builtin(Value::Num(1.0), Value::String("a".into())).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_invalid_pattern_type() {
        let err = erase_builtin(Value::String("abc".into()), Value::Num(1.0)).unwrap_err();
        assert_eq!(err.to_string(), PATTERN_TYPE_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_multirow_char_cell_element() {
        let block = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(block)], 1, 1).unwrap();
        let err = erase_builtin(Value::Cell(cell), Value::String("a".into())).unwrap_err();
        assert_eq!(err.to_string(), CELL_ELEMENT_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_invalid_pattern_cell_element() {
        let patterns = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = erase_builtin(Value::String("abc".into()), Value::Cell(patterns)).unwrap_err();
        assert_eq!(err.to_string(), CELL_ELEMENT_ERROR);
    }

    #[test]
    fn erase_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}