1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 CellArray, CharArray, StringArray, Tensor, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::common::tensor;
16use crate::builtins::strings::common::is_missing_string;
17use crate::builtins::strings::type_resolvers::numeric_text_scalar_or_tensor_type;
18use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
19
20#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strlength")]
21pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
22 name: "strlength",
23 op_kind: GpuOpKind::Custom("string-metadata"),
24 supported_precisions: &[],
25 broadcast: BroadcastSemantics::None,
26 provider_hooks: &[],
27 constant_strategy: ConstantStrategy::InlineLiteral,
28 residency: ResidencyPolicy::GatherImmediately,
29 nan_mode: ReductionNaN::Include,
30 two_pass_threshold: None,
31 workgroup_size: None,
32 accepts_nan_mode: false,
33 notes: "Measures string lengths on the CPU; any GPU-resident inputs are gathered before evaluation.",
34};
35
36#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strlength")]
37pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
38 name: "strlength",
39 shape: ShapeRequirements::Any,
40 constant_strategy: ConstantStrategy::InlineLiteral,
41 elementwise: None,
42 reduction: None,
43 emits_nan: true,
44 notes: "Metadata-only builtin; not eligible for fusion and never emits GPU kernels.",
45};
46
/// Name attached to runtime errors raised by this builtin.
const BUILTIN_NAME: &str = "strlength";
48
49const STRLENGTH_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
50 name: "L",
51 ty: BuiltinParamType::NumericArray,
52 arity: BuiltinParamArity::Required,
53 default: None,
54 description: "Character counts for each text element.",
55}];
56
57const STRLENGTH_INPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
58 name: "str",
59 ty: BuiltinParamType::Any,
60 arity: BuiltinParamArity::Required,
61 default: None,
62 description: "String array, character array, or cell array of text scalars.",
63}];
64
65const STRLENGTH_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
66 label: "L = strlength(str)",
67 inputs: &STRLENGTH_INPUT,
68 outputs: &STRLENGTH_OUTPUT,
69}];
70
71const STRLENGTH_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
72 code: "RM.STRLENGTH.INVALID_INPUT",
73 identifier: Some("RunMat:strlength:InvalidInput"),
74 when: "Input is not a string array, character array, or cell array of text scalars.",
75 message: "strlength: first argument must be a string array, character array, or cell array of character vectors",
76};
77
78const STRLENGTH_ERROR_INVALID_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
79 code: "RM.STRLENGTH.INVALID_CELL_ELEMENT",
80 identifier: Some("RunMat:strlength:InvalidCellElement"),
81 when: "A cell-array element is not a character row vector or scalar string.",
82 message: "strlength: cell array elements must be character vectors or string scalars",
83};
84
85const STRLENGTH_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
86 code: "RM.STRLENGTH.INTERNAL",
87 identifier: Some("RunMat:strlength:InternalError"),
88 when: "Internal tensor construction failed while building length results.",
89 message: "strlength: internal error",
90};
91
92const STRLENGTH_ERRORS: [BuiltinErrorDescriptor; 3] = [
93 STRLENGTH_ERROR_INVALID_INPUT,
94 STRLENGTH_ERROR_INVALID_CELL_ELEMENT,
95 STRLENGTH_ERROR_INTERNAL,
96];
97
98pub const STRLENGTH_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
99 signatures: &STRLENGTH_SIGNATURES,
100 output_mode: BuiltinOutputMode::Fixed,
101 completion_policy: BuiltinCompletionPolicy::Public,
102 errors: &STRLENGTH_ERRORS,
103};
104
105fn strlength_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
106 strlength_error_with_message(error.message, error)
107}
108
109fn strlength_error_with_message(
110 message: impl Into<String>,
111 error: &'static BuiltinErrorDescriptor,
112) -> RuntimeError {
113 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
114 if let Some(identifier) = error.identifier {
115 builder = builder.with_identifier(identifier);
116 }
117 builder.build()
118}
119
120fn remap_strlength_flow(err: RuntimeError) -> RuntimeError {
121 map_control_flow_with_builtin(err, BUILTIN_NAME)
122}
123
124#[runtime_builtin(
125 name = "strlength",
126 category = "strings/core",
127 summary = "Count characters in each element of text inputs.",
128 keywords = "strlength,string length,text,count,characters",
129 accel = "sink",
130 type_resolver(numeric_text_scalar_or_tensor_type),
131 descriptor(crate::builtins::strings::core::strlength::STRLENGTH_DESCRIPTOR),
132 builtin_path = "crate::builtins::strings::core::strlength"
133)]
134async fn strlength_builtin(value: Value) -> crate::BuiltinResult<Value> {
135 let gathered = gather_if_needed_async(&value)
136 .await
137 .map_err(remap_strlength_flow)?;
138 match gathered {
139 Value::StringArray(array) => strlength_string_array(array),
140 Value::String(text) => Ok(Value::Num(string_scalar_length(&text))),
141 Value::CharArray(array) => strlength_char_array(array),
142 Value::Cell(cell) => strlength_cell_array(cell),
143 _ => Err(strlength_error(&STRLENGTH_ERROR_INVALID_INPUT)),
144 }
145}
146
147fn strlength_string_array(array: StringArray) -> BuiltinResult<Value> {
148 let StringArray { data, shape, .. } = array;
149 let mut lengths = Vec::with_capacity(data.len());
150 for text in &data {
151 lengths.push(string_scalar_length(text));
152 }
153 let tensor =
154 Tensor::new(lengths, shape).map_err(|_| strlength_error(&STRLENGTH_ERROR_INTERNAL))?;
155 Ok(tensor::tensor_into_value(tensor))
156}
157
158fn strlength_char_array(array: CharArray) -> BuiltinResult<Value> {
159 let rows = array.rows;
160 let mut lengths = Vec::with_capacity(rows);
161 for row in 0..rows {
162 let length = if array.rows <= 1 {
163 array.cols
164 } else {
165 trimmed_row_length(&array, row)
166 } as f64;
167 lengths.push(length);
168 }
169 let tensor = Tensor::new(lengths, vec![rows, 1])
170 .map_err(|_| strlength_error(&STRLENGTH_ERROR_INTERNAL))?;
171 Ok(tensor::tensor_into_value(tensor))
172}
173
174fn strlength_cell_array(cell: CellArray) -> BuiltinResult<Value> {
175 let CellArray {
176 data, rows, cols, ..
177 } = cell;
178 let mut lengths = Vec::with_capacity(rows * cols);
179 for col in 0..cols {
180 for row in 0..rows {
181 let idx = row * cols + col;
182 let value: &Value = &data[idx];
183 let length = match value {
184 Value::String(text) => string_scalar_length(text),
185 Value::StringArray(sa) if sa.data.len() == 1 => string_scalar_length(&sa.data[0]),
186 Value::CharArray(char_vec) if char_vec.rows == 1 => char_vec.cols as f64,
187 Value::CharArray(_) => {
188 return Err(strlength_error(&STRLENGTH_ERROR_INVALID_CELL_ELEMENT));
189 }
190 _ => return Err(strlength_error(&STRLENGTH_ERROR_INVALID_CELL_ELEMENT)),
191 };
192 lengths.push(length);
193 }
194 }
195 let tensor = Tensor::new(lengths, vec![rows, cols])
196 .map_err(|_| strlength_error(&STRLENGTH_ERROR_INTERNAL))?;
197 Ok(tensor::tensor_into_value(tensor))
198}
199
200fn string_scalar_length(text: &str) -> f64 {
201 if is_missing_string(text) {
202 f64::NAN
203 } else {
204 text.chars().count() as f64
205 }
206}
207
208fn trimmed_row_length(array: &CharArray, row: usize) -> usize {
209 let cols = array.cols;
210 let mut end = cols;
211 while end > 0 {
212 let ch = array.data[row * cols + end - 1];
213 if ch == ' ' {
214 end -= 1;
215 } else {
216 break;
217 }
218 }
219 end
220}
221
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `strlength` builtin, driven through a blocking wrapper
    //! around the async entry point.

    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Synchronous wrapper so tests can call the async builtin directly.
    fn strlength_builtin(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strlength_builtin(value))
    }

    /// Extracts the message text from a runtime error for comparison.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_scalar() {
        let result = strlength_builtin(Value::String("RunMat".into())).expect("strlength");
        assert_eq!(result, Value::Num(6.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_array_with_missing() {
        let array = StringArray::new(vec!["alpha".into(), "<missing>".into()], vec![2, 1]).unwrap();
        let result = strlength_builtin(Value::StringArray(array)).expect("strlength");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![2, 1]);
                assert_eq!(tensor.data.len(), 2);
                assert_eq!(tensor.data[0], 5.0);
                assert!(tensor.data[1].is_nan());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_array_multiple_rows() {
        let data: Vec<char> = vec!['c', 'a', 't', ' ', ' ', 'h', 'o', 'r', 's', 'e'];
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![2, 1]);
                assert_eq!(tensor.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_vector_retains_explicit_spaces() {
        // Two letters plus three explicit trailing spaces: the literal must supply
        // all five characters declared by the 1x5 shape below.
        let data: Vec<char> = "hi   ".chars().collect();
        let array = CharArray::new(data, 1, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        assert_eq!(result, Value::Num(5.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_cell_array_of_char_vectors() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("red")),
                Value::CharArray(CharArray::new_row("green")),
            ],
            1,
            2,
        )
        .unwrap();
        let result = strlength_builtin(Value::Cell(cell)).expect("strlength");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 2]);
                assert_eq!(tensor.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_cell_array_with_string_scalars() {
        let cell = CellArray::new(
            vec![
                Value::String("alpha".into()),
                Value::String("beta".into()),
                Value::String("<missing>".into()),
            ],
            1,
            3,
        )
        .unwrap();
        let result = strlength_builtin(Value::Cell(cell)).expect("strlength");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 3]);
                assert_eq!(tensor.data.len(), 3);
                assert_eq!(tensor.data[0], 5.0);
                assert_eq!(tensor.data[1], 4.0);
                assert!(tensor.data[2].is_nan());
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_string_array_preserves_shape() {
        let array = StringArray::new(
            vec!["ab".into(), "c".into(), "def".into(), "".into()],
            vec![2, 2],
        )
        .unwrap();
        let result = strlength_builtin(Value::StringArray(array)).expect("strlength");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![2, 2]);
                assert_eq!(tensor.data, vec![2.0, 1.0, 3.0, 0.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_char_array_trims_padding() {
        let data: Vec<char> = vec!['d', 'o', 'g', ' ', ' ', 'h', 'o', 'r', 's', 'e'];
        let array = CharArray::new(data, 2, 5).unwrap();
        let result = strlength_builtin(Value::CharArray(array)).expect("strlength");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![2, 1]);
                assert_eq!(tensor.data, vec![3.0, 5.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_errors_on_invalid_input() {
        let err = error_message(strlength_builtin(Value::Num(1.0)).unwrap_err());
        assert_eq!(err, STRLENGTH_ERROR_INVALID_INPUT.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strlength_rejects_cell_with_invalid_element() {
        let cell = CellArray::new(
            vec![Value::CharArray(CharArray::new_row("ok")), Value::Num(5.0)],
            1,
            2,
        )
        .unwrap();
        let err = error_message(strlength_builtin(Value::Cell(cell)).unwrap_err());
        assert_eq!(err, STRLENGTH_ERROR_INVALID_CELL_ELEMENT.message);
    }

    #[test]
    fn strlength_type_is_numeric_text_scalar_or_tensor() {
        assert_eq!(
            numeric_text_scalar_or_tensor_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::Num
        );
    }
}