// runmat_runtime/builtins/common/shape.rs
use runmat_builtins::{Tensor, Value};

use crate::dispatcher::gather_if_needed_async;
use crate::RuntimeError;

/// Returns `true` when `shape` describes a scalar value: the empty shape,
/// the rank-1 shape `[1]`, or the canonical 1x1 matrix shape `[1, 1]`.
pub fn is_scalar_shape(shape: &[usize]) -> bool {
    matches!(shape, [] | [1] | [1, 1])
}
12
/// The canonical shape of a scalar: a 1x1 matrix.
pub fn canonical_scalar_shape() -> Vec<usize> {
    vec![1; 2]
}
17
18pub fn normalize_scalar_shape(shape: &[usize]) -> Vec<usize> {
20 if is_scalar_shape(shape) {
21 canonical_scalar_shape()
22 } else {
23 shape.to_vec()
24 }
25}
26
27fn normalize_shape(shape: &[usize]) -> Vec<usize> {
29 if shape.len() == 1 && shape[0] != 1 {
30 return vec![1, shape[0]];
31 }
32 if is_scalar_shape(shape) {
33 return canonical_scalar_shape();
34 }
35 shape.to_vec()
36}
37
38#[async_recursion::async_recursion(?Send)]
40pub async fn value_dimensions(value: &Value) -> Result<Vec<usize>, RuntimeError> {
41 let dims = match value {
42 Value::Tensor(t) => normalize_shape(&t.shape),
43 Value::ComplexTensor(t) => normalize_shape(&t.shape),
44 Value::LogicalArray(la) => normalize_shape(&la.shape),
45 Value::StringArray(sa) => normalize_shape(&sa.shape),
46 Value::CharArray(ca) => vec![ca.rows, ca.cols],
47 Value::Cell(ca) => normalize_shape(&ca.shape),
48 Value::GpuTensor(handle) => {
49 if handle.shape.is_empty() {
50 let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone())).await?;
51 return value_dimensions(&gathered).await;
52 }
53 normalize_shape(&handle.shape)
54 }
55 _ => vec![1, 1],
56 };
57 Ok(dims)
58}
59
60#[async_recursion::async_recursion(?Send)]
62pub async fn value_numel(value: &Value) -> Result<usize, RuntimeError> {
63 let numel = match value {
64 Value::Tensor(t) => t.data.len(),
65 Value::ComplexTensor(t) => t.data.len(),
66 Value::LogicalArray(la) => la.data.len(),
67 Value::StringArray(sa) => sa.data.len(),
68 Value::CharArray(ca) => ca.rows * ca.cols,
69 Value::Cell(ca) => ca.data.len(),
70 Value::GpuTensor(handle) => {
71 if handle.shape.is_empty() {
72 let gathered = gather_if_needed_async(&Value::GpuTensor(handle.clone())).await?;
73 return value_numel(&gathered).await;
74 }
75 handle
76 .shape
77 .iter()
78 .copied()
79 .fold(1usize, |acc, dim| acc.saturating_mul(dim))
80 }
81 _ => 1,
82 };
83 Ok(numel)
84}
85
86pub async fn value_ndims(value: &Value) -> Result<usize, RuntimeError> {
88 let dims = value_dimensions(value).await?;
89 if dims.len() < 2 {
90 Ok(2)
91 } else {
92 Ok(dims.len())
93 }
94}
95
96pub fn dims_to_row_tensor(dims: &[usize]) -> Result<Tensor, String> {
98 let len = dims.len();
99 let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
100 let shape = if len == 0 { vec![1, 0] } else { vec![1, len] };
101 Tensor::new(data, shape).map_err(|e| format!("shape::dims_to_row_tensor: {e}"))
102}
103
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use futures::executor::block_on;

    // NOTE(review): each test carries a wasm32 cfg_attr so it also runs under
    // the wasm-bindgen test harness; on native targets the plain #[test]
    // harness applies.

    // A plain numeric scalar has no explicit shape, so dimensions fall back
    // to the canonical 1x1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_scalar_defaults_to_one_by_one() {
        assert_eq!(
            block_on(value_dimensions(&Value::Num(5.0))).unwrap(),
            vec![1, 1]
        );
    }

    // Shapes of rank >= 2 pass through normalize_shape unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_tensor_preserves_rank() {
        let tensor = Tensor::new(vec![0.0; 12], vec![2, 3, 2]).unwrap();
        assert_eq!(
            block_on(value_dimensions(&Value::Tensor(tensor))).unwrap(),
            vec![2, 3, 2]
        );
    }

    // A GPU handle with shape metadata is counted without gathering:
    // numel is the product of the dimensions (4 * 5 * 6 = 120).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numel_gpu_uses_shape_product() {
        let handle = runmat_accelerate_api::GpuTensorHandle {
            shape: vec![4, 5, 6],
            device_id: 0,
            buffer_id: 1,
        };
        assert_eq!(
            block_on(value_numel(&Value::GpuTensor(handle))).unwrap(),
            120
        );
    }

    // dims_to_row_tensor produces a 1xN row vector of f64 dimension values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn dims_to_row_tensor_converts() {
        let tensor = dims_to_row_tensor(&[2, 4, 6]).unwrap();
        assert_eq!(tensor.shape, vec![1, 3]);
        assert_eq!(tensor.data, vec![2.0, 4.0, 6.0]);
    }
}