1use super::indexing::ravel_index;
4use alloc::format;
5use alloc::string::{String, ToString};
6use alloc::vec::Vec;
7use core::fmt::{Debug, Display, Formatter};
8use core::str::FromStr;
9use core::{
10 ops::{Deref, DerefMut, Index, IndexMut, Range},
11 slice::{Iter, IterMut, SliceIndex},
12};
13use serde::{Deserialize, Serialize};
14use smallvec::{SmallVec, smallvec};
15
16pub use crate::errors::ExpressionError;
17use crate::{
18 INLINE_DIMS,
19 indexing::{AsIndex, AsSize},
20};
21
22#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
24pub struct Shape {
25 dims: SmallVec<[usize; INLINE_DIMS]>,
27}
28
29#[allow(missing_docs)]
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub enum MetadataError {
33 RankMismatch { left: usize, right: usize },
35 IncompatibleDims {
37 left: usize,
38 right: usize,
39 dim: usize,
40 },
41 OutOfBounds { dim: usize, rank: usize },
43 IncompatibleShapes { left: Shape, right: Shape },
45 Invalid { reason: String },
47}
48
49impl MetadataError {
50 fn empty() -> Self {
51 Self::Invalid {
52 reason: "Shape is empty.".into(),
53 }
54 }
55}
56
57impl Shape {
58 pub fn new<const D: usize>(dims: [usize; D]) -> Self {
60 Self {
62 dims: SmallVec::from_slice(&dims),
63 }
64 }
65
66 pub fn new_raw(dims: SmallVec<[usize; INLINE_DIMS]>) -> Self {
68 Self { dims }
69 }
70
71 pub fn num_elements(&self) -> usize {
73 self.dims.iter().product()
74 }
75
76 pub fn num_dims(&self) -> usize {
80 self.dims.len()
81 }
82
83 pub fn rank(&self) -> usize {
87 self.num_dims()
88 }
89
90 pub fn dims<const D: usize>(&self) -> [usize; D] {
93 let mut dims = [1; D];
94 dims[..D].copy_from_slice(&self.dims[..D]);
95 dims
96 }
97
98 pub fn flatten(mut self) -> Self {
100 self.dims = SmallVec::from_slice(&[self.num_elements()]);
101 self
102 }
103
104 pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self {
134 let rank = self.rank();
135 let start = start_dim.expect_dim_index(rank);
136 let end = end_dim.expect_dim_index(rank);
137
138 assert!(
139 start <= end,
140 "start_dim ({start}) must be <= than end_dim ({end})"
141 );
142
143 let existing = self.dims;
144
145 let flattened_size = existing[start..=end].iter().product();
146
147 let new_rank = rank - (end - start);
148 let mut dims = smallvec![0; new_rank];
149 dims[..start].copy_from_slice(&existing[..start]);
150 dims[start] = flattened_size;
151 dims[start + 1..].copy_from_slice(&existing[end + 1..]);
152
153 Self { dims }
154 }
155
156 pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
170 ravel_index(indices, &self.dims)
171 }
172
173 pub fn into_ranges(self) -> Vec<Range<usize>> {
175 self.iter().map(|&d| 0..d).collect()
176 }
177
178 pub fn to_vec(&self) -> Vec<usize> {
180 self.dims.to_vec()
181 }
182
183 pub fn iter(&self) -> Iter<'_, usize> {
185 self.dims.iter()
186 }
187
188 pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
190 self.dims.iter_mut()
191 }
192
193 pub fn as_slice(&self) -> &[usize] {
195 &self.dims
196 }
197
198 pub fn as_mut_slice(&mut self) -> &mut [usize] {
200 &mut self.dims
201 }
202
203 pub fn insert(&mut self, index: usize, size: usize) {
205 self.dims.insert(index, size);
206 }
207
208 pub fn remove(&mut self, index: usize) -> usize {
210 self.dims.remove(index)
211 }
212
213 pub fn push(&mut self, size: usize) {
215 self.dims.push(size)
216 }
217
218 pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
220 self.dims.extend(iter)
221 }
222
223 pub fn swapped(mut self, dim1: usize, dim2: usize) -> Result<Self, MetadataError> {
225 if dim1 >= self.rank() {
226 return Err(MetadataError::OutOfBounds {
227 dim: dim1,
228 rank: self.rank(),
229 });
230 }
231 if dim2 >= self.rank() {
232 return Err(MetadataError::OutOfBounds {
233 dim: dim2,
234 rank: self.rank(),
235 });
236 }
237 self.dims.swap(dim1, dim2);
238 Ok(self)
239 }
240
241 pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> {
243 if axes.len() != self.rank() {
244 return Err(MetadataError::RankMismatch {
245 left: self.rank(),
246 right: axes.len(),
247 });
248 }
249 debug_assert!(axes.iter().all(|i| i < &self.rank()));
250
251 self.dims = axes.iter().map(|&i| self.dims[i]).collect();
252 Ok(())
253 }
254
255 pub fn permuted(mut self, axes: &[usize]) -> Result<Self, MetadataError> {
257 self.permute(axes)?;
258 Ok(self)
259 }
260
261 pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, MetadataError> {
263 if dim >= self.rank() {
264 return Err(MetadataError::OutOfBounds {
265 dim,
266 rank: self.rank(),
267 });
268 }
269
270 self.dims[dim] *= times;
271 Ok(self)
272 }
273
274 pub fn reduce(mut self, dim: usize) -> Result<Shape, MetadataError> {
276 if dim >= self.rank() {
277 return Err(MetadataError::OutOfBounds {
278 dim,
279 rank: self.rank(),
280 });
281 }
282
283 self.dims[dim] = 1;
284 Ok(self)
285 }
286
287 pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, MetadataError>
289 where
290 I: IntoIterator<Item = &'a Shape>,
291 {
292 let mut iter = shapes.into_iter();
293
294 let first = iter.next().ok_or(MetadataError::empty())?;
295
296 if dim >= first.rank() {
297 return Err(MetadataError::OutOfBounds {
298 dim,
299 rank: first.rank(),
300 });
301 }
302
303 let mut shape = first.clone();
304
305 for s in iter {
306 if s.rank() != shape.rank() {
307 return Err(MetadataError::RankMismatch {
308 left: shape.rank(),
309 right: s.rank(),
310 });
311 }
312
313 if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
314 return Err(MetadataError::IncompatibleShapes {
315 left: shape.clone(),
316 right: s.clone(),
317 });
318 }
319
320 shape[dim] += s[dim];
321 }
322
323 Ok(shape)
324 }
325
326 pub fn broadcast(&self, other: &Self) -> Result<Self, MetadataError> {
335 Self::broadcast_many([self, other])
336 }
337
338 pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, MetadataError>
342 where
343 I: IntoIterator<Item = &'a Shape>,
344 {
345 let mut iter = shapes.into_iter();
346 let mut broadcasted = iter.next().ok_or(MetadataError::empty())?.clone();
347 let rank = broadcasted.rank();
348
349 for shape in iter {
350 if shape.rank() != rank {
351 return Err(MetadataError::RankMismatch {
352 left: rank,
353 right: shape.rank(),
354 });
355 }
356
357 for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
358 match (*d_lhs, d_rhs) {
359 (a, b) if a == b => {} (1, b) => *d_lhs = b, (_a, 1) => {} _ => {
363 return Err(MetadataError::IncompatibleDims {
364 left: *d_lhs,
365 right: d_rhs,
366 dim,
367 });
368 }
369 }
370 }
371 }
372
373 Ok(broadcasted)
374 }
375
376 pub fn expand(&self, target: Shape) -> Result<Shape, MetadataError> {
378 let target_rank = target.rank();
379 if self.rank() > target_rank {
380 return Err(MetadataError::RankMismatch {
381 left: self.rank(),
382 right: target_rank,
383 });
384 }
385
386 for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
387 if dim_self != dim_target && *dim_self != 1 {
388 return Err(MetadataError::IncompatibleDims {
389 left: *dim_self,
390 right: *dim_target,
391 dim: target_rank - i - 1,
392 });
393 }
394 }
395
396 Ok(target)
397 }
398
399 pub fn reshape<A, T>(&self, args: A) -> Result<Shape, MetadataError>
401 where
402 A: AsRef<[T]> + Debug,
403 T: AsIndex,
404 {
405 let args = args.as_ref();
406 let mut infer_index = None;
407 let mut dims = Vec::new();
408
409 let mut new_size = 1;
410
411 for (idx, &s) in args.iter().enumerate() {
412 let s = s.as_index();
413 if s > 0 {
414 let s = s as usize;
415 new_size *= s;
416 dims.push(s);
417 } else if s == 0 {
418 let s = self.dims[idx];
421 new_size *= s;
422 dims.push(s);
423 } else if s == -1 {
424 match infer_index {
425 None => {
426 infer_index = Some(idx);
427 dims.push(1);
429 }
430 Some(_) => {
431 return Err(MetadataError::Invalid {
432 reason: "Repeated -1 in reshape".to_string(),
433 });
434 }
435 }
436 } else {
437 return Err(MetadataError::Invalid {
438 reason: "The given shape cannot contain negative dimensions (other than -1)."
439 .to_string(),
440 });
441 }
442 }
443
444 let source_size = self.num_elements();
445 match infer_index {
446 None => {
447 if source_size != new_size {
448 return Err(MetadataError::Invalid {
449 reason: format!(
450 "The given shape doesn't have the same number of elements as the current shape. Current shape: {self}, target shape: {dims:?}.",
451 ),
452 });
453 }
454 }
455 Some(idx) => {
456 if !source_size.is_multiple_of(new_size) {
457 return Err(MetadataError::Invalid {
458 reason: format!(
459 "Cannot infer a valid target shape. Current shape: {self}, target dimensions: {args:?}."
460 ),
461 });
462 }
463 dims[idx] = source_size / new_size;
464 }
465 }
466
467 Ok(dims.into())
468 }
469}
470
/// Builds a `Shape` with `vec!`-like syntax.
///
/// * `shape![]` creates a shape with no dimensions.
/// * `shape![size; n]` creates a shape of `n` dimensions, all of size `size`.
/// * `shape![a, b, c]` creates a shape from the listed sizes (trailing comma allowed).
#[macro_export]
macro_rules! shape {
    // Internal arm: maps any expression to the literal `1usize`.
    (@one $x:expr) => {
        1usize
    };
    () => {
        $crate::Shape::new_raw($crate::SmallVec::new())
    };
    ($elem:expr; $n:expr) => {
        $crate::Shape::new_raw($crate::smallvec!($elem; $n))
    };
    ($($x:expr),+ $(,)?) => {
        $crate::Shape::new_raw($crate::smallvec!($($x),*))
    };
}
484
485pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, MetadataError> {
490 let rank = lhs.rank();
491 if rank != rhs.rank() {
492 return Err(MetadataError::RankMismatch {
493 left: rank,
494 right: rhs.rank(),
495 });
496 }
497
498 if lhs[rank - 1] != rhs[rank - 2] {
499 return Err(MetadataError::IncompatibleShapes {
500 left: lhs.clone(),
501 right: rhs.clone(),
502 });
503 }
504
505 let mut shape = if rank > 2 {
506 Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
508 } else {
509 Shape::new([])
510 };
511 shape.extend([lhs[rank - 2], rhs[rank - 1]]);
512
513 Ok(shape)
514}
515
516impl Display for Shape {
517 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
518 self.dims.fmt(f)
519 }
520}
521
522impl FromStr for Shape {
523 type Err = ExpressionError;
524
525 fn from_str(source: &str) -> Result<Self, Self::Err> {
526 let mut s = source.trim();
527
528 const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")];
529
530 for (open, close) in DELIMS {
531 if let Some(p) = s.strip_prefix(open) {
532 if let Some(p) = p.strip_suffix(close) {
533 s = p.trim();
534 break;
535 } else {
536 return Err(ExpressionError::ParseError {
537 message: "Unbalanced delimiters".to_string(),
538 source: source.to_string(),
539 });
540 }
541 }
542 }
543
544 if s.is_empty() {
545 return Ok(Shape::new([]));
546 }
547
548 let dims = s
549 .split(',')
550 .map(|dim_str| {
551 dim_str
552 .trim()
553 .parse::<usize>()
554 .map_err(|_| ExpressionError::ParseError {
555 message: "Unable to parse shape".to_string(),
556 source: source.to_string(),
557 })
558 })
559 .collect::<Result<SmallVec<_>, ExpressionError>>()?;
560
561 if dims.is_empty() {
562 unreachable!("Split should have returned at least one element");
563 }
564
565 Ok(Shape { dims })
566 }
567}
568
569impl<Idx> Index<Idx> for Shape
570where
571 Idx: SliceIndex<[usize]>,
572{
573 type Output = Idx::Output;
574
575 fn index(&self, index: Idx) -> &Self::Output {
576 &self.dims[index]
577 }
578}
579
580impl<Idx> IndexMut<Idx> for Shape
581where
582 Idx: SliceIndex<[usize]>,
583{
584 fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
585 &mut self.dims[index]
586 }
587}
588
589impl Deref for Shape {
591 type Target = [usize];
592
593 fn deref(&self) -> &Self::Target {
594 &self.dims
595 }
596}
597
598impl DerefMut for Shape {
600 fn deref_mut(&mut self) -> &mut Self::Target {
601 &mut self.dims
602 }
603}
604impl AsRef<[usize]> for Shape {
609 fn as_ref(&self) -> &[usize] {
610 &self.dims
611 }
612}
613
614impl From<Shape> for Vec<usize> {
615 fn from(shape: Shape) -> Self {
616 shape.dims.to_vec()
617 }
618}
619
620impl<T, I> From<T> for Shape
621where
622 T: IntoIterator<Item = I>,
623 I: AsSize,
624{
625 fn from(dims: T) -> Self {
626 Shape {
627 dims: dims.into_iter().map(|d| d.as_size()).collect(),
628 }
629 }
630}
631
632impl From<&Shape> for Shape {
633 fn from(value: &Shape) -> Self {
634 value.clone()
635 }
636}
637
638impl<I: AsSize> FromIterator<I> for Shape {
639 fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
640 Shape {
641 dims: iter.into_iter().map(|it| it.as_size()).collect(),
642 }
643 }
644}
645
/// Unit tests for `Shape`, `MetadataError` and `calculate_matmul_output`.
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    /// Builds the `ParseError` expected for a given message and input.
    fn parse_error(message: &str, source: &str) -> Result<Shape, ExpressionError> {
        Err(ExpressionError::ParseError {
            message: message.to_string(),
            source: source.to_string(),
        })
    }

    #[test]
    fn test_shape_to_str() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.to_string(), "[2, 3, 4, 5]");
    }

    #[test]
    fn test_shape_from_str() {
        assert_eq!(
            "[2, 3, 4, 5]".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );
        assert_eq!(
            "(2, 3, 4, 5)".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );
        assert_eq!(
            "2, 3, 4, 5".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );

        assert_eq!("[2]".parse::<Shape>().unwrap(), Shape::new([2]));
        assert_eq!("(2)".parse::<Shape>().unwrap(), Shape::new([2]));
        assert_eq!("2".parse::<Shape>().unwrap(), Shape::new([2]));

        assert_eq!("[]".parse::<Shape>().unwrap(), Shape::new([]));
        assert_eq!("".parse::<Shape>().unwrap(), Shape::new([]));

        assert_eq!(
            "[".parse::<Shape>(),
            parse_error("Unbalanced delimiters", "[")
        );
        assert_eq!(
            "[[1]".parse::<Shape>(),
            parse_error("Unable to parse shape", "[[1]")
        );
        assert_eq!(
            "[[1]]".parse::<Shape>(),
            parse_error("Unable to parse shape", "[[1]]")
        );
        assert_eq!(
            "[1)".parse::<Shape>(),
            parse_error("Unbalanced delimiters", "[1)")
        );
        assert_eq!(
            "]".parse::<Shape>(),
            parse_error("Unable to parse shape", "]")
        );
        assert_eq!(
            "[a]".parse::<Shape>(),
            parse_error("Unable to parse shape", "[a]")
        );
    }

    #[test]
    fn num_dims_and_rank() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(4, shape.num_dims());
        assert_eq!(4, shape.rank());
    }

    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(120, shape.num_elements());
    }

    #[test]
    #[allow(clippy::into_iter_on_ref)]
    fn test_shape_into_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.into_iter().sum::<usize>(), 14);
    }

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[test]
    fn test_to_vec() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 3);
        assert_eq!(shape[2], 4);
        assert_eq!(shape[3], 5);

        // Range indexing yields an unsized slice, compared against an array directly.
        assert_eq!(shape[1..3], [3, 4]);
        assert_eq!(shape[1..=2], [3, 4]);
        assert_eq!(shape[..], [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_slice_methods() {
        let shape = Shape::new([2, 3, 4, 5]);

        let dim = shape.first();
        assert_eq!(dim, Some(&2));
        let dim = shape.last();
        assert_eq!(dim, Some(&5));

        assert!(!shape.is_empty());
        let shape = Shape::new([]);
        assert!(shape.is_empty());
    }

    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        for (d, sd) in dims.iter().zip(shape.iter()) {
            assert_eq!(d, sd);
        }
    }

    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);

        for d in shape.iter_mut() {
            *d += 1;
        }

        assert_eq!(shape.as_slice(), &[3, 4, 5, 6]);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        // Deref coercion to a slice.
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut);
    }

    #[test]
    fn test_shape_flatten() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);

        let shape = shape.flatten();
        assert_eq!(shape.num_elements(), 120);
        assert_eq!(shape.as_slice(), &[120]);
    }

    #[test]
    fn test_ravel() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
        assert_eq!(
            shape.ravel_index(&[1, 2, 3, 4]),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }

    #[test]
    fn test_shape_insert_remove_push() {
        let dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);
        let size = 6;
        shape.insert(1, size);

        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));

        let removed = shape.remove(1);
        assert_eq!(removed, size);
        assert_eq!(shape, Shape::new(dims));

        shape.push(6);
        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
    }

    #[test]
    fn test_shape_swap_permute() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        let shape = shape.swapped(1, 2).unwrap();

        assert_eq!(shape.as_slice(), &[2, 4, 3, 5]);

        let shape = shape.permuted(&[0, 2, 1, 3]).unwrap();
        assert_eq!(shape, Shape::new(dims));
    }

    #[test]
    fn test_shape_swap_out_of_bounds() {
        let shape = Shape::new([2, 3, 4, 5]);

        // Assert the exact error instead of relying on `unwrap()` panicking.
        let out = shape.swapped(0, 4);
        assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 4, rank: 4 }));
    }

    #[test]
    fn test_shape_permute_incomplete() {
        let shape = Shape::new([2, 3, 4, 5]);

        // Assert the exact error instead of relying on `unwrap()` panicking.
        let out = shape.permuted(&[0, 2, 1]);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 4, right: 3 }));
    }

    #[test]
    fn test_shape_repeat() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(2, 3).unwrap();
        assert_eq!(out, Shape::new([2, 3, 12, 5]));
    }

    #[test]
    fn test_shape_repeat_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(5, 3);
        assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_reduce() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 1, 5]));
    }

    #[test]
    fn test_shape_reduce_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(5);
        assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_broadcast_binary() {
        let lhs = Shape::new([1, 1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_rank_mismatch() {
        let lhs = Shape::new([1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 4]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_broadcast_incompatible_dims() {
        let lhs = Shape::new([1, 2, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleDims {
                left: 2,
                right: 6,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([7, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_many_rank_mismatch() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([1, 6, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 4, right: 3 }));
    }

    #[test]
    fn test_shape_broadcast_many_incompatible_dims() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([4, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleDims {
                left: 7,
                right: 4,
                dim: 0
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many_empty() {
        let out = Shape::broadcast_many(&[]);
        assert_eq!(out, Err(MetadataError::empty()));
    }

    #[test]
    fn test_shape_matmul_2d() {
        let lhs = Shape::new([2, 4]);
        let rhs = Shape::new([4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 2]));
    }

    #[test]
    fn test_shape_matmul_4d_broadcasted() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 3, 2, 2]));
    }

    #[test]
    fn test_shape_matmul_invalid_rank() {
        let lhs = Shape::new([3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_matmul_invalid_shape() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 3, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleShapes {
                left: lhs,
                right: rhs
            })
        );
    }

    #[test]
    fn test_shape_matmul_invalid_broadcast() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 2, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleDims {
                left: 3,
                right: 2,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_cat() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let s3 = Shape::new([4, 3, 4, 5]);

        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
        assert_eq!(out, Shape::new([7, 3, 4, 5]));

        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 2, 5]);
        let s3 = Shape::new([2, 3, 1, 5]);

        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 7, 5]));
    }

    #[test]
    fn test_shape_cat_empty() {
        let out = Shape::cat(&[], 0);
        assert_eq!(out, Err(MetadataError::empty()));
    }

    #[test]
    fn test_shape_cat_dim_out_of_bounds() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5]);
        let out = Shape::cat(&[s1, s2], 4);
        assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 4, rank: 4 }));
    }

    #[test]
    fn test_shape_cat_rank_mismatch() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5, 6]);
        let out = Shape::cat(&[s1, s2], 0);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 4, right: 5 }));
    }

    #[test]
    fn test_shape_cat_incompatible_shapes() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);

        assert_eq!(
            out,
            Err(MetadataError::IncompatibleShapes {
                left: s1,
                right: s2
            })
        );
    }

    #[test]
    fn test_shape_expand() {
        let shape = Shape::new([1, 3, 1]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded.clone()).unwrap();
        assert_eq!(out, expanded);
    }

    #[test]
    fn test_shape_expand_higher_rank() {
        let shape = Shape::new([1, 4]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded.clone()).unwrap();
        assert_eq!(out, expanded);
    }

    #[test]
    fn test_shape_expand_invalid_rank() {
        let shape = Shape::new([1, 3, 1]);
        let expanded = Shape::new([3, 4]);
        let out = shape.expand(expanded);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 2 }));
    }

    #[test]
    fn test_shape_expand_incompatible_dims() {
        let shape = Shape::new([1, 3, 2]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleDims {
                left: 2,
                right: 4,
                dim: 2
            })
        );
    }

    #[test]
    fn test_shape_reshape() {
        let shape = Shape::new([2, 3, 4, 5]);
        let reshaped = Shape::new([1, 2, 12, 5]);
        let out = shape.reshape(reshaped.clone()).unwrap();
        assert_eq!(out, reshaped);
    }

    #[test]
    fn test_shape_reshape_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);
        let reshaped = Shape::new([2, 2, 12, 5]);
        let out = shape.reshape(reshaped.clone());
        assert_eq!(
            out,
            Err(MetadataError::Invalid {
                reason: "The given shape doesn't have the same number of elements as the current shape. Current shape: [2, 3, 4, 5], target shape: [2, 2, 12, 5].".into(),
            })
        );
    }

    #[test]
    fn test_shape_reshape_invalid_inferred() {
        let shape = Shape::new([2, 4]);
        let out = shape.reshape([-1, 3]);
        assert_eq!(
            out,
            Err(MetadataError::Invalid {
                reason: "Cannot infer a valid target shape. Current shape: [2, 4], target dimensions: [-1, 3].".into(),
            })
        );
    }

    #[test]
    fn test_flatten_dims() {
        let shape = Shape::new([2, 3, 4, 5]);
        let flattened = shape.flatten_dims(-2, 3);
        assert_eq!(flattened, Shape::new([2, 3, 20]));
    }
}