1#[cfg(feature = "nightly")]
2use alloc::alloc::Allocator;
3#[cfg(not(feature = "std"))]
4use alloc::boxed::Box;
5#[cfg(not(feature = "std"))]
6use alloc::vec::Vec;
7
8use core::cmp::Ordering;
9use core::fmt::Debug;
10use core::hash::{Hash, Hasher};
11use core::slice;
12
13#[cfg(not(feature = "nightly"))]
14use crate::allocator::Allocator;
15use crate::buffer::{DynBuffer, Owned, StaticBuffer};
16use crate::dim::{Const, Dim, Dims, Dyn};
17use crate::layout::{Layout, Strided};
18
19pub trait Shape: Clone + Debug + Default + Hash + Ord + Send + Sync {
21 type Head: Dim;
23
24 type Tail: Shape;
26
27 type Reverse: Shape;
29
30 type Prepend<D: Dim>: Shape;
32
33 type Dyn: Shape;
35
36 type Merge<S: Shape>: Shape;
39
40 type Buffer<T, A: Allocator>: Owned<Item = T, Shape = Self, Alloc = A>;
42
43 type Layout<L: Layout>: Layout;
45
46 #[doc(hidden)]
47 type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync>: Dims<T>;
48
49 const RANK: Option<usize>;
51
52 #[inline]
58 fn dim(&self, index: usize) -> usize {
59 assert!(index < self.rank(), "invalid dimension");
60
61 self.with_dims(|dims| dims[index])
62 }
63
64 #[inline]
70 fn from_dims(dims: &[usize]) -> Self {
71 let mut shape = Self::new(dims.len());
72
73 shape.with_mut_dims(|dst| dst.copy_from_slice(dims));
74 shape
75 }
76
77 #[inline]
79 fn is_empty(&self) -> bool {
80 self.len() == 0
81 }
82
83 #[inline]
85 fn len(&self) -> usize {
86 self.with_dims(|dims| dims.iter().product())
87 }
88
89 #[inline]
91 fn rank(&self) -> usize {
92 self.with_dims(|dims| dims.len())
93 }
94
95 #[doc(hidden)]
96 fn new(rank: usize) -> Self;
97
98 #[doc(hidden)]
99 fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T;
100
101 #[doc(hidden)]
102 fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T;
103
104 #[doc(hidden)]
105 #[inline]
106 fn checked_len(&self) -> Option<usize> {
107 self.with_dims(|dims| dims.iter().try_fold(1usize, |acc, &x| acc.checked_mul(x)))
108 }
109
110 #[doc(hidden)]
111 #[inline]
112 fn prepend_dim<S: Shape>(&self, size: usize) -> S {
113 let mut shape = S::new(self.rank() + 1);
114
115 shape.with_mut_dims(|dims| {
116 dims[0] = size;
117 self.with_dims(|src| dims[1..].copy_from_slice(src));
118 });
119
120 shape
121 }
122
123 #[doc(hidden)]
124 #[inline]
125 fn remove_dim<S: Shape>(&self, index: usize) -> S {
126 assert!(index < self.rank(), "invalid dimension");
127
128 let mut shape = S::new(self.rank() - 1);
129
130 shape.with_mut_dims(|dims| {
131 self.with_dims(|src| {
132 dims[..index].copy_from_slice(&src[..index]);
133 dims[index..].copy_from_slice(&src[index + 1..]);
134 });
135 });
136
137 shape
138 }
139
140 #[doc(hidden)]
141 #[inline]
142 fn reshape<S: Shape>(&self, mut new_shape: S) -> S {
143 let mut inferred = None;
144
145 new_shape.with_mut_dims(|dims| {
146 for i in 0..dims.len() {
147 if dims[i] == usize::MAX {
148 assert!(inferred.is_none(), "at most one dimension can be inferred");
149
150 dims[i] = 1;
151 inferred = Some(i);
152 }
153 }
154 });
155
156 let old_len = self.len();
157 let new_len = new_shape.checked_len().expect("invalid length");
158
159 if let Some(i) = inferred {
160 assert!(old_len.is_multiple_of(new_len), "length not divisible by the new dimensions");
161
162 new_shape.with_mut_dims(|dims| dims[i] = old_len / new_len);
163 } else {
164 assert!(new_len == old_len, "length must not change");
165 }
166
167 new_shape
168 }
169
170 #[doc(hidden)]
171 #[inline]
172 fn resize_dim<S: Shape>(&self, index: usize, new_size: usize) -> S {
173 assert!(index < self.rank(), "invalid dimension");
174
175 let mut shape = S::new(self.rank());
176
177 shape.with_mut_dims(|dims| {
178 self.with_dims(|src| dims[..].copy_from_slice(src));
179 dims[index] = new_size;
180 });
181
182 shape
183 }
184
185 #[doc(hidden)]
186 #[inline]
187 fn reverse(&self) -> Self::Reverse {
188 let mut shape = Self::Reverse::new(self.rank());
189
190 shape.with_mut_dims(|dims| {
191 self.with_dims(|src| dims.copy_from_slice(src));
192 dims.reverse();
193 });
194
195 shape
196 }
197}
198
/// Shape trait for shapes whose dimensions all have constant (compile-time) sizes.
///
/// In this file it is implemented for tuples of `Const` dimensions of rank 0 to 6.
pub trait ConstShape: Copy + Shape {
    /// Nested array type holding elements of type `T` in this shape, e.g. `[[T; Y]; X]`.
    #[doc(hidden)]
    type Inner<T>;

