Skip to main content

runmat_runtime/builtins/common/
shape.rs

1use runmat_builtins::{Tensor, Value};
2
3use crate::dispatcher::gather_if_needed_async;
4use crate::RuntimeError;
5
/// Report whether `shape` describes a single element.
///
/// Three spellings are accepted as scalar: no dimensions at all, a lone
/// extent of `1`, or exactly `1×1`. Anything else (including `[1, 1, 1]`)
/// is not considered scalar.
pub fn is_scalar_shape(shape: &[usize]) -> bool {
    matches!(shape, [] | [1] | [1, 1])
}
12
/// Produce the one-row, one-column shape used to represent every scalar.
pub fn canonical_scalar_shape() -> Vec<usize> {
    [1, 1].to_vec()
}
17
18/// Normalize scalar-like shapes to the canonical scalar shape.
19pub fn normalize_scalar_shape(shape: &[usize]) -> Vec<usize> {
20    if is_scalar_shape(shape) {
21        canonical_scalar_shape()
22    } else {
23        shape.to_vec()
24    }
25}
26
27/// Normalize a raw shape vector into MATLAB-compatible dimension metadata.
28fn normalize_shape(shape: &[usize]) -> Vec<usize> {
29    if shape.len() == 1 && shape[0] != 1 {
30        return vec![1, shape[0]];
31    }
32    if is_scalar_shape(shape) {
33        return canonical_scalar_shape();
34    }
35    shape.to_vec()
36}
37
38/// Return the MATLAB-visible dimension vector for a runtime value.
39#[async_recursion::async_recursion(?Send)]
40pub async fn value_dimensions(value: &Value) -> Result<Vec<usize>, RuntimeError> {
41    let dims = match value {
42        Value::Tensor(t) => normalize_shape(&t.shape),
43        Value::ComplexTensor(t) => normalize_shape(&t.shape),
44        Value::LogicalArray(la) => normalize_shape(&la.shape),
45        Value::StringArray(sa) => normalize_shape(&sa.shape),
46        Value::CharArray(ca) => vec![ca.rows, ca.cols],
47        Value::Cell(ca) => normalize_shape(&ca.shape),
48        Value::GpuTensor(handle) => {
49            if handle.shape.is_empty() {
50                let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone())).await?;
51                return value_dimensions(&gathered).await;
52            }
53            normalize_shape(&handle.shape)
54        }
55        _ => vec![1, 1],
56    };
57    Ok(dims)
58}
59
60/// Compute the total number of elements contained in a runtime value.
61#[async_recursion::async_recursion(?Send)]
62pub async fn value_numel(value: &Value) -> Result<usize, RuntimeError> {
63    let numel = match value {
64        Value::Tensor(t) => t.data.len(),
65        Value::ComplexTensor(t) => t.data.len(),
66        Value::LogicalArray(la) => la.data.len(),
67        Value::StringArray(sa) => sa.data.len(),
68        Value::CharArray(ca) => ca.rows * ca.cols,
69        Value::Cell(ca) => ca.data.len(),
70        Value::GpuTensor(handle) => {
71            if handle.shape.is_empty() {
72                let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone())).await?;
73                return value_numel(&gathered).await;
74            }
75            handle
76                .shape
77                .iter()
78                .copied()
79                .fold(1usize, |acc, dim| acc.saturating_mul(dim))
80        }
81        _ => 1,
82    };
83    Ok(numel)
84}
85
86/// Compute the dimensionality (NDIMS) of a runtime value, with MATLAB semantics.
87pub async fn value_ndims(value: &Value) -> Result<usize, RuntimeError> {
88    let dims = value_dimensions(value).await?;
89    if dims.len() < 2 {
90        Ok(2)
91    } else {
92        Ok(dims.len())
93    }
94}
95
96/// Convert a dimension vector into a 1×N tensor encoded as `f64`.
97pub fn dims_to_row_tensor(dims: &[usize]) -> Result<Tensor, String> {
98    let len = dims.len();
99    let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
100    let shape = if len == 0 { vec![1, 0] } else { vec![1, len] };
101    Tensor::new(data, shape).map_err(|e| format!("shape::dims_to_row_tensor: {e}"))
102}
103
/// Unit tests for the shape helpers defined in this file.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use futures::executor::block_on;

    /// A plain numeric scalar reports dimensions `[1, 1]`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_scalar_defaults_to_one_by_one() {
        let scalar = Value::Num(5.0);
        let dims = block_on(value_dimensions(&scalar)).unwrap();
        assert_eq!(dims, vec![1, 1]);
    }

    /// A rank-3 tensor keeps all three extents.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_tensor_preserves_rank() {
        let cube = Tensor::new(vec![0.0; 12], vec![2, 3, 2]).unwrap();
        let dims = block_on(value_dimensions(&Value::Tensor(cube))).unwrap();
        assert_eq!(dims, vec![2, 3, 2]);
    }

    /// A GPU handle with a non-empty shape is counted from that shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numel_gpu_uses_shape_product() {
        let gpu_value = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
            shape: vec![4, 5, 6],
            device_id: 0,
            buffer_id: 1,
        });
        let count = block_on(value_numel(&gpu_value)).unwrap();
        assert_eq!(count, 120);
    }

    /// Dimension vectors are encoded as a 1×N row of `f64` values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_to_row_tensor_converts() {
        let row = dims_to_row_tensor(&[2, 4, 6]).unwrap();
        assert_eq!(row.shape, vec![1, 3]);
        assert_eq!(row.data, vec![2.0, 4.0, 6.0]);
    }
}