Skip to main content

runmat_runtime/builtins/strings/transform/
strtrim.rs

//! MATLAB-compatible `strtrim` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
12use crate::builtins::strings::type_resolvers::text_preserve_type;
13use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
14
15#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::strtrim")]
16pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
17    name: "strtrim",
18    op_kind: GpuOpKind::Custom("string-transform"),
19    supported_precisions: &[],
20    broadcast: BroadcastSemantics::None,
21    provider_hooks: &[],
22    constant_strategy: ConstantStrategy::InlineLiteral,
23    residency: ResidencyPolicy::GatherImmediately,
24    nan_mode: ReductionNaN::Include,
25    two_pass_threshold: None,
26    workgroup_size: None,
27    accepts_nan_mode: false,
28    notes:
29        "Executes on the CPU; GPU-resident inputs are gathered to host memory before trimming whitespace.",
30};
31
32#[runmat_macros::register_fusion_spec(
33    builtin_path = "crate::builtins::strings::transform::strtrim"
34)]
35pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
36    name: "strtrim",
37    shape: ShapeRequirements::Any,
38    constant_strategy: ConstantStrategy::InlineLiteral,
39    elementwise: None,
40    reduction: None,
41    emits_nan: false,
42    notes: "String transformation builtin; not eligible for fusion and always gathers GPU inputs.",
43};
44
45const BUILTIN_NAME: &str = "strtrim";
46const ARG_TYPE_ERROR: &str =
47    "strtrim: first argument must be a string array, character array, or cell array of character vectors";
48const CELL_ELEMENT_ERROR: &str =
49    "strtrim: cell array elements must be string scalars or character vectors";
50
51fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
52    build_runtime_error(message)
53        .with_builtin(BUILTIN_NAME)
54        .build()
55}
56
57fn map_flow(err: RuntimeError) -> RuntimeError {
58    map_control_flow_with_builtin(err, BUILTIN_NAME)
59}
60
61#[runtime_builtin(
62    name = "strtrim",
63    category = "strings/transform",
64    summary = "Remove leading and trailing whitespace from strings, character arrays, and cell arrays.",
65    keywords = "strtrim,trim,whitespace,strings,character array,text",
66    accel = "sink",
67    type_resolver(text_preserve_type),
68    builtin_path = "crate::builtins::strings::transform::strtrim"
69)]
70async fn strtrim_builtin(value: Value) -> BuiltinResult<Value> {
71    let gathered = gather_if_needed_async(&value).await.map_err(map_flow)?;
72    match gathered {
73        Value::String(text) => Ok(Value::String(trim_string(text))),
74        Value::StringArray(array) => strtrim_string_array(array),
75        Value::CharArray(array) => strtrim_char_array(array),
76        Value::Cell(cell) => strtrim_cell_array(cell).await,
77        _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
78    }
79}
80
81fn strtrim_string_array(array: StringArray) -> BuiltinResult<Value> {
82    let StringArray { data, shape, .. } = array;
83    let trimmed = data.into_iter().map(trim_string).collect::<Vec<_>>();
84    let out = StringArray::new(trimmed, shape)
85        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
86    Ok(Value::StringArray(out))
87}
88
89fn strtrim_char_array(array: CharArray) -> BuiltinResult<Value> {
90    let CharArray { data, rows, cols } = array;
91    if rows == 0 {
92        return Ok(Value::CharArray(CharArray { data, rows, cols }));
93    }
94
95    let mut trimmed_rows: Vec<Vec<char>> = Vec::with_capacity(rows);
96    let mut target_cols: usize = 0;
97    for row in 0..rows {
98        let text = char_row_to_string_slice(&data, cols, row);
99        let trimmed = trim_whitespace(&text);
100        let chars: Vec<char> = trimmed.chars().collect();
101        target_cols = target_cols.max(chars.len());
102        trimmed_rows.push(chars);
103    }
104
105    let mut new_data: Vec<char> = Vec::with_capacity(rows * target_cols);
106    for mut chars in trimmed_rows {
107        if chars.len() < target_cols {
108            chars.resize(target_cols, ' ');
109        }
110        new_data.extend(chars);
111    }
112
113    CharArray::new(new_data, rows, target_cols)
114        .map(Value::CharArray)
115        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
116}
117
118async fn strtrim_cell_array(cell: CellArray) -> BuiltinResult<Value> {
119    let CellArray {
120        data, rows, cols, ..
121    } = cell;
122    let mut trimmed_values = Vec::with_capacity(rows * cols);
123    for value in &data {
124        let trimmed = strtrim_cell_element(value).await?;
125        trimmed_values.push(trimmed);
126    }
127    make_cell(trimmed_values, rows, cols)
128        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
129}
130
131async fn strtrim_cell_element(value: &Value) -> BuiltinResult<Value> {
132    match gather_if_needed_async(value).await.map_err(map_flow)? {
133        Value::String(text) => Ok(Value::String(trim_string(text))),
134        Value::StringArray(sa) if sa.data.len() == 1 => {
135            let text = sa.data.into_iter().next().unwrap();
136            Ok(Value::String(trim_string(text)))
137        }
138        Value::CharArray(ca) if ca.rows <= 1 => {
139            if ca.rows == 0 {
140                return Ok(Value::CharArray(ca));
141            }
142            let source = char_row_to_string_slice(&ca.data, ca.cols, 0);
143            let trimmed = trim_whitespace(&source);
144            let chars: Vec<char> = trimmed.chars().collect();
145            let cols = chars.len();
146            CharArray::new(chars, ca.rows, cols)
147                .map(Value::CharArray)
148                .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
149        }
150        Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
151        _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
152    }
153}
154
155fn trim_string(text: String) -> String {
156    if is_missing_string(&text) {
157        text
158    } else {
159        trim_whitespace(&text)
160    }
161}
162
163fn trim_whitespace(text: &str) -> String {
164    let trimmed = text.trim_matches(|c: char| c.is_whitespace());
165    trimmed.to_string()
166}
167
168#[cfg(test)]
169pub(crate) mod tests {
170    use super::*;
171    use runmat_builtins::{ResolveContext, Type};
172
173    fn run_strtrim(value: Value) -> BuiltinResult<Value> {
174        futures::executor::block_on(strtrim_builtin(value))
175    }
176
177    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
178    #[test]
179    fn strtrim_string_scalar_trims_whitespace() {
180        let result =
181            run_strtrim(Value::String("  RunMat  ".into())).expect("strtrim string scalar");
182        assert_eq!(result, Value::String("RunMat".into()));
183    }
184
185    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
186    #[test]
187    fn strtrim_string_array_preserves_shape() {
188        let array = StringArray::new(
189            vec![
190                " one ".into(),
191                "<missing>".into(),
192                "two".into(),
193                " three ".into(),
194            ],
195            vec![2, 2],
196        )
197        .unwrap();
198        let result = run_strtrim(Value::StringArray(array)).expect("strtrim string array");
199        match result {
200            Value::StringArray(sa) => {
201                assert_eq!(sa.shape, vec![2, 2]);
202                assert_eq!(
203                    sa.data,
204                    vec![
205                        String::from("one"),
206                        String::from("<missing>"),
207                        String::from("two"),
208                        String::from("three")
209                    ]
210                );
211            }
212            other => panic!("expected string array, got {other:?}"),
213        }
214    }
215
216    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
217    #[test]
218    fn strtrim_char_array_multiple_rows() {
219        let data: Vec<char> = "  cat  ".chars().chain(" dog   ".chars()).collect();
220        let array = CharArray::new(data, 2, 7).unwrap();
221        let result = run_strtrim(Value::CharArray(array)).expect("strtrim char array");
222        match result {
223            Value::CharArray(ca) => {
224                assert_eq!(ca.rows, 2);
225                assert_eq!(ca.cols, 3);
226                assert_eq!(ca.data, vec!['c', 'a', 't', 'd', 'o', 'g']);
227            }
228            other => panic!("expected char array, got {other:?}"),
229        }
230    }
231
232    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
233    #[test]
234    fn strtrim_char_array_all_whitespace_yields_zero_width() {
235        let array = CharArray::new("   ".chars().collect(), 1, 3).unwrap();
236        let result = run_strtrim(Value::CharArray(array)).expect("strtrim char whitespace");
237        match result {
238            Value::CharArray(ca) => {
239                assert_eq!(ca.rows, 1);
240                assert_eq!(ca.cols, 0);
241                assert!(ca.data.is_empty());
242            }
243            other => panic!("expected empty char array, got {other:?}"),
244        }
245    }
246
247    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
248    #[test]
249    fn strtrim_cell_array_mixed_content() {
250        let cell = CellArray::new(
251            vec![
252                Value::CharArray(CharArray::new_row("  GPU  ")),
253                Value::String(" Accelerate ".into()),
254            ],
255            1,
256            2,
257        )
258        .unwrap();
259        let result = run_strtrim(Value::Cell(cell)).expect("strtrim cell array");
260        match result {
261            Value::Cell(out) => {
262                let first = out.get(0, 0).unwrap();
263                let second = out.get(0, 1).unwrap();
264                assert_eq!(first, Value::CharArray(CharArray::new_row("GPU")));
265                assert_eq!(second, Value::String("Accelerate".into()));
266            }
267            other => panic!("expected cell array, got {other:?}"),
268        }
269    }
270
271    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
272    #[test]
273    fn strtrim_preserves_missing_strings() {
274        let result =
275            run_strtrim(Value::String("<missing>".into())).expect("strtrim missing string");
276        assert_eq!(result, Value::String("<missing>".into()));
277    }
278
279    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
280    #[test]
281    fn strtrim_handles_tabs_and_newlines() {
282        let input = Value::String("\tMetrics \n".into());
283        let result = run_strtrim(input).expect("strtrim tab/newline");
284        assert_eq!(result, Value::String("Metrics".into()));
285    }
286
287    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
288    #[test]
289    fn strtrim_trims_unicode_whitespace() {
290        let input = Value::String("\u{00A0}RunMat\u{2003}".into());
291        let result = run_strtrim(input).expect("strtrim unicode whitespace");
292        assert_eq!(result, Value::String("RunMat".into()));
293    }
294
295    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
296    #[test]
297    fn strtrim_char_array_zero_rows_stable() {
298        let array = CharArray::new(Vec::new(), 0, 0).unwrap();
299        let result = run_strtrim(Value::CharArray(array.clone())).expect("strtrim 0x0 char");
300        assert_eq!(result, Value::CharArray(array));
301    }
302
303    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
304    #[test]
305    fn strtrim_cell_array_accepts_string_scalar() {
306        let scalar = StringArray::new(vec![" padded ".into()], vec![1, 1]).unwrap();
307        let cell = CellArray::new(vec![Value::StringArray(scalar)], 1, 1).unwrap();
308        let trimmed = run_strtrim(Value::Cell(cell)).expect("strtrim cell string scalar");
309        match trimmed {
310            Value::Cell(out) => {
311                let value = out.get(0, 0).expect("cell element");
312                assert_eq!(value, Value::String("padded".into()));
313            }
314            other => panic!("expected cell array, got {other:?}"),
315        }
316    }
317
318    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
319    #[test]
320    fn strtrim_cell_array_rejects_non_text() {
321        let cell = CellArray::new(vec![Value::Num(5.0)], 1, 1).unwrap();
322        let err = run_strtrim(Value::Cell(cell)).expect_err("strtrim cell non-text");
323        assert!(err.to_string().contains("cell array elements"));
324    }
325
326    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
327    #[test]
328    fn strtrim_errors_on_invalid_input() {
329        let err = run_strtrim(Value::Num(1.0)).unwrap_err();
330        assert!(err.to_string().contains("strtrim"));
331    }
332
333    #[test]
334    fn strtrim_type_preserves_text() {
335        assert_eq!(
336            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
337            Type::String
338        );
339    }
340}