    /// Owned buffer type for the shape obtained by prepending a `Const<N>` dimension.
    #[doc(hidden)]
    type WithConst<T, const N: usize, A: Allocator>: Owned<Item = T, Shape = Self::Prepend<Const<N>>, Alloc = A>;
}
207
/// Conversion of a value (a shape, a dimension, or a list of sizes) into a shape.
pub trait IntoShape {
    /// Shape type produced by the conversion.
    type IntoShape: Shape;

    /// Converts the value into a shape.
    fn into_shape(self) -> Self::IntoShape;

    /// Calls `f` with the dimension sizes of the value, without requiring that a shape be
    /// built first.
    #[doc(hidden)]
    fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T;
}
219
/// Shape with dynamic rank.
///
/// The dimension sizes are stored either on the heap or, for rank one, inline. The
/// comparison and hashing impls in this file look only at the dimension sizes, so the
/// variant used does not affect equality, ordering or hashing.
pub enum DynRank {
    /// Dimension sizes stored in a heap-allocated boxed slice (any rank).
    Dyn(Box<[usize]>),
    /// Single dimension stored inline, without a heap allocation.
    One(usize),
}
230
/// Shape type with `N` dimensions, obtained by converting a `[usize; N]` array of sizes.
///
/// For `N` from 0 to 6 this is a tuple of `N` `Dyn` dimensions (see the `impl_into_shape!`
/// invocations below).
pub type Rank<const N: usize> = <[usize; N] as IntoShape>::IntoShape;
233
234impl DynRank {
235 #[inline]
237 pub fn dims(&self) -> &[usize] {
238 match self {
239 Self::Dyn(dims) => dims,
240 Self::One(size) => slice::from_ref(size),
241 }
242 }
243}
244
245impl Clone for DynRank {
246 #[inline]
247 fn clone(&self) -> Self {
248 match self {
249 Self::Dyn(dims) => {
250 if dims.len() == 1 {
251 Self::One(dims[0])
252 } else {
253 Self::Dyn(dims.clone())
254 }
255 }
256 Self::One(dim) => Self::One(*dim),
257 }
258 }
259
260 #[inline]
261 fn clone_from(&mut self, source: &Self) {
262 if let Self::Dyn(dims) = self
263 && let Self::Dyn(src) = source
264 && dims.len() == src.len()
265 {
266 dims.clone_from_slice(src);
267
268 return;
269 }
270
271 *self = source.clone();
272 }
273}
274
275impl Debug for DynRank {
276 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
277 self.with_dims(|dims| f.debug_tuple("DynRank").field(&dims).finish())
278 }
279}
280
281impl Default for DynRank {
282 #[inline]
283 fn default() -> Self {
284 Self::One(0)
285 }
286}
287
288impl Eq for DynRank {}
289
290impl Hash for DynRank {
291 #[inline]
292 fn hash<H: Hasher>(&self, state: &mut H) {
293 self.with_dims(|dims| dims.hash(state))
294 }
295}
296
297impl Ord for DynRank {
298 #[inline]
299 fn cmp(&self, other: &Self) -> Ordering {
300 self.with_dims(|dims| other.with_dims(|other| dims.cmp(other)))
301 }
302}
303
304impl PartialEq for DynRank {
305 #[inline]
306 fn eq(&self, other: &Self) -> bool {
307 self.with_dims(|dims| other.with_dims(|other| dims.eq(other)))
308 }
309}
310
311impl PartialOrd for DynRank {
312 #[inline]
313 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
314 Some(self.cmp(other))
315 }
316}
317
318impl Shape for DynRank {
319 type Head = Dyn;
320 type Tail = Self;
321
322 type Reverse = Self;
323 type Prepend<D: Dim> = Self;
324
325 type Dyn = Self;
326 type Merge<S: Shape> = Self;
327
328 type Buffer<T, A: Allocator> = DynBuffer<T, Self, A>;
329 type Layout<L: Layout> = Strided;
330
331 type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = Box<[T]>;
332
333 const RANK: Option<usize> = None;
334
335 #[inline]
336 fn new(rank: usize) -> Self {
337 if rank == 1 { Self::One(0) } else { Self::Dyn(Dims::new(rank)) }
338 }
339
340 #[inline]
341 fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
342 let dims = match self {
343 Self::Dyn(dims) => dims,
344 Self::One(size) => slice::from_ref(size),
345 };
346
347 f(dims)
348 }
349
350 #[inline]
351 fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
352 let dims = match self {
353 Self::Dyn(dims) => dims,
354 Self::One(size) => slice::from_mut(size),
355 };
356
357 f(dims)
358 }
359}
360
361impl Shape for () {
362 type Head = Dyn;
363 type Tail = Self;
364
365 type Reverse = Self;
366 type Prepend<D: Dim> = (D,);
367
368 type Dyn = Self;
369 type Merge<S: Shape> = S;
370
371 type Buffer<T, A: Allocator> = StaticBuffer<T, Self, A>;
372 type Layout<L: Layout> = L;
373
374 type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = [T; 0];
375
376 const RANK: Option<usize> = Some(0);
377
378 #[inline]
379 fn new(rank: usize) {
380 assert!(rank == 0, "invalid rank");
381 }
382
383 #[inline]
384 fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
385 f(&[])
386 }
387
388 #[inline]
389 fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
390 f(&mut [])
391 }
392}
393
394impl<X: Dim> Shape for (X,) {
395 type Head = X;
396 type Tail = ();
397
398 type Reverse = Self;
399 type Prepend<D: Dim> = (D, X);
400
401 type Dyn = (Dyn,);
402 type Merge<S: Shape> = <S::Tail as Shape>::Prepend<X::Merge<S::Head>>;
403
404 type Buffer<T, A: Allocator> = X::Buffer<T, (), A>;
405 type Layout<L: Layout> = Strided;
406
407 type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = [T; 1];
408
409 const RANK: Option<usize> = Some(1);
410
411 #[inline]
412 fn new(rank: usize) -> Self {
413 assert!(rank == 1, "invalid rank");
414
415 Self::default()
416 }
417
418 #[inline]
419 fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
420 f(&[self.0.size()])
421 }
422
423 #[inline]
424 fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
425 let mut dims = [self.0.size()];
426 let value = f(&mut dims);
427
428 *self = (X::from_size(dims[0]),);
429
430 value
431 }
432}
433
// Expands to the `Dyn` associated type of the tuple shapes generated by `impl_shape!`: a
// tuple of the same rank in which every dimension is `Dyn`. It is invoked with the tail
// dimension names (`$yz`).
//
// Stable version: ignores its arguments and builds the type recursively by prepending `Dyn`
// to the tail's dynamic shape.
#[cfg(not(feature = "nightly"))]
macro_rules! dyn_shape {
    ($($yz:tt),+) => {
        <<Self::Tail as Shape>::Dyn as Shape>::Prepend<Dyn>
    };
}

