1use runmat_builtins::{CellArray, CharArray, StringArray, Value};
3use runmat_macros::runtime_builtin;
4
5use crate::builtins::common::map_control_flow_with_builtin;
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::strings::common::{char_row_to_string_slice, uppercase_preserving_missing};
11use crate::builtins::strings::type_resolvers::text_preserve_type;
12use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
13
// GPU metadata for the `upper` builtin, registered under the builtin path given in the
// attribute below. The spec lists no supported precisions and no provider hooks, and uses
// the `GatherImmediately` residency policy; as the `notes` field states, the conversion
// itself executes on the CPU after GPU-resident inputs are gathered to host memory.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::upper")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "upper",
    // Tagged as a custom "string-transform" operation.
    op_kind: GpuOpKind::Custom("string-transform"),
    // Empty: no precisions are declared for this builtin.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    // Empty: no provider hooks are declared for this builtin.
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Executes on the CPU; GPU-resident inputs are gathered to host memory before conversion.",
};
30
// Fusion metadata for the `upper` builtin, registered under the builtin path given in the
// attribute below. No elementwise or reduction entry is provided; as the `notes` field
// states, this builtin is not eligible for fusion and always gathers GPU inputs.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::upper")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "upper",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Neither an elementwise nor a reduction fusion description is supplied.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "String transformation builtin; not eligible for fusion and always gathers GPU inputs.",
};
41
/// Name attached to every runtime error raised by this builtin.
const BUILTIN_NAME: &str = "upper";
/// Error message returned when the top-level argument is not one of the supported text types.
const ARG_TYPE_ERROR: &str =
    "upper: first argument must be a string array, character array, or cell array of character vectors";
/// Error message returned when a cell array element is not a string scalar or character vector.
const CELL_ELEMENT_ERROR: &str =
    "upper: cell array elements must be string scalars or character vectors";
