1use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
12use crate::builtins::strings::type_resolvers::text_preserve_type;
13use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
14
// GPU capability descriptor for `replace`. The builtin has no GPU kernel: it is registered as
// a custom "string-transform" op with no supported precisions or provider hooks, and its
// residency policy gathers any device-resident inputs back to the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::replace")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "replace",
    op_kind: GpuOpKind::Custom("string-transform"),
    // No numeric precisions apply to a text transform.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Inputs are pulled to host memory before any work is done.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory prior to replacement.",
};
31
// Fusion descriptor for `replace`. Neither an elementwise nor a reduction template is
// provided, so the fusion planner has nothing to fuse for this builtin.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::strings::transform::replace"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "replace",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // No fusible kernel templates for a string transform.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "String manipulation builtin; not eligible for fusion plans and always gathers GPU inputs.",
};
45
/// Name attached to every error raised by this builtin.
const BUILTIN_NAME: &str = "replace";
/// Raised when the text argument is not a string, string array, char array, or cell array.
const ARG_TYPE_ERROR: &str =
    "replace: first argument must be a string array, character array, or cell array of character vectors";
/// Raised when the search-string argument (`old`) has an unsupported type.
const PATTERN_TYPE_ERROR: &str =
    "replace: second argument must be a string array, character array, or cell array of character vectors";
/// Raised when the replacement argument (`new`) has an unsupported type.
const REPLACEMENT_TYPE_ERROR: &str =
    "replace: third argument must be a string array, character array, or cell array of character vectors";
/// Raised when `old` yields no search strings at all.
const EMPTY_PATTERN_ERROR: &str =
    "replace: second argument must contain at least one search string";
/// Raised when `new` yields no replacement strings at all.
const EMPTY_REPLACEMENT_ERROR: &str =
    "replace: third argument must contain at least one replacement string";
/// Raised when `new` has neither one element nor as many elements as `old`.
const SIZE_MISMATCH_ERROR: &str =
    "replace: replacement array must be a scalar or match the number of search strings";
/// Raised when a cell array holds something other than a string scalar or a char row vector.
const CELL_ELEMENT_ERROR: &str =
    "replace: cell array elements must be string scalars or character vectors";
