Skip to main content

runmat_runtime/builtins/common/random_args.rs

1use crate::builtins::common::tensor;
2use runmat_builtins::{ComplexTensor, Value};
3
4/// Extract a lowercased keyword from runtime values such as strings or
5/// single-row char arrays.
6pub(crate) fn keyword_of(value: &Value) -> Option<String> {
7    match value {
8        Value::String(s) => Some(s.to_ascii_lowercase()),
9        Value::StringArray(sa) if sa.data.len() == 1 => Some(sa.data[0].to_ascii_lowercase()),
10        Value::CharArray(ca) if ca.rows == 1 => {
11            let text: String = ca.data.iter().collect();
12            Some(text.to_ascii_lowercase())
13        }
14        _ => None,
15    }
16}
17
18/// Attempt to parse a dimension argument. Returns `Ok(Some(Vec))` when the
19/// value encodes dimensions, `Ok(None)` when the value is not a dimension
20/// argument, and `Err` when the value is dimension-like but invalid.
21pub(crate) async fn extract_dims(value: &Value, label: &str) -> Result<Option<Vec<usize>>, String> {
22    if matches!(value, Value::LogicalArray(_)) {
23        return Ok(None);
24    }
25    let gpu_scalar = match value {
26        Value::GpuTensor(handle) => tensor::element_count(&handle.shape) == 1,
27        _ => false,
28    };
29    match tensor::dims_from_value_async(value).await {
30        Ok(dims) => Ok(dims),
31        Err(err) => {
32            if matches!(value, Value::Tensor(_))
33                || (matches!(value, Value::GpuTensor(_)) && !gpu_scalar)
34            {
35                Ok(None)
36            } else {
37                Err(format!("{label}: {err}"))
38            }
39        }
40    }
41}
42
43/// Determine the output shape encoded by a runtime value.
44pub(crate) fn shape_from_value(value: &Value, label: &str) -> Result<Vec<usize>, String> {
45    match value {
46        Value::Tensor(t) => Ok(t.shape.clone()),
47        Value::ComplexTensor(t) => Ok(t.shape.clone()),
48        Value::LogicalArray(l) => Ok(l.shape.clone()),
49        Value::GpuTensor(h) => Ok(h.shape.clone()),
50        Value::CharArray(ca) => Ok(vec![ca.rows, ca.cols]),
51        Value::Cell(cell) => Ok(vec![cell.rows, cell.cols]),
52        Value::Num(_)
53        | Value::Int(_)
54        | Value::Bool(_)
55        | Value::Complex(_, _)
56        | Value::String(_)
57        | Value::StringArray(_) => Ok(vec![1, 1]),
58        other => Err(format!("{label}: unsupported prototype {other:?}")),
59    }
60}
61
62/// Convert a complex tensor back into an appropriate runtime value.
63pub(crate) fn complex_tensor_into_value(tensor: ComplexTensor) -> Value {
64    if tensor.data.len() == 1 {
65        let (re, im) = tensor.data[0];
66        Value::Complex(re, im)
67    } else {
68        Value::ComplexTensor(tensor)
69    }
70}