Skip to main content

runmat_runtime/builtins/strings/core/
strlength.rs

1//! MATLAB-compatible `strlength` builtin for RunMat.
2
3use runmat_builtins::{CellArray, CharArray, StringArray, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::common::tensor;
12use crate::builtins::strings::common::is_missing_string;
13use crate::builtins::strings::type_resolvers::numeric_text_scalar_or_tensor_type;
14use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
15
16#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strlength")]
17pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
18    name: "strlength",
19    op_kind: GpuOpKind::Custom("string-metadata"),
20    supported_precisions: &[],
21    broadcast: BroadcastSemantics::None,
22    provider_hooks: &[],
23    constant_strategy: ConstantStrategy::InlineLiteral,
24    residency: ResidencyPolicy::GatherImmediately,
25    nan_mode: ReductionNaN::Include,
26    two_pass_threshold: None,
27    workgroup_size: None,
28    accepts_nan_mode: false,
29    notes: "Measures string lengths on the CPU; any GPU-resident inputs are gathered before evaluation.",
30};
31
32#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strlength")]
33pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
34    name: "strlength",
35    shape: ShapeRequirements::Any,
36    constant_strategy: ConstantStrategy::InlineLiteral,
37    elementwise: None,
38    reduction: None,
39    emits_nan: true,
40    notes: "Metadata-only builtin; not eligible for fusion and never emits GPU kernels.",
41};
42
/// Error message returned when the argument to `strlength` is not one of the
/// value kinds the builtin dispatches on.
const ARG_TYPE_ERROR: &'static str =
    "strlength: first argument must be a string array, character array, or cell array of character vectors";

/// Error message returned when a cell array passed to `strlength` holds an
/// element that is not accepted as text.
const CELL_ELEMENT_ERROR: &'static str =
    "strlength: cell array elements must be character vectors or string scalars";
