Skip to main content

runmat_runtime/builtins/strings/transform/
replace.rs

1//! MATLAB-compatible `replace` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
12use crate::builtins::strings::type_resolvers::text_preserve_type;
13use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
14
15#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::replace")]
16pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
17    name: "replace",
18    op_kind: GpuOpKind::Custom("string-transform"),
19    supported_precisions: &[],
20    broadcast: BroadcastSemantics::None,
21    provider_hooks: &[],
22    constant_strategy: ConstantStrategy::InlineLiteral,
23    residency: ResidencyPolicy::GatherImmediately,
24    nan_mode: ReductionNaN::Include,
25    two_pass_threshold: None,
26    workgroup_size: None,
27    accepts_nan_mode: false,
28    notes:
29        "Executes on the CPU; GPU-resident inputs are gathered to host memory prior to replacement.",
30};
31
32#[runmat_macros::register_fusion_spec(
33    builtin_path = "crate::builtins::strings::transform::replace"
34)]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36    name: "replace",
37    shape: ShapeRequirements::Any,
38    constant_strategy: ConstantStrategy::InlineLiteral,
39    elementwise: None,
40    reduction: None,
41    emits_nan: false,
42    notes:
43        "String manipulation builtin; not eligible for fusion plans and always gathers GPU inputs.",
44};
45
/// Builtin name attached to every runtime error raised here.
const BUILTIN_NAME: &str = "replace";

/// Raised when the subject text (first argument) has an unsupported type.
const ARG_TYPE_ERROR: &str =
    "replace: first argument must be a string array, character array, or cell array of character vectors";

/// Raised when the search terms (second argument) have an unsupported type.
const PATTERN_TYPE_ERROR: &str =
    "replace: second argument must be a string array, character array, or cell array of character vectors";

/// Raised when the replacement text (third argument) has an unsupported type.
const REPLACEMENT_TYPE_ERROR: &str =
    "replace: third argument must be a string array, character array, or cell array of character vectors";

/// Raised when the search-term list extracted from the second argument is empty.
const EMPTY_PATTERN_ERROR: &str =
    "replace: second argument must contain at least one search string";

/// Raised when the replacement list extracted from the third argument is empty.
const EMPTY_REPLACEMENT_ERROR: &str =
    "replace: third argument must contain at least one replacement string";

/// Raised when replacements can neither be broadcast (one value) nor paired one-to-one.
const SIZE_MISMATCH_ERROR: &str =
    "replace: replacement array must be a scalar or match the number of search strings";

/// Raised when a cell array holds something other than a string scalar or char row.
const CELL_ELEMENT_ERROR: &str =
    "replace: cell array elements must be string scalars or character vectors";
