1use runmat_builtins::{CellArray, CharArray, StringArray, Value};
3use runmat_macros::runtime_builtin;
4
5use crate::builtins::common::map_control_flow_with_builtin;
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
11use crate::builtins::strings::type_resolvers::text_preserve_type;
12use crate::{
13 build_runtime_error, gather_if_needed_async, make_cell_with_shape, BuiltinResult, RuntimeError,
14};
15
/// GPU capability descriptor for `erase`.
///
/// No precisions or provider hooks are advertised, so there is no device kernel for this
/// builtin. Per `notes` (and the `gather_if_needed_async` calls in `erase_builtin`), any
/// GPU-resident argument is copied back to host memory and the work runs on the CPU.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::erase")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "erase",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before substrings are removed.",
};
32
/// Fusion descriptor for `erase`.
///
/// Both `elementwise` and `reduction` are `None`: this string builtin contributes no fusible
/// kernel body, and per `notes` it is never part of a fusion plan.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::erase")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "erase",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "String manipulation builtin; not eligible for fusion plans and always gathers GPU inputs before execution.",
};
44
/// Name attached to every error raised by this builtin.
const BUILTIN_NAME: &str = "erase";

/// Raised when the text argument is not a string/char/cell value.
const ARG_TYPE_ERROR: &str =
    "erase: first argument must be a string array, character array, or cell array of character vectors";

/// Raised when the pattern argument is not a string/char/cell value.
const PATTERN_TYPE_ERROR: &str =
    "erase: second argument must be a string array, character array, or cell array of character vectors";

/// Raised when a cell element (in either argument) is not a single line of text.
const CELL_ELEMENT_ERROR: &str =
    "erase: cell array elements must be string scalars or character vectors";
