1use std::borrow::Cow;
4
5use runmat_builtins::{
6 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
7 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
8 CellArray, CharArray, StringArray, Tensor, Value,
9};
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::map_control_flow_with_builtin;
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::strings::type_resolvers::numeric_text_scalar_or_tensor_type;
19use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
20
21#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::str2double")]
22pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
23 name: "str2double",
24 op_kind: GpuOpKind::Custom("conversion"),
25 supported_precisions: &[],
26 broadcast: BroadcastSemantics::None,
27 provider_hooks: &[],
28 constant_strategy: ConstantStrategy::InlineLiteral,
29 residency: ResidencyPolicy::GatherImmediately,
30 nan_mode: ReductionNaN::Include,
31 two_pass_threshold: None,
32 workgroup_size: None,
33 accepts_nan_mode: false,
34 notes: "Parses text on the CPU; GPU-resident inputs are gathered before conversion.",
35};
36
37#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::str2double")]
38pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
39 name: "str2double",
40 shape: ShapeRequirements::Any,
41 constant_strategy: ConstantStrategy::InlineLiteral,
42 elementwise: None,
43 reduction: None,
44 emits_nan: true,
45 notes: "Conversion builtin; not eligible for fusion and materialises host-side doubles.",
46};
47
48const BUILTIN_NAME: &str = "str2double";
49
50const STR2DOUBLE_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
51 name: "X",
52 ty: BuiltinParamType::NumericArray,
53 arity: BuiltinParamArity::Required,
54 default: None,
55 description: "Parsed double values; invalid parses become NaN.",
56}];
57
58const STR2DOUBLE_INPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
59 name: "str",
60 ty: BuiltinParamType::Any,
61 arity: BuiltinParamArity::Required,
62 default: None,
63 description: "String, character, or cell-array text input to parse.",
64}];
65
66const STR2DOUBLE_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
67 label: "X = str2double(str)",
68 inputs: &STR2DOUBLE_INPUT,
69 outputs: &STR2DOUBLE_OUTPUT,
70}];
71
72const STR2DOUBLE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
73 code: "RM.STR2DOUBLE.INVALID_INPUT",
74 identifier: Some("RunMat:str2double:InvalidInput"),
75 when: "Input is not a supported text container.",
76 message: "str2double: input must be a string array, character array, or cell array of character vectors",
77};
78
79const STR2DOUBLE_ERROR_INVALID_CELL_ELEMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
80 code: "RM.STR2DOUBLE.INVALID_CELL_ELEMENT",
81 identifier: Some("RunMat:str2double:InvalidCellElement"),
82 when: "Cell array contains non-text or non-scalar text entries.",
83 message: "str2double: cell array elements must be character vectors or string scalars",
84};
85
86const STR2DOUBLE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
87 code: "RM.STR2DOUBLE.INTERNAL",
88 identifier: Some("RunMat:str2double:InternalError"),
89 when: "Internal tensor assembly failed while building parsed output.",
90 message: "str2double: internal error",
91};
92
93const STR2DOUBLE_ERRORS: [BuiltinErrorDescriptor; 3] = [
94 STR2DOUBLE_ERROR_INVALID_INPUT,
95 STR2DOUBLE_ERROR_INVALID_CELL_ELEMENT,
96 STR2DOUBLE_ERROR_INTERNAL,
97];
98
99pub const STR2DOUBLE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
100 signatures: &STR2DOUBLE_SIGNATURES,
101 output_mode: BuiltinOutputMode::Fixed,
102 completion_policy: BuiltinCompletionPolicy::Public,
103 errors: &STR2DOUBLE_ERRORS,
104};
105
106fn str2double_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
107 str2double_error_with_message(error.message, error)
108}
109
110fn str2double_error_with_message(
111 message: impl Into<String>,
112 error: &'static BuiltinErrorDescriptor,
113) -> RuntimeError {
114 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
115 if let Some(identifier) = error.identifier {
116 builder = builder.with_identifier(identifier);
117 }
118 builder.build()
119}
120
121fn remap_str2double_flow(err: RuntimeError) -> RuntimeError {
122 map_control_flow_with_builtin(err, BUILTIN_NAME)
123}
124
125#[runtime_builtin(
126 name = "str2double",
127 category = "strings/core",
128 summary = "Convert text representations of numbers into double-precision values.",
129 keywords = "str2double,string to double,text conversion,gpu",
130 accel = "sink",
131 type_resolver(numeric_text_scalar_or_tensor_type),
132 descriptor(crate::builtins::strings::core::str2double::STR2DOUBLE_DESCRIPTOR),
133 builtin_path = "crate::builtins::strings::core::str2double"
134)]
135async fn str2double_builtin(value: Value) -> crate::BuiltinResult<Value> {
136 let gathered = gather_if_needed_async(&value)
137 .await
138 .map_err(remap_str2double_flow)?;
139 match gathered {
140 Value::String(text) => Ok(Value::Num(parse_numeric_scalar(&text))),
141 Value::StringArray(array) => str2double_string_array(array),
142 Value::CharArray(array) => str2double_char_array(array),
143 Value::Cell(cell) => str2double_cell_array(cell),
144 _ => Err(str2double_error(&STR2DOUBLE_ERROR_INVALID_INPUT)),
145 }
146}
147
148fn str2double_string_array(array: StringArray) -> BuiltinResult<Value> {
149 let StringArray { data, shape, .. } = array;
150 let mut values = Vec::with_capacity(data.len());
151 for text in &data {
152 values.push(parse_numeric_scalar(text));
153 }
154 let tensor =
155 Tensor::new(values, shape).map_err(|_| str2double_error(&STR2DOUBLE_ERROR_INTERNAL))?;
156 Ok(tensor::tensor_into_value(tensor))
157}
158
159fn str2double_char_array(array: CharArray) -> BuiltinResult<Value> {
160 let rows = array.rows;
161 let cols = array.cols;
162 let mut values = Vec::with_capacity(rows);
163 for row in 0..rows {
164 let start = row * cols;
165 let end = start + cols;
166 let row_text: String = array.data[start..end].iter().collect();
167 values.push(parse_numeric_scalar(&row_text));
168 }
169 let tensor = Tensor::new(values, vec![rows, 1])
170 .map_err(|_| str2double_error(&STR2DOUBLE_ERROR_INTERNAL))?;
171 Ok(tensor::tensor_into_value(tensor))
172}
173
174fn str2double_cell_array(cell: CellArray) -> BuiltinResult<Value> {
175 let CellArray {
176 data, rows, cols, ..
177 } = cell;
178 let mut values = Vec::with_capacity(rows * cols);
179 for col in 0..cols {
180 for row in 0..rows {
181 let idx = row * cols + col;
182 let element: &Value = &data[idx];
183 let numeric = match element {
184 Value::String(text) => parse_numeric_scalar(text),
185 Value::StringArray(sa) if sa.data.len() == 1 => parse_numeric_scalar(&sa.data[0]),
186 Value::CharArray(char_vec) if char_vec.rows == 1 => {
187 let row_text: String = char_vec.data.iter().collect();
188 parse_numeric_scalar(&row_text)
189 }
190 Value::CharArray(_) => {
191 return Err(str2double_error(&STR2DOUBLE_ERROR_INVALID_CELL_ELEMENT));
192 }
193 _ => return Err(str2double_error(&STR2DOUBLE_ERROR_INVALID_CELL_ELEMENT)),
194 };
195 values.push(numeric);
196 }
197 }
198 let tensor = Tensor::new(values, vec![rows, cols])
199 .map_err(|_| str2double_error(&STR2DOUBLE_ERROR_INTERNAL))?;
200 Ok(tensor::tensor_into_value(tensor))
201}
202
/// Parses one piece of text into a double, following MATLAB `str2double` rules.
///
/// * Surrounding whitespace is ignored; empty text yields `NaN`.
/// * `NaN`, `Inf`, and `Infinity` are matched case-insensitively, and the
///   infinities may carry a sign.
/// * Commas are treated as thousands separators and dropped, as MATLAB documents
///   for `str2double`, so `"1,200.34"` parses as `1200.34`.
/// * A Fortran-style `d`/`D` exponent marker is accepted (`"1.5D3"` is `1500`).
///
/// Text that still does not form a valid number returns `NaN` rather than an error.
fn parse_numeric_scalar(text: &str) -> f64 {
    const POSITIVE_INFINITY_FORMS: [&str; 4] = ["inf", "+inf", "infinity", "+infinity"];
    const NEGATIVE_INFINITY_FORMS: [&str; 2] = ["-inf", "-infinity"];

    let trimmed = text.trim();
    if trimmed.is_empty() {
        return f64::NAN;
    }

    // Case-insensitive comparison without allocating a lowercase copy of every element.
    if trimmed.eq_ignore_ascii_case("nan") {
        return f64::NAN;
    }
    if POSITIVE_INFINITY_FORMS
        .iter()
        .any(|form| trimmed.eq_ignore_ascii_case(form))
    {
        return f64::INFINITY;
    }
    if NEGATIVE_INFINITY_FORMS
        .iter()
        .any(|form| trimmed.eq_ignore_ascii_case(form))
    {
        return f64::NEG_INFINITY;
    }

    // Only allocate when there is something to rewrite: drop thousands-separator
    // commas and map a `d`/`D` exponent marker to the `e` that Rust's parser expects.
    let normalized: Cow<'_, str> = if trimmed.contains(|c: char| matches!(c, ',' | 'd' | 'D')) {
        Cow::Owned(
            trimmed
                .chars()
                .filter(|&c| c != ',')
                .map(|c| if matches!(c, 'd' | 'D') { 'e' } else { c })
                .collect(),
        )
    } else {
        Cow::Borrowed(trimmed)
    };

    normalized.parse::<f64>().unwrap_or(f64::NAN)
}
230
/// Unit tests for `str2double`, driven through a blocking wrapper around the async builtin.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use runmat_builtins::{ResolveContext, Type};

    /// Runs the async builtin to completion on the current thread.
    fn str2double_builtin(value: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::str2double_builtin(value))
    }

    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_string_scalar() {
        let result = str2double_builtin(Value::String("42.5".into())).expect("str2double");
        assert_eq!(result, Value::Num(42.5));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_string_scalar_invalid_returns_nan() {
        let result = str2double_builtin(Value::String("abc".into())).expect("str2double");
        match result {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_string_array_preserves_shape() {
        let array =
            StringArray::new(vec!["1".into(), " 2.5 ".into(), "foo".into()], vec![3, 1]).unwrap();
        let result = str2double_builtin(Value::StringArray(array)).expect("str2double");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![3, 1]);
                assert_eq!(tensor.data[0], 1.0);
                assert_eq!(tensor.data[1], 2.5);
                assert!(tensor.data[2].is_nan());
            }
            Value::Num(_) => panic!("expected tensor"),
            other => panic!("unexpected result {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_char_array_multiple_rows() {
        let data: Vec<char> = vec!['4', '2', ' ', ' ', '1', '0', '0', ' '];
        let array = CharArray::new(data, 2, 4).unwrap();
        let result = str2double_builtin(Value::CharArray(array)).expect("str2double");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![2, 1]);
                assert_eq!(tensor.data[0], 42.0);
                assert_eq!(tensor.data[1], 100.0);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_char_array_empty_rows() {
        let array = CharArray::new(Vec::new(), 0, 0).unwrap();
        let result = str2double_builtin(Value::CharArray(array)).expect("str2double");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![0, 1]);
                assert_eq!(tensor.data.len(), 0);
            }
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[allow(
        clippy::approx_constant,
        reason = "Test ensures literal 3.14 text stays 3.14, not π"
    )]
    fn str2double_cell_array_of_text() {
        let cell = CellArray::new(
            vec![
                Value::String("3.14".into()),
                Value::CharArray(CharArray::new_row("NaN")),
                Value::String("-Inf".into()),
            ],
            1,
            3,
        )
        .unwrap();
        let result = str2double_builtin(Value::Cell(cell)).expect("str2double");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 3]);
                assert_eq!(tensor.data[0], 3.14);
                assert!(tensor.data[1].is_nan());
                assert_eq!(tensor.data[2], f64::NEG_INFINITY);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_cell_array_output_is_column_major() {
        // str2double reads cells as `row * cols + col`, so this data is {"1", "2"; "3", "4"}
        // and the column-major output tensor must be [1, 3, 2, 4].
        let cell = CellArray::new(
            vec![
                Value::String("1".into()),
                Value::String("2".into()),
                Value::String("3".into()),
                Value::String("4".into()),
            ],
            2,
            2,
        )
        .unwrap();
        let result = str2double_builtin(Value::Cell(cell)).expect("str2double");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![2, 2]);
                assert_eq!(tensor.data, vec![1.0, 3.0, 2.0, 4.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_cell_array_invalid_element_errors() {
        let cell = CellArray::new(vec![Value::Num(5.0)], 1, 1).unwrap();
        let err = error_message(str2double_builtin(Value::Cell(cell)).unwrap_err());
        assert!(err.contains("str2double"), "unexpected error message: {err}");
        assert!(
            err.contains("cell array elements must be character vectors or string scalars"),
            "unexpected error message: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_non_text_input_errors() {
        let err = error_message(str2double_builtin(Value::Num(5.0)).unwrap_err());
        assert!(
            err.contains("input must be a string array, character array, or cell array"),
            "unexpected error message: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_supports_d_exponent() {
        let result = str2double_builtin(Value::String("1.5D3".into())).expect("str2double");
        match result {
            Value::Num(v) => assert_eq!(v, 1500.0),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_recognises_infinity_forms() {
        let array = StringArray::new(
            vec!["Inf".into(), "-Infinity".into(), "+inf".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = str2double_builtin(Value::StringArray(array)).expect("str2double");
        match result {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.data[0], f64::INFINITY);
                assert_eq!(tensor.data[1], f64::NEG_INFINITY);
                assert_eq!(tensor.data[2], f64::INFINITY);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn str2double_type_is_numeric_text_scalar_or_tensor() {
        assert_eq!(
            numeric_text_scalar_or_tensor_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::Num
        );
    }
}