// runmat_runtime/builtins/common/shape.rs

1use runmat_builtins::{Tensor, Value};
2
3/// Normalize a raw shape vector into MATLAB-compatible dimension metadata.
/// Normalize a raw shape vector into MATLAB-compatible dimension metadata.
///
/// MATLAB always reports at least two dimensions and never reports trailing
/// singleton dimensions beyond the second one, so:
/// - an empty shape (scalar storage) becomes `[1, 1]`;
/// - a single-entry shape `[n]` is treated as a row vector `[1, n]`;
/// - trailing `1`s past the second dimension are dropped
///   (`[2, 3, 1]` → `[2, 3]`, `[1, 1, 1]` → `[1, 1]`), while interior
///   singletons are kept (`[2, 1, 4]` stays `[2, 1, 4]`).
fn normalize_shape(shape: &[usize]) -> Vec<usize> {
    match shape {
        [] => vec![1, 1],
        [n] => vec![1, *n],
        _ => {
            let mut dims = shape.to_vec();
            // Only trim past index 1: `[3, 1]` and `[1, 1]` are already canonical
            // 2-D shapes and must keep both entries.
            while dims.len() > 2 && dims.last() == Some(&1) {
                dims.pop();
            }
            dims
        }
    }
}
11
12/// Return the MATLAB-visible dimension vector for a runtime value.
13pub fn value_dimensions(value: &Value) -> Vec<usize> {
14    match value {
15        Value::Tensor(t) => normalize_shape(&t.shape),
16        Value::ComplexTensor(t) => normalize_shape(&t.shape),
17        Value::LogicalArray(la) => normalize_shape(&la.shape),
18        Value::StringArray(sa) => normalize_shape(&sa.shape),
19        Value::CharArray(ca) => vec![ca.rows, ca.cols],
20        Value::Cell(ca) => normalize_shape(&ca.shape),
21        Value::GpuTensor(handle) => {
22            if handle.shape.is_empty() {
23                if let Some(provider) = runmat_accelerate_api::provider() {
24                    if let Ok(host) = provider.download(handle) {
25                        return normalize_shape(&host.shape);
26                    }
27                }
28                vec![1, 1]
29            } else {
30                normalize_shape(&handle.shape)
31            }
32        }
33        _ => vec![1, 1],
34    }
35}
36
37/// Compute the total number of elements contained in a runtime value.
38pub fn value_numel(value: &Value) -> usize {
39    match value {
40        Value::Tensor(t) => t.data.len(),
41        Value::ComplexTensor(t) => t.data.len(),
42        Value::LogicalArray(la) => la.data.len(),
43        Value::StringArray(sa) => sa.data.len(),
44        Value::CharArray(ca) => ca.rows * ca.cols,
45        Value::Cell(ca) => ca.data.len(),
46        Value::GpuTensor(handle) => {
47            if handle.shape.is_empty() {
48                if let Some(provider) = runmat_accelerate_api::provider() {
49                    if let Ok(host) = provider.download(handle) {
50                        return host.data.len();
51                    }
52                }
53                1
54            } else {
55                handle
56                    .shape
57                    .iter()
58                    .copied()
59                    .fold(1usize, |acc, dim| acc.saturating_mul(dim))
60            }
61        }
62        _ => 1,
63    }
64}
65
66/// Compute the dimensionality (NDIMS) of a runtime value, with MATLAB semantics.
67pub fn value_ndims(value: &Value) -> usize {
68    let dims = value_dimensions(value);
69    if dims.len() < 2 {
70        2
71    } else {
72        dims.len()
73    }
74}
75
76/// Convert a dimension vector into a 1×N tensor encoded as `f64`.
77pub fn dims_to_row_tensor(dims: &[usize]) -> Result<Tensor, String> {
78    let len = dims.len();
79    let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
80    let shape = if len == 0 { vec![1, 0] } else { vec![1, len] };
81    Tensor::new(data, shape).map_err(|e| format!("shape::dims_to_row_tensor: {e}"))
82}
83
84#[cfg(test)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dims_scalar_defaults_to_one_by_one() {
        assert_eq!(value_dimensions(&Value::Num(5.0)), vec![1, 1]);
    }

    #[test]
    fn dims_tensor_preserves_rank() {
        let tensor = Tensor::new(vec![0.0; 12], vec![2, 3, 2]).unwrap();
        assert_eq!(value_dimensions(&Value::Tensor(tensor)), vec![2, 3, 2]);
    }

    #[test]
    fn ndims_scalar_is_two() {
        // MATLAB never reports fewer than two dimensions.
        assert_eq!(value_ndims(&Value::Num(5.0)), 2);
    }

    #[test]
    fn normalize_shape_promotes_low_rank() {
        // Scalar storage and 1-D storage are both promoted to 2-D shapes.
        assert_eq!(normalize_shape(&[]), vec![1, 1]);
        assert_eq!(normalize_shape(&[4]), vec![1, 4]);
    }

    #[test]
    fn numel_gpu_uses_shape_product() {
        let handle = runmat_accelerate_api::GpuTensorHandle {
            shape: vec![4, 5, 6],
            device_id: 0,
            buffer_id: 1,
        };
        assert_eq!(value_numel(&Value::GpuTensor(handle)), 120);
    }

    #[test]
    fn dims_to_row_tensor_converts() {
        let tensor = dims_to_row_tensor(&[2, 4, 6]).unwrap();
        assert_eq!(tensor.shape, vec![1, 3]);
        assert_eq!(tensor.data, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn dims_to_row_tensor_empty_is_one_by_zero() {
        let tensor = dims_to_row_tensor(&[]).unwrap();
        assert_eq!(tensor.shape, vec![1, 0]);
        assert!(tensor.data.is_empty());
    }
}