1use std::fmt;
4
/// The dimension sizes of an n-dimensional array.
///
/// A shape with no dimensions denotes a scalar (0-d array) holding exactly one
/// element. `Shape::default()` is that scalar shape, identical to [`Shape::scalar`].
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Shape {
    dims: Vec<usize>,
}

impl Shape {
    /// Creates a shape from the given dimension sizes.
    pub fn new(dims: Vec<usize>) -> Self {
        Self { dims }
    }

    /// Returns the shape of a scalar: zero dimensions, one element.
    pub fn scalar() -> Self {
        Self { dims: Vec::new() }
    }

    /// Returns the number of dimensions (the rank); `0` for a scalar.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.dims.len()
    }

    /// Returns the total number of elements, i.e. the product of all dimensions.
    ///
    /// A scalar shape has size `1` (the empty product), and any zero-length
    /// dimension makes the size `0`.
    ///
    /// # Panics
    ///
    /// Panics if the product overflows `usize` while overflow checks are
    /// enabled (see `Iterator::product`).
    pub fn size(&self) -> usize {
        // The product of an empty iterator is 1, which covers the scalar case.
        self.dims.iter().product()
    }

    /// Returns the dimension sizes as a borrowed slice.
    #[inline]
    pub fn as_slice(&self) -> &[usize] {
        &self.dims
    }

    /// Returns `true` if this shape has no dimensions.
    #[inline]
    pub fn is_scalar(&self) -> bool {
        self.dims.is_empty()
    }

    /// Returns the size of dimension `index`, or `None` if `index >= ndim()`.
    pub fn get(&self, index: usize) -> Option<usize> {
        self.dims.get(index).copied()
    }

    /// Returns row-major (C-order) strides, in elements, for a contiguous
    /// array of this shape.
    ///
    /// The last axis has stride `1`. Each earlier axis's stride is the stride
    /// of the next axis times that axis's size. For example, `[2, 3, 4]` gives
    /// `[12, 4, 1]`. A scalar shape yields an empty vector.
    pub fn default_strides(&self) -> Vec<usize> {
        let mut strides = vec![1; self.ndim()];
        for i in (0..self.ndim().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * self.dims[i + 1];
        }
        strides
    }

    /// Computes the shape that results from broadcasting `self` with `other`.
    ///
    /// The shapes are aligned at their trailing axes, and the shorter one is
    /// treated as if padded with leading `1`s. Each aligned pair of sizes must
    /// either be equal or contain a `1`; in the latter case the other size is
    /// used, which may be `0`. Returns `None` if any pair is incompatible.
    pub fn broadcast_with(&self, other: &Shape) -> Option<Shape> {
        let ndim = self.ndim().max(other.ndim());
        let mut result = Vec::with_capacity(ndim);

        for i in 0..ndim {
            // The i-th axis from the end; missing leading axes act as size 1.
            let lhs = self.dims.iter().rev().nth(i).copied().unwrap_or(1);
            let rhs = other.dims.iter().rev().nth(i).copied().unwrap_or(1);
            let dim = match (lhs, rhs) {
                (a, b) if a == b => a,
                // A size-1 axis stretches to the other size, including 0.
                // Taking the max here would wrongly turn (0, 1) into 1.
                (a, 1) => a,
                (1, b) => b,
                _ => return None,
            };
            result.push(dim);
        }

        result.reverse();
        Some(Shape::new(result))
    }
}
114
115impl From<Vec<usize>> for Shape {
116 fn from(dims: Vec<usize>) -> Self {
117 Shape::new(dims)
118 }
119}
120
121impl From<&[usize]> for Shape {
122 fn from(dims: &[usize]) -> Self {
123 Shape::new(dims.to_vec())
124 }
125}
126
127impl fmt::Display for Shape {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 write!(f, "(")?;
130 for (i, dim) in self.dims.iter().enumerate() {
131 if i > 0 {
132 write!(f, ", ")?;
133 }
134 write!(f, "{}", dim)?;
135 }
136 if self.dims.len() == 1 {
137 write!(f, ",")?;
138 }
139 write!(f, ")")
140 }
141}
142
/// Unit tests for [`Shape`]: construction, size/stride computation,
/// broadcasting, conversions and `Display` formatting.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shape_creation() {
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.ndim(), 3);
        assert_eq!(shape.size(), 24);
        assert_eq!(shape.as_slice(), &[2, 3, 4]);
    }

    #[test]
    fn test_scalar_shape() {
        let shape = Shape::scalar();
        assert_eq!(shape.ndim(), 0);
        assert_eq!(shape.size(), 1);
        assert!(shape.is_scalar());
    }

    #[test]
    fn test_zero_sized_shape() {
        // Any zero-length dimension means the array holds no elements.
        let shape = Shape::new(vec![2, 0, 3]);
        assert_eq!(shape.ndim(), 3);
        assert_eq!(shape.size(), 0);
    }

    #[test]
    fn test_get_and_is_scalar() {
        let shape = Shape::new(vec![2, 3]);
        assert_eq!(shape.get(0), Some(2));
        assert_eq!(shape.get(1), Some(3));
        assert_eq!(shape.get(2), None);
        assert!(!shape.is_scalar());
        assert_eq!(Shape::scalar().get(0), None);
    }

    #[test]
    fn test_default_strides() {
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.default_strides(), vec![12, 4, 1]);

        let shape = Shape::new(vec![5]);
        assert_eq!(shape.default_strides(), vec![1]);

        let shape = Shape::scalar();
        assert_eq!(shape.default_strides(), Vec::<usize>::new());
    }

    #[test]
    fn test_broadcast() {
        let s1 = Shape::new(vec![3, 1]);
        let s2 = Shape::new(vec![1, 4]);
        assert_eq!(s1.broadcast_with(&s2), Some(Shape::new(vec![3, 4])));

        let s1 = Shape::new(vec![2, 3]);
        let s2 = Shape::new(vec![3]);
        assert_eq!(s1.broadcast_with(&s2), Some(Shape::new(vec![2, 3])));

        let s1 = Shape::new(vec![2, 3]);
        let s2 = Shape::new(vec![4]);
        assert_eq!(s1.broadcast_with(&s2), None);
    }

    #[test]
    fn test_broadcast_with_scalar() {
        // A scalar broadcasts against anything and leaves the other shape unchanged.
        let shape = Shape::new(vec![2, 3]);
        assert_eq!(shape.broadcast_with(&Shape::scalar()), Some(shape.clone()));
        assert_eq!(Shape::scalar().broadcast_with(&shape), Some(shape.clone()));
        assert_eq!(
            Shape::scalar().broadcast_with(&Shape::scalar()),
            Some(Shape::scalar())
        );
    }

    #[test]
    fn test_from_conversions() {
        assert_eq!(Shape::from(vec![2, 3]), Shape::new(vec![2, 3]));

        let dims: &[usize] = &[4, 5];
        assert_eq!(Shape::from(dims), Shape::new(vec![4, 5]));

        let shape: Shape = vec![7].into();
        assert_eq!(shape.as_slice(), &[7]);
    }

    #[test]
    fn test_display() {
        assert_eq!(Shape::new(vec![2, 3, 4]).to_string(), "(2, 3, 4)");
        assert_eq!(Shape::new(vec![5]).to_string(), "(5,)");
        assert_eq!(Shape::scalar().to_string(), "()");
    }
}