1use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, lowercase_preserving_missing};
12use crate::builtins::strings::type_resolvers::text_preserve_type;
13use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
14
// GPU metadata registered for the `lower` builtin.
//
// As the `notes` field states, the builtin executes on the CPU and GPU-resident
// inputs are gathered to host memory before conversion. Consistent with that, the
// spec lists no supported precisions and no provider hooks, and uses the
// `GatherImmediately` residency policy.
// NOTE(review): how the runtime consumes the remaining fields (`nan_mode`,
// `constant_strategy`, ...) is not visible from this file — confirm against the
// `BuiltinGpuSpec` definition before relying on them.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lower",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before conversion.",
};
31
// Fusion metadata registered for the `lower` builtin.
//
// No elementwise or reduction descriptor is provided; per the `notes` field the
// builtin is a string transformation that is not eligible for fusion and always
// gathers GPU inputs.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lower",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; not eligible for fusion and always gathers GPU inputs.",
};
42
/// Name attached to every error produced by this builtin.
const BUILTIN_NAME: &str = "lower";
/// Message returned when the top-level argument is not one of the text variants
/// handled by `lower_builtin`.
const ARG_TYPE_ERROR: &str =
    "lower: first argument must be a string array, character array, or cell array of character vectors";
/// Message returned when a cell element is rejected by `lower_cell_element`.
const CELL_ELEMENT_ERROR: &str =
    "lower: cell array elements must be string scalars or character vectors";
