Skip to main content

runmat_runtime/builtins/strings/transform/
lower.rs

1//! MATLAB-compatible `lower` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6    CellArray, CharArray, StringArray, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, lowercase_preserving_missing};
16use crate::builtins::strings::type_resolvers::text_preserve_type;
17use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
18
/// GPU registration metadata for the `lower` builtin.
///
/// The spec declares no supported precisions and no provider hooks, and uses
/// `ResidencyPolicy::GatherImmediately`; as the `notes` field states, the conversion runs on
/// the CPU after GPU-resident inputs are gathered to host memory.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "lower",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before conversion.",
};
35
/// Fusion registration metadata for the `lower` builtin.
///
/// Neither an elementwise nor a reduction template is provided (`None` for both); the `notes`
/// field records that this string transformation is not eligible for fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::lower")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "lower",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; not eligible for fusion and always gathers GPU inputs.",
};
46
/// Builtin name attached to errors raised from this module and used as the prefix of
/// internal-error messages.
const BUILTIN_NAME: &str = "lower";
48
/// Descriptor for the single required output (`out`) of `lower`.
const LOWER_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Lowercased text preserving input container kind and shape.",
}];
56
/// Descriptor for the single required input (`str`) of `lower`.
const LOWER_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "str",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "String/char/cell text input to transform.",
}];
64
/// The only call signature of `lower`: one input, one output.
const LOWER_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "out = lower(str)",
    inputs: &LOWER_INPUTS,
    outputs: &LOWER_OUTPUT,
}];
70
/// Error returned by `lower_builtin` when the (gathered) argument is not a string scalar,
/// string array, char array, or cell array.
const LOWER_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOWER.INVALID_INPUT",
    identifier: Some("RunMat:lower:InvalidInput"),
    when: "Input is not a string array, character array, or cell array of text scalars.",
    message:
        "lower: first argument must be a string array, character array, or cell array of character vectors",
};
78
/// Error returned by `lower_cell_element` for a cell entry that is neither a string scalar nor
/// a char array with at most one row.
const LOWER_ERROR_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOWER.CELL_ELEMENT",
    identifier: Some("RunMat:lower:CellElement"),
    when: "Cell array contains a non-text element or non-row char array element.",
    message: "lower: cell array elements must be string scalars or character vectors",
};
85
/// Error descriptor used when rebuilding the output container (string array, char array, or
/// cell) fails. Call sites in this file substitute their own message text and take only the
/// identifier from this descriptor.
const LOWER_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LOWER.INTERNAL",
    identifier: Some("RunMat:lower:InternalError"),
    when: "Internal output container construction failed.",
    message: "lower: internal error",
};
92
/// All error descriptors published through `LOWER_DESCRIPTOR`.
const LOWER_ERRORS: [BuiltinErrorDescriptor; 3] = [
    LOWER_ERROR_INVALID_INPUT,
    LOWER_ERROR_CELL_ELEMENT,
    LOWER_ERROR_INTERNAL,
];
98
/// Public descriptor for `lower`, bundling its signatures and error table. It is referenced by
/// path from the `#[runtime_builtin]` attribute on `lower_builtin`.
pub const LOWER_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &LOWER_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &LOWER_ERRORS,
};
105
/// Passes `err` through `map_control_flow_with_builtin` together with this builtin's name.
///
/// Used as the `map_err` adapter for the gather step in `lower_builtin`.
fn map_flow(err: RuntimeError) -> RuntimeError {
    map_control_flow_with_builtin(err, BUILTIN_NAME)
}
109
110fn lower_error_with_message(
111    message: impl Into<String>,
112    error: &'static BuiltinErrorDescriptor,
113) -> RuntimeError {
114    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
115    if let Some(identifier) = error.identifier {
116        builder = builder.with_identifier(identifier);
117    }
118    builder.build()
119}
120
/// Builds a [`RuntimeError`] from `error`, using the descriptor's own `message` as the text.
fn lower_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
    lower_error_with_message(error.message, error)
}
124
125#[runtime_builtin(
126    name = "lower",
127    category = "strings/transform",
128    summary = "Convert strings, character arrays, and cell arrays of character vectors to lowercase.",
129    keywords = "lower,lowercase,strings,character array,text",
130    accel = "sink",
131    type_resolver(text_preserve_type),
132    descriptor(crate::builtins::strings::transform::lower::LOWER_DESCRIPTOR),
133    builtin_path = "crate::builtins::strings::transform::lower"
134)]
135async fn lower_builtin(value: Value) -> BuiltinResult<Value> {
136    let gathered = gather_if_needed_async(&value).await.map_err(map_flow)?;
137    match gathered {
138        Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text))),
139        Value::StringArray(array) => lower_string_array(array),
140        Value::CharArray(array) => lower_char_array(array),
141        Value::Cell(cell) => lower_cell_array(cell),
142        _ => Err(lower_error(&LOWER_ERROR_INVALID_INPUT)),
143    }
144}
145
146fn lower_string_array(array: StringArray) -> BuiltinResult<Value> {
147    let StringArray { data, shape, .. } = array;
148    let lowered = data
149        .into_iter()
150        .map(lowercase_preserving_missing)
151        .collect::<Vec<_>>();
152    let lowered_array = StringArray::new(lowered, shape).map_err(|e| {
153        lower_error_with_message(format!("{BUILTIN_NAME}: {e}"), &LOWER_ERROR_INTERNAL)
154    })?;
155    Ok(Value::StringArray(lowered_array))
156}
157
158fn lower_char_array(array: CharArray) -> BuiltinResult<Value> {
159    let CharArray { data, rows, cols } = array;
160    if rows == 0 || cols == 0 {
161        return Ok(Value::CharArray(CharArray { data, rows, cols }));
162    }
163
164    let mut lowered_rows = Vec::with_capacity(rows);
165    let mut target_cols = cols;
166    for row in 0..rows {
167        let text = char_row_to_string_slice(&data, cols, row).to_lowercase();
168        let len = text.chars().count();
169        target_cols = target_cols.max(len);
170        lowered_rows.push(text);
171    }
172
173    let mut lowered_data = Vec::with_capacity(rows * target_cols);
174    for row_text in lowered_rows {
175        let mut chars: Vec<char> = row_text.chars().collect();
176        if chars.len() < target_cols {
177            chars.resize(target_cols, ' ');
178        }
179        lowered_data.extend(chars.into_iter());
180    }
181
182    CharArray::new(lowered_data, rows, target_cols)
183        .map(Value::CharArray)
184        .map_err(|e| {
185            lower_error_with_message(format!("{BUILTIN_NAME}: {e}"), &LOWER_ERROR_INTERNAL)
186        })
187}
188
189fn lower_cell_array(cell: CellArray) -> BuiltinResult<Value> {
190    let CellArray {
191        data, rows, cols, ..
192    } = cell;
193    let mut lowered_values = Vec::with_capacity(rows * cols);
194    for row in 0..rows {
195        for col in 0..cols {
196            let idx = row * cols + col;
197            let lowered = lower_cell_element(&data[idx])?;
198            lowered_values.push(lowered);
199        }
200    }
201    make_cell(lowered_values, rows, cols).map_err(|e| {
202        lower_error_with_message(format!("{BUILTIN_NAME}: {e}"), &LOWER_ERROR_INTERNAL)
203    })
204}
205
206fn lower_cell_element(value: &Value) -> BuiltinResult<Value> {
207    match value {
208        Value::String(text) => Ok(Value::String(lowercase_preserving_missing(text.clone()))),
209        Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(
210            lowercase_preserving_missing(sa.data[0].clone()),
211        )),
212        Value::CharArray(ca) if ca.rows <= 1 => lower_char_array(ca.clone()),
213        Value::CharArray(_) => Err(lower_error(&LOWER_ERROR_CELL_ELEMENT)),
214        _ => Err(lower_error(&LOWER_ERROR_CELL_ELEMENT)),
215    }
216}
217
// Unit tests for the `lower` builtin. Each test drives the async entry point synchronously
// through `run_lower`.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    // Blocks on the async builtin so tests can call it synchronously.
    fn run_lower(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(lower_builtin(value))
    }

    // A string scalar is lowercased and stays a string scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_scalar_value() {
        let result = run_lower(Value::String("RunMat".into())).expect("lower");
        assert_eq!(result, Value::String("runmat".into()));
    }

    // A 2x2 string array keeps its shape; the "<missing>" entry is left untouched.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_string_array_preserves_shape() {
        let array = StringArray::new(
            vec![
                "GPU".into(),
                "ACCEL".into(),
                "<missing>".into(),
                "MiXeD".into(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let result = run_lower(Value::StringArray(array)).expect("lower");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 2]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("gpu"),
                        String::from("accel"),
                        String::from("<missing>"),
                        String::from("mixed")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 2x3 char array is lowercased with its dimensions unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_multiple_rows() {
        let data: Vec<char> = vec!['C', 'A', 'T', 'D', 'O', 'G'];
        let array = CharArray::new(data, 2, 3).unwrap();
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 2);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['c', 'a', 't', 'd', 'o', 'g']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // A trailing space in a char row vector survives lowercasing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_vector_handles_padding() {
        let array = CharArray::new_row("HELLO ");
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 6);
                let expected: Vec<char> = "hello ".chars().collect();
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // 'İ' lowercases to two chars ('i' + U+0307), so the 1x2 input becomes 1x3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_char_array_unicode_expansion_extends_width() {
        let data: Vec<char> = vec!['İ', 'A'];
        let array = CharArray::new(data, 1, 2).unwrap();
        let result = run_lower(Value::CharArray(array)).expect("lower");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                let expected: Vec<char> = vec!['i', '\u{307}', 'a'];
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // A cell mixing a char row and a string scalar keeps each element's own kind.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("RUN")),
                Value::String("Mat".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(first, Value::CharArray(CharArray::new_row("run")));
                assert_eq!(second, Value::String("mat".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A numeric argument is rejected with the invalid-input message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_errors_on_invalid_input() {
        let err = run_lower(Value::Num(1.0)).unwrap_err();
        assert_eq!(err.to_string(), LOWER_ERROR_INVALID_INPUT.message);
    }

    // A numeric element inside a cell is rejected with the cell-element message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_errors_on_invalid_element() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_lower(Value::Cell(cell)).unwrap_err();
        assert_eq!(err.to_string(), LOWER_ERROR_CELL_ELEMENT.message);
    }

    // The "<missing>" string scalar passes through unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_preserves_missing_string() {
        let result = run_lower(Value::String("<missing>".into())).expect("lower");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    // A 1x0 char array inside a cell is accepted and returned as-is.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn lower_cell_allows_empty_char_vector() {
        let empty_char = CharArray::new(Vec::new(), 1, 0).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(empty_char.clone())], 1, 1).unwrap();
        let result = run_lower(Value::Cell(cell)).expect("lower");
        match result {
            Value::Cell(out) => {
                let element = out.get(0, 0).unwrap();
                assert_eq!(element, Value::CharArray(empty_char));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // Only built with the `wgpu` feature: a numeric GPU tensor argument ends in the same
    // invalid-input error as a host numeric value.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn lower_gpu_tensor_input_gathers_then_errors() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let data = [1.0f64, 2.0];
        let shape = [2usize, 1usize];
        let handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &data,
                shape: &shape,
            })
            .expect("upload");
        let err = run_lower(Value::GpuTensor(handle.clone())).unwrap_err();
        assert_eq!(err.to_string(), LOWER_ERROR_INVALID_INPUT.message);
        // Best-effort cleanup; a failed free is ignored.
        provider.free(&handle).ok();
    }

    // The type resolver maps a `String` input type to a `String` output type.
    #[test]
    fn lower_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}