61
62fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
63    build_runtime_error(message)
64        .with_builtin(BUILTIN_NAME)
65        .build()
66}
67
68fn map_flow(err: RuntimeError) -> RuntimeError {
69    map_control_flow_with_builtin(err, BUILTIN_NAME)
70}
71
72#[runtime_builtin(
73    name = "replace",
74    category = "strings/transform",
75    summary = "Replace substring occurrences in strings, character arrays, and cell arrays.",
76    keywords = "replace,strrep,strings,character array,text",
77    accel = "sink",
78    type_resolver(text_preserve_type),
79    builtin_path = "crate::builtins::strings::transform::replace"
80)]
81async fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
82    let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
83    let old = gather_if_needed_async(&old).await.map_err(map_flow)?;
84    let new = gather_if_needed_async(&new).await.map_err(map_flow)?;
85
86    let spec = ReplacementSpec::from_values(&old, &new)?;
87
88    match text {
89        Value::String(s) => Ok(Value::String(replace_string_scalar(s, &spec))),
90        Value::StringArray(sa) => replace_string_array(sa, &spec),
91        Value::CharArray(ca) => replace_char_array(ca, &spec),
92        Value::Cell(cell) => replace_cell_array(cell, &spec),
93        _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
94    }
95}
96
97fn replace_string_scalar(text: String, spec: &ReplacementSpec) -> String {
98    if is_missing_string(&text) {
99        text
100    } else {
101        spec.apply(&text)
102    }
103}
104
105fn replace_string_array(array: StringArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
106    let StringArray { data, shape, .. } = array;
107    let mut replaced = Vec::with_capacity(data.len());
108    for entry in data {
109        if is_missing_string(&entry) {
110            replaced.push(entry);
111        } else {
112            replaced.push(spec.apply(&entry));
113        }
114    }
115    let result = StringArray::new(replaced, shape)
116        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
117    Ok(Value::StringArray(result))
118}
119
120fn replace_char_array(array: CharArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
121    let CharArray { data, rows, cols } = array;
122    if rows == 0 {
123        return Ok(Value::CharArray(CharArray { data, rows, cols }));
124    }
125
126    let mut replaced_rows = Vec::with_capacity(rows);
127    let mut target_cols = 0usize;
128    for row in 0..rows {
129        let slice = char_row_to_string_slice(&data, cols, row);
130        let replaced = spec.apply(&slice);
131        let len = replaced.chars().count();
132        target_cols = target_cols.max(len);
133        replaced_rows.push(replaced);
134    }
135
136    let mut flattened = Vec::with_capacity(rows * target_cols);
137    for row_text in replaced_rows {
138        let mut chars: Vec<char> = row_text.chars().collect();
139        if chars.len() < target_cols {
140            chars.resize(target_cols, ' ');
141        }
142        flattened.extend(chars);
143    }
144
145    CharArray::new(flattened, rows, target_cols)
146        .map(Value::CharArray)
147        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
148}
149
150fn replace_cell_array(cell: CellArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
151    let CellArray {
152        data, rows, cols, ..
153    } = cell;
154    let mut replaced = Vec::with_capacity(rows * cols);
155    for row in 0..rows {
156        for col in 0..cols {
157            let idx = row * cols + col;
158            let value = replace_cell_element(&data[idx], spec)?;
159            replaced.push(value);
160        }
161    }
162    make_cell(replaced, rows, cols).map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
163}
164
165fn replace_cell_element(value: &Value, spec: &ReplacementSpec) -> BuiltinResult<Value> {
166    match value {
167        Value::String(text) => Ok(Value::String(replace_string_scalar(text.clone(), spec))),
168        Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(replace_string_scalar(
169            sa.data[0].clone(),
170            spec,
171        ))),
172        Value::CharArray(ca) if ca.rows <= 1 => replace_char_array(ca.clone(), spec),
173        Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
174        _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
175    }
176}
177
178fn extract_pattern_list(value: &Value) -> BuiltinResult<Vec<String>> {
179    extract_text_list(value, PATTERN_TYPE_ERROR)
180}
181
182fn extract_replacement_list(value: &Value) -> BuiltinResult<Vec<String>> {
183    extract_text_list(value, REPLACEMENT_TYPE_ERROR)
184}
185
186fn extract_text_list(value: &Value, type_error: &str) -> BuiltinResult<Vec<String>> {
187    match value {
188        Value::String(text) => Ok(vec![text.clone()]),
189        Value::StringArray(array) => Ok(array.data.clone()),
190        Value::CharArray(array) => {
191            let CharArray { data, rows, cols } = array.clone();
192            if rows == 0 {
193                Ok(Vec::new())
194            } else {
195                let mut entries = Vec::with_capacity(rows);
196                for row in 0..rows {
197                    entries.push(char_row_to_string_slice(&data, cols, row));
198                }
199                Ok(entries)
200            }
201        }
202        Value::Cell(cell) => {
203            let CellArray { data, .. } = cell.clone();
204            let mut entries = Vec::with_capacity(data.len());
205            for element in data {
206                match &*element {
207                    Value::String(text) => entries.push(text.clone()),
208                    Value::StringArray(sa) if sa.data.len() == 1 => {
209                        entries.push(sa.data[0].clone());
210                    }
211                    Value::CharArray(ca) if ca.rows <= 1 => {
212                        if ca.rows == 0 {
213                            entries.push(String::new());
214                        } else {
215                            entries.push(char_row_to_string_slice(&ca.data, ca.cols, 0));
216                        }
217                    }
218                    Value::CharArray(_) => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
219                    _ => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
220                }
221            }
222            Ok(entries)
223        }
224        _ => Err(runtime_error_for(type_error)),
225    }
226}
227
228struct ReplacementSpec {
229    pairs: Vec<(String, String)>,
230}
231
232impl ReplacementSpec {
233    fn from_values(old: &Value, new: &Value) -> BuiltinResult<Self> {
234        let patterns = extract_pattern_list(old)?;
235        if patterns.is_empty() {
236            return Err(runtime_error_for(EMPTY_PATTERN_ERROR));
237        }
238
239        let replacements = extract_replacement_list(new)?;
240        if replacements.is_empty() {
241            return Err(runtime_error_for(EMPTY_REPLACEMENT_ERROR));
242        }
243
244        let pairs = if replacements.len() == patterns.len() {
245            patterns.into_iter().zip(replacements).collect::<Vec<_>>()
246        } else if replacements.len() == 1 {
247            let replacement = replacements[0].clone();
248            patterns
249                .into_iter()
250                .map(|pattern| (pattern, replacement.clone()))
251                .collect::<Vec<_>>()
252        } else {
253            return Err(runtime_error_for(SIZE_MISMATCH_ERROR));
254        };
255
256        Ok(Self { pairs })
257    }
258
259    fn apply(&self, input: &str) -> String {
260        let mut current = input.to_string();
261        for (pattern, replacement) in &self.pairs {
262            if pattern.is_empty() && replacement.is_empty() {
263                continue;
264            }
265            if pattern == replacement {
266                continue;
267            }
268            current = current.replace(pattern, replacement);
269        }
270        current
271    }
272}
273
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Synchronous shim over the async builtin so tests can call it directly.
    fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::replace_builtin(text, old, new))
    }

    /// Wraps a literal as a `Value::String`.
    fn string_value(text: &str) -> Value {
        Value::String(text.into())
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_scalar_single_term() {
        let output = replace_builtin(
            string_value("RunMat runtime"),
            string_value("runtime"),
            string_value("engine"),
        )
        .expect("replace");
        assert_eq!(output, string_value("RunMat engine"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_array_multiple_terms() {
        let subject = StringArray::new(
            vec!["gpu".into(), "cpu".into(), "<missing>".into()],
            vec![3, 1],
        )
        .unwrap();
        let patterns = StringArray::new(vec!["gpu".into(), "cpu".into()], vec![2, 1]).unwrap();

        let output = replace_builtin(
            Value::StringArray(subject),
            Value::StringArray(patterns),
            string_value("device"),
        )
        .expect("replace");

        match output {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![3, 1]);
                let expected: Vec<String> = ["device", "device", "<missing>"]
                    .iter()
                    .map(|item| item.to_string())
                    .collect();
                assert_eq!(array.data, expected);
            }
            unexpected => panic!("expected string array, got {unexpected:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_adjusts_width() {
        let subject = CharArray::new("matrix".chars().collect(), 1, 6).unwrap();
        let output = replace_builtin(
            Value::CharArray(subject),
            string_value("matrix"),
            string_value("tensor"),
        )
        .expect("replace");

        match output {
            Value::CharArray(chars) => {
                assert_eq!((chars.rows, chars.cols), (1, 6));
                assert_eq!(chars.data, "tensor".chars().collect::<Vec<char>>());
            }
            unexpected => panic!("expected char array, got {unexpected:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_handles_padding() {
        let subject = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let output = replace_builtin(
            Value::CharArray(subject),
            string_value("b"),
            string_value("beta"),
        )
        .expect("replace");

        match output {
            Value::CharArray(chars) => {
                assert_eq!((chars.rows, chars.cols), (2, 5));
                // Row 0 grows to "abeta"; row 1 ("cd") is space-padded to the same width.
                let expected: Vec<char> = "abetacd   ".chars().collect();
                assert_eq!(chars.data, expected);
            }
            unexpected => panic!("expected char array, got {unexpected:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_cell_array_mixed_content() {
        let subject = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Planner")),
                string_value("GPU Fusion"),
            ],
            1,
            2,
        )
        .unwrap();
        let patterns =
            CellArray::new(vec![string_value("Kernel"), string_value("GPU")], 1, 2).unwrap();
        let replacements =
            StringArray::new(vec!["Shader".into(), "Device".into()], vec![1, 2]).unwrap();

        let output = replace_builtin(
            Value::Cell(subject),
            Value::Cell(patterns),
            Value::StringArray(replacements),
        )
        .expect("replace");

        match output {
            Value::Cell(cells) => {
                assert_eq!(
                    cells.get(0, 0).unwrap(),
                    Value::CharArray(CharArray::new_row("Shader Planner"))
                );
                assert_eq!(cells.get(0, 1).unwrap(), string_value("Device Fusion"));
            }
            unexpected => panic!("expected cell array, got {unexpected:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_first_argument() {
        let err = replace_builtin(Value::Num(1.0), string_value("a"), string_value("b"))
            .unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_pattern_type() {
        let err = replace_builtin(string_value("abc"), Value::Num(1.0), string_value("x"))
            .unwrap_err();
        assert_eq!(err.to_string(), PATTERN_TYPE_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_size_mismatch() {
        let patterns = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let replacements =
            StringArray::new(vec!["x".into(), "y".into(), "z".into()], vec![3, 1]).unwrap();
        let err = replace_builtin(
            string_value("abc"),
            Value::StringArray(patterns),
            Value::StringArray(replacements),
        )
        .unwrap_err();
        assert_eq!(err.to_string(), SIZE_MISMATCH_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_preserves_missing_string() {
        let output = replace_builtin(
            string_value("<missing>"),
            string_value("missing"),
            string_value("value"),
        )
        .expect("replace");
        assert_eq!(output, string_value("<missing>"));
    }

    #[test]
    fn replace_type_preserves_text() {
        let context = ResolveContext::new(Vec::new());
        assert_eq!(text_preserve_type(&[Type::String], &context), Type::String);
    }
}