1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 CellArray, CharArray, StringArray, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, lowercase_preserving_missing};
16use crate::builtins::strings::type_resolvers::text_preserve_type;
17use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
18
/// GPU metadata for `lower`.
///
/// `lower` has no device implementation: it lists no supported precisions and no provider
/// hooks, and its residency policy is `GatherImmediately`. As the `notes` field states,
/// GPU-resident inputs are gathered to host memory before the CPU conversion runs.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lower",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before conversion.",
};
35
/// Fusion metadata for `lower`.
///
/// No elementwise or reduction template is supplied (`elementwise` and `reduction` are `None`).
/// Per `notes`, this string builtin is not eligible for fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lower",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; not eligible for fusion and always gathers GPU inputs.",
};
46
/// Builtin name attached to errors raised by this module (via `with_builtin` and
/// `map_control_flow_with_builtin`) and used as the prefix of internal error messages.
const BUILTIN_NAME: &str = "lower";
48
/// Output parameter of the single `lower` call form: the lowercased text value.
const LOWER_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Lowercased text preserving input container kind and shape.",
}];

/// Input parameter of the single `lower` call form: the text value to lowercase.
const LOWER_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "str",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "String/char/cell text input to transform.",
}];

/// The only call form `lower` exposes: `out = lower(str)`.
const LOWER_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "out = lower(str)",
    inputs: &LOWER_INPUTS,
    outputs: &LOWER_OUTPUT,
}];
70
// Error descriptors for `lower`. The helpers in this module read only `identifier` and
// `message` from these; `code` and `when` are presumably consumed by descriptor/documentation
// tooling elsewhere (NOTE(review): confirm).

/// Raised by `lower_builtin` when the (gathered) argument is not a string, string array,
/// char array, or cell array.
const LOWER_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOWER.INVALID_INPUT",
    identifier: Some("RunMat:lower:InvalidInput"),
    when: "Input is not a string array, character array, or cell array of text scalars.",
    message:
        "lower: first argument must be a string array, character array, or cell array of character vectors",
};

/// Raised by `lower_cell_element` for a cell element that is not a string scalar, a
/// one-element string array, or a char array with at most one row.
const LOWER_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOWER.CELL_ELEMENT",
    identifier: Some("RunMat:lower:CellElement"),
    when: "Cell array contains a non-text element or non-row char array element.",
    message: "lower: cell array elements must be string scalars or character vectors",
};

/// Identifier used when rebuilding the output container (`StringArray::new`,
/// `CharArray::new`, or `make_cell`) fails. Those call sites supply their own message, built
/// from the underlying error.
const LOWER_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOWER.INTERNAL",
    identifier: Some("RunMat:lower:InternalError"),
    when: "Internal output container construction failed.",
    message: "lower: internal error",
};

