1use super::{Axis, Rank, ShapeError, Stride};
6use crate::iter::zip;
7use crate::prelude::{Ixs, SwapAxes};
8#[cfg(not(feature = "std"))]
9use alloc::vec;
10use core::ops::{self, Deref};
11#[cfg(feature = "std")]
12use std::vec;
13
/// The lengths of each axis of an n-dimensional array, stored in axis order.
///
/// A shape with no axes (rank 0) describes a scalar.
#[derive(Clone, Debug, Default, Hash)]
#[derive(Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Shape(Vec<usize>);
18
19impl Shape {
20 pub fn new(shape: Vec<usize>) -> Self {
21 Self(shape)
22 }
23 pub fn scalar() -> Self {
25 Self(Vec::new())
26 }
27
28 pub fn stride_offset(index: &[usize], strides: &Stride) -> Ixs {
29 let mut offset = 0;
30 for (&i, &s) in index.iter().zip(strides.as_slice()) {
31 offset += super::dim::stride_offset(i, s);
32 }
33 offset
34 }
35
36 pub fn with_capacity(capacity: usize) -> Self {
37 Self(Vec::with_capacity(capacity))
38 }
39 pub fn zeros(rank: usize) -> Self {
41 Self(vec![0; rank])
42 }
43 pub fn as_slice(&self) -> &[usize] {
45 &self.0
46 }
47 pub fn as_slice_mut(&mut self) -> &mut [usize] {
49 &mut self.0
50 }
51
52 pub fn check_size(&self) -> Result<usize, ShapeError> {
53 let size_nonzero = self
54 .as_slice()
55 .iter()
56 .filter(|&&d| d != 0)
57 .try_fold(1usize, |acc, &d| acc.checked_mul(d))
58 .ok_or(ShapeError::Overflow)?;
59 if size_nonzero > core::isize::MAX as usize {
60 Err(ShapeError::Overflow)
61 } else {
62 Ok(self.size())
63 }
64 }
65 pub fn dec(&self) -> Self {
68 let mut shape = self.clone();
69 shape.dec_inplace();
70 shape
71 }
72 pub fn dec_inplace(&mut self) {
74 for dim in self.iter_mut() {
75 *dim -= 1;
76 }
77 }
78 pub fn dec_axis(&mut self, axis: Axis) {
80 self[axis] -= 1;
81 }
82 pub fn diag(&self) -> Shape {
85 Self::new(i![self.nrows()])
86 }
87
88 pub fn first_index(&self) -> Option<Vec<usize>> {
89 if self.is_empty() {
90 return None;
91 }
92 Some(vec![0; *self.rank()])
93 }
94 pub fn get_final_position(&self) -> Vec<usize> {
95 self.dec().to_vec()
96 }
97 pub fn insert(&mut self, index: Axis, dim: usize) {
99 self.0.insert(*index, dim)
100 }
101 pub fn insert_axis(&self, index: Axis) -> Self {
103 let mut shape = self.clone();
104 shape.insert(index, 1);
105 shape
106 }
107 pub fn is_contiguous(&self, stride: &Stride) -> bool {
109 if self.0.len() != stride.len() {
110 return false;
111 }
112 let mut acc = 1;
113 for (&stride, &dim) in stride.iter().zip(self.iter()).rev() {
114 if stride != acc {
115 return false;
116 }
117 acc *= dim;
118 }
119 true
120 }
121 pub fn is_scalar(&self) -> bool {
123 self.0.is_empty()
124 }
125 pub fn is_square(&self) -> bool {
127 self.iter().all(|&dim| dim == self[0])
128 }
129 pub fn iter(&self) -> core::slice::Iter<usize> {
131 self.0.iter()
132 }
133 pub fn iter_mut(&mut self) -> core::slice::IterMut<usize> {
135 self.0.iter_mut()
136 }
137 pub fn ncols(&self) -> usize {
139 if self.len() >= 2 {
140 self[1]
141 } else if self.len() == 1 {
142 1
143 } else {
144 0
145 }
146 }
147 #[doc(hidden)]
148 #[inline]
152 pub fn next_for<D>(&self, index: D) -> Option<Vec<usize>>
153 where
154 D: AsRef<[usize]>,
155 {
156 let mut index = index.as_ref().to_vec();
157 let mut done = false;
158 for (&dim, ix) in zip(self.as_slice(), index.as_mut_slice()).rev() {
159 *ix += 1;
160 if *ix == dim {
161 *ix = 0;
162 } else {
163 done = true;
164 break;
165 }
166 }
167 if done {
168 Some(index)
169 } else {
170 None
171 }
172 }
173 pub fn nrows(&self) -> usize {
175 if self.len() >= 1 {
176 self[0]
177 } else {
178 0
179 }
180 }
181 pub fn pop(&mut self) -> Option<usize> {
183 self.0.pop()
184 }
185 pub fn push(&mut self, dim: usize) {
187 self.0.push(dim)
188 }
189 pub fn rank(&self) -> Rank {
191 self.0.len().into()
192 }
193 pub fn remove(&mut self, index: Axis) -> usize {
195 self.0.remove(*index)
196 }
197 pub fn remove_axis(&self, index: Axis) -> Shape {
199 let mut shape = self.clone();
200 shape.remove(index);
201 shape
202 }
203 pub fn reverse(&mut self) {
205 self.0.reverse()
206 }
207 pub fn set(&mut self, index: Axis, dim: usize) {
209 self[index] = dim
210 }
211 pub fn size(&self) -> usize {
213 self.0.iter().product()
214 }
215 pub fn swap(&mut self, a: Axis, b: Axis) {
217 self.0.swap(a.axis(), b.axis())
218 }
219 pub fn swap_axes(&self, swap: Axis, with: Axis) -> Self {
221 let mut shape = self.clone();
222 shape.swap(swap, with);
223 shape
224 }
225 pub fn to_vec(&self) -> Vec<usize> {
227 self.0.clone()
228 }
229}
230
231#[allow(dead_code)]
233#[doc(hidden)]
234impl Shape {
235 pub fn default_strides(&self) -> Stride {
236 let mut strides = Stride::zeros(self.rank());
239 if self.iter().all(|&d| d != 0) {
241 let mut it = strides.as_slice_mut().iter_mut().rev();
242 if let Some(rs) = it.next() {
244 *rs = 1;
245 }
246 let mut cum_prod = 1;
247 for (rs, dim) in it.zip(self.iter().rev()) {
248 cum_prod *= *dim;
249 *rs = cum_prod;
250 }
251 }
252 strides
253 }
254
255 pub(crate) fn matmul(&self, other: &Self) -> Result<Self, ShapeError> {
256 if self.rank() == 2 && other.rank() == 2 {
257 return Ok(Self::from((self[0], other[1])));
258 } else if self.rank() == 2 && other.rank() == 1 {
259 return Ok(Self::from(self[0]));
260 } else if self.rank() == 1 && other.rank() == 2 {
261 return Ok(Self::from(other[0]));
262 } else if self.rank() == 1 && other.rank() == 1 {
263 return Ok(Self::scalar());
264 }
265 Err(ShapeError::IncompatibleShapes)
266 }
267
268 pub(crate) fn matmul_shape(&self, other: &Self) -> Result<Self, ShapeError> {
269 if *self.rank() != 2 || *other.rank() != 2 || self[1] != other[0] {
270 return Err(ShapeError::IncompatibleShapes);
271 }
272 Ok(Self::from((self[0], other[1])))
273 }
274
275 pub(crate) fn stride_contiguous(&self) -> Stride {
276 let mut stride: Vec<_> = self
277 .0
278 .iter()
279 .rev()
280 .scan(1, |prod, u| {
281 let prod_pre_mult = *prod;
282 *prod *= u;
283 Some(prod_pre_mult)
284 })
285 .collect();
286 stride.reverse();
287 stride.into()
288 }
289
290 pub(crate) fn upcast(&self, to: &Shape, stride: &Stride) -> Option<Stride> {
291 let mut new_stride = to.as_slice().to_vec();
292 if to.rank() < self.rank() {
295 return None;
296 }
297
298 let mut iter = new_stride.as_mut_slice().iter_mut().rev();
299 for ((er, es), dr) in self
300 .as_slice()
301 .iter()
302 .rev()
303 .zip(stride.as_slice().iter().rev())
304 .zip(iter.by_ref())
305 {
306 if *dr == *er {
308 *dr = *es;
310 } else if *er == 1 {
311 *dr = 0
313 } else {
314 return None;
315 }
316 }
317
318 for dr in iter {
320 *dr = 0;
321 }
322
323 Some(new_stride.into())
324 }
325}
326
327impl AsRef<[usize]> for Shape {
328 fn as_ref(&self) -> &[usize] {
329 &self.0
330 }
331}
332
333impl AsMut<[usize]> for Shape {
334 fn as_mut(&mut self) -> &mut [usize] {
335 &mut self.0
336 }
337}
338
339impl Deref for Shape {
340 type Target = [usize];
341
342 fn deref(&self) -> &Self::Target {
343 &self.0
344 }
345}
346
347impl Extend<usize> for Shape {
348 fn extend<I: IntoIterator<Item = usize>>(&mut self, iter: I) {
349 self.0.extend(iter)
350 }
351}
352
353impl From<Shape> for Vec<usize> {
354 fn from(shape: Shape) -> Self {
355 shape.0
356 }
357}
358
// Cross-type equality for `Shape` via the crate's `impl_partial_eq!` macro (defined
// elsewhere), keyed on field `0`. NOTE(review): presumably this generates
// `PartialEq<[usize]>` and `PartialEq<Vec<usize>>` for `Shape` — confirm against the macro.
impl_partial_eq!(Shape -> 0: [[usize], Vec<usize>]);
360
361impl SwapAxes for Shape {
362 fn swap_axes(&self, a: Axis, b: Axis) -> Self {
363 self.swap_axes(a, b)
364 }
365}
366
367impl FromIterator<usize> for Shape {
368 fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
369 Self(Vec::from_iter(iter))
370 }
371}
372
373impl IntoIterator for Shape {
374 type Item = usize;
375 type IntoIter = vec::IntoIter<Self::Item>;
376
377 fn into_iter(self) -> Self::IntoIter {
378 self.0.into_iter()
379 }
380}
381
382impl<'a> IntoIterator for &'a mut Shape {
383 type Item = &'a mut usize;
384 type IntoIter = core::slice::IterMut<'a, usize>;
385
386 fn into_iter(self) -> Self::IntoIter {
387 self.0.iter_mut()
388 }
389}
390
391impl ops::Index<usize> for Shape {
392 type Output = usize;
393
394 fn index(&self, index: usize) -> &Self::Output {
395 &self.0[index]
396 }
397}
398
399impl ops::Index<Axis> for Shape {
400 type Output = usize;
401
402 fn index(&self, index: Axis) -> &Self::Output {
403 &self.0[*index]
404 }
405}
406
407impl ops::IndexMut<usize> for Shape {
408 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
409 &mut self.0[index]
410 }
411}
412
413impl ops::IndexMut<Axis> for Shape {
414 fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
415 &mut self.0[*index]
416 }
417}
418
419impl ops::Index<ops::Range<usize>> for Shape {
420 type Output = [usize];
421
422 fn index(&self, index: ops::Range<usize>) -> &Self::Output {
423 &self.0[index]
424 }
425}
426
427impl ops::Index<ops::RangeTo<usize>> for Shape {
428 type Output = [usize];
429
430 fn index(&self, index: ops::RangeTo<usize>) -> &Self::Output {
431 &self.0[index]
432 }
433}
434
435impl ops::Index<ops::RangeFrom<usize>> for Shape {
436 type Output = [usize];
437
438 fn index(&self, index: ops::RangeFrom<usize>) -> &Self::Output {
439 &self.0[index]
440 }
441}
442
443impl ops::Index<ops::RangeFull> for Shape {
444 type Output = [usize];
445
446 fn index(&self, index: ops::RangeFull) -> &Self::Output {
447 &self.0[index]
448 }
449}
450
451impl ops::Index<ops::RangeInclusive<usize>> for Shape {
452 type Output = [usize];
453
454 fn index(&self, index: ops::RangeInclusive<usize>) -> &Self::Output {
455 &self.0[index]
456 }
457}
458
459impl ops::Index<ops::RangeToInclusive<usize>> for Shape {
460 type Output = [usize];
461
462 fn index(&self, index: ops::RangeToInclusive<usize>) -> &Self::Output {
463 &self.0[index]
464 }
465}
466
467unsafe impl Send for Shape {}
468
469unsafe impl Sync for Shape {}
470
471impl From<()> for Shape {
472 fn from(_: ()) -> Self {
473 Self::default()
474 }
475}
476
477impl From<usize> for Shape {
478 fn from(dim: usize) -> Self {
479 Self(vec![dim])
480 }
481}
482
483impl From<Vec<usize>> for Shape {
484 fn from(shape: Vec<usize>) -> Self {
485 Self(shape)
486 }
487}
488
489impl From<&[usize]> for Shape {
490 fn from(shape: &[usize]) -> Self {
491 Self(shape.to_vec())
492 }
493}
494
495impl<const N: usize> From<[usize; N]> for Shape {
496 fn from(shape: [usize; N]) -> Self {
497 Self(shape.to_vec())
498 }
499}
500
501impl From<(usize,)> for Shape {
502 fn from(shape: (usize,)) -> Self {
503 Self(vec![shape.0])
504 }
505}
506
507impl From<(usize, usize)> for Shape {
508 fn from(shape: (usize, usize)) -> Self {
509 Self(vec![shape.0, shape.1])
510 }
511}
512
513impl From<(usize, usize, usize)> for Shape {
514 fn from(shape: (usize, usize, usize)) -> Self {
515 Self(vec![shape.0, shape.1, shape.2])
516 }
517}
518
519impl From<(usize, usize, usize, usize)> for Shape {
520 fn from(shape: (usize, usize, usize, usize)) -> Self {
521 Self(vec![shape.0, shape.1, shape.2, shape.3])
522 }
523}
524
525impl From<(usize, usize, usize, usize, usize)> for Shape {
526 fn from(shape: (usize, usize, usize, usize, usize)) -> Self {
527 Self(vec![shape.0, shape.1, shape.2, shape.3, shape.4])
528 }
529}
530
531impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
532 fn from(shape: (usize, usize, usize, usize, usize, usize)) -> Self {
533 Self(vec![shape.0, shape.1, shape.2, shape.3, shape.4, shape.5])
534 }
535}
536
537