runmat_runtime/builtins/common/
random_args.rs

1use runmat_builtins::{ComplexTensor, Tensor, Value};
2
3/// Extract a lowercased keyword from runtime values such as strings or
4/// single-row char arrays.
5pub(crate) fn keyword_of(value: &Value) -> Option<String> {
6    match value {
7        Value::String(s) => Some(s.to_ascii_lowercase()),
8        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].to_ascii_lowercase()),
9        Value::CharArray(ca) if ca.rows == 1 => {
10            let text: String = ca.data.iter().collect();
11            Some(text.to_ascii_lowercase())
12        }
13        _ => None,
14    }
15}
16
17/// Attempt to parse a dimension argument. Returns `Ok(Some(Vec))` when the
18/// value encodes dimensions, `Ok(None)` when the value is not a dimension
19/// argument, and `Err` when the value is dimension-like but invalid.
20pub(crate) fn extract_dims(value: &Value, label: &str) -> Result<Option<Vec<usize>>, String> {
21    match value {
22        Value::Int(i) => {
23            let dim = i.to_i64();
24            if dim < 0 {
25                return Err(format!("{label}: matrix dimensions must be non-negative"));
26            }
27            Ok(Some(vec![dim as usize]))
28        }
29        Value::Num(n) => parse_numeric_dimension(label, *n).map(|d| Some(vec![d])),
30        Value::Tensor(t) => dims_from_tensor(label, t),
31        Value::LogicalArray(_) => Ok(None),
32        _ => Ok(None),
33    }
34}
35
/// Validate a floating point scalar as a single array dimension.
///
/// The value must be finite, non-negative, and within `f64::EPSILON` of a
/// whole number. On success the rounded value is returned as `usize`
/// (the float-to-integer cast saturates for values beyond `usize::MAX`).
///
/// # Errors
///
/// Returns a message prefixed with `label` when the value is NaN or infinite,
/// negative, or not an integer.
pub(crate) fn parse_numeric_dimension(label: &str, n: f64) -> Result<usize, String> {
    let nearest = n.round();
    match n {
        // NaN and both infinities are rejected first; the later comparisons
        // are only meaningful for finite values.
        v if !v.is_finite() => Err(format!("{label}: dimensions must be finite")),
        v if v < 0.0 => Err(format!("{label}: matrix dimensions must be non-negative")),
        v if (nearest - v).abs() > f64::EPSILON => {
            Err(format!("{label}: dimensions must be integers"))
        }
        _ => Ok(nearest as usize),
    }
}
50
51/// Parse dimensions from a tensor representing a size vector.
52pub(crate) fn dims_from_tensor(label: &str, tensor: &Tensor) -> Result<Option<Vec<usize>>, String> {
53    let is_row = tensor.rows() == 1;
54    let is_column = tensor.cols() == 1;
55    let is_scalar = tensor.data.len() == 1;
56    if !(is_row || is_column || is_scalar || tensor.shape.len() == 1) {
57        return Ok(None);
58    }
59    let mut dims = Vec::with_capacity(tensor.data.len());
60    for &v in &tensor.data {
61        match parse_numeric_dimension(label, v) {
62            Ok(dim) => dims.push(dim),
63            Err(_) => return Ok(None),
64        }
65    }
66    Ok(Some(dims))
67}
68
69/// Determine the output shape encoded by a runtime value.
70pub(crate) fn shape_from_value(value: &Value, label: &str) -> Result<Vec<usize>, String> {
71    match value {
72        Value::Tensor(t) => Ok(t.shape.clone()),
73        Value::ComplexTensor(t) => Ok(t.shape.clone()),
74        Value::LogicalArray(l) => Ok(l.shape.clone()),
75        Value::GpuTensor(h) => Ok(h.shape.clone()),
76        Value::CharArray(ca) => Ok(vec![ca.rows, ca.cols]),
77        Value::Cell(cell) => Ok(vec![cell.rows, cell.cols]),
78        Value::Num(_)
79        | Value::Int(_)
80        | Value::Bool(_)
81        | Value::Complex(_, _)
82        | Value::String(_)
83        | Value::StringArray(_) => Ok(vec![1, 1]),
84        other => Err(format!("{label}: unsupported prototype {other:?}")),
85    }
86}
87
88/// Convert a complex tensor back into an appropriate runtime value.
89pub(crate) fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
90    if tensor.data.len() == 1 {
91        let (re, im) = tensor.data[0];
92        Value::Complex(re, im)
93    } else {
94        Value::ComplexTensor(tensor)
95    }
96}