Skip to main content

runmat_runtime/builtins/strings/transform/
strtrim.rs

1//! MATLAB-compatible `strtrim` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6    CellArray, CharArray, StringArray, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
16use crate::builtins::strings::type_resolvers::text_preserve_type;
17use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
18
/// GPU capability descriptor for `strtrim`, registered through
/// `register_gpu_spec`.
///
/// No precisions or provider hooks are declared and the residency policy is
/// `GatherImmediately`: as the notes state, trimming executes on the CPU and
/// GPU-resident inputs are gathered to host memory first.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::strtrim")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "strtrim",
    // Custom op kind: text transforms have no generic GPU kernel category.
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before trimming whitespace.",
};
35
/// Fusion descriptor for `strtrim`, registered through `register_fusion_spec`.
///
/// Neither an elementwise nor a reduction template is provided, so the
/// builtin is not a fusion candidate (as the notes state).
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::strings::transform::strtrim"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "strtrim",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; not eligible for fusion and always gathers GPU inputs.",
};
48
/// Builtin name used to attribute errors and to prefix internal error messages.
const BUILTIN_NAME: &str = "strtrim";

/// Output parameter of the `out = strtrim(str)` signature.
const STRTRIM_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Trimmed text preserving input container kind and shape.",
}];

/// Input parameter of the `out = strtrim(str)` signature.
const STRTRIM_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "str",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "String/char/cell text input to trim.",
}];

/// The single supported call form of `strtrim`.
const STRTRIM_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "out = strtrim(str)",
    inputs: &STRTRIM_INPUTS,
    outputs: &STRTRIM_OUTPUT,
}];

/// Reported by `strtrim_builtin` when the top-level argument is not a string,
/// string array, char array, or cell array.
const STRTRIM_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRTRIM.INVALID_INPUT",
    identifier: Some("RunMat:strtrim:InvalidInput"),
    when: "Input is not a string array, character array, or cell array of text scalars.",
    message:
        "strtrim: first argument must be a string array, character array, or cell array of character vectors",
};

/// Reported by `strtrim_cell_element` for a cell element that is not a string
/// scalar, a 1x1 string array, or a char array with at most one row.
const STRTRIM_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRTRIM.CELL_ELEMENT",
    identifier: Some("RunMat:strtrim:CellElement"),
    when: "Cell array contains a non-text element or non-row char array element.",
    message: "strtrim: cell array elements must be string scalars or character vectors",
};

/// Reported when rebuilding an output container (string array, char array,
/// or cell) fails; the underlying error text replaces `message` at the call sites.
const STRTRIM_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRTRIM.INTERNAL",
    identifier: Some("RunMat:strtrim:InternalError"),
    when: "Internal output container construction failed.",
    message: "strtrim: internal error",
};

/// Every error `strtrim` can report, published through `STRTRIM_DESCRIPTOR`.
const STRTRIM_ERRORS: [BuiltinErrorDescriptor; 3] = [
    STRTRIM_ERROR_INVALID_INPUT,
    STRTRIM_ERROR_CELL_ELEMENT,
    STRTRIM_ERROR_INTERNAL,
];

