Skip to main content

runmat_runtime/builtins/common/
tensor.rs

1use std::convert::TryFrom;
2
3use runmat_builtins::{LogicalArray, NumericDType, Tensor, Value};
4
5use crate::dispatcher::gather_if_needed_async;
6
/// Return the total number of elements for a given shape.
///
/// The count is the product of all dimensions; an empty shape yields `1`
/// (the empty product). Any zero-length dimension makes the count `0`, even when
/// the remaining dimensions are so large that their product alone would overflow.
///
/// # Panics
///
/// Panics when the product of the dimensions does not fit in `usize`.
pub fn element_count(shape: &[usize]) -> usize {
    // A zero dimension forces the result to 0 regardless of the other dimensions.
    // Checking it first keeps shapes like `[usize::MAX, usize::MAX, usize::MAX, 0]`
    // from panicking on an intermediate overflow.
    if shape.contains(&0) {
        return 0;
    }
    // Every dimension is now >= 1, so the running product never decreases: an
    // intermediate overflow of `usize` implies the final product overflows too.
    shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        .expect("tensor::element_count: overflow computing element count")
}
18
19/// Construct a zero-filled tensor with the provided shape.
20pub fn zeros(shape: &[usize]) -> Result<Tensor, String> {
21    Tensor::new(vec![0.0; element_count(shape)], shape.to_vec())
22        .map_err(|e| format!("tensor zeros: {e}"))
23}
24
25/// Construct an one-filled tensor with the provided shape.
26pub fn ones(shape: &[usize]) -> Result<Tensor, String> {
27    Tensor::new(vec![1.0; element_count(shape)], shape.to_vec())
28        .map_err(|e| format!("tensor ones: {e}"))
29}
30
31/// Construct a zero-filled tensor with an explicit dtype flag.
32pub fn zeros_with_dtype(shape: &[usize], dtype: NumericDType) -> Result<Tensor, String> {
33    Tensor::new_with_dtype(vec![0.0; element_count(shape)], shape.to_vec(), dtype)
34        .map_err(|e| format!("tensor zeros: {e}"))
35}
36
37/// Construct a one-filled tensor with an explicit dtype flag.
38pub fn ones_with_dtype(shape: &[usize], dtype: NumericDType) -> Result<Tensor, String> {
39    Tensor::new_with_dtype(vec![1.0; element_count(shape)], shape.to_vec(), dtype)
40        .map_err(|e| format!("tensor ones: {e}"))
41}
42
43/// Convert a logical array (0/1 bytes) into a numeric tensor.
44pub fn logical_to_tensor(logical: &LogicalArray) -> Result<Tensor, String> {
45    let data: Vec<f64> = logical
46        .data
47        .iter()
48        .map(|&b| if b != 0 { 1.0 } else { 0.0 })
49        .collect();
50    Tensor::new(data, logical.shape.clone()).map_err(|e| format!("logical->tensor: {e}"))
51}
52
53fn value_into_tensor_impl(name: &str, value: Value) -> Result<Tensor, String> {
54    match value {
55        Value::Tensor(t) => Ok(t),
56        Value::LogicalArray(logical) => logical_to_tensor(&logical),
57        Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("tensor: {e}")),
58        Value::Int(i) => {
59            Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("tensor: {e}"))
60        }
61        Value::Bool(b) => Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
62            .map_err(|e| format!("tensor: {e}")),
63        other => Err(format!(
64            "{name}: unsupported input type {:?}; expected numeric or logical values",
65            other
66        )),
67    }
68}
69
70/// Convert a `Value` into an owned `Tensor`, defaulting error messages to `"sum"`.
71pub fn value_into_tensor(value: Value) -> Result<Tensor, String> {
72    value_into_tensor_impl("sum", value)
73}
74
75/// Convert a `Value` into a tensor while customising the builtin name in error messages.
76pub fn value_into_tensor_for(name: &str, value: Value) -> Result<Tensor, String> {
77    value_into_tensor_impl(name, value)
78}
79
80/// Clone a `Value` and coerce it into a tensor.
81pub fn value_to_tensor(value: &Value) -> Result<Tensor, String> {
82    value_into_tensor(value.clone())
83}
84
85/// Convert a `Tensor` back into a runtime value.
86///
87/// Scalars (exactly one element) become `Value::Num`, all other tensors
88/// remain as dense tensor variants.
89pub fn tensor_into_value(tensor: Tensor) -> Value {
90    if tensor.data.len() == 1 {
91        Value::Num(tensor.data[0])
92    } else {
93        Value::Tensor(tensor)
94    }
95}
96
97/// Return true when a tensor contains exactly one scalar element.
98pub fn is_scalar_tensor(tensor: &Tensor) -> bool {
99    tensor.data.len() == 1
100}
101
102fn scalar_f64_from_host_value(value: &Value) -> Result<Option<f64>, String> {
103    match value {
104        Value::Num(n) => Ok(Some(*n)),
105        Value::Int(i) => Ok(Some(i.to_f64())),
106        Value::Bool(b) => Ok(Some(if *b { 1.0 } else { 0.0 })),
107        Value::Tensor(t) => {
108            if t.data.len() == 1 {
109                Ok(Some(t.data[0]))
110            } else {
111                Err(format!(
112                    "expected scalar tensor, got tensor of size {}",
113                    t.data.len()
114                ))
115            }
116        }
117        Value::LogicalArray(la) => {
118            if la.data.len() == 1 {
119                Ok(Some(if la.data[0] != 0 { 1.0 } else { 0.0 }))
120            } else {
121                Err(format!(
122                    "expected scalar logical array, got array of size {}",
123                    la.data.len()
124                ))
125            }
126        }
127        _ => Ok(None),
128    }
129}
130
131/// Attempt to extract a scalar f64 from a runtime value asynchronously.
132pub async fn scalar_f64_from_value_async(value: &Value) -> Result<Option<f64>, String> {
133    match value {
134        Value::GpuTensor(handle) => {
135            if !handle.shape.is_empty() {
136                let len = element_count(&handle.shape);
137                if len != 1 {
138                    return Err(format!("expected scalar gpuArray, got array of size {len}"));
139                }
140            }
141            let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone()))
142                .await
143                .map_err(|e| format!("scalar: {e}"))?;
144            scalar_f64_from_host_value(&gathered)
145        }
146        _ => scalar_f64_from_host_value(value),
147    }
148}
149
150/// Attempt to parse a dimension index from a scalar-like runtime value.
151pub async fn dimension_from_value_async(
152    value: &Value,
153    name: &str,
154    allow_zero: bool,
155) -> Result<Option<usize>, String> {
156    let Some(raw) = scalar_f64_from_value_async(value).await? else {
157        return Ok(None);
158    };
159    if !raw.is_finite() {
160        return Err(format!("{name}: dimension must be finite"));
161    }
162    let rounded = raw.round();
163    if (rounded - raw).abs() > 1e-6 {
164        return Err(format!("{name}: dimension must be an integer"));
165    }
166    let min = if allow_zero { 0.0 } else { 1.0 };
167    if rounded < min {
168        let bound = if allow_zero { 0 } else { 1 };
169        return Err(format!("{name}: dimension must be >= {bound}"));
170    }
171    Ok(Some(rounded as usize))
172}
173
/// Parse one entry of a size argument into a `usize` dimension length.
///
/// The value must be finite and non-negative. It must also be an integer: the
/// absolute tolerance of `f64::EPSILON` effectively requires an exact integer for
/// values of 2 or more. Finally, it must fit in `usize`.
///
/// # Errors
///
/// Returns a descriptive message when any of those checks fails.
fn parse_numeric_dimension(value: f64) -> Result<usize, String> {
    if !value.is_finite() {
        return Err("dimensions must be finite".to_string());
    }
    if value < 0.0 {
        return Err("matrix dimensions must be non-negative".to_string());
    }
    let rounded = value.round();
    if (rounded - value).abs() > f64::EPSILON {
        return Err("dimensions must be integers".to_string());
    }
    // `rounded as usize` would silently saturate huge values (e.g. 1e30) to
    // `usize::MAX`. Go through u128 instead: it is exact for every non-negative
    // integral f64 below 2^128 and saturates above that. Then reject anything that
    // does not fit in `usize`.
    usize::try_from(rounded as u128).map_err(|_| "dimensions are too large".to_string())
}
187
188fn dims_from_tensor_values(values: &[f64], shape: &[usize]) -> Result<Option<Vec<usize>>, String> {
189    let len = values.len();
190    if len == 0 {
191        return Ok(Some(Vec::new()));
192    }
193    let is_scalar = len == 1;
194    let is_row = shape.len() >= 2 && shape[0] == 1;
195    let is_column = shape.len() >= 2 && shape[1] == 1;
196    if !(is_row || is_column || is_scalar || shape.len() == 1) {
197        return Ok(None);
198    }
199    let mut dims = Vec::with_capacity(len);
200    for &value in values {
201        dims.push(parse_numeric_dimension(value)?);
202    }
203    Ok(Some(dims))
204}
205
206/// Attempt to parse a dimension vector from a runtime value asynchronously.
207pub async fn dims_from_value_async(value: &Value) -> Result<Option<Vec<usize>>, String> {
208    match value {
209        Value::Num(n) => parse_numeric_dimension(*n).map(|dim| Some(vec![dim])),
210        Value::Int(i) => parse_numeric_dimension(i.to_f64()).map(|dim| Some(vec![dim])),
211        Value::Tensor(t) => dims_from_tensor_values(&t.data, &t.shape),
212        Value::LogicalArray(la) => {
213            let values: Vec<f64> = la
214                .data
215                .iter()
216                .map(|&b| if b != 0 { 1.0 } else { 0.0 })
217                .collect();
218            dims_from_tensor_values(&values, &la.shape)
219        }
220        Value::GpuTensor(handle) => {
221            let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone()))
222                .await
223                .map_err(|e| format!("dimensions: {e}"))?;
224            match gathered {
225                Value::Tensor(t) => {
226                    if t.data.is_empty() {
227                        tracing::warn!(
228                            gpu_shape = ?handle.shape,
229                            "dims_from_value_async: gathered GPU tensor has no data"
230                        );
231                    }
232                    tracing::trace!(
233                        "dims_from_value_async: GPU tensor values gpu_shape={:?} host_shape={:?} values={:?}",
234                        handle.shape,
235                        t.shape,
236                        t.data
237                    );
238                    let dims = dims_from_tensor_values(&t.data, &t.shape)?;
239                    if dims.is_none() {
240                        tracing::debug!(
241                            gpu_shape = ?handle.shape,
242                            host_shape = ?t.shape,
243                            "dims_from_value_async: GPU tensor not interpretable as dims"
244                        );
245                    }
246                    Ok(dims)
247                }
248                Value::LogicalArray(la) => {
249                    let values: Vec<f64> = la
250                        .data
251                        .iter()
252                        .map(|&b| if b != 0 { 1.0 } else { 0.0 })
253                        .collect();
254                    let dims = dims_from_tensor_values(&values, &la.shape)?;
255                    if dims.is_none() {
256                        tracing::debug!(
257                            gpu_shape = ?handle.shape,
258                            host_shape = ?la.shape,
259                            "dims_from_value_async: GPU logical not interpretable as dims"
260                        );
261                    }
262                    Ok(dims)
263                }
264                Value::Num(n) => parse_numeric_dimension(n).map(|dim| Some(vec![dim])),
265                Value::Int(i) => parse_numeric_dimension(i.to_f64()).map(|dim| Some(vec![dim])),
266                _ => Ok(None),
267            }
268        }
269        _ => Ok(None),
270    }
271}
272
273/// Convert an argument into a dimension index (1-based) if possible.
274pub fn parse_dimension(value: &Value, name: &str) -> Result<usize, String> {
275    match value {
276        Value::Int(i) => {
277            let raw = i.to_i64();
278            if raw < 1 {
279                return Err(format!("{name}: dimension must be >= 1"));
280            }
281            Ok(raw as usize)
282        }
283        Value::Num(n) => {
284            if !n.is_finite() {
285                return Err(format!("{name}: dimension must be finite"));
286            }
287            let rounded = n.round();
288            // Allow small floating error tolerance when users pass float-typed dims
289            if (rounded - n).abs() > 1e-6 {
290                return Err(format!("{name}: dimension must be an integer"));
291            }
292            if rounded < 1.0 {
293                return Err(format!("{name}: dimension must be >= 1"));
294            }
295            Ok(rounded as usize)
296        }
297        other => Err(format!(
298            "{name}: dimension must be numeric, got {:?}",
299            other
300        )),
301    }
302}
303
304/// Attempt to extract a string from a runtime value.
305pub fn value_to_string(value: &Value) -> Option<String> {
306    String::try_from(value).ok()
307}
308
/// Return a canonical shape for a tensor, given its shape slice and element count.
///
/// * No elements (`len == 0`) → `[0, 1]`, whatever `shape` says.
/// * Elements present but no shape information (`shape` empty) → `[1, 1]`.
/// * Otherwise → a copy of `shape`.
pub fn default_shape_for(shape: &[usize], len: usize) -> Vec<usize> {
    match (len, shape) {
        (0, _) => vec![0, 1],
        (_, []) => vec![1, 1],
        (_, dims) => dims.to_vec(),
    }
}
323
324/// Align two numeric tensors for a binary element-wise operation with scalar broadcasting.
325///
326/// Returns `(lhs_data, rhs_data, output_shape)`.  If either operand is a
327/// single element it is broadcast to the other's length.  `builtin` names the
328/// calling builtin and is embedded in the error message when the shapes are
329/// incompatible.
330pub fn binary_numeric_tensors(
331    lhs: &Tensor,
332    rhs: &Tensor,
333    context: &str,
334    builtin: &str,
335) -> crate::BuiltinResult<(Vec<f64>, Vec<f64>, Vec<usize>)> {
336    let lhs_shape = default_shape_for(&lhs.shape, lhs.data.len());
337    let rhs_shape = default_shape_for(&rhs.shape, rhs.data.len());
338    match (lhs.data.len(), rhs.data.len()) {
339        (1, 1) => Ok((vec![lhs.data[0]], vec![rhs.data[0]], vec![1, 1])),
340        (1, len) => Ok((vec![lhs.data[0]; len], rhs.data.clone(), rhs_shape)),
341        (len, 1) => Ok((lhs.data.clone(), vec![rhs.data[0]; len], lhs_shape)),
342        (left, right) if left == right && lhs_shape == rhs_shape => {
343            Ok((lhs.data.clone(), rhs.data.clone(), lhs_shape))
344        }
345        _ => Err(crate::build_runtime_error(format!(
346            "{context}: operands must be scalar or have matching sizes"
347        ))
348        .with_builtin(builtin)
349        .build()),
350    }
351}