Skip to main content

runmat_runtime/builtins/strings/transform/
replace.rs

1//! MATLAB-compatible `replace` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6    CellArray, CharArray, StringArray, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
16use crate::builtins::strings::type_resolvers::text_preserve_type;
17use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
18
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::replace")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21    name: "replace",
22    op_kind: GpuOpKind::Custom("string-transform"),
23    supported_precisions: &[],
24    broadcast: BroadcastSemantics::None,
25    provider_hooks: &[],
26    constant_strategy: ConstantStrategy::InlineLiteral,
27    residency: ResidencyPolicy::GatherImmediately,
28    nan_mode: ReductionNaN::Include,
29    two_pass_threshold: None,
30    workgroup_size: None,
31    accepts_nan_mode: false,
32    notes:
33        "Executes on the CPU; GPU-resident inputs are gathered to host memory prior to replacement.",
34};
35
36#[runmat_macros::register_fusion_spec(
37    builtin_path = "crate::builtins::strings::transform::replace"
38)]
39pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
40    name: "replace",
41    shape: ShapeRequirements::Any,
42    constant_strategy: ConstantStrategy::InlineLiteral,
43    elementwise: None,
44    reduction: None,
45    emits_nan: false,
46    notes:
47        "String manipulation builtin; not eligible for fusion plans and always gathers GPU inputs.",
48};
49
/// Name attached to every error this module builds and used as the prefix of
/// internal-error messages.
const BUILTIN_NAME: &str = "replace";
51
52const REPLACE_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
53    name: "newText",
54    ty: BuiltinParamType::Any,
55    arity: BuiltinParamArity::Required,
56    default: None,
57    description: "Text with replacements applied, preserving input container kind.",
58}];
59
60const REPLACE_INPUTS: [BuiltinParamDescriptor; 3] = [
61    BuiltinParamDescriptor {
62        name: "str",
63        ty: BuiltinParamType::Any,
64        arity: BuiltinParamArity::Required,
65        default: None,
66        description: "Input text (string/char/cell).",
67    },
68    BuiltinParamDescriptor {
69        name: "oldText",
70        ty: BuiltinParamType::Any,
71        arity: BuiltinParamArity::Required,
72        default: None,
73        description: "Search text list (scalar or array/cell).",
74    },
75    BuiltinParamDescriptor {
76        name: "newText",
77        ty: BuiltinParamType::Any,
78        arity: BuiltinParamArity::Required,
79        default: None,
80        description: "Replacement text list (scalar or matching-size list).",
81    },
82];
83
84const REPLACE_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
85    label: "newText = replace(str, oldText, newText)",
86    inputs: &REPLACE_INPUTS,
87    outputs: &REPLACE_OUTPUT,
88}];
89
90const REPLACE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
91    code: "RM.REPLACE.INVALID_INPUT",
92    identifier: Some("RunMat:replace:InvalidInput"),
93    when: "First argument is not a string array, char array, or cell array of text scalars.",
94    message:
95        "replace: first argument must be a string array, character array, or cell array of character vectors",
96};
97
98const REPLACE_ERROR_PATTERN_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
99    code: "RM.REPLACE.PATTERN_TYPE",
100    identifier: Some("RunMat:replace:PatternType"),
101    when: "Second argument is not a text scalar/array/cell of text scalars.",
102    message:
103        "replace: second argument must be a string array, character array, or cell array of character vectors",
104};
105
106const REPLACE_ERROR_REPLACEMENT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
107    code: "RM.REPLACE.REPLACEMENT_TYPE",
108    identifier: Some("RunMat:replace:ReplacementType"),
109    when: "Third argument is not a text scalar/array/cell of text scalars.",
110    message:
111        "replace: third argument must be a string array, character array, or cell array of character vectors",
112};
113
114const REPLACE_ERROR_EMPTY_PATTERN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
115    code: "RM.REPLACE.EMPTY_PATTERN",
116    identifier: Some("RunMat:replace:EmptyPattern"),
117    when: "Search text list is empty.",
118    message: "replace: second argument must contain at least one search string",
119};
120
121const REPLACE_ERROR_EMPTY_REPLACEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
122    code: "RM.REPLACE.EMPTY_REPLACEMENT",
123    identifier: Some("RunMat:replace:EmptyReplacement"),
124    when: "Replacement text list is empty.",
125    message: "replace: third argument must contain at least one replacement string",
126};
127
128const REPLACE_ERROR_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
129    code: "RM.REPLACE.SIZE_MISMATCH",
130    identifier: Some("RunMat:replace:SizeMismatch"),
131    when: "Replacement list is neither scalar nor equal in length to search list.",
132    message: "replace: replacement array must be a scalar or match the number of search strings",
133};
134
135const REPLACE_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
136    code: "RM.REPLACE.CELL_ELEMENT",
137    identifier: Some("RunMat:replace:CellElement"),
138    when: "Cell arrays contain non-text elements or non-row char arrays.",
139    message: "replace: cell array elements must be string scalars or character vectors",
140};
141
142const REPLACE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
143    code: "RM.REPLACE.INTERNAL",
144    identifier: Some("RunMat:replace:InternalError"),
145    when: "Internal output container construction failed.",
146    message: "replace: internal error",
147};
148
149const REPLACE_ERRORS: [BuiltinErrorDescriptor; 8] = [
150    REPLACE_ERROR_INVALID_INPUT,
151    REPLACE_ERROR_PATTERN_TYPE,
152    REPLACE_ERROR_REPLACEMENT_TYPE,
153    REPLACE_ERROR_EMPTY_PATTERN,
154    REPLACE_ERROR_EMPTY_REPLACEMENT,
155    REPLACE_ERROR_SIZE_MISMATCH,
156    REPLACE_ERROR_CELL_ELEMENT,
157    REPLACE_ERROR_INTERNAL,
158];
159
160pub const REPLACE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
161    signatures: &REPLACE_SIGNATURES,
162    output_mode: BuiltinOutputMode::Fixed,
163    completion_policy: BuiltinCompletionPolicy::Public,
164    errors: &REPLACE_ERRORS,
165};
166
167fn map_flow(err: RuntimeError) -> RuntimeError {
168    map_control_flow_with_builtin(err, BUILTIN_NAME)
169}
170
171fn replace_error_with_message(
172    message: impl Into<String>,
173    error: &'static BuiltinErrorDescriptor,
174) -> RuntimeError {
175    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
176    if let Some(identifier) = error.identifier {
177        builder = builder.with_identifier(identifier);
178    }
179    builder.build()
180}
181
182fn replace_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
183    replace_error_with_message(error.message, error)
184}
185
186#[runtime_builtin(
187    name = "replace",
188    category = "strings/transform",
189    summary = "Replace substring occurrences in strings, character arrays, and cell arrays.",
190    keywords = "replace,strrep,strings,character array,text",
191    accel = "sink",
192    type_resolver(text_preserve_type),
193    descriptor(crate::builtins::strings::transform::replace::REPLACE_DESCRIPTOR),
194    builtin_path = "crate::builtins::strings::transform::replace"
195)]
196async fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
197    let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
198    let old = gather_if_needed_async(&old).await.map_err(map_flow)?;
199    let new = gather_if_needed_async(&new).await.map_err(map_flow)?;
200
201    let spec = ReplacementSpec::from_values(&old, &new)?;
202
203    match text {
204        Value::String(s) => Ok(Value::String(replace_string_scalar(s, &spec))),
205        Value::StringArray(sa) => replace_string_array(sa, &spec),
206        Value::CharArray(ca) => replace_char_array(ca, &spec),
207        Value::Cell(cell) => replace_cell_array(cell, &spec),
208        _ => Err(replace_error(&REPLACE_ERROR_INVALID_INPUT)),
209    }
210}
211
212fn replace_string_scalar(text: String, spec: &ReplacementSpec) -> String {
213    if is_missing_string(&text) {
214        text
215    } else {
216        spec.apply(&text)
217    }
218}
219
220fn replace_string_array(array: StringArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
221    let StringArray { data, shape, .. } = array;
222    let mut replaced = Vec::with_capacity(data.len());
223    for entry in data {
224        if is_missing_string(&entry) {
225            replaced.push(entry);
226        } else {
227            replaced.push(spec.apply(&entry));
228        }
229    }
230    let result = StringArray::new(replaced, shape).map_err(|e| {
231        replace_error_with_message(format!("{BUILTIN_NAME}: {e}"), &REPLACE_ERROR_INTERNAL)
232    })?;
233    Ok(Value::StringArray(result))
234}
235
236fn replace_char_array(array: CharArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
237    let CharArray { data, rows, cols } = array;
238    if rows == 0 {
239        return Ok(Value::CharArray(CharArray { data, rows, cols }));
240    }
241
242    let mut replaced_rows = Vec::with_capacity(rows);
243    let mut target_cols = 0usize;
244    for row in 0..rows {
245        let slice = char_row_to_string_slice(&data, cols, row);
246        let replaced = spec.apply(&slice);
247        let len = replaced.chars().count();
248        target_cols = target_cols.max(len);
249        replaced_rows.push(replaced);
250    }
251
252    let mut flattened = Vec::with_capacity(rows * target_cols);
253    for row_text in replaced_rows {
254        let mut chars: Vec<char> = row_text.chars().collect();
255        if chars.len() < target_cols {
256            chars.resize(target_cols, ' ');
257        }
258        flattened.extend(chars);
259    }
260
261    CharArray::new(flattened, rows, target_cols)
262        .map(Value::CharArray)
263        .map_err(|e| {
264            replace_error_with_message(format!("{BUILTIN_NAME}: {e}"), &REPLACE_ERROR_INTERNAL)
265        })
266}
267
268fn replace_cell_array(cell: CellArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
269    let CellArray {
270        data, rows, cols, ..
271    } = cell;
272    let mut replaced = Vec::with_capacity(rows * cols);
273    for row in 0..rows {
274        for col in 0..cols {
275            let idx = row * cols + col;
276            let value = replace_cell_element(&data[idx], spec)?;
277            replaced.push(value);
278        }
279    }
280    make_cell(replaced, rows, cols).map_err(|e| {
281        replace_error_with_message(format!("{BUILTIN_NAME}: {e}"), &REPLACE_ERROR_INTERNAL)
282    })
283}
284
285fn replace_cell_element(value: &Value, spec: &ReplacementSpec) -> BuiltinResult<Value> {
286    match value {
287        Value::String(text) => Ok(Value::String(replace_string_scalar(text.clone(), spec))),
288        Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(replace_string_scalar(
289            sa.data[0].clone(),
290            spec,
291        ))),
292        Value::CharArray(ca) if ca.rows <= 1 => replace_char_array(ca.clone(), spec),
293        Value::CharArray(_) => Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT)),
294        _ => Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT)),
295    }
296}
297
298fn extract_pattern_list(value: &Value) -> BuiltinResult<Vec<String>> {
299    extract_text_list(value, &REPLACE_ERROR_PATTERN_TYPE)
300}
301
302fn extract_replacement_list(value: &Value) -> BuiltinResult<Vec<String>> {
303    extract_text_list(value, &REPLACE_ERROR_REPLACEMENT_TYPE)
304}
305
306fn extract_text_list(
307    value: &Value,
308    type_error: &'static BuiltinErrorDescriptor,
309) -> BuiltinResult<Vec<String>> {
310    match value {
311        Value::String(text) => Ok(vec![text.clone()]),
312        Value::StringArray(array) => Ok(array.data.clone()),
313        Value::CharArray(array) => {
314            let CharArray { data, rows, cols } = array.clone();
315            if rows == 0 {
316                Ok(Vec::new())
317            } else {
318                let mut entries = Vec::with_capacity(rows);
319                for row in 0..rows {
320                    entries.push(char_row_to_string_slice(&data, cols, row));
321                }
322                Ok(entries)
323            }
324        }
325        Value::Cell(cell) => {
326            let CellArray { data, .. } = cell.clone();
327            let mut entries = Vec::with_capacity(data.len());
328            for element in data {
329                match &*element {
330                    Value::String(text) => entries.push(text.clone()),
331                    Value::StringArray(sa) if sa.data.len() == 1 => {
332                        entries.push(sa.data[0].clone());
333                    }
334                    Value::CharArray(ca) if ca.rows <= 1 => {
335                        if ca.rows == 0 {
336                            entries.push(String::new());
337                        } else {
338                            entries.push(char_row_to_string_slice(&ca.data, ca.cols, 0));
339                        }
340                    }
341                    Value::CharArray(_) => {
342                        return Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT));
343                    }
344                    _ => {
345                        return Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT));
346                    }
347                }
348            }
349            Ok(entries)
350        }
351        _ => Err(replace_error(type_error)),
352    }
353}
354
/// Validated list of replacements to perform.
struct ReplacementSpec {
    /// `(search, replacement)` pairs, applied one after another in this order.
    pairs: Vec<(String, String)>,
}
358
359impl ReplacementSpec {
360    fn from_values(old: &Value, new: &Value) -> BuiltinResult<Self> {
361        let patterns = extract_pattern_list(old)?;
362        if patterns.is_empty() {
363            return Err(replace_error(&REPLACE_ERROR_EMPTY_PATTERN));
364        }
365
366        let replacements = extract_replacement_list(new)?;
367        if replacements.is_empty() {
368            return Err(replace_error(&REPLACE_ERROR_EMPTY_REPLACEMENT));
369        }
370
371        let pairs = if replacements.len() == patterns.len() {
372            patterns.into_iter().zip(replacements).collect::<Vec<_>>()
373        } else if replacements.len() == 1 {
374            let replacement = replacements[0].clone();
375            patterns
376                .into_iter()
377                .map(|pattern| (pattern, replacement.clone()))
378                .collect::<Vec<_>>()
379        } else {
380            return Err(replace_error(&REPLACE_ERROR_SIZE_MISMATCH));
381        };
382
383        Ok(Self { pairs })
384    }
385
386    fn apply(&self, input: &str) -> String {
387        let mut current = input.to_string();
388        for (pattern, replacement) in &self.pairs {
389            if pattern.is_empty() && replacement.is_empty() {
390                continue;
391            }
392            if pattern == replacement {
393                continue;
394            }
395            current = current.replace(pattern, replacement);
396        }
397        current
398    }
399}
400
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async builtin to completion so tests can call it directly.
    fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
        let call = super::replace_builtin(text, old, new);
        futures::executor::block_on(call)
    }

    /// Wraps a literal in `Value::String`.
    fn text(literal: &str) -> Value {
        Value::String(literal.into())
    }

    /// Converts literals into the owned entries a `StringArray` is built from.
    fn owned(entries: &[&str]) -> Vec<String> {
        entries.iter().map(|entry| String::from(*entry)).collect()
    }

    // A single search term is swapped inside a string scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_scalar_single_term() {
        let result = replace_builtin(text("RunMat runtime"), text("runtime"), text("engine"))
            .expect("replace");
        assert_eq!(result, text("RunMat engine"));
    }

    // Several search terms share one replacement; the missing entry survives.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_array_multiple_terms() {
        let strings = StringArray::new(owned(&["gpu", "cpu", "<missing>"]), vec![3, 1]).unwrap();
        let patterns = StringArray::new(owned(&["gpu", "cpu"]), vec![2, 1]).unwrap();
        let result = replace_builtin(
            Value::StringArray(strings),
            Value::StringArray(patterns),
            text("device"),
        )
        .expect("replace");
        let sa = match result {
            Value::StringArray(sa) => sa,
            other => panic!("expected string array, got {other:?}"),
        };
        assert_eq!(sa.shape, vec![3, 1]);
        assert_eq!(sa.data, owned(&["device", "device", "<missing>"]));
    }

    // Same-length replacement keeps the row width.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_adjusts_width() {
        let chars = CharArray::new("matrix".chars().collect(), 1, 6).unwrap();
        let result = replace_builtin(Value::CharArray(chars), text("matrix"), text("tensor"))
            .expect("replace");
        let out = match result {
            Value::CharArray(out) => out,
            other => panic!("expected char array, got {other:?}"),
        };
        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, 6);
        let expected: Vec<char> = "tensor".chars().collect();
        assert_eq!(out.data, expected);
    }

    // A row that grows forces the other row to be space-padded.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_handles_padding() {
        let chars = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let result =
            replace_builtin(Value::CharArray(chars), text("b"), text("beta")).expect("replace");
        let out = match result {
            Value::CharArray(out) => out,
            other => panic!("expected char array, got {other:?}"),
        };
        assert_eq!(out.rows, 2);
        assert_eq!(out.cols, 5);
        let expected: Vec<char> = vec!['a', 'b', 'e', 't', 'a', 'c', 'd', ' ', ' ', ' '];
        assert_eq!(out.data, expected);
    }

    // Each cell element keeps its own container kind.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Planner")),
                text("GPU Fusion"),
            ],
            1,
            2,
        )
        .unwrap();
        let patterns = CellArray::new(vec![text("Kernel"), text("GPU")], 1, 2).unwrap();
        let replacements = StringArray::new(owned(&["Shader", "Device"]), vec![1, 2]).unwrap();
        let result = replace_builtin(
            Value::Cell(cell),
            Value::Cell(patterns),
            Value::StringArray(replacements),
        )
        .expect("replace");
        let out = match result {
            Value::Cell(out) => out,
            other => panic!("expected cell array, got {other:?}"),
        };
        let first = out.get(0, 0).unwrap();
        let second = out.get(0, 1).unwrap();
        assert_eq!(
            first,
            Value::CharArray(CharArray::new_row("Shader Planner"))
        );
        assert_eq!(second, text("Device Fusion"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_first_argument() {
        let err = replace_builtin(Value::Num(1.0), text("a"), text("b")).unwrap_err();
        assert_eq!(err.to_string(), REPLACE_ERROR_INVALID_INPUT.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_pattern_type() {
        let err = replace_builtin(text("abc"), Value::Num(1.0), text("x")).unwrap_err();
        assert_eq!(err.to_string(), REPLACE_ERROR_PATTERN_TYPE.message);
    }

    // Two search terms against three replacements is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_size_mismatch() {
        let patterns = StringArray::new(owned(&["a", "b"]), vec![2, 1]).unwrap();
        let replacements = StringArray::new(owned(&["x", "y", "z"]), vec![3, 1]).unwrap();
        let err = replace_builtin(
            text("abc"),
            Value::StringArray(patterns),
            Value::StringArray(replacements),
        )
        .unwrap_err();
        assert_eq!(err.to_string(), REPLACE_ERROR_SIZE_MISMATCH.message);
    }

    // The missing marker is left alone even though it contains the search text.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_preserves_missing_string() {
        let result =
            replace_builtin(text("<missing>"), text("missing"), text("value")).expect("replace");
        assert_eq!(result, text("<missing>"));
    }

    #[test]
    fn replace_type_preserves_text() {
        let resolved = text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new()));
        assert_eq!(resolved, Type::String);
    }
}