/// Public metadata for `strtrim` (signatures, output mode, completion policy,
/// errors); referenced by the `descriptor(...)` argument of `#[runtime_builtin]`.
pub const STRTRIM_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &STRTRIM_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &STRTRIM_ERRORS,
};
107
108fn map_flow(err: RuntimeError) -> RuntimeError {
109    map_control_flow_with_builtin(err, BUILTIN_NAME)
110}
111
112fn strtrim_error_with_message(
113    message: impl Into<String>,
114    error: &'static BuiltinErrorDescriptor,
115) -> RuntimeError {
116    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
117    if let Some(identifier) = error.identifier {
118        builder = builder.with_identifier(identifier);
119    }
120    builder.build()
121}
122
123fn strtrim_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
124    strtrim_error_with_message(error.message, error)
125}
126
127#[runtime_builtin(
128    name = "strtrim",
129    category = "strings/transform",
130    summary = "Remove leading and trailing whitespace from text inputs.",
131    keywords = "strtrim,trim,whitespace,strings,character array,text",
132    accel = "sink",
133    type_resolver(text_preserve_type),
134    descriptor(crate::builtins::strings::transform::strtrim::STRTRIM_DESCRIPTOR),
135    builtin_path = "crate::builtins::strings::transform::strtrim"
136)]
137async fn strtrim_builtin(value: Value) -> BuiltinResult<Value> {
138    let gathered = gather_if_needed_async(&value).await.map_err(map_flow)?;
139    match gathered {
140        Value::String(text) => Ok(Value::String(trim_string(text))),
141        Value::StringArray(array) => strtrim_string_array(array),
142        Value::CharArray(array) => strtrim_char_array(array),
143        Value::Cell(cell) => strtrim_cell_array(cell).await,
144        _ => Err(strtrim_error(&STRTRIM_ERROR_INVALID_INPUT)),
145    }
146}
147
148fn strtrim_string_array(array: StringArray) -> BuiltinResult<Value> {
149    let StringArray { data, shape, .. } = array;
150    let trimmed = data.into_iter().map(trim_string).collect::<Vec<_>>();
151    let out = StringArray::new(trimmed, shape).map_err(|e| {
152        strtrim_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRTRIM_ERROR_INTERNAL)
153    })?;
154    Ok(Value::StringArray(out))
155}
156
157fn strtrim_char_array(array: CharArray) -> BuiltinResult<Value> {
158    let CharArray { data, rows, cols } = array;
159    if rows == 0 {
160        return Ok(Value::CharArray(CharArray { data, rows, cols }));
161    }
162
163    let mut trimmed_rows: Vec<Vec<char>> = Vec::with_capacity(rows);
164    let mut target_cols: usize = 0;
165    for row in 0..rows {
166        let text = char_row_to_string_slice(&data, cols, row);
167        let trimmed = trim_whitespace(&text);
168        let chars: Vec<char> = trimmed.chars().collect();
169        target_cols = target_cols.max(chars.len());
170        trimmed_rows.push(chars);
171    }
172
173    let mut new_data: Vec<char> = Vec::with_capacity(rows * target_cols);
174    for mut chars in trimmed_rows {
175        if chars.len() < target_cols {
176            chars.resize(target_cols, ' ');
177        }
178        new_data.extend(chars);
179    }
180
181    CharArray::new(new_data, rows, target_cols)
182        .map(Value::CharArray)
183        .map_err(|e| {
184            strtrim_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRTRIM_ERROR_INTERNAL)
185        })
186}
187
188async fn strtrim_cell_array(cell: CellArray) -> BuiltinResult<Value> {
189    let CellArray {
190        data, rows, cols, ..
191    } = cell;
192    let mut trimmed_values = Vec::with_capacity(rows * cols);
193    for value in &data {
194        let trimmed = strtrim_cell_element(value).await?;
195        trimmed_values.push(trimmed);
196    }
197    make_cell(trimmed_values, rows, cols).map_err(|e| {
198        strtrim_error_with_message(format!("{BUILTIN_NAME}: {e}"), &STRTRIM_ERROR_INTERNAL)
199    })
200}
201
202async fn strtrim_cell_element(value: &Value) -> BuiltinResult<Value> {
203    match gather_if_needed_async(value).await.map_err(map_flow)? {
204        Value::String(text) => Ok(Value::String(trim_string(text))),
205        Value::StringArray(sa) if sa.data.len() == 1 => {
206            let text = sa.data.into_iter().next().unwrap();
207            Ok(Value::String(trim_string(text)))
208        }
209        Value::CharArray(ca) if ca.rows <= 1 => {
210            if ca.rows == 0 {
211                return Ok(Value::CharArray(ca));
212            }
213            let source = char_row_to_string_slice(&ca.data, ca.cols, 0);
214            let trimmed = trim_whitespace(&source);
215            let chars: Vec<char> = trimmed.chars().collect();
216            let cols = chars.len();
217            CharArray::new(chars, ca.rows, cols)
218                .map(Value::CharArray)
219                .map_err(|e| {
220                    strtrim_error_with_message(
221                        format!("{BUILTIN_NAME}: {e}"),
222                        &STRTRIM_ERROR_INTERNAL,
223                    )
224                })
225        }
226        Value::CharArray(_) => Err(strtrim_error(&STRTRIM_ERROR_CELL_ELEMENT)),
227        _ => Err(strtrim_error(&STRTRIM_ERROR_CELL_ELEMENT)),
228    }
229}
230
231fn trim_string(text: String) -> String {
232    if is_missing_string(&text) {
233        text
234    } else {
235        trim_whitespace(&text)
236    }
237}
238
/// Characters MATLAB classifies as *significant* whitespace. `strtrim` must not
/// remove them, even at the start or end of the text: next line (U+0085),
/// no-break space (U+00A0), figure space (U+2007) and narrow no-break space
/// (U+202F).
const SIGNIFICANT_WHITESPACE: [char; 4] = ['\u{0085}', '\u{00A0}', '\u{2007}', '\u{202F}'];

/// Removes leading and trailing *insignificant* whitespace from `text`.
///
/// A character is trimmed when `char::is_whitespace` reports it as Unicode
/// whitespace and it is not listed in `SIGNIFICANT_WHITESPACE`. Trimming stops
/// at the first non-trimmable character from either end. A significant
/// whitespace character therefore also keeps any ordinary whitespace lying
/// beyond it. Interior characters are never modified.
fn trim_whitespace(text: &str) -> String {
    text.trim_matches(is_insignificant_whitespace).to_string()
}

