Skip to main content

runmat_runtime/builtins/strings/transform/
lower.rs

1//! MATLAB-compatible `lower` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, lowercase_preserving_missing};
12use crate::builtins::strings::type_resolvers::text_preserve_type;
13use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
14
/// GPU execution metadata for the `lower` builtin.
///
/// As the `notes` field states, the conversion runs on the CPU and GPU-resident
/// inputs are gathered to host memory first. Consistent with that, no precisions
/// or provider hooks are listed and the residency policy is `GatherImmediately`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lower",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before conversion.",
};
31
/// Fusion metadata for the `lower` builtin.
///
/// No elementwise or reduction template is provided (`elementwise: None`,
/// `reduction: None`); the `notes` field records that this string transformation
/// is not eligible for fusion and always gathers GPU inputs.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lower",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; not eligible for fusion and always gathers GPU inputs.",
};
42
/// Builtin name attached to errors and used as the prefix of formatted error messages.
const BUILTIN_NAME: &str = "lower";
/// Error text returned when the argument is none of the value kinds `lower_builtin` dispatches on.
const ARG_TYPE_ERROR: &str =
    "lower: first argument must be a string array, character array, or cell array of character vectors";
/// Error text returned when a cell element is rejected by `lower_cell_element`.
const CELL_ELEMENT_ERROR: &str =
    "lower: cell array elements must be string scalars or character vectors";
