1use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
12use crate::builtins::strings::type_resolvers::text_preserve_type;
13use crate::{
14 build_runtime_error, gather_if_needed_async, make_cell_with_shape, BuiltinResult, RuntimeError,
15};
16
/// GPU metadata for `strrep`.
///
/// The spec declares no supported precisions and no provider hooks, and uses the
/// `GatherImmediately` residency policy. Per `notes`, the replacement runs on the
/// CPU, and any GPU-resident inputs are gathered before replacement.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "strrep",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Executes on the CPU; GPU-resident inputs are gathered before replacements are applied.",
};
32
/// Fusion metadata for `strrep`.
///
/// No elementwise or reduction template is supplied (both are `None`). Per
/// `notes`, the builtin is treated as a sink so that fusion skips GPU residency.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::strrep")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "strrep",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; marked as a sink so fusion skips GPU residency.",
};
43
/// Name attached to every error raised by this builtin.
const BUILTIN_NAME: &str = "strrep";
/// Raised when `str` is not a string scalar/array, char array, or cell array.
const ARGUMENT_TYPE_ERROR: &str =
    "strrep: first argument must be a string array, character array, or cell array of character vectors";
/// Raised when `old` or `new` is not a single piece of text.
const PATTERN_TYPE_ERROR: &str = "strrep: old and new must be string scalars or character vectors";
/// Raised when `old` and `new` are valid text but of different kinds (string vs char).
const PATTERN_MISMATCH_ERROR: &str = "strrep: old and new must be the same data type";
/// Raised when a cell array contains an element that is not text.
const CELL_ELEMENT_ERROR: &str =
    "strrep: cell array elements must be string scalars or character vectors";