48
49fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
50 build_runtime_error(message)
51 .with_builtin(BUILTIN_NAME)
52 .build()
53}
54
55fn map_flow(err: RuntimeError) -> RuntimeError {
56 map_control_flow_with_builtin(err, BUILTIN_NAME)
57}
58
59#[runtime_builtin(
60 name = "lower",
61 category = "strings/transform",
62 summary = "Convert strings, character arrays, and cell arrays of character vectors to lowercase.",
63 keywords = "lower,lowercase,strings,character array,text",
64 accel = "sink",
65 type_resolver(text_preserve_type),
66 builtin_path = "crate::builtins::strings::transform::lower"
67)]
68async fn lower_builtin(value: Value) -> BuiltinResult<Value> {
69 let gathered = gather_if_needed_async(&value).await.map_err(map_flow)?;
70 match gathered {
71 Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text))),
72 Value::StringArray(array) => lower_string_array(array),
73 Value::CharArray(array) => lower_char_array(array),
74 Value::Cell(cell) => lower_cell_array(cell),
75 _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
76 }
77}
78
79fn lower_string_array(array: StringArray) -> BuiltinResult<Value> {
80 let StringArray { data, shape, .. } = array;
81 let lowered = data
82 .into_iter()
83 .map(lowercase_preserving_missing)
84 .collect::<Vec<_>>();
85 let lowered_array = StringArray::new(lowered, shape)
86 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
87 Ok(Value::StringArray(lowered_array))
88}
89
90fn lower_char_array(array: CharArray) -> BuiltinResult<Value> {
91 let CharArray { data, rows, cols } = array;
92 if rows == 0 || cols == 0 {
93 return Ok(Value::CharArray(CharArray { data, rows, cols }));
94 }
95
96 let mut lowered_rows = Vec::with_capacity(rows);
97 let mut target_cols = cols;
98 for row in 0..rows {
99 let text = char_row_to_string_slice(&data, cols, row).to_lowercase();
100 let len = text.chars().count();
101 target_cols = target_cols.max(len);
102 lowered_rows.push(text);
103 }
104
105 let mut lowered_data = Vec::with_capacity(rows * target_cols);
106 for row_text in lowered_rows {
107 let mut chars: Vec<char> = row_text.chars().collect();
108 if chars.len() < target_cols {
109 chars.resize(target_cols, ' ');
110 }
111 lowered_data.extend(chars.into_iter());
112 }
113
114 CharArray::new(lowered_data, rows, target_cols)
115 .map(Value::CharArray)
116 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
117}
118
119fn lower_cell_array(cell: CellArray) -> BuiltinResult<Value> {
120 let CellArray {
121 data, rows, cols, ..
122 } = cell;
123 let mut lowered_values = Vec::with_capacity(rows * cols);
124 for row in 0..rows {
125 for col in 0..cols {
126 let idx = row * cols + col;
127 let lowered = lower_cell_element(&data[idx])?;
128 lowered_values.push(lowered);
129 }
130 }
131 make_cell(lowered_values, rows, cols)
132 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
133}
134
135fn lower_cell_element(value: &Value) -> BuiltinResult<Value> {
136 match value {
137 Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text.clone()))),
138 Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(
139 lowercase_preserving_missing(sa.data[0].clone()),
140 )),
141 Value::CharArray(ca) if ca.rows <= 1 => lower_char_array(ca.clone()),
142 Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
143 _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
144 }
145}
146
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Runs the async `lower` builtin to completion on the calling thread.
    fn run_lower(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(lower_builtin(value))
    }

    /// Unwraps a `Value::StringArray`, failing the test on any other variant.
    fn expect_string_array(value: Value) -> StringArray {
        match value {
            Value::StringArray(sa) => sa,
            other => panic!("expected string array, got {other:?}"),
        }
    }

    /// Unwraps a `Value::CharArray`, failing the test on any other variant.
    fn expect_char_array(value: Value) -> CharArray {
        match value {
            Value::CharArray(ca) => ca,
            other => panic!("expected char array, got {other:?}"),
        }
    }

    /// Unwraps a `Value::Cell`, failing the test on any other variant.
    fn expect_cell(value: Value) -> CellArray {
        match value {
            Value::Cell(cell) => cell,
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_scalar_value() {
        let result = run_lower(Value::String("RunMat".into())).expect("lower");
        assert_eq!(result, Value::String("runmat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_array_preserves_shape() {
        let input: Vec<String> = ["GPU", "ACCEL", "<missing>", "MiXeD"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let array = StringArray::new(input, vec![2, 2]).unwrap();

        let sa = expect_string_array(run_lower(Value::StringArray(array)).expect("lower"));

        let expected: Vec<String> = ["gpu", "accel", "<missing>", "mixed"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        assert_eq!(sa.shape, vec![2, 2]);
        assert_eq!(sa.data, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_multiple_rows() {
        let data: Vec<char> = "CATDOG".chars().collect();
        let array = CharArray::new(data, 2, 3).unwrap();

        let ca = expect_char_array(run_lower(Value::CharArray(array)).expect("lower"));

        let expected: Vec<char> = "catdog".chars().collect();
        assert_eq!(ca.rows, 2);
        assert_eq!(ca.cols, 3);
        assert_eq!(ca.data, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_vector_handles_padding() {
        let array = CharArray::new_row("HELLO ");

        let ca = expect_char_array(run_lower(Value::CharArray(array)).expect("lower"));

        let expected: Vec<char> = "hello ".chars().collect();
        assert_eq!(ca.rows, 1);
        assert_eq!(ca.cols, 6);
        assert_eq!(ca.data, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_unicode_expansion_extends_width() {
        // `İ` lowercases to two chars, so the single row grows from 2 to 3 columns.
        let data: Vec<char> = vec!['İ', 'A'];
        let array = CharArray::new(data, 1, 2).unwrap();

        let ca = expect_char_array(run_lower(Value::CharArray(array)).expect("lower"));

        let expected: Vec<char> = vec!['i', '\u{307}', 'a'];
        assert_eq!(ca.rows, 1);
        assert_eq!(ca.cols, 3);
        assert_eq!(ca.data, expected);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("RUN")),
                Value::String("Mat".into()),
            ],
            1,
            2,
        )
        .unwrap();

        let out = expect_cell(run_lower(Value::Cell(cell)).expect("lower"));

        let first = out.get(0, 0).unwrap();
        let second = out.get(0, 1).unwrap();
        assert_eq!(first, Value::CharArray(CharArray::new_row("run")));
        assert_eq!(second, Value::String("mat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_errors_on_invalid_input() {
        let err = run_lower(Value::Num(1.0)).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_errors_on_invalid_element() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_lower(Value::Cell(cell)).unwrap_err();
        assert_eq!(err.to_string(), CELL_ELEMENT_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_preserves_missing_string() {
        let result = run_lower(Value::String("<missing>".into())).expect("lower");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_allows_empty_char_vector() {
        let empty_char = CharArray::new(Vec::new(), 1, 0).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(empty_char.clone())], 1, 1).unwrap();

        let out = expect_cell(run_lower(Value::Cell(cell)).expect("lower"));

        let element = out.get(0, 0).unwrap();
        assert_eq!(element, Value::CharArray(empty_char));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn lower_gpu_tensor_input_gathers_then_errors() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let data = [1.0f64, 2.0];
        let shape = [2usize, 1usize];
        let handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &data,
                shape: &shape,
            })
            .expect("upload");
        // A numeric tensor is not text, so the builtin must reject it.
        let err = run_lower(Value::GpuTensor(handle.clone())).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
        provider.free(&handle).ok();
    }

    // Carries the same wasm attribute as every other test in this module so it
    // is not skipped on wasm32 targets.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}