// Nightly version: spells the tuple out directly, with one `Dyn` for the head plus one per
// tail argument. `${ignore($yz)}` drives the repetition without emitting `$yz` itself.
#[cfg(feature = "nightly")]
macro_rules! dyn_shape {
    ($($yz:tt),+) => {
        (Dyn $(,${ignore($yz)} Dyn)+)
    };
}
447
// Implements `Shape` for a tuple `(X, $yz...)` of rank 2 or higher.
//
// - `$n`: the rank, used for `RANK`, the `Dims` array length and the rank assertion.
// - `$jk`: tuple field indices of the tail dimensions (1 up to `$n - 1`).
// - `$yz`: type parameter names of the tail dimensions; the head is always `X`.
// - `$reverse`: the tuple type with the dimensions in reverse order.
// - `$prepend`: the shape type obtained by prepending a dimension `D`.
macro_rules! impl_shape {
    ($n:tt, ($($jk:tt),+), ($($yz:tt),+), $reverse:tt, $prepend:tt) => {
        impl<X: Dim $(,$yz: Dim)+> Shape for (X $(,$yz)+) {
            type Head = X;
            type Tail = ($($yz,)+);

            type Reverse = $reverse;
            type Prepend<D: Dim> = $prepend;

            type Dyn = dyn_shape!($($yz),+);
            // Merge the tails recursively, then prepend the merged head dimension.
            type Merge<S: Shape> =
                <<Self::Tail as Shape>::Merge<S::Tail> as Shape>::Prepend<X::Merge<S::Head>>;

            // The buffer type is delegated to the head dimension type `X`, parameterized by
            // the tail shape.
            type Buffer<T, A: Allocator> = X::Buffer<T, Self::Tail, A>;
            type Layout<L: Layout> = Strided;

            type Dims<T: Copy + Debug + Default + Hash + Ord + Send + Sync> = [T; $n];

            const RANK: Option<usize> = Some($n);

            #[inline]
            fn new(rank: usize) -> Self {
                assert!(rank == $n, "invalid rank");

                Self::default()
            }

            #[inline]
            fn with_dims<T, F: FnOnce(&[usize]) -> T>(&self, f: F) -> T {
                // Gather the size of every tuple field into a temporary array.
                f(&[self.0.size() $(,self.$jk.size())+])
            }

            #[inline]
            fn with_mut_dims<T, F: FnOnce(&mut [usize]) -> T>(&mut self, f: F) -> T {
                let mut dims = [self.0.size() $(,self.$jk.size())+];
                let value = f(&mut dims);

                // Write the (possibly modified) sizes back by rebuilding every dimension with
                // `from_size`.
                *self = (X::from_size(dims[0]) $(,$yz::from_size(dims[$jk]))+);

                value
            }
        }
    };
}