47
48fn strlength_flow(message: impl Into<String>) -> RuntimeError {
49    build_runtime_error(message)
50        .with_builtin("strlength")
51        .build()
52}
53
54fn remap_strlength_flow(err: RuntimeError) -> RuntimeError {
55    map_control_flow_with_builtin(err, "strlength")
56}
57
58#[runtime_builtin(
59    name = "strlength",
60    category = "strings/core",
61    summary = "Count characters in string arrays, character arrays, or cell arrays of character vectors.",
62    keywords = "strlength,string length,text,count,characters",
63    accel = "sink",
64    type_resolver(numeric_text_scalar_or_tensor_type),
65    builtin_path = "crate::builtins::strings::core::strlength"
66)]
67async fn strlength_builtin(value: Value) -> crate::BuiltinResult<Value> {
68    let gathered = gather_if_needed_async(&value)
69        .await
70        .map_err(remap_strlength_flow)?;
71    match gathered {
72        Value::StringArray(array) => strlength_string_array(array),
73        Value::String(text) => Ok(Value::Num(string_scalar_length(&text))),
74        Value::CharArray(array) => strlength_char_array(array),
75        Value::Cell(cell) => strlength_cell_array(cell),
76        _ => Err(strlength_flow(ARG_TYPE_ERROR)),
77    }
78}
79
80fn strlength_string_array(array: StringArray) -> BuiltinResult<Value> {
81    let StringArray { data, shape, .. } = array;
82    let mut lengths = Vec::with_capacity(data.len());
83    for text in &data {
84        lengths.push(string_scalar_length(text));
85    }
86    let tensor =
87        Tensor::new(lengths, shape).map_err(|e| strlength_flow(format!("strlength: {e}")))?;
88    Ok(tensor::tensor_into_value(tensor))
89}
90
91fn strlength_char_array(array: CharArray) -> BuiltinResult<Value> {
92    let rows = array.rows;
93    let mut lengths = Vec::with_capacity(rows);
94    for row in 0..rows {
95        let length = if array.rows <= 1 {
96            array.cols
97        } else {
98            trimmed_row_length(&array, row)
99        } as f64;
100        lengths.push(length);
101    }
102    let tensor = Tensor::new(lengths, vec![rows, 1])
103        .map_err(|e| strlength_flow(format!("strlength: {e}")))?;
104    Ok(tensor::tensor_into_value(tensor))
105}
106
107fn strlength_cell_array(cell: CellArray) -> BuiltinResult<Value> {
108    let CellArray {
109        data, rows, cols, ..
110    } = cell;
111    let mut lengths = Vec::with_capacity(rows * cols);
112    for col in 0..cols {
113        for row in 0..rows {
114            let idx = row * cols + col;
115            let value: &Value = &data[idx];
116            let length = match value {
117                Value::String(text) => string_scalar_length(text),
118                Value::StringArray(sa) if sa.data.len() == 1 => string_scalar_length(&sa.data[0]),
119                Value::CharArray(char_vec) if char_vec.rows == 1 => char_vec.cols as f64,
120                Value::CharArray(_) => return Err(strlength_flow(CELL_ELEMENT_ERROR)),
121                _ => return Err(strlength_flow(CELL_ELEMENT_ERROR)),
122            };
123            lengths.push(length);
124        }
125    }
126    let tensor = Tensor::new(lengths, vec![rows, cols])
127        .map_err(|e| strlength_flow(format!("strlength: {e}")))?;
128    Ok(tensor::tensor_into_value(tensor))
129}
130
131fn string_scalar_length(text: &str) -> f64 {
132    if is_missing_string(text) {
133        f64::NAN
134    } else {
135        text.chars().count() as f64
136    }
137}
138
139fn trimmed_row_length(array: &CharArray, row: usize) -> usize {
140    let cols = array.cols;
141    let mut end = cols;
142    while end > 0 {
143        let ch = array.data[row * cols + end - 1];
144        if ch == ' ' {
145            end -= 1;
146        } else {
147            break;
148        }
149    }
150    end
151}
152
/// Unit tests for the `strlength` builtin and its type resolver.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async builtin to completion so tests can call it synchronously.
    fn strlength_builtin(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strlength_builtin(value))
    }

    /// Extracts the message text from a runtime error for comparison.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Unwraps a tensor result, panicking with the actual value otherwise.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(tensor) => tensor,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    /// A scalar string yields a scalar numeric length.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_scalar() {
        let result = strlength_builtin(Value::String("RunMat".into())).expect("strlength");
        assert_eq!(result, Value::Num(6.0));
    }

    /// Missing strings inside a string array are reported as NaN.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_array_with_missing() {
        let array = StringArray::new(vec!["alpha".into(), "<missing>".into()], vec![2, 1]).unwrap();
        let result = strlength_builtin(Value::StringArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 1]);
        assert_eq!(tensor.data.len(), 2);
        assert_eq!(tensor.data[0], 5.0);
        assert!(tensor.data[1].is_nan());
    }

    /// A multi-row character array yields one length per row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_array_multiple_rows() {
        let data: Vec<char> = vec!['c', 'a', 't', ' ', ' ', 'h', 'o', 'r', 's', 'e'];
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 1]);
        assert_eq!(tensor.data, vec![3.0, 5.0]);
    }

    /// Trailing spaces in a single-row character vector are counted.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_vector_retains_explicit_spaces() {
        let data: Vec<char> = "hi   ".chars().collect();
        let array = CharArray::new(data, 1, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        assert_eq!(result, Value::Num(5.0));
    }

    /// A cell array of character vectors yields a tensor of the same layout.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_cell_array_of_char_vectors() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("red")),
                Value::CharArray(CharArray::new_row("green")),
            ],
            1,
            2,
        )
        .unwrap();
        let result = strlength_builtin(Value::Cell(cell)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![1, 2]);
        assert_eq!(tensor.data, vec![3.0, 5.0]);
    }

    /// String scalars inside a cell array are accepted, including missing ones.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_cell_array_with_string_scalars() {
        let cell = CellArray::new(
            vec![
                Value::String("alpha".into()),
                Value::String("beta".into()),
                Value::String("<missing>".into()),
            ],
            1,
            3,
        )
        .unwrap();
        let result = strlength_builtin(Value::Cell(cell)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![1, 3]);
        assert_eq!(tensor.data.len(), 3);
        assert_eq!(tensor.data[0], 5.0);
        assert_eq!(tensor.data[1], 4.0);
        assert!(tensor.data[2].is_nan());
    }

    /// The output tensor reuses the shape of the input string array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_array_preserves_shape() {
        let array = StringArray::new(
            vec!["ab".into(), "c".into(), "def".into(), "".into()],
            vec![2, 2],
        )
        .unwrap();
        let result = strlength_builtin(Value::StringArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data, vec![2.0, 1.0, 3.0, 0.0]);
    }

    /// Trailing spaces in rows of a multi-row character array are not counted.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_array_trims_padding() {
        let data: Vec<char> = vec!['d', 'o', 'g', ' ', ' ', 'h', 'o', 'r', 's', 'e'];
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 1]);
        assert_eq!(tensor.data, vec![3.0, 5.0]);
    }

    /// Non-text input is rejected with the argument-type error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_errors_on_invalid_input() {
        let err = error_message(strlength_builtin(Value::Num(1.0)).unwrap_err());
        assert_eq!(err, ARG_TYPE_ERROR);
    }

    /// A cell array containing a non-text element is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_rejects_cell_with_invalid_element() {
        let cell = CellArray::new(
            vec![Value::CharArray(CharArray::new_row("ok")), Value::Num(5.0)],
            1,
            2,
        )
        .unwrap();
        let err = error_message(strlength_builtin(Value::Cell(cell)).unwrap_err());
        assert_eq!(err, CELL_ELEMENT_ERROR);
    }

    /// The registered type resolver maps a string input to a numeric result.
    // Carries the same wasm attribute as every other test in this module so it
    // is also picked up when the suite runs under wasm-bindgen-test.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_type_is_numeric_text_scalar_or_tensor() {
        assert_eq!(
            numeric_text_scalar_or_tensor_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::Num
        );
    }
}