use std::collections::{HashMap, HashSet};

use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
26
/// Reduction operation applied by [`reduce`] to every axis that appears on
/// the left side of the pattern but not on the right.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EinopsReduction {
    /// Arithmetic mean of the reduced elements.
    Mean,
    /// Sum of the reduced elements.
    Sum,
    /// Maximum of the reduced elements (accumulator starts at `-inf`).
    Max,
    /// Minimum of the reduced elements (accumulator starts at `+inf`).
    Min,
}
43
/// One space-separated item on either side of an einops pattern.
#[derive(Debug, Clone, PartialEq)]
enum AxisSpec {
    /// A bare axis name, e.g. `b` — corresponds to exactly one tensor dimension.
    Single(String),
    /// A parenthesized group, e.g. `(c h w)` — several named axes merged into
    /// (or split out of) a single tensor dimension.
    Group(Vec<String>),
}
57
/// A pattern string split at `->` and parsed into per-side axis specs.
#[derive(Debug)]
struct ParsedPattern {
    /// Axis specs for the input side (left of `->`), one per input dimension.
    left: Vec<AxisSpec>,
    /// Axis specs for the output side (right of `->`), one per output dimension.
    right: Vec<AxisSpec>,
}
64
65fn flatten_axes(specs: &[AxisSpec]) -> Vec<String> {
67 let mut out = Vec::new();
68 for spec in specs {
69 match spec {
70 AxisSpec::Single(name) => out.push(name.clone()),
71 AxisSpec::Group(names) => out.extend(names.iter().cloned()),
72 }
73 }
74 out
75}
76
/// Parses one side of a pattern (the text before or after `->`) into a list
/// of [`AxisSpec`]s.
///
/// Accepted tokens, separated by whitespace:
/// - a bare name made of `[A-Za-z0-9_]` characters → `AxisSpec::Single`,
/// - a parenthesized group of such names, e.g. `(c h w)` → `AxisSpec::Group`.
///
/// Errors on an unmatched `(`, an empty group `()`, any other character
/// (including a nested `(` inside a group, which surfaces as an "empty axis
/// name" error), or any unexpected character outside a group.
fn parse_side(s: &str) -> FerrotorchResult<Vec<AxisSpec>> {
    let s = s.trim();
    let mut specs = Vec::new();
    let mut chars = s.chars().peekable();

    while let Some(&c) = chars.peek() {
        // Skip whitespace between tokens.
        if c.is_whitespace() {
            chars.next();
            continue;
        }

        if c == '(' {
            chars.next();
            let mut group = Vec::new();
            // Collect names until the matching ')' (or error at end of input).
            loop {
                // Skip whitespace inside the group.
                while let Some(&c2) = chars.peek() {
                    if c2.is_whitespace() {
                        chars.next();
                    } else {
                        break;
                    }
                }
                match chars.peek() {
                    None => {
                        return Err(FerrotorchError::InvalidArgument {
                            message: "einops: unmatched '(' in pattern".into(),
                        });
                    }
                    Some(&')') => {
                        chars.next();
                        break;
                    }
                    _ => {}
                }
                // Anything that is not a name character yields an empty string
                // here, which we reject below.
                let name = read_axis_name(&mut chars)?;
                if name.is_empty() {
                    return Err(FerrotorchError::InvalidArgument {
                        message: "einops: empty axis name inside parentheses".into(),
                    });
                }
                group.push(name);
            }
            if group.is_empty() {
                return Err(FerrotorchError::InvalidArgument {
                    message: "einops: empty parenthesized group".into(),
                });
            }
            specs.push(AxisSpec::Group(group));
        } else if c.is_ascii_alphanumeric() || c == '_' {
            let name = read_axis_name(&mut chars)?;
            specs.push(AxisSpec::Single(name));
        } else {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("einops: unexpected character '{c}' in pattern"),
            });
        }
    }

    Ok(specs)
}
142
143fn read_axis_name(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) -> FerrotorchResult<String> {
145 let mut name = String::new();
146 while let Some(&c) = chars.peek() {
147 if c.is_ascii_alphanumeric() || c == '_' {
148 name.push(c);
149 chars.next();
150 } else {
151 break;
152 }
153 }
154 Ok(name)
155}
156
157fn parse_pattern(pattern: &str) -> FerrotorchResult<ParsedPattern> {
159 let pattern = pattern.trim();
160 let (left_str, right_str) = pattern.split_once("->").ok_or_else(|| {
161 FerrotorchError::InvalidArgument {
162 message: format!("einops: pattern must contain '->', got: \"{pattern}\""),
163 }
164 })?;
165
166 let left = parse_side(left_str)?;
167 let right = parse_side(right_str)?;
168
169 let left_names = flatten_axes(&left);
171 let right_names = flatten_axes(&right);
172
173 let mut seen = HashMap::new();
174 for name in &left_names {
175 if seen.insert(name.as_str(), "left").is_some() {
176 return Err(FerrotorchError::InvalidArgument {
177 message: format!("einops: duplicate axis name '{name}' on left side of pattern"),
178 });
179 }
180 }
181 seen.clear();
182 for name in &right_names {
183 if seen.insert(name.as_str(), "right").is_some() {
184 return Err(FerrotorchError::InvalidArgument {
185 message: format!("einops: duplicate axis name '{name}' on right side of pattern"),
186 });
187 }
188 }
189
190 Ok(ParsedPattern { left, right })
191}
192
193fn resolve_sizes(
208 pattern: &ParsedPattern,
209 input_shape: &[usize],
210 axes_lengths: &[(&str, usize)],
211) -> FerrotorchResult<HashMap<String, usize>> {
212 let left_flat = flatten_axes(&pattern.left);
213 let right_flat = flatten_axes(&pattern.right);
214
215 let left_dim_count = pattern.left.len();
217 if left_dim_count != input_shape.len() {
218 return Err(FerrotorchError::InvalidArgument {
219 message: format!(
220 "einops: left side of pattern has {} axes but input tensor has {} dimensions",
221 left_dim_count,
222 input_shape.len()
223 ),
224 });
225 }
226
227 let user_sizes: HashMap<&str, usize> = axes_lengths.iter().copied().collect();
228 let mut sizes: HashMap<String, usize> = HashMap::new();
229
230 for (dim_idx, spec) in pattern.left.iter().enumerate() {
232 let dim_size = input_shape[dim_idx];
233 match spec {
234 AxisSpec::Single(name) => {
235 sizes.insert(name.clone(), dim_size);
236 }
237 AxisSpec::Group(names) => {
238 let mut unknown_idx: Option<usize> = None;
242 let mut known_product: usize = 1;
243
244 for (i, name) in names.iter().enumerate() {
245 if let Some(&sz) = user_sizes.get(name.as_str()) {
246 sizes.insert(name.clone(), sz);
247 known_product *= sz;
248 } else if let Some(&sz) = sizes.get(name) {
249 known_product *= sz;
252 } else {
253 if unknown_idx.is_some() {
254 return Err(FerrotorchError::InvalidArgument {
255 message: format!(
256 "einops: cannot infer sizes for split '({})' — \
257 provide sizes for all but one sub-axis via axes_lengths",
258 names.join(" ")
259 ),
260 });
261 }
262 unknown_idx = Some(i);
263 }
264 }
265
266 if let Some(ui) = unknown_idx {
267 if known_product == 0 || dim_size % known_product != 0 {
268 return Err(FerrotorchError::InvalidArgument {
269 message: format!(
270 "einops: dimension {} (size {}) is not divisible by \
271 known product {} for split '({})'",
272 dim_idx, dim_size, known_product,
273 names.join(" ")
274 ),
275 });
276 }
277 sizes.insert(names[ui].clone(), dim_size / known_product);
278 } else {
279 if known_product != dim_size {
281 return Err(FerrotorchError::ShapeMismatch {
282 message: format!(
283 "einops: split '({})' product {} does not match dimension {} size {}",
284 names.join(" "), known_product, dim_idx, dim_size
285 ),
286 });
287 }
288 }
289 }
290 }
291 }
292
293 for name in &right_flat {
296 if !sizes.contains_key(name) {
297 if let Some(&sz) = user_sizes.get(name.as_str()) {
298 sizes.insert(name.clone(), sz);
299 } else if !left_flat.contains(name) {
300 return Err(FerrotorchError::InvalidArgument {
301 message: format!(
302 "einops: axis '{name}' appears on the right but not the left \
303 and has no size in axes_lengths"
304 ),
305 });
306 }
307 }
308 }
309
310 Ok(sizes)
311}
312
313fn output_shape(right: &[AxisSpec], sizes: &HashMap<String, usize>) -> Vec<usize> {
320 right
321 .iter()
322 .map(|spec| match spec {
323 AxisSpec::Single(name) => *sizes.get(name).unwrap(),
324 AxisSpec::Group(names) => names.iter().map(|n| sizes.get(n).unwrap()).product(),
325 })
326 .collect()
327}
328
/// Converts a flat row-major index into per-dimension coordinates for `shape`.
fn flat_to_coords(mut flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0usize; shape.len()];
    // Peel off the fastest-varying (last) dimension first.
    for (coord, &dim) in coords.iter_mut().zip(shape).rev() {
        *coord = flat % dim;
        flat /= dim;
    }
    coords
}
339
/// Converts per-dimension coordinates back into a flat row-major index.
///
/// Horner-style accumulation: equivalent to summing `coords[d] * stride[d]`
/// with row-major strides.
fn coords_to_flat(coords: &[usize], shape: &[usize]) -> usize {
    let mut flat = 0usize;
    for d in 0..shape.len() {
        flat = flat * shape[d] + coords[d];
    }
    flat
}
350
351fn elementary_shape(specs: &[AxisSpec], sizes: &HashMap<String, usize>) -> Vec<usize> {
354 let mut shape = Vec::new();
355 for spec in specs {
356 match spec {
357 AxisSpec::Single(name) => shape.push(*sizes.get(name).unwrap()),
358 AxisSpec::Group(names) => {
359 for n in names {
360 shape.push(*sizes.get(n).unwrap());
361 }
362 }
363 }
364 }
365 shape
366}
367
368fn rearrange_impl<T: Float>(
379 data: &[T],
380 _input_shape: &[usize],
381 pattern: &ParsedPattern,
382 sizes: &HashMap<String, usize>,
383 _output_shape: &[usize],
384) -> FerrotorchResult<Vec<T>> {
385 let left_names = flatten_axes(&pattern.left);
386 let right_names = flatten_axes(&pattern.right);
387 let left_elem_shape = elementary_shape(&pattern.left, sizes);
388 let right_elem_shape = elementary_shape(&pattern.right, sizes);
389
390 let perm: Vec<usize> = right_names
394 .iter()
395 .map(|name| {
396 left_names
397 .iter()
398 .position(|n| n == name)
399 .unwrap_or(usize::MAX)
400 })
401 .collect();
402
403 let elem_numel: usize = left_elem_shape.iter().product();
413 let mut transposed = vec![<T as num_traits::Zero>::zero(); elem_numel];
414
415 for src_flat in 0..elem_numel {
416 let src_coords = flat_to_coords(src_flat, &left_elem_shape);
417 let mut dst_coords = vec![0usize; right_elem_shape.len()];
418 for (dst_dim, &src_dim) in perm.iter().enumerate() {
419 dst_coords[dst_dim] = src_coords[src_dim];
420 }
421 let dst_flat = coords_to_flat(&dst_coords, &right_elem_shape);
422 transposed[dst_flat] = data[src_flat];
423 }
424
425 Ok(transposed)
429}
430
/// Rearranges `input` according to an einops-style `pattern`, e.g.
/// `"b c h w -> b (c h w)"`, with no extra axis sizes supplied.
///
/// Convenience wrapper around [`rearrange_with`] with empty `axes_lengths`;
/// see that function for the full contract and error conditions.
pub fn rearrange<T: Float>(input: &Tensor<T>, pattern: &str) -> FerrotorchResult<Tensor<T>> {
    rearrange_with(input, pattern, &[])
}
451
452pub fn rearrange_with<T: Float>(
460 input: &Tensor<T>,
461 pattern: &str,
462 axes_lengths: &[(&str, usize)],
463) -> FerrotorchResult<Tensor<T>> {
464 let parsed = parse_pattern(pattern)?;
465 let sizes = resolve_sizes(&parsed, input.shape(), axes_lengths)?;
466
467 let left_names = flatten_axes(&parsed.left);
468 let right_names = flatten_axes(&parsed.right);
469
470 let mut left_sorted = left_names.clone();
472 left_sorted.sort();
473 let mut right_sorted = right_names.clone();
474 right_sorted.sort();
475 if left_sorted != right_sorted {
476 return Err(FerrotorchError::InvalidArgument {
477 message: format!(
478 "einops rearrange: left axes {:?} and right axes {:?} must name \
479 the same set of axes (use `repeat` for new axes, `reduce` for removed axes)",
480 left_names, right_names
481 ),
482 });
483 }
484
485 let out_shape = output_shape(&parsed.right, &sizes);
486 let data = input.data()?;
487 let result_data = rearrange_impl(data, input.shape(), &parsed, &sizes, &out_shape)?;
488
489 Tensor::from_storage(TensorStorage::cpu(result_data), out_shape, false)
490}
491
492pub fn repeat<T: Float>(
510 input: &Tensor<T>,
511 pattern: &str,
512 axes_lengths: &[(&str, usize)],
513) -> FerrotorchResult<Tensor<T>> {
514 let parsed = parse_pattern(pattern)?;
515 let sizes = resolve_sizes(&parsed, input.shape(), axes_lengths)?;
516
517 let left_names = flatten_axes(&parsed.left);
518 let right_names = flatten_axes(&parsed.right);
519
520 for name in &left_names {
522 if !right_names.contains(name) {
523 return Err(FerrotorchError::InvalidArgument {
524 message: format!(
525 "einops repeat: left axis '{name}' does not appear on the right — \
526 use `reduce` to remove axes"
527 ),
528 });
529 }
530 }
531
532 let _new_axes: Vec<&String> = right_names
534 .iter()
535 .filter(|n| !left_names.contains(n))
536 .collect();
537
538 let right_elem_shape = elementary_shape(&parsed.right, &sizes);
540 let out_shape = output_shape(&parsed.right, &sizes);
541
542 let out_numel: usize = right_elem_shape.iter().product();
545 let left_elem_shape = elementary_shape(&parsed.left, &sizes);
546 let data = input.data()?;
547
548 let mut result = Vec::with_capacity(out_numel);
549 for dst_flat in 0..out_numel {
550 let dst_coords = flat_to_coords(dst_flat, &right_elem_shape);
551 let mut src_coords = Vec::with_capacity(left_elem_shape.len());
553 for (i, name) in right_names.iter().enumerate() {
554 if left_names.contains(name) {
555 src_coords.push(dst_coords[i]);
556 }
557 }
559 let src_flat = coords_to_flat(&src_coords, &left_elem_shape);
560 result.push(data[src_flat]);
561 }
562
563 Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
566}
567
568pub fn reduce<T: Float>(
583 input: &Tensor<T>,
584 pattern: &str,
585 reduction: EinopsReduction,
586) -> FerrotorchResult<Tensor<T>> {
587 let parsed = parse_pattern(pattern)?;
588 let sizes = resolve_sizes(&parsed, input.shape(), &[])?;
589
590 let left_names = flatten_axes(&parsed.left);
591 let right_names = flatten_axes(&parsed.right);
592
593 for name in &right_names {
595 if !left_names.contains(name) {
596 return Err(FerrotorchError::InvalidArgument {
597 message: format!(
598 "einops reduce: right axis '{name}' does not appear on the left — \
599 use `repeat` to add new axes"
600 ),
601 });
602 }
603 }
604
605 let reduced_axes: Vec<&String> = left_names
607 .iter()
608 .filter(|n| !right_names.contains(n))
609 .collect();
610
611 if reduced_axes.is_empty() {
612 return Err(FerrotorchError::InvalidArgument {
613 message: "einops reduce: no axes are being reduced — use `rearrange` instead".into(),
614 });
615 }
616
617 let left_elem_shape = elementary_shape(&parsed.left, &sizes);
619 let right_elem_shape = elementary_shape(&parsed.right, &sizes);
620 let out_shape = output_shape(&parsed.right, &sizes);
621
622 let out_numel: usize = right_elem_shape.iter().product();
623 let data = input.data()?;
624 let in_numel: usize = left_elem_shape.iter().product();
625
626 let reduce_count: usize = reduced_axes
628 .iter()
629 .map(|name| sizes.get(name.as_str()).unwrap())
630 .product();
631
632 let init_val = match reduction {
636 EinopsReduction::Sum | EinopsReduction::Mean => <T as num_traits::Zero>::zero(),
637 EinopsReduction::Max => T::neg_infinity(),
638 EinopsReduction::Min => T::infinity(),
639 };
640 let mut accum = vec![init_val; out_numel];
641
642 for src_flat in 0..in_numel {
643 let src_coords = flat_to_coords(src_flat, &left_elem_shape);
644 let mut dst_coords = Vec::with_capacity(right_elem_shape.len());
646 for (i, name) in left_names.iter().enumerate() {
647 if right_names.contains(name) {
648 dst_coords.push(src_coords[i]);
649 }
650 }
651 let dst_flat = coords_to_flat(&dst_coords, &right_elem_shape);
652
653 let val = data[src_flat];
654 match reduction {
655 EinopsReduction::Sum | EinopsReduction::Mean => {
656 accum[dst_flat] = accum[dst_flat] + val;
657 }
658 EinopsReduction::Max => {
659 if val > accum[dst_flat] {
660 accum[dst_flat] = val;
661 }
662 }
663 EinopsReduction::Min => {
664 if val < accum[dst_flat] {
665 accum[dst_flat] = val;
666 }
667 }
668 }
669 }
670
671 if reduction == EinopsReduction::Mean {
673 let n = T::from(reduce_count).unwrap();
674 for v in &mut accum {
675 *v = *v / n;
676 }
677 }
678
679 Tensor::from_storage(TensorStorage::cpu(accum), out_shape, false)
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: build a non-grad CPU tensor from raw values and a shape.
    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    // ---- rearrange -------------------------------------------------------

    #[test]
    fn test_rearrange_identity() {
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let input = leaf(&values, &[2, 3, 2, 2]);

        let result = rearrange(&input, "b c h w -> b c h w").unwrap();

        assert_eq!(result.shape(), &[2, 3, 2, 2]);
        assert_eq!(result.data().unwrap(), values.as_slice());
    }

    #[test]
    fn test_rearrange_flatten() {
        // Merging trailing axes is a pure reshape: data order is unchanged.
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let input = leaf(&values, &[2, 3, 2, 2]);

        let result = rearrange(&input, "b c h w -> b (c h w)").unwrap();

        assert_eq!(result.shape(), &[2, 12]);
        assert_eq!(result.data().unwrap(), values.as_slice());
    }

    #[test]
    fn test_rearrange_transpose_nhwc_to_nchw() {
        let values: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let input = leaf(&values, &[1, 2, 2, 3]);

        let result = rearrange(&input, "b h w c -> b c h w").unwrap();
        assert_eq!(result.shape(), &[1, 3, 2, 2]);

        // Channel 0 of the NCHW output gathers every third NHWC element.
        let out = result.data().unwrap();
        assert_eq!(out[0], 0.0);
        assert_eq!(out[1], 3.0);
        assert_eq!(out[2], 6.0);
        assert_eq!(out[3], 9.0);
        assert_eq!(out[4], 1.0);
        assert_eq!(out[5], 4.0);
    }

    #[test]
    fn test_rearrange_split_with_axes_lengths() {
        // Splitting a dim is also a pure reshape when nothing is permuted.
        let values: Vec<f32> = (0..48).map(|i| i as f32).collect();
        let input = leaf(&values, &[2, 6, 4]);

        let result = rearrange_with(&input, "b (c h) w -> b c h w", &[("c", 3)]).unwrap();

        assert_eq!(result.shape(), &[2, 3, 2, 4]);
        assert_eq!(result.data().unwrap(), values.as_slice());
    }

    #[test]
    fn test_rearrange_merge_dims() {
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let input = leaf(&values, &[1, 2, 3, 4]);

        let result = rearrange(&input, "b h w c -> b (h w) c").unwrap();

        assert_eq!(result.shape(), &[1, 6, 4]);
        assert_eq!(result.data().unwrap(), values.as_slice());
    }

    // ---- repeat ----------------------------------------------------------

    #[test]
    fn test_repeat_new_batch_dim() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0];
        let input = leaf(&values, &[2, 2]);

        let result = repeat(&input, "h w -> b h w", &[("b", 3)]).unwrap();
        assert_eq!(result.shape(), &[3, 2, 2]);

        // Each of the three batch slices is a copy of the source.
        let out = result.data().unwrap();
        assert_eq!(&out[0..4], &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(&out[4..8], &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(&out[8..12], &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_repeat_tile() {
        let values = vec![10.0f32, 20.0, 30.0];
        let input = leaf(&values, &[3]);

        let result = repeat(&input, "c -> c n", &[("n", 2)]).unwrap();
        assert_eq!(result.shape(), &[3, 2]);

        // A trailing repeat axis duplicates each element in place.
        let out = result.data().unwrap();
        assert_eq!(out, &[10.0, 10.0, 20.0, 20.0, 30.0, 30.0]);
    }

    // ---- reduce ----------------------------------------------------------

    #[test]
    fn test_reduce_mean_spatial() {
        // Global average pool: mean of each 2x2 spatial plane per channel.
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input = leaf(&values, &[1, 2, 2, 2]);

        let result = reduce(&input, "b c h w -> b c", EinopsReduction::Mean).unwrap();
        assert_eq!(result.shape(), &[1, 2]);

        let out = result.data().unwrap();
        assert!((out[0] - 2.5).abs() < 1e-6, "expected 2.5, got {}", out[0]);
        assert!((out[1] - 6.5).abs() < 1e-6, "expected 6.5, got {}", out[1]);
    }

    #[test]
    fn test_reduce_sum_batch() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let input = leaf(&values, &[3, 2]);

        let result = reduce(&input, "b c -> c", EinopsReduction::Sum).unwrap();
        assert_eq!(result.shape(), &[2]);

        let out = result.data().unwrap();
        assert!((out[0] - 9.0).abs() < 1e-6);
        assert!((out[1] - 12.0).abs() < 1e-6);
    }

    #[test]
    fn test_reduce_max() {
        let values = vec![1.0f32, 5.0, 3.0, 2.0, 4.0, 6.0];
        let input = leaf(&values, &[3, 2]);

        let result = reduce(&input, "b c -> c", EinopsReduction::Max).unwrap();
        assert_eq!(result.shape(), &[2]);

        let out = result.data().unwrap();
        assert!((out[0] - 4.0).abs() < 1e-6);
        assert!((out[1] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_reduce_min() {
        let values = vec![1.0f32, 5.0, 3.0, 2.0, 4.0, 6.0];
        let input = leaf(&values, &[3, 2]);

        let result = reduce(&input, "b c -> c", EinopsReduction::Min).unwrap();
        assert_eq!(result.shape(), &[2]);

        let out = result.data().unwrap();
        assert!((out[0] - 1.0).abs() < 1e-6);
        assert!((out[1] - 2.0).abs() < 1e-6);
    }

    // ---- error paths -----------------------------------------------------

    #[test]
    fn test_invalid_pattern_no_arrow() {
        let input = leaf(&[1.0, 2.0, 3.0], &[3]);
        assert!(rearrange(&input, "a b c").is_err());
    }

    #[test]
    fn test_mismatched_axis_count() {
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert!(rearrange(&input, "a b c -> a b c").is_err());
    }

    #[test]
    fn test_rearrange_missing_axis_on_right() {
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let input = leaf(&values, &[2, 3, 2, 2]);
        assert!(rearrange(&input, "b c h w -> b c").is_err());
    }

    #[test]
    fn test_rearrange_extra_axis_on_right() {
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert!(rearrange(&input, "b c -> b c n").is_err());
    }

    #[test]
    fn test_repeat_missing_new_axis_size() {
        // The new axis 'n' is given no length anywhere, so repeat must fail.
        let input = leaf(&[1.0, 2.0], &[2]);
        assert!(repeat(&input, "c -> c n", &[]).is_err());
    }

    #[test]
    fn test_reduce_no_reduction() {
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert!(reduce(&input, "b c -> b c", EinopsReduction::Sum).is_err());
    }

    #[test]
    fn test_unmatched_paren() {
        let input = leaf(&[1.0, 2.0], &[2]);
        assert!(rearrange(&input, "(a -> a").is_err());
    }

    #[test]
    fn test_duplicate_axis_name() {
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert!(rearrange(&input, "a a -> a a").is_err());
    }

    // ---- pattern parsing -------------------------------------------------

    #[test]
    fn test_parse_simple() {
        let parsed = parse_pattern("b c h w -> b c h w").unwrap();
        assert_eq!(flatten_axes(&parsed.left), vec!["b", "c", "h", "w"]);
        assert_eq!(flatten_axes(&parsed.right), vec!["b", "c", "h", "w"]);
    }

    #[test]
    fn test_parse_groups() {
        let parsed = parse_pattern("b c h w -> b (c h w)").unwrap();
        assert_eq!(parsed.right.len(), 2);
        match &parsed.right[1] {
            AxisSpec::Group(names) => assert_eq!(names, &["c", "h", "w"]),
            _ => panic!("expected Group"),
        }
    }

    #[test]
    fn test_parse_left_group() {
        let parsed = parse_pattern("b (c h) w -> b c h w").unwrap();
        assert_eq!(parsed.left.len(), 3);
        match &parsed.left[1] {
            AxisSpec::Group(names) => assert_eq!(names, &["c", "h"]),
            _ => panic!("expected Group"),
        }
    }
}