1use runmat_builtins::{CellArray, CharArray, StringArray, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::common::tensor;
12use crate::builtins::strings::common::is_missing_string;
13use crate::builtins::strings::type_resolvers::numeric_text_scalar_or_tensor_type;
14use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
15
/// GPU dispatch metadata registered for the `strlength` builtin.
///
/// The spec lists no supported precisions and no provider hooks and uses the
/// `GatherImmediately` residency policy; as `notes` states, lengths are measured on the
/// CPU after any GPU-resident input has been gathered.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strlength")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "strlength",
    // Custom op kind tag; no dedicated kernel category is declared for this builtin.
    op_kind: GpuOpKind::Custom("string-metadata"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Measures string lengths on the CPU; any GPU-resident inputs are gathered before evaluation.",
};
31
/// Fusion metadata registered for the `strlength` builtin.
///
/// Both `elementwise` and `reduction` are `None`; per `notes`, the builtin is not
/// eligible for fusion and never emits GPU kernels.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strlength")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "strlength",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // NOTE(review): the CPU path returns NaN for missing strings (see
    // `string_scalar_length`); presumably that is why this flag is `true` even though no
    // kernel is emitted — confirm against how the fusion planner reads `emits_nan`.
    emits_nan: true,
    notes: "Metadata-only builtin; not eligible for fusion and never emits GPU kernels.",
};
42
/// Message used when the argument is none of the value kinds `strlength_builtin` dispatches on.
const ARG_TYPE_ERROR: &str =
    "strlength: first argument must be a string array, character array, or cell array of character vectors";
/// Message used when a cell element is not a string scalar or a single-row character array.
const CELL_ELEMENT_ERROR: &str =
    "strlength: cell array elements must be character vectors or string scalars";
