1use std::{
2 borrow::{Borrow, Cow},
3 fmt,
4 hash::Hash,
5 ops::{Deref, DerefMut, Index, RangeBounds},
6};
7
8use serde::*;
9use smallvec::SmallVec;
10
11#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Serialize, Deserialize)]
13#[serde(transparent)]
14pub struct Shape {
15 dims: SmallVec<[usize; INLINE_DIMS]>,
16}
17const INLINE_DIMS: usize = 2;
18
19impl Shape {
20 pub const SCALAR: Self = Shape {
22 dims: SmallVec::new_const(),
23 };
24 pub const EMPTY_LIST: Self = Shape {
26 dims: unsafe { SmallVec::from_const_with_len_unchecked([0; INLINE_DIMS], 1) },
27 };
28 pub fn with_capacity(capacity: usize) -> Self {
30 Shape {
31 dims: SmallVec::with_capacity(capacity),
32 }
33 }
34 pub fn drain(&mut self, range: impl RangeBounds<usize>) {
36 self.dims.drain(range);
37 }
38 pub fn prepend(&mut self, dim: usize) {
40 self.dims.insert(0, dim);
41 }
42 pub fn push(&mut self, dim: usize) {
44 self.dims.push(dim);
45 }
46 pub fn pop(&mut self) -> Option<usize> {
48 self.dims.pop()
49 }
50 pub fn insert(&mut self, index: usize, dim: usize) {
52 self.dims.insert(index, dim);
53 }
54 pub fn row_count_mut(&mut self) -> &mut usize {
56 if self.is_empty() {
57 self.push(1);
58 }
59 &mut self.dims[0]
60 }
61 pub fn remove(&mut self, index: usize) -> usize {
63 self.dims.remove(index)
64 }
65 #[inline(always)]
67 pub fn row_count(&self) -> usize {
68 self.dims.first().copied().unwrap_or(1)
69 }
70 pub fn row_len(&self) -> usize {
72 self.dims.iter().skip(1).product()
73 }
74 pub fn row(&self) -> Shape {
76 let mut shape = self.clone();
77 shape.make_row();
78 shape
79 }
80 pub fn row_slice(&self) -> &[usize] {
82 &self.dims[self.len().min(1)..]
83 }
84 pub fn subshape<R>(&self, range: R) -> Shape
86 where
87 [usize]: Index<R>,
88 Self: for<'a> From<&'a <[usize] as Index<R>>::Output>,
89 {
90 Shape::from(&self.dims.as_slice()[range])
91 }
92 pub fn elements(&self) -> usize {
94 self.iter().product()
95 }
96 pub fn make_row(&mut self) {
98 if !self.is_empty() {
99 self.dims.remove(0);
100 }
101 }
102 pub fn deshape(&mut self) {
104 if self.len() != 1 {
105 *self = self.elements().into();
106 }
107 }
108 pub fn fix(&mut self) {
110 self.fix_depth(0);
111 }
112 pub(crate) fn fix_depth(&mut self, depth: usize) -> usize {
113 let depth = depth.min(self.len());
114 self.insert(depth, 1);
115 depth
116 }
117 pub fn unfix(&mut self) -> Result<(), Cow<'static, str>> {
119 match self.unfix_inner() {
120 Some(1) => Ok(()),
121 Some(d) => Err(Cow::Owned(format!("Cannot unfix array with length {d}"))),
122 None if self.contains(&0) => Err("Cannot unfix empty array".into()),
123 None if self.is_empty() => Err("Cannot unfix scalar".into()),
124 None => Err(Cow::Owned(format!(
125 "Cannot unfix array with shape {self:?}"
126 ))),
127 }
128 }
129 pub fn undo_fix(&mut self) {
131 self.unfix_inner();
132 }
133 fn unfix_inner(&mut self) -> Option<usize> {
137 match &mut **self {
138 [1, ..] => Some(self.remove(0)),
139 [a, b, ..] => {
140 let new_first_dim = *a * *b;
141 *b = new_first_dim;
142 Some(self.remove(0))
143 }
144 _ => None,
145 }
146 }
147 pub fn extend_from_slice(&mut self, dims: &[usize]) {
149 self.dims.extend_from_slice(dims);
150 }
151 pub fn split_off(&mut self, at: usize) -> Self {
153 let (_, b) = self.dims.split_at(at);
154 let second = Shape::from(b);
155 self.dims.truncate(at);
156 second
157 }
158 pub fn dims_mut(&mut self) -> &mut [usize] {
160 &mut self.dims
161 }
162 #[track_caller]
164 pub fn truncate(&mut self, len: usize) {
165 self.dims.truncate(len);
166 }
167 pub(crate) fn flat_to_dims(&self, flat: usize, index: &mut Vec<usize>) {
168 index.clear();
169 let mut flat = flat;
170 for &dim in self.dims.iter().rev() {
171 index.push(flat % dim);
172 flat /= dim;
173 }
174 index.reverse();
175 }
176 pub(crate) fn dims_to_flat(
177 &self,
178 index: impl IntoIterator<Item = impl Borrow<usize>>,
179 ) -> Option<usize> {
180 let mut flat = 0;
181 for (&dim, i) in self.dims.iter().zip(index) {
182 let i = *i.borrow();
183 if i >= dim {
184 return None;
185 }
186 flat = flat * dim + i;
187 }
188 Some(flat)
189 }
190 pub(crate) fn i_dims_to_flat(
191 &self,
192 index: impl IntoIterator<Item = impl Borrow<isize>>,
193 ) -> Option<usize> {
194 let mut flat = 0;
195 for (&dim, i) in self.dims.iter().zip(index) {
196 let i = *i.borrow();
197 if i < 0 || i >= dim as isize {
198 return None;
199 }
200 flat = flat * dim + i as usize;
201 }
202 Some(flat)
203 }
204}
205
206impl fmt::Debug for Shape {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 write!(f, "[")?;
209 for (i, dim) in self.dims.iter().enumerate() {
210 if i > 0 {
211 write!(f, " × ")?;
212 }
213 write!(f, "{dim}")?;
214 }
215 write!(f, "]")
216 }
217}
218
219impl fmt::Display for Shape {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 write!(f, "{self:?}")
222 }
223}
224
225impl From<usize> for Shape {
226 fn from(dim: usize) -> Self {
227 Self::from([dim])
228 }
229}
230
231impl From<&[usize]> for Shape {
232 fn from(dims: &[usize]) -> Self {
233 Self {
234 dims: dims.iter().copied().collect(),
235 }
236 }
237}
238
239impl From<Vec<usize>> for Shape {
240 fn from(dims: Vec<usize>) -> Self {
241 Self {
242 dims: SmallVec::from_vec(dims),
243 }
244 }
245}
246
247impl<const N: usize> From<[usize; N]> for Shape {
248 fn from(dims: [usize; N]) -> Self {
249 dims.as_slice().into()
250 }
251}
252
253impl Deref for Shape {
254 type Target = [usize];
255 fn deref(&self) -> &Self::Target {
256 &self.dims
257 }
258}
259
260impl DerefMut for Shape {
261 fn deref_mut(&mut self) -> &mut Self::Target {
262 &mut self.dims
263 }
264}
265
266impl IntoIterator for Shape {
267 type Item = usize;
268 type IntoIter = <SmallVec<[usize; INLINE_DIMS]> as IntoIterator>::IntoIter;
269 fn into_iter(self) -> Self::IntoIter {
270 self.dims.into_iter()
271 }
272}
273
274impl<'a> IntoIterator for &'a Shape {
275 type Item = &'a usize;
276 type IntoIter = <&'a [usize] as IntoIterator>::IntoIter;
277 fn into_iter(self) -> Self::IntoIter {
278 self.dims.iter()
279 }
280}
281
282impl FromIterator<usize> for Shape {
283 fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
284 Self {
285 dims: iter.into_iter().collect(),
286 }
287 }
288}
289
290impl Extend<usize> for Shape {
291 fn extend<I: IntoIterator<Item = usize>>(&mut self, iter: I) {
292 self.dims.extend(iter);
293 }
294}
295
296impl PartialEq<usize> for Shape {
297 fn eq(&self, other: &usize) -> bool {
298 self == [*other]
299 }
300}
301
302impl PartialEq<usize> for &Shape {
303 fn eq(&self, other: &usize) -> bool {
304 *self == [*other]
305 }
306}
307
308impl<const N: usize> PartialEq<[usize; N]> for Shape {
309 fn eq(&self, other: &[usize; N]) -> bool {
310 self == other.as_slice()
311 }
312}
313
314impl<const N: usize> PartialEq<[usize; N]> for &Shape {
315 fn eq(&self, other: &[usize; N]) -> bool {
316 *self == other.as_slice()
317 }
318}
319
320impl PartialEq<[usize]> for Shape {
321 fn eq(&self, other: &[usize]) -> bool {
322 self.dims.as_slice() == other
323 }
324}
325
326impl PartialEq<[usize]> for &Shape {
327 fn eq(&self, other: &[usize]) -> bool {
328 *self == other
329 }
330}
331
332impl PartialEq<&[usize]> for Shape {
333 fn eq(&self, other: &&[usize]) -> bool {
334 self.dims.as_slice() == *other
335 }
336}
337
338impl PartialEq<Shape> for &[usize] {
339 fn eq(&self, other: &Shape) -> bool {
340 other == self
341 }
342}
343
344impl PartialEq<Shape> for [usize] {
345 fn eq(&self, other: &Shape) -> bool {
346 other == self
347 }
348}