Skip to main content

runmat_runtime/builtins/strings/transform/
strrep.rs

1//! MATLAB-compatible `strrep` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
12use crate::builtins::strings::type_resolvers::text_preserve_type;
13use crate::{
14    build_runtime_error, gather_if_needed_async, make_cell_with_shape, BuiltinResult, RuntimeError,
15};
16
17#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
18pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
19    name: "strrep",
20    op_kind: GpuOpKind::Custom("string-transform"),
21    supported_precisions: &[],
22    broadcast: BroadcastSemantics::None,
23    provider_hooks: &[],
24    constant_strategy: ConstantStrategy::InlineLiteral,
25    residency: ResidencyPolicy::GatherImmediately,
26    nan_mode: ReductionNaN::Include,
27    two_pass_threshold: None,
28    workgroup_size: None,
29    accepts_nan_mode: false,
30    notes: "Executes on the CPU; GPU-resident inputs are gathered before replacements are applied.",
31};
32
33#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
34pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
35    name: "strrep",
36    shape: ShapeRequirements::Any,
37    constant_strategy: ConstantStrategy::InlineLiteral,
38    elementwise: None,
39    reduction: None,
40    emits_nan: false,
41    notes: "String transformation builtin; marked as a sink so fusion skips GPU residency.",
42};
43
/// Name attached to every error raised by this builtin.
const BUILTIN_NAME: &str = "strrep";

// --- Errors about the `old` / `new` arguments ---

/// Raised when `old` or `new` is not a string scalar or a character vector.
const PATTERN_TYPE_ERROR: &str = "strrep: old and new must be string scalars or character vectors";
/// Raised when `old` and `new` are not both strings or both character vectors.
const PATTERN_MISMATCH_ERROR: &str = "strrep: old and new must be the same data type";

// --- Errors about the first (text) argument ---

/// Raised when the first argument is not one of the accepted text containers.
const ARGUMENT_TYPE_ERROR: &str =
    "strrep: first argument must be a string array, character array, or cell array of character vectors";
/// Raised when a cell array element is not a supported text value.
const CELL_ELEMENT_ERROR: &str =
    "strrep: cell array elements must be string scalars or character vectors";