61
62fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
63 build_runtime_error(message)
64 .with_builtin(BUILTIN_NAME)
65 .build()
66}
67
68fn map_flow(err: RuntimeError) -> RuntimeError {
69 map_control_flow_with_builtin(err, BUILTIN_NAME)
70}
71
72#[runtime_builtin(
73 name = "replace",
74 category = "strings/transform",
75 summary = "Replace substring occurrences in strings, character arrays, and cell arrays.",
76 keywords = "replace,strrep,strings,character array,text",
77 accel = "sink",
78 type_resolver(text_preserve_type),
79 builtin_path = "crate::builtins::strings::transform::replace"
80)]
81async fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
82 let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
83 let old = gather_if_needed_async(&old).await.map_err(map_flow)?;
84 let new = gather_if_needed_async(&new).await.map_err(map_flow)?;
85
86 let spec = ReplacementSpec::from_values(&old, &new)?;
87
88 match text {
89 Value::String(s) => Ok(Value::String(replace_string_scalar(s, &spec))),
90 Value::StringArray(sa) => replace_string_array(sa, &spec),
91 Value::CharArray(ca) => replace_char_array(ca, &spec),
92 Value::Cell(cell) => replace_cell_array(cell, &spec),
93 _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
94 }
95}
96
97fn replace_string_scalar(text: String, spec: &ReplacementSpec) -> String {
98 if is_missing_string(&text) {
99 text
100 } else {
101 spec.apply(&text)
102 }
103}
104
105fn replace_string_array(array: StringArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
106 let StringArray { data, shape, .. } = array;
107 let mut replaced = Vec::with_capacity(data.len());
108 for entry in data {
109 if is_missing_string(&entry) {
110 replaced.push(entry);
111 } else {
112 replaced.push(spec.apply(&entry));
113 }
114 }
115 let result = StringArray::new(replaced, shape)
116 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
117 Ok(Value::StringArray(result))
118}
119
120fn replace_char_array(array: CharArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
121 let CharArray { data, rows, cols } = array;
122 if rows == 0 {
123 return Ok(Value::CharArray(CharArray { data, rows, cols }));
124 }
125
126 let mut replaced_rows = Vec::with_capacity(rows);
127 let mut target_cols = 0usize;
128 for row in 0..rows {
129 let slice = char_row_to_string_slice(&data, cols, row);
130 let replaced = spec.apply(&slice);
131 let len = replaced.chars().count();
132 target_cols = target_cols.max(len);
133 replaced_rows.push(replaced);
134 }
135
136 let mut flattened = Vec::with_capacity(rows * target_cols);
137 for row_text in replaced_rows {
138 let mut chars: Vec<char> = row_text.chars().collect();
139 if chars.len() < target_cols {
140 chars.resize(target_cols, ' ');
141 }
142 flattened.extend(chars);
143 }
144
145 CharArray::new(flattened, rows, target_cols)
146 .map(Value::CharArray)
147 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
148}
149
150fn replace_cell_array(cell: CellArray, spec: &ReplacementSpec) -> BuiltinResult<Value> {
151 let CellArray {
152 data, rows, cols, ..
153 } = cell;
154 let mut replaced = Vec::with_capacity(rows * cols);
155 for row in 0..rows {
156 for col in 0..cols {
157 let idx = row * cols + col;
158 let value = replace_cell_element(&data[idx], spec)?;
159 replaced.push(value);
160 }
161 }
162 make_cell(replaced, rows, cols).map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
163}
164
165fn replace_cell_element(value: &Value, spec: &ReplacementSpec) -> BuiltinResult<Value> {
166 match value {
167 Value::String(text) => Ok(Value::String(replace_string_scalar(text.clone(), spec))),
168 Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(replace_string_scalar(
169 sa.data[0].clone(),
170 spec,
171 ))),
172 Value::CharArray(ca) if ca.rows <= 1 => replace_char_array(ca.clone(), spec),
173 Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
174 _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
175 }
176}
177
178fn extract_pattern_list(value: &Value) -> BuiltinResult<Vec<String>> {
179 extract_text_list(value, PATTERN_TYPE_ERROR)
180}
181
182fn extract_replacement_list(value: &Value) -> BuiltinResult<Vec<String>> {
183 extract_text_list(value, REPLACEMENT_TYPE_ERROR)
184}
185
186fn extract_text_list(value: &Value, type_error: &str) -> BuiltinResult<Vec<String>> {
187 match value {
188 Value::String(text) => Ok(vec![text.clone()]),
189 Value::StringArray(array) => Ok(array.data.clone()),
190 Value::CharArray(array) => {
191 let CharArray { data, rows, cols } = array.clone();
192 if rows == 0 {
193 Ok(Vec::new())
194 } else {
195 let mut entries = Vec::with_capacity(rows);
196 for row in 0..rows {
197 entries.push(char_row_to_string_slice(&data, cols, row));
198 }
199 Ok(entries)
200 }
201 }
202 Value::Cell(cell) => {
203 let CellArray { data, .. } = cell.clone();
204 let mut entries = Vec::with_capacity(data.len());
205 for element in data {
206 match &*element {
207 Value::String(text) => entries.push(text.clone()),
208 Value::StringArray(sa) if sa.data.len() == 1 => {
209 entries.push(sa.data[0].clone());
210 }
211 Value::CharArray(ca) if ca.rows <= 1 => {
212 if ca.rows == 0 {
213 entries.push(String::new());
214 } else {
215 entries.push(char_row_to_string_slice(&ca.data, ca.cols, 0));
216 }
217 }
218 Value::CharArray(_) => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
219 _ => return Err(runtime_error_for(CELL_ELEMENT_ERROR)),
220 }
221 }
222 Ok(entries)
223 }
224 _ => Err(runtime_error_for(type_error)),
225 }
226}
227
228struct ReplacementSpec {
229 pairs: Vec<(String, String)>,
230}
231
232impl ReplacementSpec {
233 fn from_values(old: &Value, new: &Value) -> BuiltinResult<Self> {
234 let patterns = extract_pattern_list(old)?;
235 if patterns.is_empty() {
236 return Err(runtime_error_for(EMPTY_PATTERN_ERROR));
237 }
238
239 let replacements = extract_replacement_list(new)?;
240 if replacements.is_empty() {
241 return Err(runtime_error_for(EMPTY_REPLACEMENT_ERROR));
242 }
243
244 let pairs = if replacements.len() == patterns.len() {
245 patterns.into_iter().zip(replacements).collect::<Vec<_>>()
246 } else if replacements.len() == 1 {
247 let replacement = replacements[0].clone();
248 patterns
249 .into_iter()
250 .map(|pattern| (pattern, replacement.clone()))
251 .collect::<Vec<_>>()
252 } else {
253 return Err(runtime_error_for(SIZE_MISMATCH_ERROR));
254 };
255
256 Ok(Self { pairs })
257 }
258
259 fn apply(&self, input: &str) -> String {
260 let mut current = input.to_string();
261 for (pattern, replacement) in &self.pairs {
262 if pattern.is_empty() && replacement.is_empty() {
263 continue;
264 }
265 if pattern == replacement {
266 continue;
267 }
268 current = current.replace(pattern, replacement);
269 }
270 current
271 }
272}
273
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for the `replace` builtin. The async builtin is driven synchronously through
    // the local `replace_builtin` shim defined below.
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    // Blocks on the async builtin so each test can call it like a plain function.
    fn replace_builtin(text: Value, old: Value, new: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::replace_builtin(text, old, new))
    }

    // A single search term inside a string scalar is replaced.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_scalar_single_term() {
        let result = replace_builtin(
            Value::String("RunMat runtime".into()),
            Value::String("runtime".into()),
            Value::String("engine".into()),
        )
        .expect("replace");
        assert_eq!(result, Value::String("RunMat engine".into()));
    }

    // Several search terms share one scalar replacement; the shape is preserved and the
    // "<missing>" element passes through unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_string_array_multiple_terms() {
        let strings = StringArray::new(
            vec!["gpu".into(), "cpu".into(), "<missing>".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = replace_builtin(
            Value::StringArray(strings),
            Value::StringArray(
                StringArray::new(vec!["gpu".into(), "cpu".into()], vec![2, 1]).unwrap(),
            ),
            Value::String("device".into()),
        )
        .expect("replace");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("device"),
                        String::from("device"),
                        String::from("<missing>")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A same-length replacement in a 1-row char array keeps the width at 6.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_adjusts_width() {
        let chars = CharArray::new("matrix".chars().collect(), 1, 6).unwrap();
        let result = replace_builtin(
            Value::CharArray(chars),
            Value::String("matrix".into()),
            Value::String("tensor".into()),
        )
        .expect("replace");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 6);
                let expected: Vec<char> = "tensor".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // Rows grow unevenly ("ab" -> "abeta", "cd" unchanged), so the result widens to 5 columns
    // and the shorter row is right-padded with spaces. The expected data is laid out row by row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_char_array_handles_padding() {
        let chars = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let result = replace_builtin(
            Value::CharArray(chars),
            Value::String("b".into()),
            Value::String("beta".into()),
        )
        .expect("replace");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 5);
                let expected: Vec<char> = vec!['a', 'b', 'e', 't', 'a', 'c', 'd', ' ', ' ', ' '];
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // A cell holding a char vector and a string scalar, with cell patterns and a string-array
    // replacement. Each element keeps its own type (char stays char, string stays string).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Planner")),
                Value::String("GPU Fusion".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = replace_builtin(
            Value::Cell(cell),
            Value::Cell(
                CellArray::new(
                    vec![Value::String("Kernel".into()), Value::String("GPU".into())],
                    1,
                    2,
                )
                .unwrap(),
            ),
            Value::StringArray(
                StringArray::new(vec!["Shader".into(), "Device".into()], vec![1, 2]).unwrap(),
            ),
        )
        .expect("replace");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(
                    first,
                    Value::CharArray(CharArray::new_row("Shader Planner"))
                );
                assert_eq!(second, Value::String("Device Fusion".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A numeric first argument is rejected with ARG_TYPE_ERROR.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_first_argument() {
        let err = replace_builtin(
            Value::Num(1.0),
            Value::String("a".into()),
            Value::String("b".into()),
        )
        .unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
    }

    // A numeric search argument is rejected with PATTERN_TYPE_ERROR.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_invalid_pattern_type() {
        let err = replace_builtin(
            Value::String("abc".into()),
            Value::Num(1.0),
            Value::String("x".into()),
        )
        .unwrap_err();
        assert_eq!(err.to_string(), PATTERN_TYPE_ERROR);
    }

    // Two search strings with three replacements (neither 1 nor 2) trigger SIZE_MISMATCH_ERROR.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_errors_on_size_mismatch() {
        let err = replace_builtin(
            Value::String("abc".into()),
            Value::StringArray(StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap()),
            Value::StringArray(
                StringArray::new(vec!["x".into(), "y".into(), "z".into()], vec![3, 1]).unwrap(),
            ),
        )
        .unwrap_err();
        assert_eq!(err.to_string(), SIZE_MISMATCH_ERROR);
    }

    // A missing string scalar is returned as-is, even though its text contains the pattern.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn replace_preserves_missing_string() {
        let result = replace_builtin(
            Value::String("<missing>".into()),
            Value::String("missing".into()),
            Value::String("value".into()),
        )
        .expect("replace");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    // The registered type resolver maps a String input to a String result.
    #[test]
    fn replace_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}