Skip to main content

runmat_runtime/builtins/strings/core/
num2str.rs

1//! MATLAB-compatible `num2str` builtin with GPU-aware semantics for RunMat.
2
3use regex::Regex;
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    CharArray, ComplexTensor, Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::map_control_flow_with_builtin;
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::strings::type_resolvers::string_scalar_type;
19use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
20
/// Significant digits used when no precision or format argument is supplied.
const DEFAULT_PRECISION: usize = 15;
/// Largest precision accepted, both as a numeric argument and inside a custom format.
const MAX_PRECISION: usize = 52;

/// Builtin name attached to every runtime error produced by this module.
const BUILTIN_NAME: &str = "num2str";
25
/// Output descriptor shared by every `num2str` signature.
const NUM2STR_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "txt",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Character array containing formatted numeric values.",
}];

/// Inputs for the single-argument form `num2str(A)`.
const NUM2STR_INPUT_A: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "A",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Numeric or logical scalar/vector/matrix input.",
}];

/// Inputs for the precision form `num2str(A, p)`.
const NUM2STR_INPUT_PREC: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Numeric or logical input.",
    },
    BuiltinParamDescriptor {
        name: "p",
        ty: BuiltinParamType::IntegerScalar,
        arity: BuiltinParamArity::Required,
        default: Some("15"),
        description: "General-format precision (0..52).",
    },
];

/// Inputs for the custom-format form `num2str(A, formatSpec)`.
const NUM2STR_INPUT_FORMAT: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Numeric or logical input.",
    },
    BuiltinParamDescriptor {
        name: "formatSpec",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Custom format such as \"%0.3f\" or \"%.5g\".",
    },
];

/// Inputs for the locale-aware form `num2str(A, arg2, "local")`.
const NUM2STR_INPUT_LOCAL: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Numeric or logical input.",
    },
    BuiltinParamDescriptor {
        name: "arg2",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Precision or format string.",
    },
    BuiltinParamDescriptor {
        name: "local",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"local\""),
        description: "Locale-aware decimal separator option.",
    },
];
99
/// The four documented call forms, each pairing a label with its input and output tables.
const NUM2STR_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "txt = num2str(A)",
        inputs: &NUM2STR_INPUT_A,
        outputs: &NUM2STR_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "txt = num2str(A, p)",
        inputs: &NUM2STR_INPUT_PREC,
        outputs: &NUM2STR_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "txt = num2str(A, formatSpec)",
        inputs: &NUM2STR_INPUT_FORMAT,
        outputs: &NUM2STR_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "txt = num2str(A, arg2, \"local\")",
        inputs: &NUM2STR_INPUT_LOCAL,
        outputs: &NUM2STR_OUTPUT,
    },
];
122
/// Raised when the value to convert is not a supported numeric or logical input.
const NUM2STR_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.NUM2STR.INVALID_INPUT",
    identifier: Some("RunMat:num2str:InvalidInput"),
    when: "Input value is not a supported numeric/logical scalar, vector, or matrix.",
    message: "num2str: unsupported input type",
};

/// Raised when the optional arguments are malformed or too numerous.
const NUM2STR_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.NUM2STR.INVALID_OPTION",
    identifier: Some("RunMat:num2str:InvalidOption"),
    when: "Optional arguments are malformed or too many were supplied.",
    message: "num2str: invalid option arguments",
};

/// Raised when a precision is non-finite, non-integer, or outside `0..=MAX_PRECISION`.
const NUM2STR_ERROR_INVALID_PRECISION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.NUM2STR.INVALID_PRECISION",
    identifier: Some("RunMat:num2str:InvalidPrecision"),
    when: "Precision argument is non-finite, non-integer, or out of range.",
    message: "num2str: invalid precision",
};

/// Raised when a custom format string cannot be parsed.
const NUM2STR_ERROR_INVALID_FORMAT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.NUM2STR.INVALID_FORMAT",
    identifier: Some("RunMat:num2str:InvalidFormat"),
    when: "Custom format string is unsupported or malformed.",
    message: "num2str: unsupported format string",
};

/// Raised when building the resulting character array fails.
const NUM2STR_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.NUM2STR.INTERNAL",
    identifier: Some("RunMat:num2str:InternalError"),
    when: "Internal char-array assembly failed.",
    message: "num2str: internal error",
};

