1use runmat_builtins::{
3 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
4 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
5 CellArray, CharArray, StringArray, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::map_control_flow_with_builtin;
10use crate::builtins::common::spec::{
11 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12 ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
15use crate::builtins::strings::type_resolvers::text_preserve_type;
16use crate::{
17 build_runtime_error, gather_if_needed_async, make_cell_with_shape, BuiltinResult, RuntimeError,
18};
19
20#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::erase")]
21pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
22 name: "erase",
23 op_kind: GpuOpKind::Custom("string-transform"),
24 supported_precisions: &[],
25 broadcast: BroadcastSemantics::None,
26 provider_hooks: &[],
27 constant_strategy: ConstantStrategy::InlineLiteral,
28 residency: ResidencyPolicy::GatherImmediately,
29 nan_mode: ReductionNaN::Include,
30 two_pass_threshold: None,
31 workgroup_size: None,
32 accepts_nan_mode: false,
33 notes:
34 "Executes on the CPU; GPU-resident inputs are gathered to host memory before substrings are removed.",
35};
36
37#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::erase")]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39 name: "erase",
40 shape: ShapeRequirements::Any,
41 constant_strategy: ConstantStrategy::InlineLiteral,
42 elementwise: None,
43 reduction: None,
44 emits_nan: false,
45 notes:
46 "String manipulation builtin; not eligible for fusion plans and always gathers GPU inputs before execution.",
47};
48
/// Name attached to runtime errors built in this module and used as the prefix
/// of internal-error messages.
const BUILTIN_NAME: &str = "erase";
50
51const ERASE_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
52 name: "newStr",
53 ty: BuiltinParamType::Any,
54 arity: BuiltinParamArity::Required,
55 default: None,
56 description: "Text with substring occurrences removed, preserving input container kind.",
57}];
58
59const ERASE_INPUTS: [BuiltinParamDescriptor; 2] = [
60 BuiltinParamDescriptor {
61 name: "str",
62 ty: BuiltinParamType::Any,
63 arity: BuiltinParamArity::Required,
64 default: None,
65 description: "Input text (string/char/cell).",
66 },
67 BuiltinParamDescriptor {
68 name: "pattern",
69 ty: BuiltinParamType::Any,
70 arity: BuiltinParamArity::Required,
71 default: None,
72 description: "Pattern text list (scalar or array/cell).",
73 },
74];
75
76const ERASE_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
77 label: "newStr = erase(str, pattern)",
78 inputs: &ERASE_INPUTS,
79 outputs: &ERASE_OUTPUT,
80}];
81
82const ERASE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
83 code: "RM.ERASE.INVALID_INPUT",
84 identifier: Some("RunMat:erase:InvalidInput"),
85 when: "First argument is not a string array, char array, or cell array of text scalars.",
86 message:
87 "erase: first argument must be a string array, character array, or cell array of character vectors",
88};
89
90const ERASE_ERROR_PATTERN_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
91 code: "RM.ERASE.PATTERN_TYPE",
92 identifier: Some("RunMat:erase:PatternType"),
93 when: "Second argument is not a text scalar/array/cell of text scalars.",
94 message:
95 "erase: second argument must be a string array, character array, or cell array of character vectors",
96};
97
98const ERASE_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
99 code: "RM.ERASE.CELL_ELEMENT",
100 identifier: Some("RunMat:erase:CellElement"),
101 when: "Cell arrays contain non-text elements or non-row char arrays.",
102 message: "erase: cell array elements must be string scalars or character vectors",
103};
104
105const ERASE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
106 code: "RM.ERASE.INTERNAL",
107 identifier: Some("RunMat:erase:InternalError"),
108 when: "Internal output container construction failed.",
109 message: "erase: internal error",
110};
111
112const ERASE_ERRORS: [BuiltinErrorDescriptor; 4] = [
113 ERASE_ERROR_INVALID_INPUT,
114 ERASE_ERROR_PATTERN_TYPE,
115 ERASE_ERROR_CELL_ELEMENT,
116 ERASE_ERROR_INTERNAL,
117];
118
119pub const ERASE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
120 signatures: &ERASE_SIGNATURES,
121 output_mode: BuiltinOutputMode::Fixed,
122 completion_policy: BuiltinCompletionPolicy::Public,
123 errors: &ERASE_ERRORS,
124};
125
126fn map_flow(err: RuntimeError) -> RuntimeError {
127 map_control_flow_with_builtin(err, BUILTIN_NAME)
128}
129
130fn erase_error_with_message(
131 message: impl Into<String>,
132 error: &'static BuiltinErrorDescriptor,
133) -> RuntimeError {
134 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
135 if let Some(identifier) = error.identifier {
136 builder = builder.with_identifier(identifier);
137 }
138 builder.build()
139}
140
141fn erase_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
142 erase_error_with_message(error.message, error)
143}
144
145#[runtime_builtin(
146 name = "erase",
147 category = "strings/transform",
148 summary = "Remove substring occurrences from text inputs.",
149 keywords = "erase,remove substring,strings,character array,text",
150 accel = "sink",
151 type_resolver(text_preserve_type),
152 descriptor(crate::builtins::strings::transform::erase::ERASE_DESCRIPTOR),
153 builtin_path = "crate::builtins::strings::transform::erase"
154)]
155async fn erase_builtin(text: Value, pattern: Value) -> BuiltinResult<Value> {
156 let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
157 let pattern = gather_if_needed_async(&pattern).await.map_err(map_flow)?;
158
159 let patterns = PatternList::from_value(&pattern)?;
160
161 match text {
162 Value::String(s) => Ok(Value::String(erase_string_scalar(s, &patterns))),
163 Value::StringArray(sa) => erase_string_array(sa, &patterns),
164 Value::CharArray(ca) => erase_char_array(ca, &patterns),
165 Value::Cell(cell) => erase_cell_array(cell, &patterns),
166 _ => Err(erase_error(&ERASE_ERROR_INVALID_INPUT)),
167 }
168}
169
170struct PatternList {
171 entries: Vec<String>,
172}
173
174impl PatternList {
175 fn from_value(value: &Value) -> BuiltinResult<Self> {
176 let entries = match value {
177 Value::String(text) => vec![text.clone()],
178 Value::StringArray(array) => array.data.clone(),
179 Value::CharArray(array) => {
180 if array.rows == 0 {
181 Vec::new()
182 } else {
183 let mut list = Vec::with_capacity(array.rows);
184 for row in 0..array.rows {
185 list.push(char_row_to_string_slice(&array.data, array.cols, row));
186 }
187 list
188 }
189 }
190 Value::Cell(cell) => {
191 let mut list = Vec::with_capacity(cell.data.len());
192 for handle in &cell.data {
193 match &**handle {
194 Value::String(text) => list.push(text.clone()),
195 Value::StringArray(sa) if sa.data.len() == 1 => {
196 list.push(sa.data[0].clone());
197 }
198 Value::CharArray(ca) if ca.rows == 0 => list.push(String::new()),
199 Value::CharArray(ca) if ca.rows == 1 => {
200 list.push(char_row_to_string_slice(&ca.data, ca.cols, 0));
201 }
202 Value::CharArray(_) => return Err(erase_error(&ERASE_ERROR_CELL_ELEMENT)),
203 _ => return Err(erase_error(&ERASE_ERROR_CELL_ELEMENT)),
204 }
205 }
206 list
207 }
208 _ => return Err(erase_error(&ERASE_ERROR_PATTERN_TYPE)),
209 };
210 Ok(Self { entries })
211 }
212
213 fn apply(&self, input: &str) -> String {
214 if self.entries.is_empty() {
215 return input.to_string();
216 }
217 let mut current = input.to_string();
218 for pattern in &self.entries {
219 if pattern.is_empty() {
220 continue;
221 }
222 if current.is_empty() {
223 break;
224 }
225 current = current.replace(pattern, "");
226 }
227 current
228 }
229}
230
231fn erase_string_scalar(text: String, patterns: &PatternList) -> String {
232 if is_missing_string(&text) {
233 text
234 } else {
235 patterns.apply(&text)
236 }
237}
238
239fn erase_string_array(array: StringArray, patterns: &PatternList) -> BuiltinResult<Value> {
240 let StringArray { data, shape, .. } = array;
241 let mut erased = Vec::with_capacity(data.len());
242 for entry in data {
243 if is_missing_string(&entry) {
244 erased.push(entry);
245 } else {
246 erased.push(patterns.apply(&entry));
247 }
248 }
249 StringArray::new(erased, shape)
250 .map(Value::StringArray)
251 .map_err(|e| {
252 erase_error_with_message(format!("{BUILTIN_NAME}: {e}"), &ERASE_ERROR_INTERNAL)
253 })
254}
255
256fn erase_char_array(array: CharArray, patterns: &PatternList) -> BuiltinResult<Value> {
257 let CharArray { data, rows, cols } = array;
258 if rows == 0 {
259 return Ok(Value::CharArray(CharArray { data, rows, cols }));
260 }
261
262 let mut processed: Vec<String> = Vec::with_capacity(rows);
263 let mut target_cols = 0usize;
264 for row in 0..rows {
265 let slice = char_row_to_string_slice(&data, cols, row);
266 let erased = patterns.apply(&slice);
267 let len = erased.chars().count();
268 if len > target_cols {
269 target_cols = len;
270 }
271 processed.push(erased);
272 }
273
274 let mut flattened: Vec<char> = Vec::with_capacity(rows * target_cols);
275 for row_text in processed {
276 let mut chars: Vec<char> = row_text.chars().collect();
277 if chars.len() < target_cols {
278 chars.resize(target_cols, ' ');
279 }
280 flattened.extend(chars);
281 }
282
283 CharArray::new(flattened, rows, target_cols)
284 .map(Value::CharArray)
285 .map_err(|e| {
286 erase_error_with_message(format!("{BUILTIN_NAME}: {e}"), &ERASE_ERROR_INTERNAL)
287 })
288}
289
290fn erase_cell_array(cell: CellArray, patterns: &PatternList) -> BuiltinResult<Value> {
291 let shape = cell.shape.clone();
292 let mut values = Vec::with_capacity(cell.data.len());
293 for handle in &cell.data {
294 values.push(erase_cell_element(handle, patterns)?);
295 }
296 make_cell_with_shape(values, shape).map_err(|e| {
297 erase_error_with_message(format!("{BUILTIN_NAME}: {e}"), &ERASE_ERROR_INTERNAL)
298 })
299}
300
301fn erase_cell_element(value: &Value, patterns: &PatternList) -> BuiltinResult<Value> {
302 match value {
303 Value::String(text) => Ok(Value::String(erase_string_scalar(text.clone(), patterns))),
304 Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(erase_string_scalar(
305 sa.data[0].clone(),
306 patterns,
307 ))),
308 Value::CharArray(ca) if ca.rows == 0 => Ok(Value::CharArray(ca.clone())),
309 Value::CharArray(ca) if ca.rows == 1 => {
310 let slice = char_row_to_string_slice(&ca.data, ca.cols, 0);
311 let erased = patterns.apply(&slice);
312 Ok(Value::CharArray(CharArray::new_row(&erased)))
313 }
314 Value::CharArray(_) => Err(erase_error(&ERASE_ERROR_CELL_ELEMENT)),
315 _ => Err(erase_error(&ERASE_ERROR_CELL_ELEMENT)),
316 }
317}
318
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Blocks on the async builtin so tests can call it synchronously.
    fn erase_builtin(text: Value, pattern: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::erase_builtin(text, pattern))
    }

    /// Turns string literals into the owned strings the array constructors take.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| String::from(*item)).collect()
    }

    /// Unwraps a string-array result or fails the test.
    fn into_string_array(value: Value) -> StringArray {
        match value {
            Value::StringArray(sa) => sa,
            other => panic!("expected string array, got {other:?}"),
        }
    }

    /// Unwraps a char-array result or fails the test.
    fn into_char_array(value: Value) -> CharArray {
        match value {
            Value::CharArray(out) => out,
            other => panic!("expected char array, got {other:?}"),
        }
    }

    /// Unwraps a cell-array result or fails the test.
    fn into_cell(value: Value) -> CellArray {
        match value {
            Value::Cell(out) => out,
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_scalar_single_pattern() {
        let text = Value::String("RunMat runtime".into());
        let pattern = Value::String(" runtime".into());
        let result = erase_builtin(text, pattern).expect("erase");
        assert_eq!(result, Value::String("RunMat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_array_multiple_patterns() {
        let strings = StringArray::new(owned(&["gpu", "cpu", "<missing>"]), vec![3, 1]).unwrap();
        let patterns = StringArray::new(owned(&["g", "c"]), vec![2, 1]).unwrap();
        let result = erase_builtin(Value::StringArray(strings), Value::StringArray(patterns))
            .expect("erase");
        let sa = into_string_array(result);
        assert_eq!(sa.shape, vec![3, 1]);
        // The missing element must come back untouched.
        assert_eq!(sa.data, owned(&["pu", "pu", "<missing>"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_string_array_shape_mismatch_applies_all_patterns() {
        let strings = StringArray::new(owned(&["GPU kernel", "CPU kernel"]), vec![2, 1]).unwrap();
        let patterns = StringArray::new(owned(&["GPU ", "CPU "]), vec![1, 2]).unwrap();
        let result = erase_builtin(Value::StringArray(strings), Value::StringArray(patterns))
            .expect("erase");
        let sa = into_string_array(result);
        assert_eq!(sa.shape, vec![2, 1]);
        assert_eq!(sa.data, owned(&["kernel", "kernel"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_adjusts_width() {
        let chars = CharArray::new("matrix".chars().collect(), 1, 6).unwrap();
        let pattern = Value::String("tr".into());
        let result = erase_builtin(Value::CharArray(chars), pattern).expect("erase");
        let out = into_char_array(result);
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, 4);
        let expected: Vec<char> = "maix".chars().collect();
        assert_eq!(out.data, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_handles_full_removal() {
        let chars = CharArray::new_row("abc");
        let pattern = Value::String("abc".into());
        let result = erase_builtin(Value::CharArray(chars), pattern).expect("erase");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected empty char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_char_array_multiple_rows_sequential_patterns() {
        // Two 12-column rows stored one after the other.
        let data: Vec<char> = "GPU pipeline"
            .chars()
            .chain("CPU pipeline".chars())
            .collect();
        let chars = CharArray::new(data, 2, 12).unwrap();
        let patterns = CharArray::new_row("GPU ");
        let result =
            erase_builtin(Value::CharArray(chars), Value::CharArray(patterns)).expect("erase");
        let out = into_char_array(result);
        assert_eq!(out.rows, 2);
        assert_eq!(out.cols, 12);
        let first = char_row_to_string_slice(&out.data, out.cols, 0);
        let second = char_row_to_string_slice(&out.data, out.cols, 1);
        assert_eq!(first.trim_end(), "pipeline");
        assert_eq!(second.trim_end(), "CPU pipeline");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_cell_array_mixed_content() {
        let texts = vec![
            Value::CharArray(CharArray::new_row("Kernel Planner")),
            Value::String("GPU Fusion".into()),
        ];
        let cell = CellArray::new(texts, 1, 2).unwrap();
        let pattern_values = vec![
            Value::String("Kernel ".into()),
            Value::String("GPU ".into()),
        ];
        let patterns = CellArray::new(pattern_values, 1, 2).unwrap();
        let result = erase_builtin(Value::Cell(cell), Value::Cell(patterns)).expect("erase");
        let out = into_cell(result);
        let first = out.get(0, 0).unwrap();
        let second = out.get(0, 1).unwrap();
        // Each element keeps its own container kind.
        assert_eq!(first, Value::CharArray(CharArray::new_row("Planner")));
        assert_eq!(second, Value::String("Fusion".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_cell_array_preserves_shape() {
        let texts: Vec<Value> = owned(&["alpha", "beta", "gamma", "delta"])
            .into_iter()
            .map(Value::String)
            .collect();
        let cell = CellArray::new(texts, 2, 2).unwrap();
        let patterns = StringArray::new(owned(&["a"]), vec![1, 1]).unwrap();
        let result = erase_builtin(Value::Cell(cell), Value::StringArray(patterns)).expect("erase");
        let out = into_cell(result);
        assert_eq!(out.rows, 2);
        assert_eq!(out.cols, 2);
        assert_eq!(out.get(0, 0).unwrap(), Value::String("lph".into()));
        assert_eq!(out.get(1, 1).unwrap(), Value::String("delt".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_preserves_missing_string() {
        let text = Value::String("<missing>".into());
        let pattern = Value::String("missing".into());
        let result = erase_builtin(text, pattern).expect("erase");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_allows_empty_pattern_list() {
        let strings = StringArray::new(owned(&["alpha", "beta"]), vec![2, 1]).unwrap();
        let pattern = StringArray::new(Vec::<String>::new(), vec![0, 0]).unwrap();
        let expected = Value::StringArray(strings.clone());
        let result = erase_builtin(Value::StringArray(strings), Value::StringArray(pattern))
            .expect("erase");
        assert_eq!(result, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_invalid_first_argument() {
        let outcome = erase_builtin(Value::Num(1.0), Value::String("a".into()));
        let err = outcome.unwrap_err();
        assert_eq!(err.to_string(), ERASE_ERROR_INVALID_INPUT.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn erase_errors_on_invalid_pattern_type() {
        let outcome = erase_builtin(Value::String("abc".into()), Value::Num(1.0));
        let err = outcome.unwrap_err();
        assert_eq!(err.to_string(), ERASE_ERROR_PATTERN_TYPE.message);
    }

    #[test]
    fn erase_type_preserves_text() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(text_preserve_type(&[Type::String], &ctx), Type::String);
    }
}