1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 CellArray, CharArray, StringArray, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
16use crate::builtins::strings::type_resolvers::text_preserve_type;
17use crate::{
18 build_runtime_error, gather_if_needed_async, make_cell_with_shape, BuiltinResult, RuntimeError,
19};
20
/// GPU metadata for `strrep`.
///
/// No GPU precisions or provider hooks are declared (both lists are empty).
/// The residency policy is `GatherImmediately`: GPU-resident inputs are brought
/// back to the host, and the replacement itself runs on the CPU.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "strrep",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Executes on the CPU; GPU-resident inputs are gathered before replacements are applied.",
};
36
/// Fusion metadata for `strrep`.
///
/// Declares neither an elementwise nor a reduction template (both `None`).
/// Per the notes, the builtin is treated as a sink, matching the
/// `accel = "sink"` setting on the builtin registration.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "strrep",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; marked as a sink so fusion skips GPU residency.",
};
47
/// Builtin name used to tag runtime errors and to prefix internal error messages.
const BUILTIN_NAME: &str = "strrep";

/// Output descriptor for the single `newStr` result.
const STRREP_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "newStr",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Text with pattern occurrences replaced, preserving input container kind.",
}];

/// Input descriptors for `str`, `old`, and `new`.
///
/// All three are required and typed as `Any`. The accepted text forms are
/// checked at runtime by `parse_pattern` and by the entry point's `match`.
const STRREP_INPUTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "str",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input text (string/char/cell).",
    },
    BuiltinParamDescriptor {
        name: "old",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Pattern text scalar (string or char row).",
    },
    BuiltinParamDescriptor {
        name: "new",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Replacement text scalar matching old's data type family.",
    },
];