47
/// Builds a [`RuntimeError`] from `message`, attaching [`BUILTIN_NAME`] via `with_builtin`
/// before finalising the builder.
fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message)
        .with_builtin(BUILTIN_NAME)
        .build()
}
53
54fn map_flow(err: RuntimeError) -> RuntimeError {
55 map_control_flow_with_builtin(err, BUILTIN_NAME)
56}
57
58#[runtime_builtin(
59 name = "upper",
60 category = "strings/transform",
61 summary = "Convert strings, character arrays, and cell arrays of character vectors to uppercase.",
62 keywords = "upper,uppercase,strings,character array,text",
63 accel = "sink",
64 type_resolver(text_preserve_type),
65 builtin_path = "crate::builtins::strings::transform::upper"
66)]
67async fn upper_builtin(value: Value) -> BuiltinResult<Value> {
68 let gathered = gather_if_needed_async(&value).await.map_err(map_flow)?;
69 match gathered {
70 Value::String(text) => Ok(Value::String(uppercase_preserving_missing(text))),
71 Value::StringArray(array) => upper_string_array(array),
72 Value::CharArray(array) => upper_char_array(array),
73 Value::Cell(cell) => upper_cell_array(cell),
74 _ => Err(runtime_error_for(ARG_TYPE_ERROR)),
75 }
76}
77
78fn upper_string_array(array: StringArray) -> BuiltinResult<Value> {
79 let StringArray { data, shape, .. } = array;
80 let uppered = data
81 .into_iter()
82 .map(uppercase_preserving_missing)
83 .collect::<Vec<_>>();
84 let upper_array = StringArray::new(uppered, shape)
85 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
86 Ok(Value::StringArray(upper_array))
87}
88
89fn upper_char_array(array: CharArray) -> BuiltinResult<Value> {
90 let CharArray { data, rows, cols } = array;
91 if rows == 0 || cols == 0 {
92 return Ok(Value::CharArray(CharArray { data, rows, cols }));
93 }
94
95 let mut upper_rows = Vec::with_capacity(rows);
96 let mut target_cols = cols;
97 for row in 0..rows {
98 let text = char_row_to_string_slice(&data, cols, row).to_uppercase();
99 let len = text.chars().count();
100 target_cols = target_cols.max(len);
101 upper_rows.push(text);
102 }
103
104 let mut upper_data = Vec::with_capacity(rows * target_cols);
105 for row_text in upper_rows {
106 let mut chars: Vec<char> = row_text.chars().collect();
107 if chars.len() < target_cols {
108 chars.resize(target_cols, ' ');
109 }
110 upper_data.extend(chars.into_iter());
111 }
112
113 CharArray::new(upper_data, rows, target_cols)
114 .map(Value::CharArray)
115 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
116}
117
118fn upper_cell_array(cell: CellArray) -> BuiltinResult<Value> {
119 let CellArray {
120 data, rows, cols, ..
121 } = cell;
122 let mut upper_values = Vec::with_capacity(rows * cols);
123 for row in 0..rows {
124 for col in 0..cols {
125 let idx = row * cols + col;
126 let upper = upper_cell_element(&data[idx])?;
127 upper_values.push(upper);
128 }
129 }
130 make_cell(upper_values, rows, cols)
131 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
132}
133
134fn upper_cell_element(value: &Value) -> BuiltinResult<Value> {
135 match value {
136 Value::String(text) => Ok(Value::String(uppercase_preserving_missing(text.clone()))),
137 Value::StringArray(sa) if sa.data.len() == 1 => Ok(Value::String(
138 uppercase_preserving_missing(sa.data[0].clone()),
139 )),
140 Value::CharArray(ca) if ca.rows <= 1 => upper_char_array(ca.clone()),
141 Value::CharArray(_) => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
142 _ => Err(runtime_error_for(CELL_ELEMENT_ERROR)),
143 }
144}
145
// Unit tests for the `upper` builtin: string scalars, string arrays, character arrays,
// cell arrays, error paths, and the type resolver.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    // Blocks on the async builtin so the tests can call it synchronously.
    fn run_upper(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(upper_builtin(value))
    }

    // A mixed-case string scalar comes back fully uppercased.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_string_scalar_value() {
        let result = run_upper(Value::String("RunMat".into())).expect("upper");
        assert_eq!(result, Value::String("RUNMAT".into()));
    }

    // A 2x2 string array keeps its shape; the "<missing>" entry is expected back unchanged
    // while the other entries are uppercased.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_string_array_preserves_shape() {
        let array = StringArray::new(
            vec![
                "gpu".into(),
                "accel".into(),
                "<missing>".into(),
                "MiXeD".into(),
            ],
            vec![2, 2],
        )
        .unwrap();
        let result = run_upper(Value::StringArray(array)).expect("upper");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 2]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("GPU"),
                        String::from("ACCEL"),
                        String::from("<missing>"),
                        String::from("MIXED")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 2x3 character array keeps its dimensions and has every character uppercased.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_char_array_multiple_rows() {
        let data: Vec<char> = vec!['c', 'a', 't', 'd', 'o', 'g'];
        let array = CharArray::new(data, 2, 3).unwrap();
        let result = run_upper(Value::CharArray(array)).expect("upper");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 2);
                assert_eq!(ca.cols, 3);
                assert_eq!(ca.data, vec!['C', 'A', 'T', 'D', 'O', 'G']);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // A row vector with a trailing space keeps that space and its width of 6.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_char_vector_handles_padding() {
        let array = CharArray::new_row("hello ");
        let result = run_upper(Value::CharArray(array)).expect("upper");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 6);
                let expected: Vec<char> = "HELLO ".chars().collect();
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // 'ß' uppercases to two characters, so a 1x2 input is expected to become 1x3.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_char_array_unicode_expansion_extends_width() {
        let data: Vec<char> = vec!['ß', 'a'];
        let array = CharArray::new(data, 1, 2).unwrap();
        let result = run_upper(Value::CharArray(array)).expect("upper");
        match result {
            Value::CharArray(ca) => {
                assert_eq!(ca.rows, 1);
                assert_eq!(ca.cols, 3);
                let expected: Vec<char> = vec!['S', 'S', 'A'];
                assert_eq!(ca.data, expected);
            }
            other => panic!("expected char array, got {other:?}"),
        }
    }

    // A 1x2 cell holding a char row and a string scalar: each element is uppercased and
    // keeps its own variant.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_cell_array_mixed_content() {
        let cell = CellArray::new(
            vec![
                Value::CharArray(CharArray::new_row("run")),
                Value::String("Mat".into()),
            ],
            1,
            2,
        )
        .unwrap();
        let result = run_upper(Value::Cell(cell)).expect("upper");
        match result {
            Value::Cell(out) => {
                let first = out.get(0, 0).unwrap();
                let second = out.get(0, 1).unwrap();
                assert_eq!(first, Value::CharArray(CharArray::new_row("RUN")));
                assert_eq!(second, Value::String("MAT".into()));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A numeric argument is rejected with the argument-type error message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_errors_on_invalid_input() {
        let err = run_upper(Value::Num(1.0)).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
    }

    // A cell containing a numeric element is rejected with the cell-element error message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_cell_errors_on_invalid_element() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err = run_upper(Value::Cell(cell)).unwrap_err();
        assert_eq!(err.to_string(), CELL_ELEMENT_ERROR);
    }

    // The string scalar "<missing>" is expected back unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_preserves_missing_string() {
        let result = run_upper(Value::String("<missing>".into())).expect("upper");
        assert_eq!(result, Value::String("<missing>".into()));
    }

    // A 1x0 character array inside a cell is accepted and returned unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn upper_cell_allows_empty_char_vector() {
        let empty_char = CharArray::new(Vec::new(), 1, 0).unwrap();
        let cell = CellArray::new(vec![Value::CharArray(empty_char.clone())], 1, 1).unwrap();
        let result = run_upper(Value::Cell(cell)).expect("upper");
        match result {
            Value::Cell(out) => {
                let element = out.get(0, 0).unwrap();
                assert_eq!(element, Value::CharArray(empty_char));
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // Only built with the `wgpu` feature: registers the wgpu provider, uploads a 2x1 f64
    // tensor, and expects `upper` to reject the GPU tensor handle with the argument-type
    // error. The handle is freed afterwards; a failure to free is ignored.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn upper_gpu_tensor_input_gathers_then_errors() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let data = [1.0f64, 2.0];
        let shape = [2usize, 1usize];
        let handle = provider
            .upload(&runmat_accelerate_api::HostTensorView {
                data: &data,
                shape: &shape,
            })
            .expect("upload");
        let err = run_upper(Value::GpuTensor(handle.clone())).unwrap_err();
        assert_eq!(err.to_string(), ARG_TYPE_ERROR);
        provider.free(&handle).ok();
    }

    // The type resolver maps a `Type::String` input to `Type::String`.
    // NOTE(review): unlike every other test in this module, this one has no
    // `cfg_attr(target_arch = "wasm32", ...)` attribute — confirm whether that is intentional.
    #[test]
    fn upper_type_preserves_text() {
        assert_eq!(
            text_preserve_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}