48
49fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
50    build_runtime_error(message)
51        .with_builtin(BUILTIN_NAME)
52        .build()
53}
54
/// Passes `err` through `map_control_flow_with_builtin`, supplying this builtin's name.
fn map_flow(err: RuntimeError) -> RuntimeError {
    map_control_flow_with_builtin(err, BUILTIN_NAME)
}
58
59#[runtime_builtin(
60    name = "lower",
61    category = "strings/transform",
62    summary = "Convert strings, character arrays, and cell arrays of character vectors to lowercase.",
63    keywords = "lower,lowercase,strings,character array,text",
64    accel = "sink",
65    type_resolver(text_preserve_type),
66    builtin_path = "crate::builtins::strings::transform::lower"
67)]
68async fn lower_builtin(value: Value) -> BuiltinResult<Value> {
69    let gathered = gather_if_needed_async(&value).await.map_err(map_flow)?;
70    match gathered {
71        Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text))),
72        Value::StringArray(array) => lower_string_array(array),
73        Value::CharArray(array) => lower_char_array(array),
74        Value::Cell(cell) => lower_cell_array(cell),
75        _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
76    }
77}
78
79fn lower_string_array(array: StringArray) -> BuiltinResult<Value> {
80    let StringArray { data, shape, .. } = array;
81    let lowered = data
82        .into_iter()
83        .map(lowercase_preserving_missing)
84        .collect::<Vec<_>>();
85    let lowered_array = StringArray::new(lowered, shape)
86        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
87    Ok(Value::StringArray(lowered_array))
88}
89
90fn lower_char_array(array: CharArray) -> BuiltinResult<Value> {
91    let CharArray { data, rows, cols } = array;
92    if rows == 0 || cols == 0 {
93        return Ok(Value::CharArray(CharArray { data, rows, cols }));
94    }
95
96    let mut lowered_rows = Vec::with_capacity(rows);
97    let mut target_cols = cols;
98    for row in 0..rows {
99        let text = char_row_to_string_slice(&data, cols, row).to_lowercase();
100        let len = text.chars().count();
101        target_cols = target_cols.max(len);
102        lowered_rows.push(text);
103    }
104
105    let mut lowered_data = Vec::with_capacity(rows * target_cols);
106    for row_text in lowered_rows {
107        let mut chars: Vec<char> = row_text.chars().collect();
108        if chars.len() < target_cols {
109            chars.resize(target_cols, ' ');
110        }
111        lowered_data.extend(chars.into_iter());
112    }
113
114    CharArray::new(lowered_data, rows, target_cols)
115        .map(Value::CharArray)
116        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
117}
118
119fn lower_cell_array(cell: CellArray) -> BuiltinResult<Value> {
120    let CellArray {
121        data, rows, cols, ..
122    } = cell;
123    let mut lowered_values = Vec::with_capacity(rows * cols);
124    for row in 0..rows {
125        for col in 0..cols {
126            let idx = row * cols + col;
127            let lowered = lower_cell_element(&data[idx])?;
128            lowered_values.push(lowered);
129        }
130    }
131    make_cell(lowered_values, rows, cols)
132        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
133}
134
135fn lower_cell_element(value: &Value) -> BuiltinResult<Value> {
136    match value {
137        Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text.clone()))),
138        Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(
139            lowercase_preserving_missing(sa.data[0].clone()),
140        )),
141        Value::CharArray(ca) if ca.rows <= 1 => lower_char_array(ca.clone()),
142        Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
143        _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
144    }
145}
146
/// Unit tests for the `lower` builtin.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async builtin to completion on the current thread.
    fn run_lower(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(lower_builtin(value))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_scalar_value() {
        let result = run_lower(Value::String("RunMat".into())).expect("lower");
        assert_eq!(result, Value::String("runmat".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_array_preserves_shape() {
        let array = StringArray::new(
            vec![
                "GPU".into(),
                "ACCEL".into(),
                "<missing>".into(),
                "MiXeD".into(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let result = run_lower(Value::StringArray(array)).expect("lower");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 2]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("gpu"),
                        String::from("accel"),
                        String::from("<missing>"),
                        String::from("mixed")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_multiple_rows() {
        let data: Vec<char> = vec!['C', 'A', 'T', 'D', 'O', 'G'];
        let array = CharArray::new(data, 2, 3).unwrap();
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 2);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['c', 'a', 't', 'd', 'o', 'g']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_vector_handles_padding() {
        let array = CharArray::new_row("HELLO ");
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 6);
                let expected: Vec<char> = "hello ".chars().collect();
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // Lowercasing 'İ' yields two characters, so the row grows from 2 to 3 columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_unicode_expansion_extends_width() {
        let data: Vec<char> = vec!['İ', 'A'];
        let array = CharArray::new(data, 1, 2).unwrap();
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                let expected: Vec<char> = vec!['i', '\u{307}', 'a'];
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("RUN")),
                Value::String("Mat".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(first, Value::CharArray(CharArray::new_row("run")));
                assert_eq!(second, Value::String("mat".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A one-element string array inside a cell takes the scalar-string branch of
    // `lower_cell_element` and comes back as a `Value::String`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_scalar_string_array_becomes_string() {
        let scalar = StringArray::new(vec!["RunMat".into()], vec![1, 1]).unwrap();
        let cell = CellArray::new(vec![Value::StringArray(scalar)], 1, 1).unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let element = out.get(0, 0).unwrap();
                assert_eq!(element, Value::String("runmat".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_errors_on_invalid_input() {
        let err = run_lower(Value::Num(1.0)).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_errors_on_invalid_element() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_lower(Value::Cell(cell)).unwrap_err();
        assert_eq!(err.to_string(), CELL_ELEMENT_ERROR);
    }

    // Character arrays with more than one row are not character vectors, so a
    // cell holding one must be rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_errors_on_multirow_char_array() {
        let matrix = CharArray::new(vec!['C', 'A', 'T', 'D', 'O', 'G'], 2, 3).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(matrix)], 1, 1).unwrap();
        let err = run_lower(Value::Cell(cell)).unwrap_err();
        assert_eq!(err.to_string(), CELL_ELEMENT_ERROR);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_preserves_missing_string() {
        let result = run_lower(Value::String("<missing>".into())).expect("lower");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_allows_empty_char_vector() {
        let empty_char = CharArray::new(Vec::new(), 1, 0).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(empty_char.clone())], 1, 1).unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let element = out.get(0, 0).unwrap();
                assert_eq!(element, Value::CharArray(empty_char));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A numeric GPU tensor is still a non-text argument, so the builtin is
    // expected to fail with the argument-type error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn lower_gpu_tensor_input_gathers_then_errors() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let data = [1.0f64, 2.0];
        let shape = [2usize, 1usize];
        let handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &data,
                shape: &shape,
            })
            .expect("upload");
        let err = run_lower(Value::GpuTensor(handle.clone())).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
        provider.free(&handle).ok();
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}