// Tuple shapes of rank 2 to 6. Rank 6 is the largest tuple shape, so prepending a dimension
// to it yields the dynamic-rank `DynRank`.
impl_shape!(2, (1), (Y), (Y, X), (D, X, Y));
impl_shape!(3, (1, 2), (Y, Z), (Z, Y, X), (D, X, Y, Z));
impl_shape!(4, (1, 2, 3), (Y, Z, W), (W, Z, Y, X), (D, X, Y, Z, W));
impl_shape!(5, (1, 2, 3, 4), (Y, Z, W, U), (U, W, Z, Y, X), (D, X, Y, Z, W, U));
impl_shape!(6, (1, 2, 3, 4, 5), (Y, Z, W, U, V), (V, U, W, Z, Y, X), DynRank);
498
// Implements `ConstShape` for a tuple of `Const` dimensions named by `$xyz`.
//
// - `$inner`: the matching nested array type (innermost array = last dimension).
// - `$with_const`: the buffer type used for the shape with one more leading `Const<N>`.
macro_rules! impl_const_shape {
    (($($xyz:tt),*), $inner:ty, $with_const:tt) => {
        impl<$(const $xyz: usize),*> ConstShape for ($(Const<$xyz>,)*) {
            type Inner<T> = $inner;
            type WithConst<T, const N: usize, A: Allocator> =
                $with_const<T, Self::Prepend<Const<N>>, A>;
        }
    };
}

