runmat_runtime/builtins/common/
tensor.rs

1use std::convert::TryFrom;
2
3use runmat_builtins::{LogicalArray, NumericDType, Tensor, Value};
4
/// Return the total number of elements for a given shape.
///
/// The count is the product of all dimensions. An empty shape yields `1`
/// (the empty product), and any shape containing a zero-sized dimension
/// yields `0`.
///
/// # Panics
///
/// Panics if the product of the (all non-zero) dimensions does not fit in
/// a `usize`.
pub fn element_count(shape: &[usize]) -> usize {
    // A zero-sized dimension makes the whole tensor empty. Answer before
    // multiplying so that huge sibling dimensions cannot trigger a spurious
    // overflow panic for a shape whose true element count is 0.
    if shape.contains(&0) {
        return 0;
    }
    // Accumulate in u128 so the overflow check is independent of the
    // platform's pointer width; the final narrowing is checked separately.
    let total = shape
        .iter()
        .try_fold(1u128, |acc, &dim| acc.checked_mul(dim as u128))
        .expect("tensor::element_count: overflow computing element count");
    usize::try_from(total).expect("tensor::element_count: overflow converting to usize")
}
16
17/// Construct a zero-filled tensor with the provided shape.
18pub fn zeros(shape: &[usize]) -> Result<Tensor, String> {
19    Tensor::new(vec![0.0; element_count(shape)], shape.to_vec())
20        .map_err(|e| format!("tensor zeros: {e}"))
21}
22
23/// Construct an one-filled tensor with the provided shape.
24pub fn ones(shape: &[usize]) -> Result<Tensor, String> {
25    Tensor::new(vec![1.0; element_count(shape)], shape.to_vec())
26        .map_err(|e| format!("tensor ones: {e}"))
27}
28
29/// Construct a zero-filled tensor with an explicit dtype flag.
30pub fn zeros_with_dtype(shape: &[usize], dtype: NumericDType) -> Result<Tensor, String> {
31    Tensor::new_with_dtype(vec![0.0; element_count(shape)], shape.to_vec(), dtype)
32        .map_err(|e| format!("tensor zeros: {e}"))
33}
34
35/// Construct a one-filled tensor with an explicit dtype flag.
36pub fn ones_with_dtype(shape: &[usize], dtype: NumericDType) -> Result<Tensor, String> {
37    Tensor::new_with_dtype(vec![1.0; element_count(shape)], shape.to_vec(), dtype)
38        .map_err(|e| format!("tensor ones: {e}"))
39}
40
41/// Convert a logical array (0/1 bytes) into a numeric tensor.
42pub fn logical_to_tensor(logical: &LogicalArray) -> Result<Tensor, String> {
43    let data: Vec<f64> = logical
44        .data
45        .iter()
46        .map(|&b| if b != 0 { 1.0 } else { 0.0 })
47        .collect();
48    Tensor::new(data, logical.shape.clone()).map_err(|e| format!("logical->tensor: {e}"))
49}
50
51fn value_into_tensor_impl(name: &str, value: Value) -> Result<Tensor, String> {
52    match value {
53        Value::Tensor(t) => Ok(t),
54        Value::LogicalArray(logical) => logical_to_tensor(&logical),
55        Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("tensor: {e}")),
56        Value::Int(i) => {
57            Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("tensor: {e}"))
58        }
59        Value::Bool(b) => Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
60            .map_err(|e| format!("tensor: {e}")),
61        other => Err(format!(
62            "{name}: unsupported input type {:?}; expected numeric or logical values",
63            other
64        )),
65    }
66}
67
68/// Convert a `Value` into an owned `Tensor`, defaulting error messages to `"sum"`.
69pub fn value_into_tensor(value: Value) -> Result<Tensor, String> {
70    value_into_tensor_impl("sum", value)
71}
72
73/// Convert a `Value` into a tensor while customising the builtin name in error messages.
74pub fn value_into_tensor_for(name: &str, value: Value) -> Result<Tensor, String> {
75    value_into_tensor_impl(name, value)
76}
77
78/// Clone a `Value` and coerce it into a tensor.
79pub fn value_to_tensor(value: &Value) -> Result<Tensor, String> {
80    value_into_tensor(value.clone())
81}
82
83/// Convert a `Tensor` back into a runtime value.
84///
85/// Scalars (exactly one element) become `Value::Num`, all other tensors
86/// remain as dense tensor variants.
87pub fn tensor_into_value(tensor: Tensor) -> Value {
88    if tensor.data.len() == 1 {
89        Value::Num(tensor.data[0])
90    } else {
91        Value::Tensor(tensor)
92    }
93}
94
95/// Return true when a tensor contains exactly one scalar element.
96pub fn is_scalar_tensor(tensor: &Tensor) -> bool {
97    tensor.data.len() == 1
98}
99
100/// Convert an argument into a dimension index (1-based) if possible.
101pub fn parse_dimension(value: &Value, name: &str) -> Result<usize, String> {
102    match value {
103        Value::Int(i) => {
104            let raw = i.to_i64();
105            if raw < 1 {
106                return Err(format!("{name}: dimension must be >= 1"));
107            }
108            Ok(raw as usize)
109        }
110        Value::Num(n) => {
111            if !n.is_finite() {
112                return Err(format!("{name}: dimension must be finite"));
113            }
114            let rounded = n.round();
115            // Allow small floating error tolerance when users pass float-typed dims
116            if (rounded - n).abs() > 1e-6 {
117                return Err(format!("{name}: dimension must be an integer"));
118            }
119            if rounded < 1.0 {
120                return Err(format!("{name}: dimension must be >= 1"));
121            }
122            Ok(rounded as usize)
123        }
124        other => Err(format!(
125            "{name}: dimension must be numeric, got {:?}",
126            other
127        )),
128    }
129}
130
131/// Attempt to extract a string from a runtime value.
132pub fn value_to_string(value: &Value) -> Option<String> {
133    String::try_from(value).ok()
134}