1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 CellArray, CharArray, StringArray, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
16use crate::builtins::strings::type_resolvers::text_preserve_type;
17use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
18
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::replace")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21 name: "replace",
22 op_kind: GpuOpKind::Custom("string-transform"),
23 supported_precisions: &[],
24 broadcast: BroadcastSemantics::None,
25 provider_hooks: &[],
26 constant_strategy: ConstantStrategy::InlineLiteral,
27 residency: ResidencyPolicy::GatherImmediately,
28 nan_mode: ReductionNaN::Include,
29 two_pass_threshold: None,
30 workgroup_size: None,
31 accepts_nan_mode: false,
32 notes:
33 "Executes on the CPU; GPU-resident inputs are gathered to host memory prior to replacement.",
34};
35
36#[runmat_macros::register_fusion_spec(
37 builtin_path = "crate::builtins::strings::transform::replace"
38)]
39pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
40 name: "replace",
41 shape: ShapeRequirements::Any,
42 constant_strategy: ConstantStrategy::InlineLiteral,
43 elementwise: None,
44 reduction: None,
45 emits_nan: false,
46 notes:
47 "String manipulation builtin; not eligible for fusion plans and always gathers GPU inputs.",
48};
49
50const BUILTIN_NAME: &str = "replace";
51
52const REPLACE_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
53 name: "newText",
54 ty: BuiltinParamType::Any,
55 arity: BuiltinParamArity::Required,
56 default: None,
57 description: "Text with replacements applied, preserving input container kind.",
58}];
59
60const REPLACE_INPUTS: [BuiltinParamDescriptor; 3] = [
61 BuiltinParamDescriptor {
62 name: "str",
63 ty: BuiltinParamType::Any,
64 arity: BuiltinParamArity::Required,
65 default: None,
66 description: "Input text (string/char/cell).",
67 },
68 BuiltinParamDescriptor {
69 name: "oldText",
70 ty: BuiltinParamType::Any,
71 arity: BuiltinParamArity::Required,
72 default: None,
73 description: "Search text list (scalar or array/cell).",
74 },
75 BuiltinParamDescriptor {
76 name: "newText",
77 ty: BuiltinParamType::Any,
78 arity: BuiltinParamArity::Required,
79 default: None,
80 description: "Replacement text list (scalar or matching-size list).",
81 },
82];
83
84const REPLACE_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
85 label: "newText = replace(str, oldText, newText)",
86 inputs: &REPLACE_INPUTS,
87 outputs: &REPLACE_OUTPUT,
88}];
89
90const REPLACE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
91 code: "RM.REPLACE.INVALID_INPUT",
92 identifier: Some("RunMat:replace:InvalidInput"),
93 when: "First argument is not a string array, char array, or cell array of text scalars.",
94 message:
95 "replace: first argument must be a string array, character array, or cell array of character vectors",
96};
97
98const REPLACE_ERROR_PATTERN_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
99 code: "RM.REPLACE.PATTERN_TYPE",
100 identifier: Some("RunMat:replace:PatternType"),
101 when: "Second argument is not a text scalar/array/cell of text scalars.",
102 message:
103 "replace: second argument must be a string array, character array, or cell array of character vectors",
104};
105
106const REPLACE_ERROR_REPLACEMENT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
107 code: "RM.REPLACE.REPLACEMENT_TYPE",
108 identifier: Some("RunMat:replace:ReplacementType"),
109 when: "Third argument is not a text scalar/array/cell of text scalars.",
110 message:
111 "replace: third argument must be a string array, character array, or cell array of character vectors",
112};
113
114const REPLACE_ERROR_EMPTY_PATTERN: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
115 code: "RM.REPLACE.EMPTY_PATTERN",
116 identifier: Some("RunMat:replace:EmptyPattern"),
117 when: "Search text list is empty.",
118 message: "replace: second argument must contain at least one search string",
119};
120
121const REPLACE_ERROR_EMPTY_REPLACEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
122 code: "RM.REPLACE.EMPTY_REPLACEMENT",
123 identifier: Some("RunMat:replace:EmptyReplacement"),
124 when: "Replacement text list is empty.",
125 message: "replace: third argument must contain at least one replacement string",
126};
127
128const REPLACE_ERROR_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
129 code: "RM.REPLACE.SIZE_MISMATCH",
130 identifier: Some("RunMat:replace:SizeMismatch"),
131 when: "Replacement list is neither scalar nor equal in length to search list.",
132 message: "replace: replacement array must be a scalar or match the number of search strings",
133};
134
135const REPLACE_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
136 code: "RM.REPLACE.CELL_ELEMENT",
137 identifier: Some("RunMat:replace:CellElement"),
138 when: "Cell arrays contain non-text elements or non-row char arrays.",
139 message: "replace: cell array elements must be string scalars or character vectors",
140};
141
142const REPLACE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
143 code: "RM.REPLACE.INTERNAL",
144 identifier: Some("RunMat:replace:InternalError"),
145 when: "Internal output container construction failed.",
146 message: "replace: internal error",
147};
148
149const REPLACE_ERRORS: [BuiltinErrorDescriptor; 8] = [
150 REPLACE_ERROR_INVALID_INPUT,
151 REPLACE_ERROR_PATTERN_TYPE,
152 REPLACE_ERROR_REPLACEMENT_TYPE,
153 REPLACE_ERROR_EMPTY_PATTERN,
154 REPLACE_ERROR_EMPTY_REPLACEMENT,
155 REPLACE_ERROR_SIZE_MISMATCH,
156 REPLACE_ERROR_CELL_ELEMENT,
157 REPLACE_ERROR_INTERNAL,
158];
159
160pub const REPLACE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
161 signatures: &REPLACE_SIGNATURES,
162 output_mode: BuiltinOutputMode::Fixed,
163 completion_policy: BuiltinCompletionPolicy::Public,
164 errors: &REPLACE_ERRORS,
165};
166
167fn map_flow(err: RuntimeError) -> RuntimeError {
168 map_control_flow_with_builtin(err, BUILTIN_NAME)
169}
170
171fn replace_error_with_message(
172 message: impl Into<String>,
173 error: &'static BuiltinErrorDescriptor,
174) -> RuntimeError {
175 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
176 if let Some(identifier) = error.identifier {
177 builder = builder.with_identifier(identifier);
178 }
179 builder.build()
180}
181
182fn replace_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
183 replace_error_with_message(error.message, error)
184}
185
186#[runtime_builtin(
187 name = "replace",
188 category = "strings/transform",
189 summary = "Replace substring occurrences in strings, character arrays, and cell arrays.",
190 keywords = "replace,strrep,strings,character array,text",
191 accel = "sink",
192 type_resolver(text_preserve_type),
193 descriptor(crate::builtins::strings::transform::replace::REPLACE_DESCRIPTOR),
194 builtin_path = "crate::builtins::strings::transform::replace"
195)]
196async fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
197 let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
198 let old = gather_if_needed_async(&old).await.map_err(map_flow)?;
199 let new = gather_if_needed_async(&new).await.map_err(map_flow)?;
200
201 let spec = ReplacementSpec::from_values(&old, &new)?;
202
203 match text {
204 Value::String(s) => Ok(Value::String(replace_string_scalar(s, &spec))),
205 Value::StringArray(sa) => replace_string_array(sa, &spec),
206 Value::CharArray(ca) => replace_char_array(ca, &spec),
207 Value::Cell(cell) => replace_cell_array(cell, &spec),
208 _ => Err(replace_error(&REPLACE_ERROR_INVALID_INPUT)),
209 }
210}
211
212fn replace_string_scalar(text: String, spec: &ReplacementSpec) -> String {
213 if is_missing_string(&text) {
214 text
215 } else {
216 spec.apply(&text)
217 }
218}
219
220fn replace_string_array(array: StringArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
221 let StringArray { data, shape, .. } = array;
222 let mut replaced = Vec::with_capacity(data.len());
223 for entry in data {
224 if is_missing_string(&entry) {
225 replaced.push(entry);
226 } else {
227 replaced.push(spec.apply(&entry));
228 }
229 }
230 let result = StringArray::new(replaced, shape).map_err(|e| {
231 replace_error_with_message(format!("{BUILTIN_NAME}: {e}"), &REPLACE_ERROR_INTERNAL)
232 })?;
233 Ok(Value::StringArray(result))
234}
235
236fn replace_char_array(array: CharArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
237 let CharArray { data, rows, cols } = array;
238 if rows == 0 {
239 return Ok(Value::CharArray(CharArray { data, rows, cols }));
240 }
241
242 let mut replaced_rows = Vec::with_capacity(rows);
243 let mut target_cols = 0usize;
244 for row in 0..rows {
245 let slice = char_row_to_string_slice(&data, cols, row);
246 let replaced = spec.apply(&slice);
247 let len = replaced.chars().count();
248 target_cols = target_cols.max(len);
249 replaced_rows.push(replaced);
250 }
251
252 let mut flattened = Vec::with_capacity(rows * target_cols);
253 for row_text in replaced_rows {
254 let mut chars: Vec<char> = row_text.chars().collect();
255 if chars.len() < target_cols {
256 chars.resize(target_cols, ' ');
257 }
258 flattened.extend(chars);
259 }
260
261 CharArray::new(flattened, rows, target_cols)
262 .map(Value::CharArray)
263 .map_err(|e| {
264 replace_error_with_message(format!("{BUILTIN_NAME}: {e}"), &REPLACE_ERROR_INTERNAL)
265 })
266}
267
268fn replace_cell_array(cell: CellArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
269 let CellArray {
270 data, rows, cols, ..
271 } = cell;
272 let mut replaced = Vec::with_capacity(rows * cols);
273 for row in 0..rows {
274 for col in 0..cols {
275 let idx = row * cols + col;
276 let value = replace_cell_element(&data[idx], spec)?;
277 replaced.push(value);
278 }
279 }
280 make_cell(replaced, rows, cols).map_err(|e| {
281 replace_error_with_message(format!("{BUILTIN_NAME}: {e}"), &REPLACE_ERROR_INTERNAL)
282 })
283}
284
285fn replace_cell_element(value: &Value, spec: &ReplacementSpec) -> BuiltinResult<Value> {
286 match value {
287 Value::String(text) => Ok(Value::String(replace_string_scalar(text.clone(), spec))),
288 Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(replace_string_scalar(
289 sa.data[0].clone(),
290 spec,
291 ))),
292 Value::CharArray(ca) if ca.rows <= 1 => replace_char_array(ca.clone(), spec),
293 Value::CharArray(_) => Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT)),
294 _ => Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT)),
295 }
296}
297
298fn extract_pattern_list(value: &Value) -> BuiltinResult<Vec<String>> {
299 extract_text_list(value, &REPLACE_ERROR_PATTERN_TYPE)
300}
301
302fn extract_replacement_list(value: &Value) -> BuiltinResult<Vec<String>> {
303 extract_text_list(value, &REPLACE_ERROR_REPLACEMENT_TYPE)
304}
305
306fn extract_text_list(
307 value: &Value,
308 type_error: &'static BuiltinErrorDescriptor,
309) -> BuiltinResult<Vec<String>> {
310 match value {
311 Value::String(text) => Ok(vec![text.clone()]),
312 Value::StringArray(array) => Ok(array.data.clone()),
313 Value::CharArray(array) => {
314 let CharArray { data, rows, cols } = array.clone();
315 if rows == 0 {
316 Ok(Vec::new())
317 } else {
318 let mut entries = Vec::with_capacity(rows);
319 for row in 0..rows {
320 entries.push(char_row_to_string_slice(&data, cols, row));
321 }
322 Ok(entries)
323 }
324 }
325 Value::Cell(cell) => {
326 let CellArray { data, .. } = cell.clone();
327 let mut entries = Vec::with_capacity(data.len());
328 for element in data {
329 match &*element {
330 Value::String(text) => entries.push(text.clone()),
331 Value::StringArray(sa) if sa.data.len() == 1 => {
332 entries.push(sa.data[0].clone());
333 }
334 Value::CharArray(ca) if ca.rows <= 1 => {
335 if ca.rows == 0 {
336 entries.push(String::new());
337 } else {
338 entries.push(char_row_to_string_slice(&ca.data, ca.cols, 0));
339 }
340 }
341 Value::CharArray(_) => {
342 return Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT));
343 }
344 _ => {
345 return Err(replace_error(&REPLACE_ERROR_CELL_ELEMENT));
346 }
347 }
348 }
349 Ok(entries)
350 }
351 _ => Err(replace_error(type_error)),
352 }
353}
354
355struct ReplacementSpec {
356 pairs: Vec<(String, String)>,
357}
358
359impl ReplacementSpec {
360 fn from_values(old: &Value, new: &Value) -> BuiltinResult<Self> {
361 let patterns = extract_pattern_list(old)?;
362 if patterns.is_empty() {
363 return Err(replace_error(&REPLACE_ERROR_EMPTY_PATTERN));
364 }
365
366 let replacements = extract_replacement_list(new)?;
367 if replacements.is_empty() {
368 return Err(replace_error(&REPLACE_ERROR_EMPTY_REPLACEMENT));
369 }
370
371 let pairs = if replacements.len() == patterns.len() {
372 patterns.into_iter().zip(replacements).collect::<Vec<_>>()
373 } else if replacements.len() == 1 {
374 let replacement = replacements[0].clone();
375 patterns
376 .into_iter()
377 .map(|pattern| (pattern, replacement.clone()))
378 .collect::<Vec<_>>()
379 } else {
380 return Err(replace_error(&REPLACE_ERROR_SIZE_MISMATCH));
381 };
382
383 Ok(Self { pairs })
384 }
385
386 fn apply(&self, input: &str) -> String {
387 let mut current = input.to_string();
388 for (pattern, replacement) in &self.pairs {
389 if pattern.is_empty() && replacement.is_empty() {
390 continue;
391 }
392 if pattern == replacement {
393 continue;
394 }
395 current = current.replace(pattern, replacement);
396 }
397 current
398 }
399}
400
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Runs the async builtin to completion on the calling thread.
    fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
        let pending = super::replace_builtin(text, old, new);
        futures::executor::block_on(pending)
    }

    /// Wraps `text` as a string scalar value.
    fn scalar(text: &str) -> Value {
        Value::String(text.into())
    }

    /// Builds a string-array value from `items` with the given `shape`.
    fn strings(items: &[&str], shape: Vec<usize>) -> Value {
        let data: Vec<String> = items.iter().map(|item| String::from(*item)).collect();
        Value::StringArray(StringArray::new(data, shape).unwrap())
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_scalar_single_term() {
        let result =
            replace_builtin(scalar("RunMat runtime"), scalar("runtime"), scalar("engine"))
                .expect("replace");
        assert_eq!(result, scalar("RunMat engine"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_array_multiple_terms() {
        let result = replace_builtin(
            strings(&["gpu", "cpu", "<missing>"], vec![3, 1]),
            strings(&["gpu", "cpu"], vec![2, 1]),
            scalar("device"),
        )
        .expect("replace");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                let expected: Vec<String> = ["device", "device", "<missing>"]
                    .iter()
                    .map(|item| String::from(*item))
                    .collect();
                assert_eq!(sa.data, expected);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_adjusts_width() {
        let input = CharArray::new("matrix".chars().collect(), 1, 6).unwrap();
        let result = replace_builtin(Value::CharArray(input), scalar("matrix"), scalar("tensor"))
            .expect("replace");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 6);
                let expected: Vec<char> = "tensor".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_handles_padding() {
        let input = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let result = replace_builtin(Value::CharArray(input), scalar("b"), scalar("beta"))
            .expect("replace");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 5);
                // Row one grows to "abeta"; row two ("cd") is padded to the same width.
                let expected: Vec<char> = "abetacd   ".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_cell_array_mixed_content() {
        let subject = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Planner")),
                scalar("GPU Fusion"),
            ],
            1,
            2,
        )
        .unwrap();
        let search = CellArray::new(vec![scalar("Kernel"), scalar("GPU")], 1, 2).unwrap();
        let result = replace_builtin(
            Value::Cell(subject),
            Value::Cell(search),
            strings(&["Shader", "Device"], vec![1, 2]),
        )
        .expect("replace");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(
                    first,
                    Value::CharArray(CharArray::new_row("Shader Planner"))
                );
                assert_eq!(second, scalar("Device Fusion"));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_first_argument() {
        let err = replace_builtin(Value::Num(1.0), scalar("a"), scalar("b")).unwrap_err();
        assert_eq!(err.to_string(), REPLACE_ERROR_INVALID_INPUT.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_pattern_type() {
        let err = replace_builtin(scalar("abc"), Value::Num(1.0), scalar("x")).unwrap_err();
        assert_eq!(err.to_string(), REPLACE_ERROR_PATTERN_TYPE.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_size_mismatch() {
        let err = replace_builtin(
            scalar("abc"),
            strings(&["a", "b"], vec![2, 1]),
            strings(&["x", "y", "z"], vec![3, 1]),
        )
        .unwrap_err();
        assert_eq!(err.to_string(), REPLACE_ERROR_SIZE_MISMATCH.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_preserves_missing_string() {
        let result = replace_builtin(scalar("<missing>"), scalar("missing"), scalar("value"))
            .expect("replace");
        assert_eq!(result, scalar("<missing>"));
    }

    #[test]
    fn replace_type_preserves_text() {
        let context = ResolveContext::new(Vec::new());
        assert_eq!(text_preserve_type(&[Type::String], &context), Type::String);
    }
}