/// The single call form: `newStr = strrep(str, old, new)`.
const STRREP_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "newStr = strrep(str, old, new)",
    inputs: &STRREP_INPUTS,
    outputs: &STRREP_OUTPUT,
}];
87
88const STRREP_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
89 code: "RM.STRREP.INVALID_INPUT",
90 identifier: Some("RunMat:strrep:InvalidInput"),
91 when: "First argument is not a string array, char array, or cell array of text scalars.",
92 message:
93 "strrep: first argument must be a string array, character array, or cell array of character vectors",
94};
95
96const STRREP_ERROR_PATTERN_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
97 code: "RM.STRREP.PATTERN_TYPE",
98 identifier: Some("RunMat:strrep:PatternType"),
99 when: "old/new arguments are not string scalars or character vectors.",
100 message: "strrep: old and new must be string scalars or character vectors",
101};
102
103const STRREP_ERROR_PATTERN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
104 code: "RM.STRREP.PATTERN_MISMATCH",
105 identifier: Some("RunMat:strrep:PatternMismatch"),
106 when: "old and new are different text data families (string vs char).",
107 message: "strrep: old and new must be the same data type",
108};
109
110const STRREP_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
111 code: "RM.STRREP.CELL_ELEMENT",
112 identifier: Some("RunMat:strrep:CellElement"),
113 when: "Cell input contains non-text elements or non-row char arrays.",
114 message: "strrep: cell array elements must be string scalars or character vectors",
115};
116
117const STRREP_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
118 code: "RM.STRREP.INTERNAL",
119 identifier: Some("RunMat:strrep:InternalError"),
120 when: "Internal output container construction failed.",
121 message: "strrep: internal error",
122};
123
124const STRREP_ERRORS: [BuiltinErrorDescriptor; 5] = [
125 STRREP_ERROR_INVALID_INPUT,
126 STRREP_ERROR_PATTERN_TYPE,
127 STRREP_ERROR_PATTERN_MISMATCH,
128 STRREP_ERROR_CELL_ELEMENT,
129 STRREP_ERROR_INTERNAL,
130];
131
/// Public descriptor for `strrep`: its signature, fixed output mode, public
/// completion policy, and error catalog.
///
/// Referenced by the `descriptor(...)` argument of the `runtime_builtin`
/// attribute on `strrep_builtin`.
pub const STRREP_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &STRREP_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &STRREP_ERRORS,
};
138
/// Text family of an `old`/`new` pattern argument.
///
/// `strrep` requires both patterns to come from the same family. A mismatch
/// raises `RunMat:strrep:PatternMismatch` in the builtin entry point.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PatternKind {
    /// A string scalar or a one-element string array.
    String,
    /// A char array with at most one row (zero rows means the empty pattern).
    Char,
}
144
145fn map_flow(err: RuntimeError) -> RuntimeError {
146 map_control_flow_with_builtin(err, BUILTIN_NAME)
147}
148
149fn strrep_error_with_message(
150 message: impl Into<String>,
151 error: &'static BuiltinErrorDescriptor,
152) -> RuntimeError {
153 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
154 if let Some(identifier) = error.identifier {
155 builder = builder.with_identifier(identifier);
156 }
157 builder.build()
158}
159
160fn strrep_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
161 strrep_error_with_message(error.message, error)
162}
163
164#[runtime_builtin(
165 name = "strrep",
166 category = "strings/transform",
167 summary = "Replace non-overlapping substring occurrences in text inputs.",
168 keywords = "strrep,replace,strings,character array,text",
169 accel = "sink",
170 type_resolver(text_preserve_type),
171 descriptor(crate::builtins::strings::transform::strrep::STRREP_DESCRIPTOR),
172 builtin_path = "crate::builtins::strings::transform::strrep"
173)]
174async fn strrep_builtin(
175 str_value: Value,
176 old_value: Value,
177 new_value: Value,
178) -> BuiltinResult<Value> {
179 let gathered_str = gather_if_needed_async(&str_value).await.map_err(map_flow)?;
180 let gathered_old = gather_if_needed_async(&old_value).await.map_err(map_flow)?;
181 let gathered_new = gather_if_needed_async(&new_value).await.map_err(map_flow)?;
182
183 let (old_text, old_kind) = parse_pattern(gathered_old)?;
184 let (new_text, new_kind) = parse_pattern(gathered_new)?;
185 if old_kind != new_kind {
186 return Err(strrep_error(&STRREP_ERROR_PATTERN_MISMATCH));
187 }
188
189 match gathered_str {
190 Value::String(text) => Ok(Value::String(strrep_string_value(
191 text, &old_text, &new_text,
192 ))),
193 Value::StringArray(array) => strrep_string_array(array, &old_text, &new_text),
194 Value::CharArray(array) => strrep_char_array(array, &old_text, &new_text),
195 Value::Cell(cell) => strrep_cell_array(cell, &old_text, &new_text),
196 _ => Err(strrep_error(&STRREP_ERROR_INVALID_INPUT)),
197 }
198}
199
200fn parse_pattern(value: Value) -> BuiltinResult<(String, PatternKind)> {
201 match value {
202 Value::String(text) => Ok((text, PatternKind::String)),
203 Value::StringArray(array) => {
204 if array.data.len() == 1 {
205 Ok((array.data[0].clone(), PatternKind::String))
206 } else {
207 Err(strrep_error(&STRREP_ERROR_PATTERN_TYPE))
208 }
209 }
210 Value::CharArray(array) => {
211 if array.rows <= 1 {
212 let text = if array.rows == 0 {
213 String::new()
214 } else {
215 char_row_to_string_slice(&array.data, array.cols, 0)
216 };
217 Ok((text, PatternKind::Char))
218 } else {
219 Err(strrep_error(&STRREP_ERROR_PATTERN_TYPE))
220 }
221 }
222 _ => Err(strrep_error(&STRREP_ERROR_PATTERN_TYPE)),
223 }
224}
225
226fn strrep_string_value(text: String, old: &str, new: &str) -> String {
227 if is_missing_string(&text) {
228 text
229 } else {
230 text.replace(old, new)
231 }
232}
233
234fn strrep_string_array(array: StringArray, old: &str, new: &str) -> BuiltinResult<Value> {
235 let StringArray { data, shape, .. } = array;
236 let replaced = data
237 .into_iter()
238 .map(|text| strrep_string_value(text, old, new))
239 .collect::<Vec<_>>();
240 let rebuilt = StringArray::new(replaced, shape).map_err(|e| {
241 strrep_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRREP_ERROR_INTERNAL)
242 })?;
243 Ok(Value::StringArray(rebuilt))
244}
245
246fn strrep_char_array(array: CharArray, old: &str, new: &str) -> BuiltinResult<Value> {
247 let CharArray { data, rows, cols } = array;
248 if rows == 0 || cols == 0 {
249 return Ok(Value::CharArray(CharArray { data, rows, cols }));
250 }
251
252 let mut replaced_rows = Vec::with_capacity(rows);
253 let mut target_cols = 0usize;
254 for row in 0..rows {
255 let text = char_row_to_string_slice(&data, cols, row);
256 let replaced = text.replace(old, new);
257 target_cols = target_cols.max(replaced.chars().count());
258 replaced_rows.push(replaced);
259 }
260
261 let mut new_data = Vec::with_capacity(rows * target_cols);
262 for row_text in replaced_rows {
263 let mut chars: Vec<char> = row_text.chars().collect();
264 if chars.len() < target_cols {
265 chars.resize(target_cols, ' ');
266 }
267 new_data.extend(chars);
268 }
269
270 CharArray::new(new_data, rows, target_cols)
271 .map(Value::CharArray)
272 .map_err(|e| {
273 strrep_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRREP_ERROR_INTERNAL)
274 })
275}
276
277fn strrep_cell_array(cell: CellArray, old: &str, new: &str) -> BuiltinResult<Value> {
278 let CellArray { data, shape, .. } = cell;
279 let mut replaced = Vec::with_capacity(data.len());
280 for ptr in &data {
281 replaced.push(strrep_cell_element(ptr, old, new)?);
282 }
283 make_cell_with_shape(replaced, shape).map_err(|e| {
284 strrep_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRREP_ERROR_INTERNAL)
285 })
286}
287
288fn strrep_cell_element(value: &Value, old: &str, new: &str) -> BuiltinResult<Value> {
289 match value {
290 Value::String(text) => Ok(Value::String(strrep_string_value(text.clone(), old, new))),
291 Value::StringArray(array) => strrep_string_array(array.clone(), old, new),
292 Value::CharArray(array) => strrep_char_array(array.clone(), old, new),
293 _ => Err(strrep_error(&STRREP_ERROR_CELL_ELEMENT)),
294 }
295}
296
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async builtin to completion on the current thread.
    fn run_strrep(str_value: Value, old_value: Value, new_value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(strrep_builtin(str_value, old_value, new_value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_scalar_basic() {
        let result = run_strrep(
            Value::String("RunMat Ignite".into()),
            Value::String("Ignite".into()),
            Value::String("Interpreter".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("RunMat Interpreter".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_preserves_missing() {
        let array = StringArray::new(
            vec![
                String::from("gpu"),
                String::from("<missing>"),
                String::from("planner"),
            ],
            vec![3, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::String("gpu".into()),
            Value::String("GPU".into()),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("GPU"),
                        String::from("<missing>"),
                        String::from("planner")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_with_char_pattern() {
        let array = StringArray::new(
            vec![String::from("alpha"), String::from("beta")],
            vec![2, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::CharArray(CharArray::new_row("a")),
            Value::CharArray(CharArray::new_row("A")),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_padding() {
        let chars = CharArray::new(vec!['R', 'u', 'n', ' ', 'M', 'a', 't'], 1, 7).unwrap();
        let result = run_strrep(
            Value::CharArray(chars),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 7);
                let expected: Vec<char> = "Run_Mat".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_shrinks_rows_pad_with_spaces() {
        let mut data: Vec<char> = "alpha".chars().collect();
        data.extend("beta ".chars());
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = run_strrep(
            Value::CharArray(array),
            Value::String("a".into()),
            Value::String("".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 4);
                let expected: Vec<char> = vec!['l', 'p', 'h', ' ', 'b', 'e', 't', ' '];
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_empty_char_array_passthrough() {
        let empty = CharArray::new(Vec::new(), 0, 0).unwrap();
        let result = run_strrep(
            Value::CharArray(empty),
            Value::String("a".into()),
            Value::String("b".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 0);
                assert_eq!(out.cols, 0);
                assert!(out.data.is_empty());
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_vectors() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Fusion")),
                Value::CharArray(CharArray::new_row("GPU Planner")),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(
                    out.get(0, 0).unwrap(),
                    Value::CharArray(CharArray::new_row("Kernel_Fusion"))
                );
                assert_eq!(
                    out.get(0, 1).unwrap(),
                    Value::CharArray(CharArray::new_row("GPU_Planner"))
                );
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_scalars() {
        let cell = CellArray::new(
            vec![
                Value::String("Planner".into()),
                Value::String("Profiler".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("er".into()),
            Value::String("ER".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(out.get(0, 0).unwrap(), Value::String("PlannER".into()));
                assert_eq!(out.get(0, 1).unwrap(), Value::String("ProfilER".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_invalid_element_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_strrep(
            Value::Cell(cell),
            Value::String("1".into()),
            Value::String("one".into()),
        )
        .expect_err("expected cell element error");
        assert!(err.to_string().contains("cell array elements"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_matrix_element() {
        let mut chars: Vec<char> = "alpha".chars().collect();
        chars.extend("beta ".chars());
        let element = CharArray::new(chars, 2, 5).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::CharArray(ca) => {
                        assert_eq!(ca.rows, 2);
                        assert_eq!(ca.cols, 5);
                        let expected: Vec<char> =
                            vec!['A', 'l', 'p', 'h', 'A', 'b', 'e', 't', 'A', ' '];
                        assert_eq!(ca.data, expected);
                    }
                    other => panic!("expected char array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_arrays() {
        let element = StringArray::new(vec!["alpha".into(), "beta".into()], vec![1, 2]).unwrap();
        let cell = CellArray::new(vec![Value::StringArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::StringArray(sa) => {
                        assert_eq!(sa.shape, vec![1, 2]);
                        assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
                    }
                    other => panic!("expected string array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_empty_pattern_inserts_replacement() {
        let result = run_strrep(
            Value::String("abc".into()),
            Value::String("".into()),
            Value::String("-".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("-a-b-c-".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_scalar_string_array_pattern_accepted() {
        let pattern = StringArray::new(vec![String::from("an")], vec![1, 1]).unwrap();
        let result = run_strrep(
            Value::String("banana".into()),
            Value::StringArray(pattern),
            Value::String("AN".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("bANANa".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_non_scalar_string_pattern_errors() {
        let pattern =
            StringArray::new(vec![String::from("a"), String::from("b")], vec![1, 2]).unwrap();
        let err = run_strrep(
            Value::String("abc".into()),
            Value::StringArray(pattern),
            Value::String("x".into()),
        )
        .expect_err("expected pattern error");
        assert!(err
            .to_string()
            .contains("string scalars or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_multi_row_char_pattern_errors() {
        let pattern = CharArray::new(vec!['a', 'b'], 2, 1).unwrap();
        let err = run_strrep(
            Value::CharArray(CharArray::new_row("abc")),
            Value::CharArray(pattern),
            Value::CharArray(CharArray::new_row("x")),
        )
        .expect_err("expected pattern error");
        assert!(err
            .to_string()
            .contains("string scalars or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_type_mismatch_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::String("a".into()),
            Value::CharArray(CharArray::new_row("x")),
        )
        .expect_err("expected type mismatch");
        assert!(err.to_string().contains("same data type"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_invalid_pattern_type_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::Num(1.0),
            Value::String("x".into()),
        )
        .expect_err("expected pattern error");
        assert!(err
            .to_string()
            .contains("string scalars or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_first_argument_type_error() {
        let err = run_strrep(
            Value::Num(42.0),
            Value::String("a".into()),
            Value::String("b".into()),
        )
        .expect_err("expected argument type error");
        assert!(err.to_string().contains("first argument"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strrep_wgpu_provider_fallback() {
        if runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .is_err()
        {
            return;
        }
        let result = run_strrep(
            Value::String("Turbine Engine".into()),
            Value::String("Engine".into()),
            Value::String("JIT".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("Turbine JIT".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}