51
/// Text flavour of an `old` / `new` argument.
///
/// `strrep_builtin` compares the kinds of both arguments and rejects the call
/// when they differ.
#[derive(Copy, Clone, Eq, PartialEq)]
enum PatternKind {
    /// The argument was a character array.
    Char,
    /// The argument was a string scalar (or a one-element string array).
    String,
}
57
58fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
59    build_runtime_error(message)
60        .with_builtin(BUILTIN_NAME)
61        .build()
62}
63
64fn map_flow(err: RuntimeError) -> RuntimeError {
65    map_control_flow_with_builtin(err, BUILTIN_NAME)
66}
67
68#[runtime_builtin(
69    name = "strrep",
70    category = "strings/transform",
71    summary = "Replace substring occurrences with MATLAB-compatible semantics.",
72    keywords = "strrep,replace,strings,character array,text",
73    accel = "sink",
74    type_resolver(text_preserve_type),
75    builtin_path = "crate::builtins::strings::transform::strrep"
76)]
77async fn strrep_builtin(
78    str_value: Value,
79    old_value: Value,
80    new_value: Value,
81) -> BuiltinResult<Value> {
82    let gathered_str = gather_if_needed_async(&str_value).await.map_err(map_flow)?;
83    let gathered_old = gather_if_needed_async(&old_value).await.map_err(map_flow)?;
84    let gathered_new = gather_if_needed_async(&new_value).await.map_err(map_flow)?;
85
86    let (old_text, old_kind) = parse_pattern(gathered_old)?;
87    let (new_text, new_kind) = parse_pattern(gathered_new)?;
88    if old_kind != new_kind {
89        return Err(runtime_error_for(PATTERN_MISMATCH_ERROR));
90    }
91
92    match gathered_str {
93        Value::String(text) => Ok(Value::String(strrep_string_value(
94            text, &old_text, &new_text,
95        ))),
96        Value::StringArray(array) => strrep_string_array(array, &old_text, &new_text),
97        Value::CharArray(array) => strrep_char_array(array, &old_text, &new_text),
98        Value::Cell(cell) => strrep_cell_array(cell, &old_text, &new_text),
99        _ => Err(runtime_error_for(ARGUMENT_TYPE_ERROR)),
100    }
101}
102
103fn parse_pattern(value: Value) -> BuiltinResult<(String, PatternKind)> {
104    match value {
105        Value::String(text) => Ok((text, PatternKind::String)),
106        Value::StringArray(array) => {
107            if array.data.len() == 1 {
108                Ok((array.data[0].clone(), PatternKind::String))
109            } else {
110                Err(runtime_error_for(PATTERN_TYPE_ERROR))
111            }
112        }
113        Value::CharArray(array) => {
114            if array.rows <= 1 {
115                let text = if array.rows == 0 {
116                    String::new()
117                } else {
118                    char_row_to_string_slice(&array.data, array.cols, 0)
119                };
120                Ok((text, PatternKind::Char))
121            } else {
122                Err(runtime_error_for(PATTERN_TYPE_ERROR))
123            }
124        }
125        _ => Err(runtime_error_for(PATTERN_TYPE_ERROR)),
126    }
127}
128
129fn strrep_string_value(text: String, old: &str, new: &str) -> String {
130    if is_missing_string(&text) {
131        text
132    } else {
133        text.replace(old, new)
134    }
135}
136
137fn strrep_string_array(array: StringArray, old: &str, new: &str) -> BuiltinResult<Value> {
138    let StringArray { data, shape, .. } = array;
139    let replaced = data
140        .into_iter()
141        .map(|text| strrep_string_value(text, old, new))
142        .collect::<Vec<_>>();
143    let rebuilt = StringArray::new(replaced, shape)
144        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
145    Ok(Value::StringArray(rebuilt))
146}
147
148fn strrep_char_array(array: CharArray, old: &str, new: &str) -> BuiltinResult<Value> {
149    let CharArray { data, rows, cols } = array;
150    if rows == 0 || cols == 0 {
151        return Ok(Value::CharArray(CharArray { data, rows, cols }));
152    }
153
154    let mut replaced_rows = Vec::with_capacity(rows);
155    let mut target_cols = 0usize;
156    for row in 0..rows {
157        let text = char_row_to_string_slice(&data, cols, row);
158        let replaced = text.replace(old, new);
159        target_cols = target_cols.max(replaced.chars().count());
160        replaced_rows.push(replaced);
161    }
162
163    let mut new_data = Vec::with_capacity(rows * target_cols);
164    for row_text in replaced_rows {
165        let mut chars: Vec<char> = row_text.chars().collect();
166        if chars.len() < target_cols {
167            chars.resize(target_cols, ' ');
168        }
169        new_data.extend(chars);
170    }
171
172    CharArray::new(new_data, rows, target_cols)
173        .map(Value::CharArray)
174        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
175}
176
177fn strrep_cell_array(cell: CellArray, old: &str, new: &str) -> BuiltinResult<Value> {
178    let CellArray { data, shape, .. } = cell;
179    let mut replaced = Vec::with_capacity(data.len());
180    for ptr in &data {
181        replaced.push(strrep_cell_element(ptr, old, new)?);
182    }
183    make_cell_with_shape(replaced, shape)
184        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
185}
186
187fn strrep_cell_element(value: &Value, old: &str, new: &str) -> BuiltinResult<Value> {
188    match value {
189        Value::String(text) => Ok(Value::String(strrep_string_value(text.clone(), old, new))),
190        Value::StringArray(array) => strrep_string_array(array.clone(), old, new),
191        Value::CharArray(array) => strrep_char_array(array.clone(), old, new),
192        _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
193    }
194}
195
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async builtin to completion on the current thread.
    fn run_strrep(str_value: Value, old_value: Value, new_value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(strrep_builtin(str_value, old_value, new_value))
    }

    // A string scalar input yields a string scalar with the match replaced.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_scalar_basic() {
        let result = run_strrep(
            Value::String("RunMat Ignite".into()),
            Value::String("Ignite".into()),
            Value::String("Interpreter".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("RunMat Interpreter".into()));
    }

    // Shape is preserved and the `<missing>` entry comes back unchanged.
    // NOTE(review): the pattern "gpu" does not occur in "<missing>", so this
    // would also pass without the missing-value guard; consider a pattern that
    // overlaps the placeholder text.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_preserves_missing() {
        let array = StringArray::new(
            vec![
                String::from("gpu"),
                String::from("<missing>"),
                String::from("planner"),
            ],
            vec![3, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::String("gpu".into()),
            Value::String("GPU".into()),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("GPU"),
                        String::from("<missing>"),
                        String::from("planner")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Character-vector patterns may be applied to a string array input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_with_char_pattern() {
        let array = StringArray::new(
            vec![String::from("alpha"), String::from("beta")],
            vec![2, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::CharArray(CharArray::new_row("a")),
            Value::CharArray(CharArray::new_row("A")),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Single-row char array with a same-length replacement keeps its width.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_padding() {
        let chars = CharArray::new(vec!['R', 'u', 'n', ' ', 'M', 'a', 't'], 1, 7).unwrap();
        let result = run_strrep(
            Value::CharArray(chars),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 7);
                let expected: Vec<char> = "Run_Mat".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // Rows that shrink by different amounts are padded to the widest row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_shrinks_rows_pad_with_spaces() {
        let mut data: Vec<char> = "alpha".chars().collect();
        data.extend("beta ".chars());
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = run_strrep(
            Value::CharArray(array),
            Value::String("a".into()),
            Value::String("".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 4);
                let expected: Vec<char> = vec!['l', 'p', 'h', ' ', 'b', 'e', 't', ' '];
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // Cell arrays of character vectors are processed element by element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_vectors() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Fusion")),
                Value::CharArray(CharArray::new_row("GPU Planner")),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(
                    out.get(0, 0).unwrap(),
                    Value::CharArray(CharArray::new_row("Kernel_Fusion"))
                );
                assert_eq!(
                    out.get(0, 1).unwrap(),
                    Value::CharArray(CharArray::new_row("GPU_Planner"))
                );
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // Cell arrays of string scalars keep string-scalar elements.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_scalars() {
        let cell = CellArray::new(
            vec![
                Value::String("Planner".into()),
                Value::String("Profiler".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("er".into()),
            Value::String("ER".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(out.get(0, 0).unwrap(), Value::String("PlannER".into()));
                assert_eq!(out.get(0, 1).unwrap(), Value::String("ProfilER".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A numeric cell element is rejected with the cell-element error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_invalid_element_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_strrep(
            Value::Cell(cell),
            Value::String("1".into()),
            Value::String("one".into()),
        )
        .expect_err("expected cell element error");
        assert!(err.to_string().contains("cell array elements"));
    }

    // A char matrix nested in a cell is replaced row by row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_matrix_element() {
        let mut chars: Vec<char> = "alpha".chars().collect();
        chars.extend("beta ".chars());
        let element = CharArray::new(chars, 2, 5).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::CharArray(ca) => {
                        assert_eq!(ca.rows, 2);
                        assert_eq!(ca.cols, 5);
                        let expected: Vec<char> =
                            vec!['A', 'l', 'p', 'h', 'A', 'b', 'e', 't', 'A', ' '];
                        assert_eq!(ca.data, expected);
                    }
                    other => panic!("expected char array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A string array nested in a cell is replaced element by element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_arrays() {
        let element = StringArray::new(vec!["alpha".into(), "beta".into()], vec![1, 2]).unwrap();
        let cell = CellArray::new(vec![Value::StringArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::StringArray(sa) => {
                        assert_eq!(sa.shape, vec![1, 2]);
                        assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
                    }
                    other => panic!("expected string array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // Pins the current behaviour for an empty `old`: the replacement is
    // inserted before, between, and after every character.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_empty_pattern_inserts_replacement() {
        let result = run_strrep(
            Value::String("abc".into()),
            Value::String("".into()),
            Value::String("-".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("-a-b-c-".into()));
    }

    // `old` as a string with `new` as a char vector is a kind mismatch.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_type_mismatch_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::String("a".into()),
            Value::CharArray(CharArray::new_row("x")),
        )
        .expect_err("expected type mismatch");
        assert!(err.to_string().contains("same data type"));
    }

    // A numeric `old` is rejected with the pattern type error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_invalid_pattern_type_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::Num(1.0),
            Value::String("x".into()),
        )
        .expect_err("expected pattern error");
        assert!(err
            .to_string()
            .contains("string scalars or character vectors"));
    }

    // A numeric first argument is rejected with the argument type error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_first_argument_type_error() {
        let err = run_strrep(
            Value::Num(42.0),
            Value::String("a".into()),
            Value::String("b".into()),
        )
        .expect_err("expected argument type error");
        assert!(err.to_string().contains("first argument"));
    }

    // With the wgpu provider registered, host strings still take the CPU path.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strrep_wgpu_provider_fallback() {
        if runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .is_err()
        {
            // Unable to initialize the provider in this environment; skip.
            return;
        }
        let result = run_strrep(
            Value::String("Turbine Engine".into()),
            Value::String("Engine".into()),
            Value::String("JIT".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("Turbine JIT".into()));
    }

    // The registered type resolver maps a string input type to a string
    // result type. Carries the same wasm attribute as every other test in
    // this module so it is not silently skipped on wasm32.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}