Skip to main content

runmat_runtime/builtins/strings/transform/strrep.rs

//! MATLAB-compatible `strrep` builtin for RunMat. Replacement runs on the CPU; GPU-resident inputs are gathered to the host first.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6    CellArray, CharArray, StringArray, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
16use crate::builtins::strings::type_resolvers::text_preserve_type;
17use crate::{
18    build_runtime_error, gather_if_needed_async, make_cell_with_shape, BuiltinResult, RuntimeError,
19};
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23    name: "strrep",
24    op_kind: GpuOpKind::Custom("string-transform"),
25    supported_precisions: &[],
26    broadcast: BroadcastSemantics::None,
27    provider_hooks: &[],
28    constant_strategy: ConstantStrategy::InlineLiteral,
29    residency: ResidencyPolicy::GatherImmediately,
30    nan_mode: ReductionNaN::Include,
31    two_pass_threshold: None,
32    workgroup_size: None,
33    accepts_nan_mode: false,
34    notes: "Executes on the CPU; GPU-resident inputs are gathered before replacements are applied.",
35};
36
37#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39    name: "strrep",
40    shape: ShapeRequirements::Any,
41    constant_strategy: ConstantStrategy::InlineLiteral,
42    elementwise: None,
43    reduction: None,
44    emits_nan: false,
45    notes: "String transformation builtin; marked as a sink so fusion skips GPU residency.",
46};
47
48const BUILTIN_NAME: &str = "strrep";
49
50const STRREP_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
51    name: "newStr",
52    ty: BuiltinParamType::Any,
53    arity: BuiltinParamArity::Required,
54    default: None,
55    description: "Text with pattern occurrences replaced, preserving input container kind.",
56}];
57
58const STRREP_INPUTS: [BuiltinParamDescriptor; 3] = [
59    BuiltinParamDescriptor {
60        name: "str",
61        ty: BuiltinParamType::Any,
62        arity: BuiltinParamArity::Required,
63        default: None,
64        description: "Input text (string/char/cell).",
65    },
66    BuiltinParamDescriptor {
67        name: "old",
68        ty: BuiltinParamType::Any,
69        arity: BuiltinParamArity::Required,
70        default: None,
71        description: "Pattern text scalar (string or char row).",
72    },
73    BuiltinParamDescriptor {
74        name: "new",
75        ty: BuiltinParamType::Any,
76        arity: BuiltinParamArity::Required,
77        default: None,
78        description: "Replacement text scalar matching old's data type family.",
79    },
80];
81
82const STRREP_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
83    label: "newStr = strrep(str, old, new)",
84    inputs: &STRREP_INPUTS,
85    outputs: &STRREP_OUTPUT,
86}];
87
88const STRREP_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
89    code: "RM.STRREP.INVALID_INPUT",
90    identifier: Some("RunMat:strrep:InvalidInput"),
91    when: "First argument is not a string array, char array, or cell array of text scalars.",
92    message:
93        "strrep: first argument must be a string array, character array, or cell array of character vectors",
94};
95
96const STRREP_ERROR_PATTERN_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
97    code: "RM.STRREP.PATTERN_TYPE",
98    identifier: Some("RunMat:strrep:PatternType"),
99    when: "old/new arguments are not string scalars or character vectors.",
100    message: "strrep: old and new must be string scalars or character vectors",
101};
102
103const STRREP_ERROR_PATTERN_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
104    code: "RM.STRREP.PATTERN_MISMATCH",
105    identifier: Some("RunMat:strrep:PatternMismatch"),
106    when: "old and new are different text data families (string vs char).",
107    message: "strrep: old and new must be the same data type",
108};
109
110const STRREP_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
111    code: "RM.STRREP.CELL_ELEMENT",
112    identifier: Some("RunMat:strrep:CellElement"),
113    when: "Cell input contains non-text elements or non-row char arrays.",
114    message: "strrep: cell array elements must be string scalars or character vectors",
115};
116
117const STRREP_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
118    code: "RM.STRREP.INTERNAL",
119    identifier: Some("RunMat:strrep:InternalError"),
120    when: "Internal output container construction failed.",
121    message: "strrep: internal error",
122};
123
124const STRREP_ERRORS: [BuiltinErrorDescriptor; 5] = [
125    STRREP_ERROR_INVALID_INPUT,
126    STRREP_ERROR_PATTERN_TYPE,
127    STRREP_ERROR_PATTERN_MISMATCH,
128    STRREP_ERROR_CELL_ELEMENT,
129    STRREP_ERROR_INTERNAL,
130];
131
132pub const STRREP_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
133    signatures: &STRREP_SIGNATURES,
134    output_mode: BuiltinOutputMode::Fixed,
135    completion_policy: BuiltinCompletionPolicy::Public,
136    errors: &STRREP_ERRORS,
137};
138
/// Text family of an `old`/`new` pattern argument; both patterns must share a family.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PatternKind {
    /// Supplied as a string scalar or a one-element string array.
    String,
    /// Supplied as a char array with at most one row.
    Char,
}
144
145fn map_flow(err: RuntimeError) -> RuntimeError {
146    map_control_flow_with_builtin(err, BUILTIN_NAME)
147}
148
149fn strrep_error_with_message(
150    message: impl Into<String>,
151    error: &'static BuiltinErrorDescriptor,
152) -> RuntimeError {
153    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
154    if let Some(identifier) = error.identifier {
155        builder = builder.with_identifier(identifier);
156    }
157    builder.build()
158}
159
160fn strrep_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
161    strrep_error_with_message(error.message, error)
162}
163
164#[runtime_builtin(
165    name = "strrep",
166    category = "strings/transform",
167    summary = "Replace non-overlapping substring occurrences in text inputs.",
168    keywords = "strrep,replace,strings,character array,text",
169    accel = "sink",
170    type_resolver(text_preserve_type),
171    descriptor(crate::builtins::strings::transform::strrep::STRREP_DESCRIPTOR),
172    builtin_path = "crate::builtins::strings::transform::strrep"
173)]
174async fn strrep_builtin(
175    str_value: Value,
176    old_value: Value,
177    new_value: Value,
178) -> BuiltinResult<Value> {
179    let gathered_str = gather_if_needed_async(&str_value).await.map_err(map_flow)?;
180    let gathered_old = gather_if_needed_async(&old_value).await.map_err(map_flow)?;
181    let gathered_new = gather_if_needed_async(&new_value).await.map_err(map_flow)?;
182
183    let (old_text, old_kind) = parse_pattern(gathered_old)?;
184    let (new_text, new_kind) = parse_pattern(gathered_new)?;
185    if old_kind != new_kind {
186        return Err(strrep_error(&STRREP_ERROR_PATTERN_MISMATCH));
187    }
188
189    match gathered_str {
190        Value::String(text) => Ok(Value::String(strrep_string_value(
191            text, &old_text, &new_text,
192        ))),
193        Value::StringArray(array) => strrep_string_array(array, &old_text, &new_text),
194        Value::CharArray(array) => strrep_char_array(array, &old_text, &new_text),
195        Value::Cell(cell) => strrep_cell_array(cell, &old_text, &new_text),
196        _ => Err(strrep_error(&STRREP_ERROR_INVALID_INPUT)),
197    }
198}
199
200fn parse_pattern(value: Value) -> BuiltinResult<(String, PatternKind)> {
201    match value {
202        Value::String(text) => Ok((text, PatternKind::String)),
203        Value::StringArray(array) => {
204            if array.data.len() == 1 {
205                Ok((array.data[0].clone(), PatternKind::String))
206            } else {
207                Err(strrep_error(&STRREP_ERROR_PATTERN_TYPE))
208            }
209        }
210        Value::CharArray(array) => {
211            if array.rows <= 1 {
212                let text = if array.rows == 0 {
213                    String::new()
214                } else {
215                    char_row_to_string_slice(&array.data, array.cols, 0)
216                };
217                Ok((text, PatternKind::Char))
218            } else {
219                Err(strrep_error(&STRREP_ERROR_PATTERN_TYPE))
220            }
221        }
222        _ => Err(strrep_error(&STRREP_ERROR_PATTERN_TYPE)),
223    }
224}
225
226fn strrep_string_value(text: String, old: &str, new: &str) -> String {
227    if is_missing_string(&text) {
228        text
229    } else {
230        text.replace(old, new)
231    }
232}
233
234fn strrep_string_array(array: StringArray, old: &str, new: &str) -> BuiltinResult<Value> {
235    let StringArray { data, shape, .. } = array;
236    let replaced = data
237        .into_iter()
238        .map(|text| strrep_string_value(text, old, new))
239        .collect::<Vec<_>>();
240    let rebuilt = StringArray::new(replaced, shape).map_err(|e| {
241        strrep_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRREP_ERROR_INTERNAL)
242    })?;
243    Ok(Value::StringArray(rebuilt))
244}
245
246fn strrep_char_array(array: CharArray, old: &str, new: &str) -> BuiltinResult<Value> {
247    let CharArray { data, rows, cols } = array;
248    if rows == 0 || cols == 0 {
249        return Ok(Value::CharArray(CharArray { data, rows, cols }));
250    }
251
252    let mut replaced_rows = Vec::with_capacity(rows);
253    let mut target_cols = 0usize;
254    for row in 0..rows {
255        let text = char_row_to_string_slice(&data, cols, row);
256        let replaced = text.replace(old, new);
257        target_cols = target_cols.max(replaced.chars().count());
258        replaced_rows.push(replaced);
259    }
260
261    let mut new_data = Vec::with_capacity(rows * target_cols);
262    for row_text in replaced_rows {
263        let mut chars: Vec<char> = row_text.chars().collect();
264        if chars.len() < target_cols {
265            chars.resize(target_cols, ' ');
266        }
267        new_data.extend(chars);
268    }
269
270    CharArray::new(new_data, rows, target_cols)
271        .map(Value::CharArray)
272        .map_err(|e| {
273            strrep_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRREP_ERROR_INTERNAL)
274        })
275}
276
277fn strrep_cell_array(cell: CellArray, old: &str, new: &str) -> BuiltinResult<Value> {
278    let CellArray { data, shape, .. } = cell;
279    let mut replaced = Vec::with_capacity(data.len());
280    for ptr in &data {
281        replaced.push(strrep_cell_element(ptr, old, new)?);
282    }
283    make_cell_with_shape(replaced, shape).map_err(|e| {
284        strrep_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRREP_ERROR_INTERNAL)
285    })
286}
287
288fn strrep_cell_element(value: &Value, old: &str, new: &str) -> BuiltinResult<Value> {
289    match value {
290        Value::String(text) => Ok(Value::String(strrep_string_value(text.clone(), old, new))),
291        Value::StringArray(array) => strrep_string_array(array.clone(), old, new),
292        Value::CharArray(array) => strrep_char_array(array.clone(), old, new),
293        _ => Err(strrep_error(&STRREP_ERROR_CELL_ELEMENT)),
294    }
295}
296
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for `strrep`, driven synchronously through `run_strrep`.

    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Runs the async builtin to completion on the current thread.
    fn run_strrep(str_value: Value, old_value: Value, new_value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(strrep_builtin(str_value, old_value, new_value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_scalar_basic() {
        let result = run_strrep(
            Value::String("RunMat Ignite".into()),
            Value::String("Ignite".into()),
            Value::String("Interpreter".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("RunMat Interpreter".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_preserves_missing() {
        let array = StringArray::new(
            vec![
                String::from("gpu"),
                String::from("<missing>"),
                String::from("planner"),
            ],
            vec![3, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::String("gpu".into()),
            Value::String("GPU".into()),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("GPU"),
                        String::from("<missing>"),
                        String::from("planner")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_with_char_pattern() {
        let array = StringArray::new(
            vec![String::from("alpha"), String::from("beta")],
            vec![2, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::CharArray(CharArray::new_row("a")),
            Value::CharArray(CharArray::new_row("A")),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_accepts_single_element_string_array_pattern() {
        let pattern = StringArray::new(vec![String::from("Mat")], vec![1, 1]).unwrap();
        let result = run_strrep(
            Value::String("RunMat".into()),
            Value::StringArray(pattern),
            Value::String("Time".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("RunTime".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_padding() {
        let chars = CharArray::new(vec!['R', 'u', 'n', ' ', 'M', 'a', 't'], 1, 7).unwrap();
        let result = run_strrep(
            Value::CharArray(chars),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 7);
                let expected: Vec<char> = "Run_Mat".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_shrinks_rows_pad_with_spaces() {
        let mut data: Vec<char> = "alpha".chars().collect();
        data.extend("beta ".chars());
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = run_strrep(
            Value::CharArray(array),
            Value::String("a".into()),
            Value::String("".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 4);
                let expected: Vec<char> = vec!['l', 'p', 'h', ' ', 'b', 'e', 't', ' '];
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_empty_char_array_passthrough() {
        // Empty char arrays skip replacement entirely, even with an empty pattern.
        let empty = CharArray::new_row("");
        let result = run_strrep(
            Value::CharArray(empty.clone()),
            Value::CharArray(CharArray::new_row("")),
            Value::CharArray(CharArray::new_row("-")),
        )
        .expect("strrep");
        assert_eq!(result, Value::CharArray(empty));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_vectors() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Fusion")),
                Value::CharArray(CharArray::new_row("GPU Planner")),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(
                    out.get(0, 0).unwrap(),
                    Value::CharArray(CharArray::new_row("Kernel_Fusion"))
                );
                assert_eq!(
                    out.get(0, 1).unwrap(),
                    Value::CharArray(CharArray::new_row("GPU_Planner"))
                );
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_scalars() {
        let cell = CellArray::new(
            vec![
                Value::String("Planner".into()),
                Value::String("Profiler".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("er".into()),
            Value::String("ER".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(out.get(0, 0).unwrap(), Value::String("PlannER".into()));
                assert_eq!(out.get(0, 1).unwrap(), Value::String("ProfilER".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_invalid_element_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_strrep(
            Value::Cell(cell),
            Value::String("1".into()),
            Value::String("one".into()),
        )
        .expect_err("expected cell element error");
        assert!(err.to_string().contains("cell array elements"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_matrix_element() {
        let mut chars: Vec<char> = "alpha".chars().collect();
        chars.extend("beta ".chars());
        let element = CharArray::new(chars, 2, 5).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::CharArray(ca) => {
                        assert_eq!(ca.rows, 2);
                        assert_eq!(ca.cols, 5);
                        let expected: Vec<char> =
                            vec!['A', 'l', 'p', 'h', 'A', 'b', 'e', 't', 'A', ' '];
                        assert_eq!(ca.data, expected);
                    }
                    other => panic!("expected char array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_arrays() {
        let element = StringArray::new(vec!["alpha".into(), "beta".into()], vec![1, 2]).unwrap();
        let cell = CellArray::new(vec![Value::StringArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::StringArray(sa) => {
                        assert_eq!(sa.shape, vec![1, 2]);
                        assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
                    }
                    other => panic!("expected string array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_empty_pattern_inserts_replacement() {
        let result = run_strrep(
            Value::String("abc".into()),
            Value::String("".into()),
            Value::String("-".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("-a-b-c-".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_type_mismatch_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::String("a".into()),
            Value::CharArray(CharArray::new_row("x")),
        )
        .expect_err("expected type mismatch");
        assert!(err.to_string().contains("same data type"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_invalid_pattern_type_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::Num(1.0),
            Value::String("x".into()),
        )
        .expect_err("expected pattern error");
        assert!(err
            .to_string()
            .contains("string scalars or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_multirow_char_pattern_errors() {
        // A 2x1 char array is not a character vector, so it cannot serve as a pattern.
        let pattern = CharArray::new(vec!['a', 'b'], 2, 1).unwrap();
        let err = run_strrep(
            Value::CharArray(CharArray::new_row("abc")),
            Value::CharArray(pattern),
            Value::CharArray(CharArray::new_row("x")),
        )
        .expect_err("expected pattern error");
        assert!(err
            .to_string()
            .contains("string scalars or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_first_argument_type_error() {
        let err = run_strrep(
            Value::Num(42.0),
            Value::String("a".into()),
            Value::String("b".into()),
        )
        .expect_err("expected argument type error");
        assert!(err.to_string().contains("first argument"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strrep_wgpu_provider_fallback() {
        if runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .is_err()
        {
            // Unable to initialize the provider in this environment; skip.
            return;
        }
        let result = run_strrep(
            Value::String("Turbine Engine".into()),
            Value::String("Engine".into()),
            Value::String("JIT".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("Turbine JIT".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}