/// All error descriptors `lower` can produce, published through `LOWER_DESCRIPTOR`.
const LOWER_ERRORS: [BuiltinErrorDescriptor; 3] = [
    LOWER_ERROR_INVALID_INPUT,
    LOWER_ERROR_CELL_ELEMENT,
    LOWER_ERROR_INTERNAL,
];
98
/// Public descriptor for `lower`: one fixed-output signature, public completion, and the
/// three error descriptors above.
///
/// Referenced by the `descriptor(...)` argument of the `#[runtime_builtin]` attribute on
/// `lower_builtin`.
pub const LOWER_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &LOWER_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &LOWER_ERRORS,
};
105
106fn map_flow(err: RuntimeError) -> RuntimeError {
107 map_control_flow_with_builtin(err, BUILTIN_NAME)
108}
109
110fn lower_error_with_message(
111 message: impl Into<String>,
112 error: &'static BuiltinErrorDescriptor,
113) -> RuntimeError {
114 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
115 if let Some(identifier) = error.identifier {
116 builder = builder.with_identifier(identifier);
117 }
118 builder.build()
119}
120
121fn lower_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
122 lower_error_with_message(error.message, error)
123}
124
125#[runtime_builtin(
126 name = "lower",
127 category = "strings/transform",
128 summary = "Convert strings, character arrays, and cell arrays of character vectors to lowercase.",
129 keywords = "lower,lowercase,strings,character array,text",
130 accel = "sink",
131 type_resolver(text_preserve_type),
132 descriptor(crate::builtins::strings::transform::lower::LOWER_DESCRIPTOR),
133 builtin_path = "crate::builtins::strings::transform::lower"
134)]
135async fn lower_builtin(value: Value) -> BuiltinResult<Value> {
136 let gathered = gather_if_needed_async(&value).await.map_err(map_flow)?;
137 match gathered {
138 Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text))),
139 Value::StringArray(array) => lower_string_array(array),
140 Value::CharArray(array) => lower_char_array(array),
141 Value::Cell(cell) => lower_cell_array(cell),
142 _ => Err(lower_error(&LOWER_ERROR_INVALID_INPUT)),
143 }
144}
145
146fn lower_string_array(array: StringArray) -> BuiltinResult<Value> {
147 let StringArray { data, shape, .. } = array;
148 let lowered = data
149 .into_iter()
150 .map(lowercase_preserving_missing)
151 .collect::<Vec<_>>();
152 let lowered_array = StringArray::new(lowered, shape).map_err(|e| {
153 lower_error_with_message(format!("{BUILTIN_NAME}: {e}"), &LOWER_ERROR_INTERNAL)
154 })?;
155 Ok(Value::StringArray(lowered_array))
156}
157
158fn lower_char_array(array: CharArray) -> BuiltinResult<Value> {
159 let CharArray { data, rows, cols } = array;
160 if rows == 0 || cols == 0 {
161 return Ok(Value::CharArray(CharArray { data, rows, cols }));
162 }
163
164 let mut lowered_rows = Vec::with_capacity(rows);
165 let mut target_cols = cols;
166 for row in 0..rows {
167 let text = char_row_to_string_slice(&data, cols, row).to_lowercase();
168 let len = text.chars().count();
169 target_cols = target_cols.max(len);
170 lowered_rows.push(text);
171 }
172
173 let mut lowered_data = Vec::with_capacity(rows * target_cols);
174 for row_text in lowered_rows {
175 let mut chars: Vec<char> = row_text.chars().collect();
176 if chars.len() < target_cols {
177 chars.resize(target_cols, ' ');
178 }
179 lowered_data.extend(chars.into_iter());
180 }
181
182 CharArray::new(lowered_data, rows, target_cols)
183 .map(Value::CharArray)
184 .map_err(|e| {
185 lower_error_with_message(format!("{BUILTIN_NAME}: {e}"), &LOWER_ERROR_INTERNAL)
186 })
187}
188
189fn lower_cell_array(cell: CellArray) -> BuiltinResult<Value> {
190 let CellArray {
191 data, rows, cols, ..
192 } = cell;
193 let mut lowered_values = Vec::with_capacity(rows * cols);
194 for row in 0..rows {
195 for col in 0..cols {
196 let idx = row * cols + col;
197 let lowered = lower_cell_element(&data[idx])?;
198 lowered_values.push(lowered);
199 }
200 }
201 make_cell(lowered_values, rows, cols).map_err(|e| {
202 lower_error_with_message(format!("{BUILTIN_NAME}: {e}"), &LOWER_ERROR_INTERNAL)
203 })
204}
205
206fn lower_cell_element(value: &Value) -> BuiltinResult<Value> {
207 match value {
208 Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text.clone()))),
209 Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(
210 lowercase_preserving_missing(sa.data[0].clone()),
211 )),
212 Value::CharArray(ca) if ca.rows <= 1 => lower_char_array(ca.clone()),
213 Value::CharArray(_) => Err(lower_error(&LOWER_ERROR_CELL_ELEMENT)),
214 _ => Err(lower_error(&LOWER_ERROR_CELL_ELEMENT)),
215 }
216}
217
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for `lower`. Covered: scalar, array and cell inputs; Unicode widening of
    //! char rows; error paths; GPU gathering (behind the `wgpu` feature); type resolution.

    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Drive the async builtin to completion on the current thread.
    fn run_lower(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(lower_builtin(value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_scalar_value() {
        let result = run_lower(Value::String("RunMat".into())).expect("lower");
        assert_eq!(result, Value::String("runmat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_array_preserves_shape() {
        let array = StringArray::new(
            vec![
                "GPU".into(),
                "ACCEL".into(),
                "<missing>".into(),
                "MiXeD".into(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let result = run_lower(Value::StringArray(array)).expect("lower");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 2]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("gpu"),
                        String::from("accel"),
                        String::from("<missing>"),
                        String::from("mixed")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_multiple_rows() {
        let data: Vec<char> = vec!['C', 'A', 'T', 'D', 'O', 'G'];
        let array = CharArray::new(data, 2, 3).unwrap();
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 2);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['c', 'a', 't', 'd', 'o', 'g']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_vector_handles_padding() {
        let array = CharArray::new_row("HELLO ");
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 6);
                let expected: Vec<char> = "hello ".chars().collect();
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_unicode_expansion_extends_width() {
        let data: Vec<char> = vec!['İ', 'A'];
        let array = CharArray::new(data, 1, 2).unwrap();
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                let expected: Vec<char> = vec!['i', '\u{307}', 'a'];
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("RUN")),
                Value::String("Mat".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(first, Value::CharArray(CharArray::new_row("run")));
                assert_eq!(second, Value::String("mat".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_unwraps_single_element_string_array() {
        let scalar_array = StringArray::new(vec!["ABC".into()], vec![1, 1]).unwrap();
        let cell = CellArray::new(vec![Value::StringArray(scalar_array)], 1, 1).unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let element = out.get(0, 0).unwrap();
                assert_eq!(element, Value::String("abc".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_errors_on_invalid_input() {
        let err = run_lower(Value::Num(1.0)).unwrap_err();
        assert_eq!(err.to_string(), LOWER_ERROR_INVALID_INPUT.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_errors_on_invalid_element() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_lower(Value::Cell(cell)).unwrap_err();
        assert_eq!(err.to_string(), LOWER_ERROR_CELL_ELEMENT.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_errors_on_multirow_char_element() {
        // Only char row vectors are valid cell elements; a 2x2 char matrix must be rejected.
        let matrix = CharArray::new(vec!['A', 'B', 'C', 'D'], 2, 2).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(matrix)], 1, 1).unwrap();
        let err = run_lower(Value::Cell(cell)).unwrap_err();
        assert_eq!(err.to_string(), LOWER_ERROR_CELL_ELEMENT.message);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_preserves_missing_string() {
        let result = run_lower(Value::String("<missing>".into())).expect("lower");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_allows_empty_char_vector() {
        let empty_char = CharArray::new(Vec::new(), 1, 0).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(empty_char.clone())], 1, 1).unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let element = out.get(0, 0).unwrap();
                assert_eq!(element, Value::CharArray(empty_char));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn lower_gpu_tensor_input_gathers_then_errors() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let data = [1.0f64, 2.0];
        let shape = [2usize, 1usize];
        let handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &data,
                shape: &shape,
            })
            .expect("upload");
        let err = run_lower(Value::GpuTensor(handle.clone())).unwrap_err();
        assert_eq!(err.to_string(), LOWER_ERROR_INVALID_INPUT.message);
        provider.free(&handle).ok();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}