52
53fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
54 build_runtime_error(message)
55 .with_builtin(BUILTIN_NAME)
56 .build()
57}
58
59fn map_flow(err: RuntimeError) -> RuntimeError {
60 map_control_flow_with_builtin(err, BUILTIN_NAME)
61}
62
63#[runtime_builtin(
64 name = "erase",
65 category = "strings/transform",
66 summary = "Remove substring occurrences from strings, character arrays, and cell arrays.",
67 keywords = "erase,remove substring,strings,character array,text",
68 accel = "sink",
69 type_resolver(text_preserve_type),
70 builtin_path = "crate::builtins::strings::transform::erase"
71)]
72async fn erase_builtin(text: Value, pattern: Value) -> BuiltinResult<Value> {
73 let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
74 let pattern = gather_if_needed_async(&pattern).await.map_err(map_flow)?;
75
76 let patterns = PatternList::from_value(&pattern)?;
77
78 match text {
79 Value::String(s) => Ok(Value::String(erase_string_scalar(s, &patterns))),
80 Value::StringArray(sa) => erase_string_array(sa, &patterns),
81 Value::CharArray(ca) => erase_char_array(ca, &patterns),
82 Value::Cell(cell) => erase_cell_array(cell, &patterns),
83 _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
84 }
85}
86
87struct PatternList {
88 entries: Vec<String>,
89}
90
91impl PatternList {
92 fn from_value(value: &Value) -> BuiltinResult<Self> {
93 let entries = match value {
94 Value::String(text) => vec![text.clone()],
95 Value::StringArray(array) => array.data.clone(),
96 Value::CharArray(array) => {
97 if array.rows == 0 {
98 Vec::new()
99 } else {
100 let mut list = Vec::with_capacity(array.rows);
101 for row in 0..array.rows {
102 list.push(char_row_to_string_slice(&array.data, array.cols, row));
103 }
104 list
105 }
106 }
107 Value::Cell(cell) => {
108 let mut list = Vec::with_capacity(cell.data.len());
109 for handle in &cell.data {
110 match &**handle {
111 Value::String(text) => list.push(text.clone()),
112 Value::StringArray(sa) if sa.data.len() == 1 => {
113 list.push(sa.data[0].clone());
114 }
115 Value::CharArray(ca) if ca.rows == 0 => list.push(String::new()),
116 Value::CharArray(ca) if ca.rows == 1 => {
117 list.push(char_row_to_string_slice(&ca.data, ca.cols, 0));
118 }
119 Value::CharArray(_) => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
120 _ => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
121 }
122 }
123 list
124 }
125 _ => return Err(runtime_error_for(PATTERN_TYPE_ERROR)),
126 };
127 Ok(Self { entries })
128 }
129
130 fn apply(&self, input: &str) -> String {
131 if self.entries.is_empty() {
132 return input.to_string();
133 }
134 let mut current = input.to_string();
135 for pattern in &self.entries {
136 if pattern.is_empty() {
137 continue;
138 }
139 if current.is_empty() {
140 break;
141 }
142 current = current.replace(pattern, "");
143 }
144 current
145 }
146}
147
148fn erase_string_scalar(text: String, patterns: &PatternList) -> String {
149 if is_missing_string(&text) {
150 text
151 } else {
152 patterns.apply(&text)
153 }
154}
155
156fn erase_string_array(array: StringArray, patterns: &PatternList) -> BuiltinResult<Value> {
157 let StringArray { data, shape, .. } = array;
158 let mut erased = Vec::with_capacity(data.len());
159 for entry in data {
160 if is_missing_string(&entry) {
161 erased.push(entry);
162 } else {
163 erased.push(patterns.apply(&entry));
164 }
165 }
166 StringArray::new(erased, shape)
167 .map(Value::StringArray)
168 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
169}
170
171fn erase_char_array(array: CharArray, patterns: &PatternList) -> BuiltinResult<Value> {
172 let CharArray { data, rows, cols } = array;
173 if rows == 0 {
174 return Ok(Value::CharArray(CharArray { data, rows, cols }));
175 }
176
177 let mut processed: Vec<String> = Vec::with_capacity(rows);
178 let mut target_cols = 0usize;
179 for row in 0..rows {
180 let slice = char_row_to_string_slice(&data, cols, row);
181 let erased = patterns.apply(&slice);
182 let len = erased.chars().count();
183 if len > target_cols {
184 target_cols = len;
185 }
186 processed.push(erased);
187 }
188
189 let mut flattened: Vec<char> = Vec::with_capacity(rows * target_cols);
190 for row_text in processed {
191 let mut chars: Vec<char> = row_text.chars().collect();
192 if chars.len() < target_cols {
193 chars.resize(target_cols, ' ');
194 }
195 flattened.extend(chars);
196 }
197
198 CharArray::new(flattened, rows, target_cols)
199 .map(Value::CharArray)
200 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
201}
202
203fn erase_cell_array(cell: CellArray, patterns: &PatternList) -> BuiltinResult<Value> {
204 let shape = cell.shape.clone();
205 let mut values = Vec::with_capacity(cell.data.len());
206 for handle in &cell.data {
207 values.push(erase_cell_element(handle, patterns)?);
208 }
209 make_cell_with_shape(values, shape)
210 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
211}
212
213fn erase_cell_element(value: &Value, patterns: &PatternList) -> BuiltinResult<Value> {
214 match value {
215 Value::String(text) => Ok(Value::String(erase_string_scalar(text.clone(), patterns))),
216 Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(erase_string_scalar(
217 sa.data[0].clone(),
218 patterns,
219 ))),
220 Value::CharArray(ca) if ca.rows == 0 => Ok(Value::CharArray(ca.clone())),
221 Value::CharArray(ca) if ca.rows == 1 => {
222 let slice = char_row_to_string_slice(&ca.data, ca.cols, 0);
223 let erased = patterns.apply(&slice);
224 Ok(Value::CharArray(CharArray::new_row(&erased)))
225 }
226 Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
227 _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
228 }
229}
230
// Unit tests for `erase`. Each test calls the builtin through the synchronous wrapper below.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    // Blocks on the async builtin so the tests themselves can stay synchronous.
    fn erase_builtin(text: Value, pattern: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::erase_builtin(text, pattern))
    }

    // A single pattern is removed from a string scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_scalar_single_pattern() {
        let result = erase_builtin(
            Value::String("RunMat runtime".into()),
            Value::String(" runtime".into()),
        )
        .expect("erase");
        assert_eq!(result, Value::String("RunMat".into()));
    }

    // Every pattern is removed from every element. The "<missing>" entry contains 'g'
    // but survives, which relies on it being recognised as a missing string.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_array_multiple_patterns() {
        let strings = StringArray::new(
            vec!["gpu".into(), "cpu".into(), "<missing>".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = erase_builtin(
            Value::StringArray(strings),
            Value::StringArray(StringArray::new(vec!["g".into(), "c".into()], vec![2, 1]).unwrap()),
        )
        .expect("erase");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("pu"),
                        String::from("pu"),
                        String::from("<missing>")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Text (2x1) and pattern (1x2) shapes need not agree. All patterns apply to every
    // element, and the output keeps the text's shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_array_shape_mismatch_applies_all_patterns() {
        let strings =
            StringArray::new(vec!["GPU kernel".into(), "CPU kernel".into()], vec![2, 1]).unwrap();
        let patterns = StringArray::new(vec!["GPU ".into(), "CPU ".into()], vec![1, 2]).unwrap();
        let result = erase_builtin(Value::StringArray(strings), Value::StringArray(patterns))
            .expect("erase");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(
                    sa.data,
                    vec![String::from("kernel"), String::from("kernel")]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A single-row char array shrinks to the erased length (6 -> 4 columns).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_adjusts_width() {
        let chars = CharArray::new("matrix".chars().collect(), 1, 6).unwrap();
        let result =
            erase_builtin(Value::CharArray(chars), Value::String("tr".into())).expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 4);
                let expected: Vec<char> = "maix".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // Erasing the whole row leaves a 1x0 char array rather than removing the row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_handles_full_removal() {
        let chars = CharArray::new_row("abc");
        let result = erase_builtin(Value::CharArray(chars.clone()), Value::String("abc".into()))
            .expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected empty char array, got {other:?}"),
        }
    }

    // In a multi-row char matrix only the matching row shrinks. The width stays at the
    // longest row (12), so the shortened row is space-padded; hence the `trim_end` checks.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_multiple_rows_sequential_patterns() {
        let chars = CharArray::new(
            vec![
                'G', 'P', 'U', ' ', 'p', 'i', 'p', 'e', 'l', 'i', 'n', 'e', 'C', 'P', 'U', ' ',
                'p', 'i', 'p', 'e', 'l', 'i', 'n', 'e',
            ],
            2,
            12,
        )
        .unwrap();
        let patterns = CharArray::new_row("GPU ");
        let result =
            erase_builtin(Value::CharArray(chars), Value::CharArray(patterns)).expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 12);
                let first = char_row_to_string_slice(&out.data, out.cols, 0);
                let second = char_row_to_string_slice(&out.data, out.cols, 1);
                assert_eq!(first.trim_end(), "pipeline");
                assert_eq!(second.trim_end(), "CPU pipeline");
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // Cell elements keep their own kind: a char vector stays a char vector and a string
    // stays a string. The pattern argument is itself a cell of strings.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Planner")),
                Value::String("GPU Fusion".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = erase_builtin(
            Value::Cell(cell),
            Value::Cell(
                CellArray::new(
                    vec![
                        Value::String("Kernel ".into()),
                        Value::String("GPU ".into()),
                    ],
                    1,
                    2,
                )
                .unwrap(),
            ),
        )
        .expect("erase");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(first, Value::CharArray(CharArray::new_row("Planner")));
                assert_eq!(second, Value::String("Fusion".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // The 2x2 cell shape survives, and every occurrence of the pattern is removed.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_cell_array_preserves_shape() {
        let cell = CellArray::new(
            vec![
                Value::String("alpha".into()),
                Value::String("beta".into()),
                Value::String("gamma".into()),
                Value::String("delta".into()),
            ],
            2,
            2,
        )
        .unwrap();
        let patterns = StringArray::new(vec!["a".into()], vec![1, 1]).unwrap();
        let result = erase_builtin(Value::Cell(cell), Value::StringArray(patterns)).expect("erase");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 2);
                assert_eq!(out.get(0, 0).unwrap(), Value::String("lph".into()));
                assert_eq!(out.get(1, 1).unwrap(), Value::String("delt".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A missing string scalar is returned as-is even though the pattern occurs in its text.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_preserves_missing_string() {
        let result = erase_builtin(
            Value::String("<missing>".into()),
            Value::String("missing".into()),
        )
        .expect("erase");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    // An empty (0x0) pattern array is accepted and leaves the input unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_allows_empty_pattern_list() {
        let strings = StringArray::new(vec!["alpha".into(), "beta".into()], vec![2, 1]).unwrap();
        let pattern = StringArray::new(Vec::<String>::new(), vec![0, 0]).unwrap();
        let result = erase_builtin(
            Value::StringArray(strings.clone()),
            Value::StringArray(pattern),
        )
        .expect("erase");
        assert_eq!(result, Value::StringArray(strings));
    }

    // A numeric first argument is rejected with the first-argument message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_invalid_first_argument() {
        let err = erase_builtin(Value::Num(1.0), Value::String("a".into())).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
    }

    // A numeric pattern is rejected with the second-argument message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_invalid_pattern_type() {
        let err = erase_builtin(Value::String("abc".into()), Value::Num(1.0)).unwrap_err();
        assert_eq!(err.to_string(), PATTERN_TYPE_ERROR);
    }

    // The registered type resolver maps a string input type to a string output type.
    #[test]
    fn erase_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}