/// Returns `true` for whitespace that `strtrim` is allowed to strip.
fn is_insignificant_whitespace(c: char) -> bool {
    c.is_whitespace() && !SIGNIFICANT_WHITESPACE.contains(&c)
}

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    fn run_strtrim(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(strtrim_builtin(value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_string_scalar_trims_whitespace() {
        let result =
            run_strtrim(Value::String("  RunMat  ".into())).expect("strtrim string scalar");
        assert_eq!(result, Value::String("RunMat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_string_array_preserves_shape() {
        let array = StringArray::new(
            vec![
                " one ".into(),
                "<missing>".into(),
                "two".into(),
                " three ".into(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let result = run_strtrim(Value::StringArray(array)).expect("strtrim string array");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 2]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("one"),
                        String::from("<missing>"),
                        String::from("two"),
                        String::from("three")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_char_array_multiple_rows() {
        let data: Vec<char> = "  cat  ".chars().chain(" dog   ".chars()).collect();
        let array = CharArray::new(data, 2, 7).unwrap();
        let result = run_strtrim(Value::CharArray(array)).expect("strtrim char array");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 2);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['c', 'a', 't', 'd', 'o', 'g']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_char_array_all_whitespace_yields_zero_width() {
        let array = CharArray::new("   ".chars().collect(), 1, 3).unwrap();
        let result = run_strtrim(Value::CharArray(array)).expect("strtrim char whitespace");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 0);
                assert!(ca.data.is_empty());
            }
            other => panic!("expected empty char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("  GPU  ")),
                Value::String(" Accelerate ".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_strtrim(Value::Cell(cell)).expect("strtrim cell array");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(first, Value::CharArray(CharArray::new_row("GPU")));
                assert_eq!(second, Value::String("Accelerate".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_preserves_missing_strings() {
        let result =
            run_strtrim(Value::String("<missing>".into())).expect("strtrim missing string");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_handles_tabs_and_newlines() {
        let input = Value::String("\tMetrics \n".into());
        let result = run_strtrim(input).expect("strtrim tab/newline");
        assert_eq!(result, Value::String("Metrics".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_trims_unicode_whitespace() {
        // Em space and ideographic space are ordinary (insignificant) whitespace.
        let input = Value::String("\u{2003}RunMat\u{3000}".into());
        let result = run_strtrim(input).expect("strtrim unicode whitespace");
        assert_eq!(result, Value::String("RunMat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_preserves_significant_whitespace() {
        for ch in SIGNIFICANT_WHITESPACE.iter() {
            let text = format!("{ch}RunMat{ch}");
            let result = run_strtrim(Value::String(text.clone()))
                .expect("strtrim significant whitespace");
            assert_eq!(result, Value::String(text));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_stops_at_significant_whitespace() {
        let input = Value::String(" \u{00A0}RunMat\u{00A0}\t".into());
        let result = run_strtrim(input).expect("strtrim stops at nbsp");
        assert_eq!(result, Value::String("\u{00A0}RunMat\u{00A0}".into()));
    }

    #[test]
    fn trim_whitespace_keeps_interior_and_significant_chars() {
        assert_eq!(trim_whitespace("\t a  b \n"), "a  b");
        assert_eq!(trim_whitespace("\u{202F}x\u{2007}"), "\u{202F}x\u{2007}");
        assert_eq!(trim_whitespace(" \u{0085} "), "\u{0085}");
        assert_eq!(trim_whitespace("   "), "");
        assert_eq!(trim_whitespace(""), "");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_char_array_zero_rows_stable() {
        let array = CharArray::new(Vec::new(), 0, 0).unwrap();
        let result = run_strtrim(Value::CharArray(array.clone())).expect("strtrim 0x0 char");
        assert_eq!(result, Value::CharArray(array));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_cell_array_accepts_string_scalar() {
        let scalar = StringArray::new(vec![" padded ".into()], vec![1, 1]).unwrap();
        let cell = CellArray::new(vec![Value::StringArray(scalar)], 1, 1).unwrap();
        let trimmed = run_strtrim(Value::Cell(cell)).expect("strtrim cell string scalar");
        match trimmed {
            Value::Cell(out) => {
                let value = out.get(0, 0).expect("cell element");
                assert_eq!(value, Value::String("padded".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_cell_array_rejects_non_text() {
        let cell = CellArray::new(vec![Value::Num(5.0)], 1, 1).unwrap();
        let err = run_strtrim(Value::Cell(cell)).expect_err("strtrim cell non-text");
        assert!(err.to_string().contains("cell array elements"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strtrim_errors_on_invalid_input() {
        let err = run_strtrim(Value::Num(1.0)).unwrap_err();
        assert!(err.to_string().contains("strtrim"));
    }

    #[test]
    fn strtrim_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}