runmat_runtime/builtins/common/
shape.rs1use runmat_builtins::{Tensor, Value};
2
/// Expand a raw stored shape so that it always has at least two dimensions.
///
/// * an empty shape is reported as `[1, 1]`;
/// * a single-entry shape `[n]` is reported as the row `[1, n]`;
/// * any shape with two or more entries is returned as an owned copy, unchanged.
fn normalize_shape(shape: &[usize]) -> Vec<usize> {
    match shape {
        [] => vec![1, 1],
        [extent] => vec![1, *extent],
        _ => shape.to_vec(),
    }
}
11
12pub fn value_dimensions(value: &Value) -> Vec<usize> {
14 match value {
15 Value::Tensor(t) => normalize_shape(&t.shape),
16 Value::ComplexTensor(t) => normalize_shape(&t.shape),
17 Value::LogicalArray(la) => normalize_shape(&la.shape),
18 Value::StringArray(sa) => normalize_shape(&sa.shape),
19 Value::CharArray(ca) => vec![ca.rows, ca.cols],
20 Value::Cell(ca) => normalize_shape(&ca.shape),
21 Value::GpuTensor(handle) => {
22 if handle.shape.is_empty() {
23 if let Some(provider) = runmat_accelerate_api::provider() {
24 if let Ok(host) = provider.download(handle) {
25 return normalize_shape(&host.shape);
26 }
27 }
28 vec![1, 1]
29 } else {
30 normalize_shape(&handle.shape)
31 }
32 }
33 _ => vec![1, 1],
34 }
35}
36
37pub fn value_numel(value: &Value) -> usize {
39 match value {
40 Value::Tensor(t) => t.data.len(),
41 Value::ComplexTensor(t) => t.data.len(),
42 Value::LogicalArray(la) => la.data.len(),
43 Value::StringArray(sa) => sa.data.len(),
44 Value::CharArray(ca) => ca.rows * ca.cols,
45 Value::Cell(ca) => ca.data.len(),
46 Value::GpuTensor(handle) => {
47 if handle.shape.is_empty() {
48 if let Some(provider) = runmat_accelerate_api::provider() {
49 if let Ok(host) = provider.download(handle) {
50 return host.data.len();
51 }
52 }
53 1
54 } else {
55 handle
56 .shape
57 .iter()
58 .copied()
59 .fold(1usize, |acc, dim| acc.saturating_mul(dim))
60 }
61 }
62 _ => 1,
63 }
64}
65
66pub fn value_ndims(value: &Value) -> usize {
68 let dims = value_dimensions(value);
69 if dims.len() < 2 {
70 2
71 } else {
72 dims.len()
73 }
74}
75
76pub fn dims_to_row_tensor(dims: &[usize]) -> Result<Tensor, String> {
78 let len = dims.len();
79 let data: Vec<f64> = dims.iter().map(|&d| d as f64).collect();
80 let shape = if len == 0 { vec![1, 0] } else { vec![1, len] };
81 Tensor::new(data, shape).map_err(|e| format!("shape::dims_to_row_tensor: {e}"))
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain numeric scalar has no stored shape and must be reported as `1x1`.
    #[test]
    fn dims_scalar_defaults_to_one_by_one() {
        let scalar = Value::Num(5.0);
        let dims = value_dimensions(&scalar);
        assert_eq!(dims, vec![1, 1]);
    }

    /// A rank-3 tensor keeps all three of its dimensions.
    #[test]
    fn dims_tensor_preserves_rank() {
        let cube = Tensor::new(vec![0.0; 12], vec![2, 3, 2]).expect("12 elements fill a 2x3x2 tensor");
        let dims = value_dimensions(&Value::Tensor(cube));
        assert_eq!(dims, vec![2, 3, 2]);
    }

    /// A GPU handle with a recorded shape is counted from that shape alone (4 * 5 * 6).
    #[test]
    fn numel_gpu_uses_shape_product() {
        let gpu_value = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
            shape: vec![4, 5, 6],
            device_id: 0,
            buffer_id: 1,
        });
        assert_eq!(value_numel(&gpu_value), 120);
    }

    /// Dimensions become a single row of floating-point values.
    #[test]
    fn dims_to_row_tensor_converts() {
        let row = dims_to_row_tensor(&[2, 4, 6]).expect("a 1x3 row tensor is always valid");
        assert_eq!(row.shape, vec![1, 3]);
        assert_eq!(row.data, vec![2.0, 4.0, 6.0]);
    }
}