/// Every error descriptor this builtin can report, as listed in `NUM2STR_DESCRIPTOR`.
const NUM2STR_ERRORS: [BuiltinErrorDescriptor; 5] = [
    NUM2STR_ERROR_INVALID_INPUT,
    NUM2STR_ERROR_INVALID_OPTION,
    NUM2STR_ERROR_INVALID_PRECISION,
    NUM2STR_ERROR_INVALID_FORMAT,
    NUM2STR_ERROR_INTERNAL,
];
165
/// Public descriptor for `num2str`, referenced by the `#[runtime_builtin]` attribute:
/// bundles the signature table, output mode, completion policy and error table.
pub const NUM2STR_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &NUM2STR_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &NUM2STR_ERRORS,
};
172
173fn num2str_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
174    num2str_error_with_message(error.message, error)
175}
176
177fn num2str_error_with_message(
178    message: impl Into<String>,
179    error: &'static BuiltinErrorDescriptor,
180) -> RuntimeError {
181    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
182    if let Some(identifier) = error.identifier {
183        builder = builder.with_identifier(identifier);
184    }
185    builder.build()
186}
187
/// Passes an error raised by a helper (such as a gather) through
/// `map_control_flow_with_builtin` together with this builtin's name.
fn remap_num2str_flow(err: RuntimeError) -> RuntimeError {
    map_control_flow_with_builtin(err, BUILTIN_NAME)
}
191
/// GPU registration for `num2str`: it declares no supported precisions and no
/// provider hooks, and its residency policy is `GatherImmediately`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::num2str")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "num2str",
    op_kind: GpuOpKind::Custom("conversion"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Always gathers GPU data to host memory before formatting numeric text.",
};
207
/// Fusion registration for `num2str`: no elementwise or reduction kernel is
/// provided, so the builtin does not take part in fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::num2str")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "num2str",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes:
        "Conversion builtin; not eligible for fusion and always materialises host character arrays.",
};
219
220#[runtime_builtin(
221    name = "num2str",
222    category = "strings/core",
223    summary = "Convert numeric values to character arrays.",
224    keywords = "num2str,number to string,format,precision",
225    examples = "txt = num2str([1 2 3]);",
226    type_resolver(string_scalar_type),
227    descriptor(crate::builtins::strings::core::num2str::NUM2STR_DESCRIPTOR),
228    builtin_path = "crate::builtins::strings::core::num2str"
229)]
230async fn num2str_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
231    let gathered = gather_if_needed_async(&value)
232        .await
233        .map_err(remap_num2str_flow)?;
234    let data = extract_numeric_data(gathered).await?;
235
236    let options = parse_options(rest).await?;
237    let char_array = format_numeric_data(data, &options)?;
238    Ok(Value::CharArray(char_array))
239}
240
/// Resolved formatting options for one `num2str` call.
struct FormatOptions {
    /// How each numeric element is rendered.
    spec: FormatSpec,
    /// Character substituted for the first `.` of each formatted element.
    decimal: char,
}
245
/// Rendering strategy selected by the second argument of `num2str`.
#[derive(Clone)]
enum FormatSpec {
    /// General formatting with the given number of significant digits.
    General { digits: usize },
    /// A parsed printf-style format such as `%0.3f`.
    Custom(CustomFormat),
}
251
/// A parsed printf-style conversion, as produced by `parse_custom_format`.
#[derive(Clone)]
struct CustomFormat {
    /// Conversion family (`f`, `e` or `g`).
    kind: CustomKind,
    /// Minimum field width in characters, if one was given.
    width: Option<usize>,
    /// Digits after the `.` in the format, if a precision was given.
    precision: Option<usize>,
    /// `+` flag: prefix non-negative values with `+`.
    sign_always: bool,
    /// `-` flag: pad on the right instead of the left.
    left_align: bool,
    /// `0` flag: request zero padding.
    zero_pad: bool,
    /// Set for the upper-case conversions `F`, `E` and `G`.
    uppercase: bool,
}
262
/// Conversion family of a custom format.
#[derive(Clone, Copy, PartialEq, Eq)]
enum CustomKind {
    /// `%f` / `%F`: fixed number of decimals.
    Fixed,
    /// `%e` / `%E`: exponent notation.
    Exponent,
    /// `%g` / `%G`: significant-digit formatting via `format_general`.
    General,
}
269
/// Host-side numeric payload extracted from the input value.
///
/// The matrix formatters read element `(row, col)` at index `row + col * rows`.
enum NumericData {
    /// Real values together with the matrix dimensions.
    Real {
        data: Vec<f64>,
        rows: usize,
        cols: usize,
    },
    /// Complex values as `(re, im)` pairs together with the matrix dimensions.
    Complex {
        data: Vec<(f64, f64)>,
        rows: usize,
        cols: usize,
    },
}
282
283async fn parse_options(args: Vec<Value>) -> BuiltinResult<FormatOptions> {
284    if args.is_empty() {
285        return Ok(FormatOptions {
286            spec: FormatSpec::General {
287                digits: DEFAULT_PRECISION,
288            },
289            decimal: '.',
290        });
291    }
292
293    let mut gathered = Vec::with_capacity(args.len());
294    for arg in args {
295        gathered.push(
296            gather_if_needed_async(&arg)
297                .await
298                .map_err(remap_num2str_flow)?,
299        );
300    }
301
302    let mut iter = gathered.into_iter();
303    let mut spec = FormatSpec::General {
304        digits: DEFAULT_PRECISION,
305    };
306    let mut decimal = '.';
307
308    if let Some(first) = iter.next() {
309        if is_local_token(&first)? {
310            decimal = detect_decimal_separator(true);
311            if iter.next().is_some() {
312                return Err(num2str_error_with_message(
313                    "num2str: too many input arguments",
314                    &NUM2STR_ERROR_INVALID_OPTION,
315                ));
316            }
317            return Ok(FormatOptions { spec, decimal });
318        }
319
320        spec = if let Some(digits) = try_extract_precision(&first)? {
321            FormatSpec::General { digits }
322        } else if let Some(text) = value_to_text(&first) {
323            FormatSpec::Custom(parse_custom_format(&text)?)
324        } else {
325            return Err(num2str_error_with_message(
326                "num2str: second argument must be a precision or format string",
327                &NUM2STR_ERROR_INVALID_OPTION,
328            ));
329        };
330    }
331
332    if let Some(second) = iter.next() {
333        if !is_local_token(&second)? {
334            return Err(num2str_error_with_message(
335                "num2str: expected 'local' as the third argument",
336                &NUM2STR_ERROR_INVALID_OPTION,
337            ));
338        }
339        decimal = detect_decimal_separator(true);
340    }
341
342    if iter.next().is_some() {
343        return Err(num2str_error_with_message(
344            "num2str: too many input arguments",
345            &NUM2STR_ERROR_INVALID_OPTION,
346        ));
347    }
348
349    Ok(FormatOptions { spec, decimal })
350}
351
352fn is_local_token(value: &Value) -> BuiltinResult<bool> {
353    let Some(text) = value_to_text(value) else {
354        return Ok(false);
355    };
356    Ok(text.trim().eq_ignore_ascii_case("local"))
357}
358
359fn try_extract_precision(value: &Value) -> BuiltinResult<Option<usize>> {
360    match value {
361        Value::Int(i) => {
362            let digits = i.to_i64();
363            validate_precision(digits)?;
364            Ok(Some(digits as usize))
365        }
366        Value::Num(n) => {
367            if !n.is_finite() {
368                return Err(num2str_error_with_message(
369                    "num2str: precision must be finite",
370                    &NUM2STR_ERROR_INVALID_PRECISION,
371                ));
372            }
373            let rounded = n.round();
374            if (rounded - n).abs() > f64::EPSILON {
375                return Err(num2str_error_with_message(
376                    "num2str: precision must be an integer",
377                    &NUM2STR_ERROR_INVALID_PRECISION,
378                ));
379            }
380            validate_precision(rounded as i64)?;
381            Ok(Some(rounded as usize))
382        }
383        Value::Tensor(t) if t.data.len() == 1 => {
384            let value = t.data[0];
385            if !value.is_finite() {
386                return Err(num2str_error_with_message(
387                    "num2str: precision must be finite",
388                    &NUM2STR_ERROR_INVALID_PRECISION,
389                ));
390            }
391            let rounded = value.round();
392            if (rounded - value).abs() > f64::EPSILON {
393                return Err(num2str_error_with_message(
394                    "num2str: precision must be an integer",
395                    &NUM2STR_ERROR_INVALID_PRECISION,
396                ));
397            }
398            validate_precision(rounded as i64)?;
399            Ok(Some(rounded as usize))
400        }
401        Value::LogicalArray(la) if la.data.len() == 1 => {
402            let digits = if la.data[0] != 0 { 1 } else { 0 };
403            validate_precision(digits)?;
404            Ok(Some(digits as usize))
405        }
406        Value::Bool(b) => {
407            let digits = if *b { 1 } else { 0 };
408            Ok(Some(digits))
409        }
410        _ => Ok(None),
411    }
412}
413
414fn validate_precision(value: i64) -> BuiltinResult<()> {
415    if value < 0 || value > MAX_PRECISION as i64 {
416        return Err(num2str_error_with_message(
417            format!("num2str: precision must satisfy 0 <= p <= {MAX_PRECISION}"),
418            &NUM2STR_ERROR_INVALID_PRECISION,
419        ));
420    }
421    Ok(())
422}
423
424fn value_to_text(value: &Value) -> Option<String> {
425    match value {
426        Value::String(s) => Some(s.clone()),
427        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].clone()),
428        Value::CharArray(ca) if ca.rows == 1 => Some(ca.data.iter().collect()),
429        _ => None,
430    }
431}
432
/// Chooses the decimal separator for the `"local"` option.
///
/// When `local` is `false` the separator is always `.`. Otherwise:
/// 1. the first non-whitespace character of `RUNMAT_DECIMAL_SEPARATOR` wins;
/// 2. else the first non-empty value among `LC_NUMERIC`, `RUNMAT_LOCALE` and
///    `LANG` is lower-cased and its language part (the text before the first
///    `.`, `_` or `@`) is compared against a list of comma-using languages;
/// 3. else `.` is returned.
///
/// A variable that is set but empty is treated as unset, so an empty
/// `LC_NUMERIC` no longer hides a meaningful `LANG`.
fn detect_decimal_separator(local: bool) -> char {
    // Languages whose conventional decimal separator is a comma.
    const COMMA_LOCALES: &[&str] = &[
        "af", "bs", "ca", "cs", "da", "de", "el", "es", "eu", "fi", "fr", "gl", "hr", "hu", "id",
        "is", "it", "lt", "lv", "nb", "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sr", "sv", "tr",
        "uk", "vi",
    ];

    if !local {
        return '.';
    }

    if let Ok(custom) = std::env::var("RUNMAT_DECIMAL_SEPARATOR") {
        if let Some(ch) = custom.trim().chars().next() {
            return ch;
        }
    }

    // Skip variables that are unset OR empty and keep looking down the chain.
    let locale = ["LC_NUMERIC", "RUNMAT_LOCALE", "LANG"]
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .find(|value| !value.is_empty())
        .unwrap_or_default()
        .to_lowercase();

    if locale.is_empty() {
        return '.';
    }

    let language = locale.split(['.', '_', '@']).next().unwrap_or(&locale);
    // NOTE(review): this is a prefix match, so a longer code such as `fil`
    // matches `fi`; confirm whether exact matching is wanted.
    if COMMA_LOCALES
        .iter()
        .any(|prefix| language.starts_with(prefix))
    {
        ','
    } else {
        '.'
    }
}
468
469fn parse_custom_format(text: &str) -> BuiltinResult<CustomFormat> {
470    if !text.starts_with('%') {
471        return Err(num2str_error_with_message(
472            "num2str: format must start with '%'",
473            &NUM2STR_ERROR_INVALID_FORMAT,
474        ));
475    }
476    if text == "%%" {
477        return Err(num2str_error_with_message(
478            "num2str: '%' escape is not supported for numeric conversion",
479            &NUM2STR_ERROR_INVALID_FORMAT,
480        ));
481    }
482
483    static FORMAT_RE: once_cell::sync::Lazy<Regex> = once_cell::sync::Lazy::new(|| {
484        Regex::new(r"^%([+\-0]*)(\d+)?(?:\.(\d*))?([fFeEgG])$").expect("format regex")
485    });
486
487    let captures = FORMAT_RE.captures(text).ok_or_else(|| {
488        num2str_error_with_message(
489            format!(
490                "{}; expected variants like '%0.3f' or '%.5g'",
491                NUM2STR_ERROR_INVALID_FORMAT.message
492            ),
493            &NUM2STR_ERROR_INVALID_FORMAT,
494        )
495    })?;
496
497    let flags = captures.get(1).map(|m| m.as_str()).unwrap_or("");
498    let width = captures
499        .get(2)
500        .map(|m| m.as_str().parse::<usize>().expect("width parse"));
501    let precision = captures.get(3).map(|m| {
502        if m.as_str().is_empty() {
503            0usize
504        } else {
505            m.as_str().parse::<usize>().expect("precision parse")
506        }
507    });
508    let conversion = captures
509        .get(4)
510        .map(|m| m.as_str().chars().next().unwrap())
511        .unwrap();
512
513    let mut sign_always = false;
514    let mut left_align = false;
515    let mut zero_pad = false;
516
517    for ch in flags.chars() {
518        match ch {
519            '+' => sign_always = true,
520            '-' => left_align = true,
521            '0' => zero_pad = true,
522            _ => {
523                return Err(num2str_error_with_message(
524                    format!(
525                    "num2str: unsupported format flag '{}'; only '+', '-', and '0' are supported",
526                    ch
527                ),
528                    &NUM2STR_ERROR_INVALID_FORMAT,
529                ))
530            }
531        }
532    }
533
534    if let Some(p) = precision {
535        if p > MAX_PRECISION {
536            return Err(num2str_error_with_message(
537                format!("num2str: precision must satisfy 0 <= p <= {MAX_PRECISION}"),
538                &NUM2STR_ERROR_INVALID_PRECISION,
539            ));
540        }
541    }
542
543    let (kind, uppercase) = match conversion {
544        'f' => (CustomKind::Fixed, false),
545        'F' => (CustomKind::Fixed, true),
546        'e' => (CustomKind::Exponent, false),
547        'E' => (CustomKind::Exponent, true),
548        'g' => (CustomKind::General, false),
549        'G' => (CustomKind::General, true),
550        _ => unreachable!(),
551    };
552
553    Ok(CustomFormat {
554        kind,
555        width,
556        precision,
557        sign_always,
558        left_align,
559        zero_pad,
560        uppercase,
561    })
562}
563
564async fn extract_numeric_data(value: Value) -> BuiltinResult<NumericData> {
565    match value {
566        Value::Num(n) => Ok(NumericData::Real {
567            data: vec![n],
568            rows: 1,
569            cols: 1,
570        }),
571        Value::Int(i) => Ok(NumericData::Real {
572            data: vec![i.to_f64()],
573            rows: 1,
574            cols: 1,
575        }),
576        Value::Bool(b) => Ok(NumericData::Real {
577            data: vec![if b { 1.0 } else { 0.0 }],
578            rows: 1,
579            cols: 1,
580        }),
581        Value::Tensor(t) => tensor_to_numeric_data(t),
582        Value::LogicalArray(la) => {
583            let tensor = tensor::logical_to_tensor(&la)
584                .map_err(|_| num2str_error(&NUM2STR_ERROR_INVALID_INPUT))?;
585            tensor_to_numeric_data(tensor)
586        }
587        Value::Complex(re, im) => Ok(NumericData::Complex {
588            data: vec![(re, im)],
589            rows: 1,
590            cols: 1,
591        }),
592        Value::ComplexTensor(t) => complex_tensor_to_data(t),
593        Value::GpuTensor(handle) => {
594            let gathered = gpu_helpers::gather_tensor_async(&handle)
595                .await
596                .map_err(remap_num2str_flow)?;
597            tensor_to_numeric_data(gathered)
598        }
599        other => Err(num2str_error_with_message(
600            format!(
601                "{} {:?}; expected numeric or logical values",
602                NUM2STR_ERROR_INVALID_INPUT.message, other
603            ),
604            &NUM2STR_ERROR_INVALID_INPUT,
605        )),
606    }
607}
608
609fn tensor_to_numeric_data(tensor: Tensor) -> BuiltinResult<NumericData> {
610    if tensor.shape.len() > 2 {
611        return Err(num2str_error_with_message(
612            "num2str: input must be scalar, vector, or 2-D matrix",
613            &NUM2STR_ERROR_INVALID_INPUT,
614        ));
615    }
616    let rows = tensor.rows();
617    let cols = tensor.cols();
618    if rows == 0 || cols == 0 {
619        return Ok(NumericData::Real {
620            data: tensor.data,
621            rows,
622            cols,
623        });
624    }
625    Ok(NumericData::Real {
626        data: tensor.data,
627        rows,
628        cols,
629    })
630}
631
632fn complex_tensor_to_data(tensor: ComplexTensor) -> BuiltinResult<NumericData> {
633    if tensor.shape.len() > 2 {
634        return Err(num2str_error_with_message(
635            "num2str: complex input must be scalar, vector, or 2-D matrix",
636            &NUM2STR_ERROR_INVALID_INPUT,
637        ));
638    }
639    let rows = tensor.rows;
640    let cols = tensor.cols;
641    Ok(NumericData::Complex {
642        data: tensor.data,
643        rows,
644        cols,
645    })
646}
647
/// One formatted matrix element together with its display width.
#[derive(Clone)]
struct CellEntry {
    /// Formatted text of the element.
    text: String,
    /// Number of `char`s in `text`, stored so alignment does not recount them.
    width: usize,
}
653
654fn format_numeric_data(data: NumericData, options: &FormatOptions) -> BuiltinResult<CharArray> {
655    match data {
656        NumericData::Real { data, rows, cols } => format_real_matrix(&data, rows, cols, options),
657        NumericData::Complex { data, rows, cols } => {
658            format_complex_matrix(&data, rows, cols, options)
659        }
660    }
661}
662
663fn format_real_matrix(
664    data: &[f64],
665    rows: usize,
666    cols: usize,
667    options: &FormatOptions,
668) -> BuiltinResult<CharArray> {
669    if rows == 0 {
670        return CharArray::new(Vec::new(), 0, 0)
671            .map_err(|_| num2str_error(&NUM2STR_ERROR_INTERNAL));
672    }
673    if cols == 0 {
674        return CharArray::new(Vec::new(), rows, 0)
675            .map_err(|_| num2str_error(&NUM2STR_ERROR_INTERNAL));
676    }
677
678    let mut entries = vec![
679        vec![
680            CellEntry {
681                text: String::new(),
682                width: 0
683            };
684            cols
685        ];
686        rows
687    ];
688    let mut col_widths = vec![0usize; cols];
689
690    for (col, width) in col_widths.iter_mut().enumerate() {
691        for (row, row_entries) in entries.iter_mut().enumerate() {
692            let idx = row + col * rows;
693            let value = data.get(idx).copied().unwrap_or(0.0);
694            let text = format_real(value, &options.spec, options.decimal);
695            let entry_width = text.chars().count();
696            row_entries[col] = CellEntry {
697                text,
698                width: entry_width,
699            };
700            if entry_width > *width {
701                *width = entry_width;
702            }
703        }
704    }
705
706    if cols > 1 {
707        for (idx, width) in col_widths.iter_mut().enumerate() {
708            if idx > 0 {
709                *width += 1;
710            }
711        }
712    }
713
714    let rows_str = assemble_rows(entries, col_widths);
715    rows_to_char_array(rows_str)
716}
717
718fn format_complex_matrix(
719    data: &[(f64, f64)],
720    rows: usize,
721    cols: usize,
722    options: &FormatOptions,
723) -> BuiltinResult<CharArray> {
724    if rows == 0 {
725        return CharArray::new(Vec::new(), 0, 0)
726            .map_err(|_| num2str_error(&NUM2STR_ERROR_INTERNAL));
727    }
728    if cols == 0 {
729        return CharArray::new(Vec::new(), rows, 0)
730            .map_err(|_| num2str_error(&NUM2STR_ERROR_INTERNAL));
731    }
732
733    let mut entries = vec![
734        vec![
735            CellEntry {
736                text: String::new(),
737                width: 0
738            };
739            cols
740        ];
741        rows
742    ];
743    let mut col_widths = vec![0usize; cols];
744
745    for (col, width) in col_widths.iter_mut().enumerate() {
746        for (row, row_entries) in entries.iter_mut().enumerate() {
747            let idx = row + col * rows;
748            let (re, im) = data.get(idx).copied().unwrap_or((0.0, 0.0));
749            let text = format_complex(re, im, &options.spec, options.decimal);
750            let entry_width = text.chars().count();
751            row_entries[col] = CellEntry {
752                text,
753                width: entry_width,
754            };
755            if entry_width > *width {
756                *width = entry_width;
757            }
758        }
759    }
760
761    if cols > 1 {
762        for (idx, width) in col_widths.iter_mut().enumerate() {
763            if idx > 0 {
764                *width += 1;
765            }
766        }
767    }
768
769    let rows_str = assemble_rows(entries, col_widths);
770    rows_to_char_array(rows_str)
771}
772
773fn assemble_rows(entries: Vec<Vec<CellEntry>>, col_widths: Vec<usize>) -> Vec<String> {
774    entries
775        .into_iter()
776        .map(|row_entries| {
777            row_entries
778                .into_iter()
779                .enumerate()
780                .fold(String::new(), |mut acc, (col, entry)| {
781                    if col > 0 {
782                        acc.push(' ');
783                    }
784                    let target = col_widths[col];
785                    let pad = target.saturating_sub(entry.width);
786                    acc.extend(std::iter::repeat_n(' ', pad));
787                    acc.push_str(&entry.text);
788                    acc
789                })
790        })
791        .collect()
792}
793
794fn rows_to_char_array(rows: Vec<String>) -> BuiltinResult<CharArray> {
795    if rows.is_empty() {
796        return CharArray::new(Vec::new(), 0, 0)
797            .map_err(|_| num2str_error(&NUM2STR_ERROR_INTERNAL));
798    }
799    let row_count = rows.len();
800    let col_count = rows
801        .iter()
802        .map(|row| row.chars().count())
803        .max()
804        .unwrap_or(0);
805
806    let mut data = Vec::with_capacity(row_count * col_count);
807    for row in rows {
808        let mut chars: Vec<char> = row.chars().collect();
809        if chars.len() < col_count {
810            chars.extend(std::iter::repeat_n(' ', col_count - chars.len()));
811        }
812        data.extend(chars);
813    }
814
815    CharArray::new(data, row_count, col_count).map_err(|_| num2str_error(&NUM2STR_ERROR_INTERNAL))
816}
817
818fn format_real(value: f64, spec: &FormatSpec, decimal: char) -> String {
819    let text = match spec {
820        FormatSpec::General { digits } => format_general(value, *digits, false),
821        FormatSpec::Custom(custom) => format_custom(value, custom),
822    };
823    apply_decimal_locale(text, decimal)
824}
825
826fn format_complex(re: f64, im: f64, spec: &FormatSpec, decimal: char) -> String {
827    let real_str = format_real(re, spec, decimal);
828    let imag_sign = if im.is_sign_negative() { '-' } else { '+' };
829    let abs_im = if im == 0.0 { 0.0 } else { im.abs() };
830    let imag_str = format_real(abs_im, spec, decimal);
831
832    if abs_im == 0.0 && !im.is_nan() {
833        return real_str;
834    }
835
836    if re == 0.0 && !re.is_sign_negative() && !re.is_nan() {
837        if im.is_sign_negative() && !im.is_nan() {
838            return format!(
839                "{}i",
840                if imag_str.starts_with('-') {
841                    imag_str.clone()
842                } else {
843                    format!("-{imag_str}")
844                }
845            );
846        }
847        return format!("{imag_str}i");
848    }
849
850    format!("{real_str} {imag_sign} {imag_str}i")
851}
852
853fn format_general(value: f64, digits: usize, uppercase: bool) -> String {
854    if value.is_nan() {
855        return "NaN".to_string();
856    }
857    if value.is_infinite() {
858        return if value.is_sign_negative() {
859            "-Inf".to_string()
860        } else {
861            "Inf".to_string()
862        };
863    }
864    if value == 0.0 {
865        return "0".to_string();
866    }
867
868    let sig_digits = digits.max(1);
869    let abs_val = value.abs();
870    let exp10 = abs_val.log10().floor() as i32;
871    let use_scientific = exp10 < -4 || exp10 >= sig_digits as i32;
872
873    if use_scientific {
874        let precision = sig_digits.saturating_sub(1);
875        let s = if uppercase {
876            format!("{:.*E}", precision, value)
877        } else {
878            format!("{:.*e}", precision, value)
879        };
880        let marker = if uppercase { 'E' } else { 'e' };
881        if let Some(idx) = s.find(marker) {
882            let (mantissa, exponent) = s.split_at(idx);
883            let mut mant = mantissa.to_string();
884            trim_trailing_zeros(&mut mant);
885            normalize_negative_zero(&mut mant);
886            let mut result = mant;
887            result.push_str(exponent);
888            return result;
889        }
890        s
891    } else {
892        let decimals = if sig_digits as i32 - 1 - exp10 < 0 {
893            0
894        } else {
895            (sig_digits as i32 - 1 - exp10) as usize
896        };
897        let mut s = format!("{:.*}", decimals, value);
898        trim_trailing_zeros(&mut s);
899        normalize_negative_zero(&mut s);
900        s
901    }
902}
903
/// Removes trailing zeros from the fractional part of a formatted number, and
/// the decimal point too when no fractional digits remain.
///
/// Text without a `.` is left unchanged. When the text contains an exponent
/// (`e` or `E`), only the mantissa before it is trimmed, so exponent digits
/// such as the `0` in `e10` are never removed.
fn trim_trailing_zeros(text: &mut String) {
    let Some(dot_pos) = text.find('.') else {
        return;
    };
    // Trimming stops at the exponent marker, or at the end if there is none.
    let mantissa_end = text.find(['e', 'E']).unwrap_or(text.len());
    if mantissa_end <= dot_pos {
        return;
    }

    let bytes = text.as_bytes();
    let mut end = mantissa_end;
    while end > dot_pos + 1 && bytes[end - 1] == b'0' {
        end -= 1;
    }
    if end > dot_pos && bytes[end - 1] == b'.' {
        end -= 1;
    }
    text.replace_range(end..mantissa_end, "");
}
916
/// Rewrites text made of a leading `-` followed only by `0` characters as `0`.
fn normalize_negative_zero(text: &mut String) {
    let is_negative_zero = match text.strip_prefix('-') {
        Some(rest) => rest.chars().all(|ch| ch == '0'),
        None => false,
    };
    if is_negative_zero {
        *text = String::from("0");
    }
}
922
923fn format_custom(value: f64, fmt: &CustomFormat) -> String {
924    if value.is_nan() {
925        return "NaN".to_string();
926    }
927    if value.is_infinite() {
928        return if value.is_sign_negative() {
929            "-Inf".to_string()
930        } else {
931            "Inf".to_string()
932        };
933    }
934
935    let precision = fmt.precision.unwrap_or(match fmt.kind {
936        CustomKind::Fixed | CustomKind::Exponent => 6,
937        CustomKind::General => DEFAULT_PRECISION,
938    });
939
940    let mut text = match fmt.kind {
941        CustomKind::Fixed => format!("{:.*}", precision, value),
942        CustomKind::Exponent => {
943            let mut s = format!("{:.*e}", precision, value);
944            if fmt.uppercase {
945                s = s.to_uppercase();
946            }
947            s
948        }
949        CustomKind::General => format_general(value, precision.max(1), fmt.uppercase),
950    };
951
952    if fmt.kind != CustomKind::Fixed {
953        trim_trailing_zeros(&mut text);
954        normalize_negative_zero(&mut text);
955    }
956
957    apply_format_flags(text, fmt)
958}
959
/// Replaces the first `.` in `text` with the `decimal` separator character.
///
/// The text is returned unchanged when `decimal` is already `.` or when it
/// contains no `.` at all; later dots are never touched.
fn apply_decimal_locale(mut text: String, decimal: char) -> String {
    if decimal != '.' {
        if let Some(dot_pos) = text.find('.') {
            // `.` is a single byte, so the replaced range is exactly one byte wide.
            let mut utf8 = [0u8; 4];
            text.replace_range(dot_pos..dot_pos + 1, decimal.encode_utf8(&mut utf8));
        }
    }
    text
}
976
977fn apply_format_flags(mut text: String, fmt: &CustomFormat) -> String {
978    if fmt.sign_always && !text.starts_with('-') && !text.starts_with('+') && text != "NaN" {
979        text.insert(0, '+');
980    }
981
982    let width = fmt.width.unwrap_or(0);
983    if width == 0 {
984        return text;
985    }
986
987    let len = text.chars().count();
988    if len >= width {
989        return text;
990    }
991
992    let pad_count = width - len;
993    let pad_char = if fmt.zero_pad && !fmt.left_align {
994        '0'
995    } else {
996        ' '
997    };
998
999    if fmt.left_align {
1000        let mut result = text.clone();
1001        result.extend(std::iter::repeat_n(' ', pad_count));
1002        return result;
1003    }
1004
1005    if pad_char == '0' && (text.starts_with('+') || text.starts_with('-')) {
1006        let mut chars = text.chars();
1007        let sign = chars.next().unwrap();
1008        let remainder: String = chars.collect();
1009        let mut result = String::with_capacity(width);
1010        result.push(sign);
1011        result.extend(std::iter::repeat_n('0', pad_count));
1012        result.push_str(&remainder);
1013        return result;
1014    }
1015
1016    let mut result = String::with_capacity(width);
1017    result.extend(std::iter::repeat_n(' ', pad_count));
1018    result.push_str(&text);
1019    result
1020}
1021
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, LogicalArray, ResolveContext, Tensor, Type};

    /// Environment variable the `"local"` test uses to select the decimal separator.
    const DECIMAL_SEPARATOR_VAR: &str = "RUNMAT_DECIMAL_SEPARATOR";

    /// Sets an environment variable for the lifetime of the guard and restores
    /// the previous state on drop, so a panicking test cannot leak the override
    /// into other tests.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// Synchronous wrapper around the async builtin under test.
    fn num2str_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::num2str_builtin(value, rest))
    }

    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Unwraps a `Value::CharArray`, panicking with the offending value otherwise.
    fn expect_char_array(value: Value) -> CharArray {
        match value {
            Value::CharArray(ca) => ca,
            other => panic!("expected char array, got {other:?}"),
        }
    }

    /// Unwraps a `Value::CharArray` and concatenates its characters into a `String`.
    fn expect_text(value: Value) -> String {
        expect_char_array(value).data.iter().collect()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_scalar_default_precision() {
        let value = Value::Num(std::f64::consts::PI);
        let out = num2str_builtin(value, Vec::new()).expect("num2str");
        let ca = expect_char_array(out);
        let text: String = ca.data.iter().collect();
        assert_eq!(ca.rows, 1);
        assert!(text.starts_with("3.1415926535897"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_precision_argument() {
        let value = Value::Num(std::f64::consts::PI);
        let out = num2str_builtin(value, vec![Value::Int(IntValue::I32(4))]).expect("num2str");
        assert_eq!(expect_text(out).trim(), "3.142");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_matrix_alignment() {
        let tensor =
            Tensor::new(vec![1.0, 78.0, 23.0, 9.0, 456.0, 10.0], vec![2, 3]).expect("tensor");
        let out = num2str_builtin(Value::Tensor(tensor), Vec::new()).expect("num2str");
        let ca = expect_char_array(out);
        assert_eq!(ca.rows, 2);
        assert_eq!(ca.cols, 11);
        let rows: Vec<String> = ca
            .data
            .chunks(ca.cols)
            .map(|chunk| chunk.iter().collect())
            .collect();
        assert_eq!(rows[0], " 1  23  456");
        assert_eq!(rows[1], "78   9   10");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_custom_format() {
        let tensor = Tensor::new(vec![1.234, 5.678], vec![1, 2]).expect("tensor");
        let fmt = Value::String("%.2f".to_string());
        let out = num2str_builtin(Value::Tensor(tensor), vec![fmt]).expect("num2str");
        assert_eq!(expect_text(out), "1.23  5.68");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_complex_values() {
        let complex = ComplexTensor::new(vec![(3.0, 4.0), (5.0, -6.0)], vec![1, 2]).expect("cplx");
        let out = num2str_builtin(Value::ComplexTensor(complex), Vec::new()).expect("num2str");
        assert_eq!(expect_text(out), "3 + 4i  5 - 6i");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_local_decimal() {
        // Keep the override scoped to the call itself; the guard restores the
        // environment even if the builtin panics.
        let result = {
            let _separator = EnvVarGuard::set(DECIMAL_SEPARATOR_VAR, ",");
            num2str_builtin(Value::Num(0.5), vec![Value::String("local".into())])
        };
        assert_eq!(expect_text(result.expect("num2str")), "0,5");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_logical_array() {
        let logical = LogicalArray::new(vec![1, 0, 1], vec![1, 3]).expect("logical");
        let out = num2str_builtin(Value::LogicalArray(logical), Vec::new()).expect("num2str");
        assert_eq!(expect_text(out), "1  0  1");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_gpu_tensor_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![10.5, 20.5], vec![1, 2]).expect("tensor");
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let out = num2str_builtin(Value::GpuTensor(handle), vec![Value::String("%.1f".into())])
                .expect("num2str");
            assert_eq!(expect_text(out), "10.5  20.5");
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_invalid_input_type() {
        let err =
            error_message(num2str_builtin(Value::String("hello".into()), Vec::new()).unwrap_err());
        assert!(err.contains("unsupported input type"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn num2str_invalid_format_string() {
        let err = error_message(
            num2str_builtin(Value::Num(1.0), vec![Value::String("%q".into())]).unwrap_err(),
        );
        assert!(err.contains("unsupported format string"));
    }

    #[test]
    fn num2str_type_is_string_scalar() {
        assert_eq!(
            string_scalar_type(&[Type::Num], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}