runmat_runtime/builtins/common/
shape.rs1use runmat_builtins::{Tensor, Value};
2
3use crate::dispatcher::gather_if_needed_async;
4use crate::RuntimeError;
5
/// Returns `true` when `shape` is one of the accepted scalar spellings:
/// `[]`, `[1]`, or `[1, 1]`.
///
/// Higher-rank all-ones shapes such as `[1, 1, 1]` are *not* treated as scalar.
pub fn is_scalar_shape(shape: &[usize]) -> bool {
    matches!(shape, [] | [1] | [1, 1])
}
12
/// Returns the canonical shape used to describe a scalar: `[1, 1]`.
pub fn canonical_scalar_shape() -> Vec<usize> {
    vec![1; 2]
}
17
18pub fn normalize_scalar_shape(shape: &[usize]) -> Vec<usize> {
20 if is_scalar_shape(shape) {
21 canonical_scalar_shape()
22 } else {
23 shape.to_vec()
24 }
25}
26
27fn normalize_shape(shape: &[usize]) -> Vec<usize> {
29 if shape.len() == 1 && shape[0] != 1 {
30 return vec![1, shape[0]];
31 }
32 if is_scalar_shape(shape) {
33 return canonical_scalar_shape();
34 }
35 shape.to_vec()
36}
37
38#[async_recursion::async_recursion(?Send)]
40pub async fn value_dimensions(value: &Value) -> Result<Vec<usize>, RuntimeError> {
41 let dims = match value {
42 Value::Tensor(t) => normalize_shape(&t.shape),
43 Value::SparseTensor(t) => normalize_shape(&[t.rows, t.cols]),
44 Value::ComplexTensor(t) => normalize_shape(&t.shape),
45 Value::LogicalArray(la) => normalize_shape(&la.shape),
46 Value::StringArray(sa) => normalize_shape(&sa.shape),
47 Value::CharArray(ca) => vec![ca.rows, ca.cols],
48 Value::Cell(ca) => normalize_shape(&ca.shape),
49 Value::GpuTensor(handle) => {
50 if handle.shape.is_empty() {
51 let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone())).await?;
52 return value_dimensions(&gathered).await;
53 }
54 normalize_shape(&handle.shape)
55 }
56 _ => vec![1, 1],
57 };
58 Ok(dims)
59}
60
61#[async_recursion::async_recursion(?Send)]
63pub async fn value_numel(value: &Value) -> Result<usize, RuntimeError> {
64 let numel = match value {
65 Value::Tensor(t) => t.data.len(),
66 Value::SparseTensor(t) => t.rows.saturating_mul(t.cols),
67 Value::ComplexTensor(t) => t.data.len(),
68 Value::LogicalArray(la) => la.data.len(),
69 Value::StringArray(sa) => sa.data.len(),
70 Value::CharArray(ca) => ca.rows * ca.cols,
71 Value::Cell(ca) => ca.data.len(),
72 Value::GpuTensor(handle) => {
73 if handle.shape.is_empty() {
74 let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone())).await?;
75 return value_numel(&gathered).await;
76 }
77 handle
78 .shape
79 .iter()
80 .copied()
81 .fold(1usize, |acc, dim| acc.saturating_mul(dim))
82 }
83 _ => 1,
84 };
85 Ok(numel)
86}
87
88pub async fn value_ndims(value: &Value) -> Result<usize, RuntimeError> {
90 let dims = value_dimensions(value).await?;
91 if dims.len() < 2 {
92 Ok(2)
93 } else {
94 Ok(dims.len())
95 }
96}
97
98pub fn dims_to_row_tensor(dims: &[usize]) -> Result<Tensor, String> {
100 let len = dims.len();
101 let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
102 let shape = if len == 0 { vec![1, 0] } else { vec![1, len] };
103 Tensor::new(data, shape).map_err(|e| format!("shape::dims_to_row_tensor: {e}"))
104}
105
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use futures::executor::block_on;

    /// A plain numeric scalar has no stored shape and falls back to `[1, 1]`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_scalar_defaults_to_one_by_one() {
        let scalar = Value::Num(5.0);
        let dims = block_on(value_dimensions(&scalar)).expect("dimensions of a scalar");
        assert_eq!(dims, vec![1, 1]);
    }

    /// A rank-3 tensor keeps all three of its dimensions.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_tensor_preserves_rank() {
        let cube = Tensor::new(vec![0.0; 12], vec![2, 3, 2]).expect("valid 2x3x2 tensor");
        let dims = block_on(value_dimensions(&Value::Tensor(cube))).expect("dimensions of a tensor");
        assert_eq!(dims, vec![2, 3, 2]);
    }

    /// A GPU handle with shape metadata is counted from its shape alone.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numel_gpu_uses_shape_product() {
        let gpu = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
            shape: vec![4, 5, 6],
            device_id: 0,
            buffer_id: 1,
        });
        let count = block_on(value_numel(&gpu)).expect("numel of a GPU handle");
        assert_eq!(count, 4 * 5 * 6);
    }

    /// Dimension lists become `1 x N` row tensors of the same values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_to_row_tensor_converts() {
        let row = dims_to_row_tensor(&[2, 4, 6]).expect("row tensor from dims");
        assert_eq!(row.shape, vec![1, 3]);
        assert_eq!(row.data, vec![2.0, 4.0, 6.0]);
    }
}