47
48fn strlength_flow(message: impl Into<String>) -> RuntimeError {
49 build_runtime_error(message)
50 .with_builtin("strlength")
51 .build()
52}
53
54fn remap_strlength_flow(err: RuntimeError) -> RuntimeError {
55 map_control_flow_with_builtin(err, "strlength")
56}
57
58#[runtime_builtin(
59 name = "strlength",
60 category = "strings/core",
61 summary = "Count characters in string arrays, character arrays, or cell arrays of character vectors.",
62 keywords = "strlength,string length,text,count,characters",
63 accel = "sink",
64 type_resolver(numeric_text_scalar_or_tensor_type),
65 builtin_path = "crate::builtins::strings::core::strlength"
66)]
67async fn strlength_builtin(value: Value) -> crate::BuiltinResult<Value> {
68 let gathered = gather_if_needed_async(&value)
69 .await
70 .map_err(remap_strlength_flow)?;
71 match gathered {
72 Value::StringArray(array) => strlength_string_array(array),
73 Value::String(text) => Ok(Value::Num(string_scalar_length(&text))),
74 Value::CharArray(array) => strlength_char_array(array),
75 Value::Cell(cell) => strlength_cell_array(cell),
76 _ => Err(strlength_flow(ARG_TYPE_ERROR)),
77 }
78}
79
80fn strlength_string_array(array: StringArray) -> BuiltinResult<Value> {
81 let StringArray { data, shape, .. } = array;
82 let mut lengths = Vec::with_capacity(data.len());
83 for text in &data {
84 lengths.push(string_scalar_length(text));
85 }
86 let tensor =
87 Tensor::new(lengths, shape).map_err(|e| strlength_flow(format!("strlength: {e}")))?;
88 Ok(tensor::tensor_into_value(tensor))
89}
90
91fn strlength_char_array(array: CharArray) -> BuiltinResult<Value> {
92 let rows = array.rows;
93 let mut lengths = Vec::with_capacity(rows);
94 for row in 0..rows {
95 let length = if array.rows <= 1 {
96 array.cols
97 } else {
98 trimmed_row_length(&array, row)
99 } as f64;
100 lengths.push(length);
101 }
102 let tensor = Tensor::new(lengths, vec![rows, 1])
103 .map_err(|e| strlength_flow(format!("strlength: {e}")))?;
104 Ok(tensor::tensor_into_value(tensor))
105}
106
107fn strlength_cell_array(cell: CellArray) -> BuiltinResult<Value> {
108 let CellArray {
109 data, rows, cols, ..
110 } = cell;
111 let mut lengths = Vec::with_capacity(rows * cols);
112 for col in 0..cols {
113 for row in 0..rows {
114 let idx = row * cols + col;
115 let value: &Value = &data[idx];
116 let length = match value {
117 Value::String(text) => string_scalar_length(text),
118 Value::StringArray(sa) if sa.data.len() == 1 => string_scalar_length(&sa.data[0]),
119 Value::CharArray(char_vec) if char_vec.rows == 1 => char_vec.cols as f64,
120 Value::CharArray(_) => return Err(strlength_flow(CELL_ELEMENT_ERROR)),
121 _ => return Err(strlength_flow(CELL_ELEMENT_ERROR)),
122 };
123 lengths.push(length);
124 }
125 }
126 let tensor = Tensor::new(lengths, vec![rows, cols])
127 .map_err(|e| strlength_flow(format!("strlength: {e}")))?;
128 Ok(tensor::tensor_into_value(tensor))
129}
130
131fn string_scalar_length(text: &str) -> f64 {
132 if is_missing_string(text) {
133 f64::NAN
134 } else {
135 text.chars().count() as f64
136 }
137}
138
139fn trimmed_row_length(array: &CharArray, row: usize) -> usize {
140 let cols = array.cols;
141 let mut end = cols;
142 while end > 0 {
143 let ch = array.data[row * cols + end - 1];
144 if ch == ' ' {
145 end -= 1;
146 } else {
147 break;
148 }
149 }
150 end
151}
152
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async builtin to completion so tests can call it synchronously.
    fn strlength_builtin(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strlength_builtin(value))
    }

    /// Extracts the message text from a runtime error for comparison.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Unwraps a tensor result, failing the test with the offending value otherwise.
    fn expect_tensor(value: Value) -> Tensor {
        match value {
            Value::Tensor(tensor) => tensor,
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_scalar() {
        let result = strlength_builtin(Value::String("RunMat".into())).expect("strlength");
        assert_eq!(result, Value::Num(6.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_array_with_missing() {
        let array = StringArray::new(vec!["alpha".into(), "<missing>".into()], vec![2, 1]).unwrap();
        let result = strlength_builtin(Value::StringArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 1]);
        assert_eq!(tensor.data.len(), 2);
        assert_eq!(tensor.data[0], 5.0);
        assert!(tensor.data[1].is_nan());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_array_multiple_rows() {
        let data: Vec<char> = vec!['c', 'a', 't', ' ', ' ', 'h', 'o', 'r', 's', 'e'];
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 1]);
        assert_eq!(tensor.data, vec![3.0, 5.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_vector_retains_explicit_spaces() {
        // Five characters spelled out one by one so the data length always matches the
        // 1x5 dimensions: "hi" followed by three explicit trailing spaces.
        let data: Vec<char> = vec!['h', 'i', ' ', ' ', ' '];
        let array = CharArray::new(data, 1, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        assert_eq!(result, Value::Num(5.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_cell_array_of_char_vectors() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("red")),
                Value::CharArray(CharArray::new_row("green")),
            ],
            1,
            2,
        )
        .unwrap();
        let result = strlength_builtin(Value::Cell(cell)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![1, 2]);
        assert_eq!(tensor.data, vec![3.0, 5.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_cell_array_with_string_scalars() {
        let cell = CellArray::new(
            vec![
                Value::String("alpha".into()),
                Value::String("beta".into()),
                Value::String("<missing>".into()),
            ],
            1,
            3,
        )
        .unwrap();
        let result = strlength_builtin(Value::Cell(cell)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![1, 3]);
        assert_eq!(tensor.data.len(), 3);
        assert_eq!(tensor.data[0], 5.0);
        assert_eq!(tensor.data[1], 4.0);
        assert!(tensor.data[2].is_nan());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_array_preserves_shape() {
        let array = StringArray::new(
            vec!["ab".into(), "c".into(), "def".into(), "".into()],
            vec![2, 2],
        )
        .unwrap();
        let result = strlength_builtin(Value::StringArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data, vec![2.0, 1.0, 3.0, 0.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_array_trims_padding() {
        let data: Vec<char> = vec!['d', 'o', 'g', ' ', ' ', 'h', 'o', 'r', 's', 'e'];
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        let tensor = expect_tensor(result);
        assert_eq!(tensor.shape, vec![2, 1]);
        assert_eq!(tensor.data, vec![3.0, 5.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_errors_on_invalid_input() {
        let err = error_message(strlength_builtin(Value::Num(1.0)).unwrap_err());
        assert_eq!(err, ARG_TYPE_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_rejects_cell_with_invalid_element() {
        let cell = CellArray::new(
            vec![Value::CharArray(CharArray::new_row("ok")), Value::Num(5.0)],
            1,
            2,
        )
        .unwrap();
        let err = error_message(strlength_builtin(Value::Cell(cell)).unwrap_err());
        assert_eq!(err, CELL_ELEMENT_ERROR);
    }

    #[test]
    fn strlength_type_is_numeric_text_scalar_or_tensor() {
        assert_eq!(
            numeric_text_scalar_or_tensor_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::Num
        );
    }
}