1use crate::error::{GhostError, Result};
4use smallvec::SmallVec;
5
6const MAX_INLINE_DIMS: usize = 6;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11pub struct Shape(SmallVec<[usize; MAX_INLINE_DIMS]>);
12
13impl Shape {
14 pub fn new(dims: &[usize]) -> Self {
16 Shape(SmallVec::from_slice(dims))
17 }
18
19 pub fn scalar() -> Self {
21 Shape(SmallVec::new())
22 }
23
24 pub fn ndim(&self) -> usize {
26 self.0.len()
27 }
28
29 pub fn numel(&self) -> usize {
31 self.0.iter().product()
32 }
33
34 pub fn dim(&self, idx: usize) -> Option<usize> {
36 self.0.get(idx).copied()
37 }
38
39 pub fn dims(&self) -> &[usize] {
41 &self.0
42 }
43
44 pub fn is_scalar(&self) -> bool {
46 self.0.is_empty()
47 }
48
49 pub fn broadcast_with(&self, other: &Shape) -> Result<Shape> {
51 let max_ndim = self.ndim().max(other.ndim());
52 let mut result = SmallVec::with_capacity(max_ndim);
53
54 for i in 0..max_ndim {
55 let a = if i < self.ndim() {
56 self.0[self.ndim() - 1 - i]
57 } else {
58 1
59 };
60 let b = if i < other.ndim() {
61 other.0[other.ndim() - 1 - i]
62 } else {
63 1
64 };
65
66 if a == b {
67 result.push(a);
68 } else if a == 1 {
69 result.push(b);
70 } else if b == 1 {
71 result.push(a);
72 } else {
73 return Err(GhostError::BroadcastError {
74 a: self.0.to_vec(),
75 b: other.0.to_vec(),
76 });
77 }
78 }
79
80 result.reverse();
81 Ok(Shape(result))
82 }
83
84 pub fn default_strides(&self) -> Strides {
86 if self.is_scalar() {
87 return Strides::new(&[]);
88 }
89
90 let mut strides = SmallVec::with_capacity(self.ndim());
91 let mut stride = 1usize;
92
93 for &dim in self.0.iter().rev() {
94 strides.push(stride);
95 stride *= dim;
96 }
97
98 strides.reverse();
99 Strides(strides)
100 }
101}
102
103impl From<&[usize]> for Shape {
104 fn from(dims: &[usize]) -> Self {
105 Shape::new(dims)
106 }
107}
108
109impl From<Vec<usize>> for Shape {
110 fn from(dims: Vec<usize>) -> Self {
111 Shape(SmallVec::from_vec(dims))
112 }
113}
114
115impl std::fmt::Display for Shape {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 write!(f, "[")?;
118 for (i, d) in self.0.iter().enumerate() {
119 if i > 0 {
120 write!(f, ", ")?;
121 }
122 write!(f, "{}", d)?;
123 }
124 write!(f, "]")
125 }
126}
127
128#[derive(Debug, Clone, PartialEq, Eq, Hash)]
130pub struct Strides(SmallVec<[usize; MAX_INLINE_DIMS]>);
131
132impl Strides {
133 pub fn new(strides: &[usize]) -> Self {
135 Strides(SmallVec::from_slice(strides))
136 }
137
138 pub fn stride(&self, idx: usize) -> Option<usize> {
140 self.0.get(idx).copied()
141 }
142
143 pub fn as_slice(&self) -> &[usize] {
145 &self.0
146 }
147
148 pub fn is_contiguous(&self, shape: &Shape) -> bool {
150 if shape.is_scalar() {
151 return true;
152 }
153
154 let expected = shape.default_strides();
155 self.0 == expected.0
156 }
157
158 pub fn offset(&self, indices: &[usize]) -> usize {
160 indices
161 .iter()
162 .zip(self.0.iter())
163 .map(|(&idx, &stride)| idx * stride)
164 .sum()
165 }
166}
167
168impl From<&[usize]> for Strides {
169 fn from(strides: &[usize]) -> Self {
170 Strides::new(strides)
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Element counts for a rank-3 shape, a single-element shape, and a scalar.
    #[test]
    fn test_shape_numel() {
        let cases = [
            (Shape::new(&[2, 3, 4]), 24),
            (Shape::new(&[1]), 1),
            (Shape::scalar(), 1),
        ];
        for (shape, expected) in cases.iter() {
            assert_eq!(shape.numel(), *expected);
        }
    }

    /// A column `[3, 1]` broadcast against a row `[1, 4]` yields `[3, 4]`.
    #[test]
    fn test_broadcast() {
        let column = Shape::new(&[3, 1]);
        let row = Shape::new(&[1, 4]);
        let combined = column.broadcast_with(&row).unwrap();
        assert_eq!(combined.dims(), &[3, 4]);
    }

    /// Default strides of `[2, 3, 4]` are row-major: `[12, 4, 1]`.
    #[test]
    fn test_strides() {
        let computed = Shape::new(&[2, 3, 4]).default_strides();
        assert_eq!(computed.as_slice(), &[12, 4, 1]);
    }
}