// Ranks 0 to 5 prepend into a tuple shape and use `StaticBuffer`. Rank 6 prepends into
// `DynRank` (see the rank-6 `impl_shape!` invocation) and therefore uses `DynBuffer`.
impl_const_shape!((), T, StaticBuffer);
impl_const_shape!((X), [T; X], StaticBuffer);
impl_const_shape!((X, Y), [[T; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z), [[[T; Z]; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z, W), [[[[T; W]; Z]; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z, W, U), [[[[[T; U]; W]; Z]; Y]; X], StaticBuffer);
impl_const_shape!((X, Y, Z, W, U, V), [[[[[[T; V]; U]; W]; Z]; Y]; X], DynBuffer);
516
517impl<S: Shape> IntoShape for S {
518 type IntoShape = S;
519
520 #[inline]
521 fn into_shape(self) -> S {
522 self
523 }
524
525 #[inline]
526 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
527 self.with_dims(f)
528 }
529}
530
531impl<const N: usize> IntoShape for &[usize; N] {
532 type IntoShape = DynRank;
533
534 #[inline]
535 fn into_shape(self) -> DynRank {
536 Shape::from_dims(self)
537 }
538
539 #[inline]
540 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
541 f(self)
542 }
543}
544
545impl IntoShape for &[usize] {
546 type IntoShape = DynRank;
547
548 #[inline]
549 fn into_shape(self) -> DynRank {
550 Shape::from_dims(self)
551 }
552
553 #[inline]
554 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
555 f(self)
556 }
557}
558
559impl IntoShape for Box<[usize]> {
560 type IntoShape = DynRank;
561
562 #[inline]
563 fn into_shape(self) -> DynRank {
564 DynRank::Dyn(self)
565 }
566
567 #[inline]
568 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
569 f(&self)
570 }
571}
572
573impl<const N: usize> IntoShape for Const<N> {
574 type IntoShape = (Self,);
575
576 #[inline]
577 fn into_shape(self) -> Self::IntoShape {
578 (self,)
579 }
580
581 #[inline]
582 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
583 f(&[N])
584 }
585}
586
587impl IntoShape for Dyn {
588 type IntoShape = (Self,);
589
590 #[inline]
591 fn into_shape(self) -> Self::IntoShape {
592 (self,)
593 }
594
595 #[inline]
596 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
597 f(&[self])
598 }
599}
600
601impl IntoShape for Vec<usize> {
602 type IntoShape = DynRank;
603
604 #[inline]
605 fn into_shape(self) -> DynRank {
606 DynRank::Dyn(self.into())
607 }
608
609 #[inline]
610 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
611 f(&self)
612 }
613}
614
615macro_rules! impl_into_shape {
616 ($n:tt, $shape:ty) => {
617 impl IntoShape for [usize; $n] {
618 type IntoShape = $shape;
619
620 #[inline]
621 fn into_shape(self) -> Self::IntoShape {
622 Shape::from_dims(&self)
623 }
624
625 #[inline]
626 fn into_dims<T, F: FnOnce(&[usize]) -> T>(self, f: F) -> T {
627 f(&self)
628 }
629 }
630 };
631}
632
633impl_into_shape!(0, ());
634impl_into_shape!(1, (Dyn,));
635impl_into_shape!(2, (Dyn, Dyn));
636impl_into_shape!(3, (Dyn, Dyn, Dyn));
637impl_into_shape!(4, (Dyn, Dyn, Dyn, Dyn));
638impl_into_shape!(5, (Dyn, Dyn, Dyn, Dyn, Dyn));
639impl_into_shape!(6, (Dyn, Dyn, Dyn, Dyn, Dyn, Dyn));