51
/// The kind of text a pattern argument (`old` or `new`) was supplied as.
///
/// `strrep` requires both patterns to share the same kind. `Debug` is derived so
/// the kind can appear in diagnostics and in `assert_eq!`/`assert_ne!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PatternKind {
    /// A string scalar, or a string array with exactly one element.
    String,
    /// A character array with at most one row.
    Char,
}
57
58fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
59 build_runtime_error(message)
60 .with_builtin(BUILTIN_NAME)
61 .build()
62}
63
64fn map_flow(err: RuntimeError) -> RuntimeError {
65 map_control_flow_with_builtin(err, BUILTIN_NAME)
66}
67
68#[runtime_builtin(
69 name = "strrep",
70 category = "strings/transform",
71 summary = "Replace substring occurrences with MATLAB-compatible semantics.",
72 keywords = "strrep,replace,strings,character array,text",
73 accel = "sink",
74 type_resolver(text_preserve_type),
75 builtin_path = "crate::builtins::strings::transform::strrep"
76)]
77async fn strrep_builtin(
78 str_value: Value,
79 old_value: Value,
80 new_value: Value,
81) -> BuiltinResult<Value> {
82 let gathered_str = gather_if_needed_async(&str_value).await.map_err(map_flow)?;
83 let gathered_old = gather_if_needed_async(&old_value).await.map_err(map_flow)?;
84 let gathered_new = gather_if_needed_async(&new_value).await.map_err(map_flow)?;
85
86 let (old_text, old_kind) = parse_pattern(gathered_old)?;
87 let (new_text, new_kind) = parse_pattern(gathered_new)?;
88 if old_kind != new_kind {
89 return Err(runtime_error_for(PATTERN_MISMATCH_ERROR));
90 }
91
92 match gathered_str {
93 Value::String(text) => Ok(Value::String(strrep_string_value(
94 text, &old_text, &new_text,
95 ))),
96 Value::StringArray(array) => strrep_string_array(array, &old_text, &new_text),
97 Value::CharArray(array) => strrep_char_array(array, &old_text, &new_text),
98 Value::Cell(cell) => strrep_cell_array(cell, &old_text, &new_text),
99 _ => Err(runtime_error_for(ARGUMENT_TYPE_ERROR)),
100 }
101}
102
103fn parse_pattern(value: Value) -> BuiltinResult<(String, PatternKind)> {
104 match value {
105 Value::String(text) => Ok((text, PatternKind::String)),
106 Value::StringArray(array) => {
107 if array.data.len() == 1 {
108 Ok((array.data[0].clone(), PatternKind::String))
109 } else {
110 Err(runtime_error_for(PATTERN_TYPE_ERROR))
111 }
112 }
113 Value::CharArray(array) => {
114 if array.rows <= 1 {
115 let text = if array.rows == 0 {
116 String::new()
117 } else {
118 char_row_to_string_slice(&array.data, array.cols, 0)
119 };
120 Ok((text, PatternKind::Char))
121 } else {
122 Err(runtime_error_for(PATTERN_TYPE_ERROR))
123 }
124 }
125 _ => Err(runtime_error_for(PATTERN_TYPE_ERROR)),
126 }
127}
128
129fn strrep_string_value(text: String, old: &str, new: &str) -> String {
130 if is_missing_string(&text) {
131 text
132 } else {
133 text.replace(old, new)
134 }
135}
136
137fn strrep_string_array(array: StringArray, old: &str, new: &str) -> BuiltinResult<Value> {
138 let StringArray { data, shape, .. } = array;
139 let replaced = data
140 .into_iter()
141 .map(|text| strrep_string_value(text, old, new))
142 .collect::<Vec<_>>();
143 let rebuilt = StringArray::new(replaced, shape)
144 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
145 Ok(Value::StringArray(rebuilt))
146}
147
148fn strrep_char_array(array: CharArray, old: &str, new: &str) -> BuiltinResult<Value> {
149 let CharArray { data, rows, cols } = array;
150 if rows == 0 || cols == 0 {
151 return Ok(Value::CharArray(CharArray { data, rows, cols }));
152 }
153
154 let mut replaced_rows = Vec::with_capacity(rows);
155 let mut target_cols = 0usize;
156 for row in 0..rows {
157 let text = char_row_to_string_slice(&data, cols, row);
158 let replaced = text.replace(old, new);
159 target_cols = target_cols.max(replaced.chars().count());
160 replaced_rows.push(replaced);
161 }
162
163 let mut new_data = Vec::with_capacity(rows * target_cols);
164 for row_text in replaced_rows {
165 let mut chars: Vec<char> = row_text.chars().collect();
166 if chars.len() < target_cols {
167 chars.resize(target_cols, ' ');
168 }
169 new_data.extend(chars);
170 }
171
172 CharArray::new(new_data, rows, target_cols)
173 .map(Value::CharArray)
174 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
175}
176
177fn strrep_cell_array(cell: CellArray, old: &str, new: &str) -> BuiltinResult<Value> {
178 let CellArray { data, shape, .. } = cell;
179 let mut replaced = Vec::with_capacity(data.len());
180 for ptr in &data {
181 replaced.push(strrep_cell_element(ptr, old, new)?);
182 }
183 make_cell_with_shape(replaced, shape)
184 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
185}
186
187fn strrep_cell_element(value: &Value, old: &str, new: &str) -> BuiltinResult<Value> {
188 match value {
189 Value::String(text) => Ok(Value::String(strrep_string_value(text.clone(), old, new))),
190 Value::StringArray(array) => strrep_string_array(array.clone(), old, new),
191 Value::CharArray(array) => strrep_char_array(array.clone(), old, new),
192 _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
193 }
194}
195
// Unit tests for `strrep`. They cover:
// - string scalar, string array, char array, and cell array inputs;
// - padding of char matrices;
// - error messages;
// - the type resolver.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Runs the async builtin to completion on the current thread.
    fn run_strrep(str_value: Value, old_value: Value, new_value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(strrep_builtin(str_value, old_value, new_value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_scalar_basic() {
        let result = run_strrep(
            Value::String("RunMat Ignite".into()),
            Value::String("Ignite".into()),
            Value::String("Interpreter".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("RunMat Interpreter".into()));
    }

    // NOTE(review): assumes `is_missing_string` treats the literal text
    // "<missing>" as a missing string; that is what this expectation relies on.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_preserves_missing() {
        let array = StringArray::new(
            vec![
                String::from("gpu"),
                String::from("<missing>"),
                String::from("planner"),
            ],
            vec![3, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::String("gpu".into()),
            Value::String("GPU".into()),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("GPU"),
                        String::from("<missing>"),
                        String::from("planner")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Char-typed patterns are accepted against a string-array subject; only
    // `old` and `new` must share a kind.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_string_array_with_char_pattern() {
        let array = StringArray::new(
            vec![String::from("alpha"), String::from("beta")],
            vec![2, 1],
        )
        .unwrap();
        let result = run_strrep(
            Value::StringArray(array),
            Value::CharArray(CharArray::new_row("a")),
            Value::CharArray(CharArray::new_row("A")),
        )
        .expect("strrep");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Same-length replacement keeps the 1x7 char row width unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_padding() {
        let chars = CharArray::new(vec!['R', 'u', 'n', ' ', 'M', 'a', 't'], 1, 7).unwrap();
        let result = run_strrep(
            Value::CharArray(chars),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 7);
                let expected: Vec<char> = "Run_Mat".chars().collect();
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // Rows shrink by different amounts ("alpha" -> "lph", "beta " -> "bet ").
    // The matrix narrows to the longest result, and shorter rows are
    // space-padded.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_char_array_shrinks_rows_pad_with_spaces() {
        let mut data: Vec<char> = "alpha".chars().collect();
        data.extend("beta ".chars());
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = run_strrep(
            Value::CharArray(array),
            Value::String("a".into()),
            Value::String("".into()),
        )
        .expect("strrep");
        match result {
            Value::CharArray(out) => {
                assert_eq!(out.rows, 2);
                assert_eq!(out.cols, 4);
                let expected: Vec<char> = vec!['l', 'p', 'h', ' ', 'b', 'e', 't', ' '];
                assert_eq!(out.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_vectors() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("Kernel Fusion")),
                Value::CharArray(CharArray::new_row("GPU Planner")),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String(" ".into()),
            Value::String("_".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(
                    out.get(0, 0).unwrap(),
                    Value::CharArray(CharArray::new_row("Kernel_Fusion"))
                );
                assert_eq!(
                    out.get(0, 1).unwrap(),
                    Value::CharArray(CharArray::new_row("GPU_Planner"))
                );
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_scalars() {
        let cell = CellArray::new(
            vec![
                Value::String("Planner".into()),
                Value::String("Profiler".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("er".into()),
            Value::String("ER".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                assert_eq!(out.rows, 1);
                assert_eq!(out.cols, 2);
                assert_eq!(out.get(0, 0).unwrap(), Value::String("PlannER".into()));
                assert_eq!(out.get(0, 1).unwrap(), Value::String("ProfilER".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A numeric cell element must produce CELL_ELEMENT_ERROR.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_invalid_element_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_strrep(
            Value::Cell(cell),
            Value::String("1".into()),
            Value::String("one".into()),
        )
        .expect_err("expected cell element error");
        assert!(err.to_string().contains("cell array elements"));
    }

    // Multi-row char matrices nested in a cell keep their 2x5 shape when widths
    // do not change.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_char_matrix_element() {
        let mut chars: Vec<char> = "alpha".chars().collect();
        chars.extend("beta ".chars());
        let element = CharArray::new(chars, 2, 5).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::CharArray(ca) => {
                        assert_eq!(ca.rows, 2);
                        assert_eq!(ca.cols, 5);
                        let expected: Vec<char> =
                            vec!['A', 'l', 'p', 'h', 'A', 'b', 'e', 't', 'A', ' '];
                        assert_eq!(ca.data, expected);
                    }
                    other => panic!("expected char array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_cell_array_string_arrays() {
        let element = StringArray::new(vec!["alpha".into(), "beta".into()], vec![1, 2]).unwrap();
        let cell = CellArray::new(vec![Value::StringArray(element)], 1, 1).unwrap();
        let result = run_strrep(
            Value::Cell(cell),
            Value::String("a".into()),
            Value::String("A".into()),
        )
        .expect("strrep");
        match result {
            Value::Cell(out) => {
                let nested = out.get(0, 0).unwrap();
                match nested {
                    Value::StringArray(sa) => {
                        assert_eq!(sa.shape, vec![1, 2]);
                        assert_eq!(sa.data, vec![String::from("AlphA"), String::from("betA")]);
                    }
                    other => panic!("expected string array element, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // Pins the empty-`old` behaviour: `new` is inserted before every character
    // and at the end.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_empty_pattern_inserts_replacement() {
        let result = run_strrep(
            Value::String("abc".into()),
            Value::String("".into()),
            Value::String("-".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("-a-b-c-".into()));
    }

    // A string `old` paired with a char `new` must be rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_type_mismatch_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::String("a".into()),
            Value::CharArray(CharArray::new_row("x")),
        )
        .expect_err("expected type mismatch");
        assert!(err.to_string().contains("same data type"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_invalid_pattern_type_errors() {
        let err = run_strrep(
            Value::String("abc".into()),
            Value::Num(1.0),
            Value::String("x".into()),
        )
        .expect_err("expected pattern error");
        assert!(err
            .to_string()
            .contains("string scalars or character vectors"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strrep_first_argument_type_error() {
        let err = run_strrep(
            Value::Num(42.0),
            Value::String("a".into()),
            Value::String("b".into()),
        )
        .expect_err("expected argument type error");
        assert!(err.to_string().contains("first argument"));
    }

    // Registers a wgpu provider when one is available, then checks that the CPU
    // replacement still works. Returns early, without failing, if registration
    // errors (presumably no usable adapter on the host).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strrep_wgpu_provider_fallback() {
        if runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        )
        .is_err()
        {
            return;
        }
        let result = run_strrep(
            Value::String("Turbine Engine".into()),
            Value::String("Engine".into()),
            Value::String("JIT".into()),
        )
        .expect("strrep");
        assert_eq!(result, Value::String("Turbine JIT".into()));
    }

    // NOTE(review): unlike the other tests, this one has no wasm32
    // `wasm_bindgen_test` attribute.
    #[test]
    fn strrep_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}