// ferrotorch_core/einops.rs

//! Einops-style tensor rearrangement operations.
//!
//! Provides `rearrange`, `repeat`, and `reduce` with readable string patterns
//! for expressing tensor shape transformations declaratively.
//!
//! # Pattern syntax
//!
//! A pattern has the form `"left -> right"` where `left` and `right` are
//! space-separated axis names. Parenthesized groups denote merged/split
//! dimensions:
//!
//! - `"b c h w -> b (c h w)"` merges `c`, `h`, `w` into one axis
//! - `"b (c h) w -> b c h w"` splits a dimension (requires `axes_lengths`)
//! - `"b h w c -> b c h w"` transposes (reorders) axes
//!
//! For `rearrange`, both sides must name exactly the same set of axes.
//! Axes present on the left but absent on the right are reduced — use
//! `reduce`; axes present on the right but absent on the left are new
//! axes — use `repeat`.
use std::collections::{HashMap, HashSet};

use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
26
27// ---------------------------------------------------------------------------
28// Public API — Reduction enum
29// ---------------------------------------------------------------------------
30
/// Reduction operation for [`reduce`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EinopsReduction {
    /// Arithmetic mean along reduced axes (sum divided by the number of
    /// reduced elements).
    Mean,
    /// Sum along reduced axes.
    Sum,
    /// Element-wise maximum along reduced axes.
    Max,
    /// Element-wise minimum along reduced axes.
    Min,
}
43
44// ---------------------------------------------------------------------------
45// Pattern parser
46// ---------------------------------------------------------------------------
47
/// A single axis on one side of the pattern. Either a bare name or a
/// parenthesized group of names (representing a merged/split dimension).
#[derive(Debug, Clone, PartialEq)]
enum AxisSpec {
    /// A single named axis, e.g. `b`.
    Single(String),
    /// A parenthesized group of axes, e.g. `(c h w)`. On the left side a
    /// group splits one input dimension into several named axes; on the
    /// right side it merges several axes into one output dimension.
    Group(Vec<String>),
}
57
/// Parsed einops pattern.
#[derive(Debug)]
struct ParsedPattern {
    /// Axis specs for the input side (left of `->`), one per input dimension.
    left: Vec<AxisSpec>,
    /// Axis specs for the output side (right of `->`), one per output dimension.
    right: Vec<AxisSpec>,
}
64
65/// Flatten an `AxisSpec` list into individual axis names in order.
66fn flatten_axes(specs: &[AxisSpec]) -> Vec<String> {
67    let mut out = Vec::new();
68    for spec in specs {
69        match spec {
70            AxisSpec::Single(name) => out.push(name.clone()),
71            AxisSpec::Group(names) => out.extend(names.iter().cloned()),
72        }
73    }
74    out
75}
76
/// Parse one side of the pattern (e.g. `"b (c h) w"`) into a list of
/// `AxisSpec` entries.
///
/// Informal grammar: a side is a whitespace-separated sequence of items,
/// where each item is either a bare axis name (`[A-Za-z0-9_]+`) or a
/// parenthesized group of one or more names. Nested parentheses are
/// rejected (a '(' inside a group reads as an empty axis name and errors).
///
/// # Errors
/// `InvalidArgument` on an unmatched '(', an empty `()` group, an empty
/// name inside a group, or any character outside `[A-Za-z0-9_ ()]`.
fn parse_side(s: &str) -> FerrotorchResult<Vec<AxisSpec>> {
    let s = s.trim();
    let mut specs = Vec::new();
    let mut chars = s.chars().peekable();

    while let Some(&c) = chars.peek() {
        if c.is_whitespace() {
            chars.next();
            continue;
        }

        if c == '(' {
            // Consume the opening paren.
            chars.next();
            let mut group = Vec::new();
            loop {
                // Skip whitespace inside parens.
                while let Some(&c2) = chars.peek() {
                    if c2.is_whitespace() {
                        chars.next();
                    } else {
                        break;
                    }
                }
                match chars.peek() {
                    None => {
                        // Ran off the end of the string before a ')'.
                        return Err(FerrotorchError::InvalidArgument {
                            message: "einops: unmatched '(' in pattern".into(),
                        });
                    }
                    Some(&')') => {
                        // Closing paren terminates the group.
                        chars.next();
                        break;
                    }
                    _ => {}
                }
                // Read an axis name. Any non-name character here (including a
                // nested '(') yields an empty name and is rejected below.
                let name = read_axis_name(&mut chars)?;
                if name.is_empty() {
                    return Err(FerrotorchError::InvalidArgument {
                        message: "einops: empty axis name inside parentheses".into(),
                    });
                }
                group.push(name);
            }
            if group.is_empty() {
                // `()` carries no information and is almost certainly a typo.
                return Err(FerrotorchError::InvalidArgument {
                    message: "einops: empty parenthesized group".into(),
                });
            }
            specs.push(AxisSpec::Group(group));
        } else if c.is_ascii_alphanumeric() || c == '_' {
            let name = read_axis_name(&mut chars)?;
            specs.push(AxisSpec::Single(name));
        } else {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("einops: unexpected character '{c}' in pattern"),
            });
        }
    }

    Ok(specs)
}
142
143/// Read an axis name: a run of alphanumeric / underscore characters.
144fn read_axis_name(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) -> FerrotorchResult<String> {
145    let mut name = String::new();
146    while let Some(&c) = chars.peek() {
147        if c.is_ascii_alphanumeric() || c == '_' {
148            name.push(c);
149            chars.next();
150        } else {
151            break;
152        }
153    }
154    Ok(name)
155}
156
157/// Parse a full pattern like `"b c h w -> b (c h) w"`.
158fn parse_pattern(pattern: &str) -> FerrotorchResult<ParsedPattern> {
159    let pattern = pattern.trim();
160    let (left_str, right_str) = pattern.split_once("->").ok_or_else(|| {
161        FerrotorchError::InvalidArgument {
162            message: format!("einops: pattern must contain '->', got: \"{pattern}\""),
163        }
164    })?;
165
166    let left = parse_side(left_str)?;
167    let right = parse_side(right_str)?;
168
169    // Validate: no duplicate axis names within a side.
170    let left_names = flatten_axes(&left);
171    let right_names = flatten_axes(&right);
172
173    let mut seen = HashMap::new();
174    for name in &left_names {
175        if seen.insert(name.as_str(), "left").is_some() {
176            return Err(FerrotorchError::InvalidArgument {
177                message: format!("einops: duplicate axis name '{name}' on left side of pattern"),
178            });
179        }
180    }
181    seen.clear();
182    for name in &right_names {
183        if seen.insert(name.as_str(), "right").is_some() {
184            return Err(FerrotorchError::InvalidArgument {
185                message: format!("einops: duplicate axis name '{name}' on right side of pattern"),
186            });
187        }
188    }
189
190    Ok(ParsedPattern { left, right })
191}
192
193// ---------------------------------------------------------------------------
194// Axis-size resolution
195// ---------------------------------------------------------------------------
196
197/// Resolve the size of every named axis. Returns a map from axis name to
198/// its size.
199///
200/// - Axes that appear as `Single` on the left get their size from the
201///   corresponding input dimension.
202/// - Axes inside a `Group` on the left come from splitting an input dim.
203///   If there are N sub-axes and all but one have known sizes (from
204///   `axes_lengths`), the remaining one is inferred.
205/// - Axes that only appear on the right (new axes) must have their size
206///   supplied in `axes_lengths`.
/// Resolve the size of every named axis. Returns a map from axis name to
/// its size.
///
/// - Axes that appear as `Single` on the left get their size from the
///   corresponding input dimension.
/// - Axes inside a `Group` on the left come from splitting an input dim.
///   If there are N sub-axes and all but one have known sizes (from
///   `axes_lengths`), the remaining one is inferred by division.
/// - Axes that only appear on the right (new axes) must have their size
///   supplied in `axes_lengths`.
///
/// # Errors
/// - `InvalidArgument` when the left side's axis count differs from the
///   input rank, when a split is under-determined (two or more unknown
///   sub-axes), when a split's known product does not divide the dimension,
///   or when a right-only axis has no entry in `axes_lengths`.
/// - `ShapeMismatch` when a fully-specified split's product disagrees with
///   the actual dimension size.
fn resolve_sizes(
    pattern: &ParsedPattern,
    input_shape: &[usize],
    axes_lengths: &[(&str, usize)],
) -> FerrotorchResult<HashMap<String, usize>> {
    let left_flat = flatten_axes(&pattern.left);
    let right_flat = flatten_axes(&pattern.right);

    // Count how many input dimensions the left side represents. Each
    // `AxisSpec` (single or group) consumes exactly one input dimension.
    let left_dim_count = pattern.left.len();
    if left_dim_count != input_shape.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "einops: left side of pattern has {} axes but input tensor has {} dimensions",
                left_dim_count,
                input_shape.len()
            ),
        });
    }

    let user_sizes: HashMap<&str, usize> = axes_lengths.iter().copied().collect();
    let mut sizes: HashMap<String, usize> = HashMap::new();

    // First pass: assign sizes from the left side.
    for (dim_idx, spec) in pattern.left.iter().enumerate() {
        let dim_size = input_shape[dim_idx];
        match spec {
            AxisSpec::Single(name) => {
                // A bare axis simply takes the whole dimension's extent.
                sizes.insert(name.clone(), dim_size);
            }
            AxisSpec::Group(names) => {
                // This is a split: one input dim is being decomposed into
                // multiple named axes. We need axes_lengths for all but
                // (at most) one of them.
                let mut unknown_idx: Option<usize> = None;
                let mut known_product: usize = 1;

                for (i, name) in names.iter().enumerate() {
                    if let Some(&sz) = user_sizes.get(name.as_str()) {
                        sizes.insert(name.clone(), sz);
                        known_product *= sz;
                    } else if let Some(&sz) = sizes.get(name) {
                        // Already known from a previous occurrence (shouldn't happen
                        // since we check duplicates, but be defensive).
                        known_product *= sz;
                    } else {
                        // More than one unknown makes the split ambiguous.
                        if unknown_idx.is_some() {
                            return Err(FerrotorchError::InvalidArgument {
                                message: format!(
                                    "einops: cannot infer sizes for split '({})' — \
                                     provide sizes for all but one sub-axis via axes_lengths",
                                    names.join(" ")
                                ),
                            });
                        }
                        unknown_idx = Some(i);
                    }
                }

                if let Some(ui) = unknown_idx {
                    // Infer the single unknown by division; the guard also
                    // catches a user-supplied size of 0 (division by zero).
                    if known_product == 0 || dim_size % known_product != 0 {
                        return Err(FerrotorchError::InvalidArgument {
                            message: format!(
                                "einops: dimension {} (size {}) is not divisible by \
                                 known product {} for split '({})'",
                                dim_idx, dim_size, known_product,
                                names.join(" ")
                            ),
                        });
                    }
                    sizes.insert(names[ui].clone(), dim_size / known_product);
                } else {
                    // All sub-axes are known; verify the product matches.
                    if known_product != dim_size {
                        return Err(FerrotorchError::ShapeMismatch {
                            message: format!(
                                "einops: split '({})' product {} does not match dimension {} size {}",
                                names.join(" "), known_product, dim_idx, dim_size
                            ),
                        });
                    }
                }
            }
        }
    }

    // Second pass: axes that only appear on the right (new axes) must come
    // from axes_lengths.
    for name in &right_flat {
        if !sizes.contains_key(name) {
            if let Some(&sz) = user_sizes.get(name.as_str()) {
                sizes.insert(name.clone(), sz);
            } else if !left_flat.contains(name) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "einops: axis '{name}' appears on the right but not the left \
                         and has no size in axes_lengths"
                    ),
                });
            }
        }
    }

    Ok(sizes)
}
312
313// ---------------------------------------------------------------------------
314// Core implementation helpers
315// ---------------------------------------------------------------------------
316
317/// Compute the output shape from the right side of the pattern and the
318/// resolved axis sizes.
319fn output_shape(right: &[AxisSpec], sizes: &HashMap<String, usize>) -> Vec<usize> {
320    right
321        .iter()
322        .map(|spec| match spec {
323            AxisSpec::Single(name) => *sizes.get(name).unwrap(),
324            AxisSpec::Group(names) => names.iter().map(|n| sizes.get(n).unwrap()).product(),
325        })
326        .collect()
327}
328
/// Convert a flat (row-major) index to per-axis coordinates.
/// Inverse of [`coords_to_flat`] for in-range indices.
fn flat_to_coords(flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut remainder = flat;
    let mut coords = vec![0usize; shape.len()];
    // Peel off coordinates from the innermost (fastest-varying) axis outward.
    for (d, &extent) in shape.iter().enumerate().rev() {
        coords[d] = remainder % extent;
        remainder /= extent;
    }
    coords
}
339
/// Convert per-axis coordinates to a flat (row-major) index.
/// Horner-style accumulation: each step scales by the current axis extent
/// and adds that axis's coordinate, which is equivalent to summing
/// `coord * stride` with strides built from the trailing dimensions.
fn coords_to_flat(coords: &[usize], shape: &[usize]) -> usize {
    let mut flat = 0usize;
    for d in 0..shape.len() {
        flat = flat * shape[d] + coords[d];
    }
    flat
}
350
351/// Build the "elementary" shape from a pattern side: each `AxisSpec::Group`
352/// is expanded into its individual sub-axis sizes.
353fn elementary_shape(specs: &[AxisSpec], sizes: &HashMap<String, usize>) -> Vec<usize> {
354    let mut shape = Vec::new();
355    for spec in specs {
356        match spec {
357            AxisSpec::Single(name) => shape.push(*sizes.get(name).unwrap()),
358            AxisSpec::Group(names) => {
359                for n in names {
360                    shape.push(*sizes.get(n).unwrap());
361                }
362            }
363        }
364    }
365    shape
366}
367
368/// Perform the general rearrange operation. This is the core engine used by
369/// `rearrange`, `repeat`, and `reduce`.
370///
371/// The algorithm:
372/// 1. Compute the "elementary" shapes for left and right (fully expanded).
373/// 2. Reshape input from `input_shape` to `left_elementary_shape` (splits).
374/// 3. Determine the permutation from left-elementary to right-elementary
375///    axis ordering.
376/// 4. Transpose data according to the permutation.
377/// 5. Reshape from `right_elementary_shape` to `output_shape` (merges).
378fn rearrange_impl<T: Float>(
379    data: &[T],
380    _input_shape: &[usize],
381    pattern: &ParsedPattern,
382    sizes: &HashMap<String, usize>,
383    _output_shape: &[usize],
384) -> FerrotorchResult<Vec<T>> {
385    let left_names = flatten_axes(&pattern.left);
386    let right_names = flatten_axes(&pattern.right);
387    let left_elem_shape = elementary_shape(&pattern.left, sizes);
388    let right_elem_shape = elementary_shape(&pattern.right, sizes);
389
390    // The right_names should be a permutation of left_names (for rearrange).
391    // Build the permutation: for each axis in right_names, find its index
392    // in left_names.
393    let perm: Vec<usize> = right_names
394        .iter()
395        .map(|name| {
396            left_names
397                .iter()
398                .position(|n| n == name)
399                .unwrap_or(usize::MAX)
400        })
401        .collect();
402
403    // If there are axes only on the right (repeat) or only on the left (reduce),
404    // they won't have a valid permutation entry. For a pure rearrange, every
405    // entry should be valid.
406
407    // Step 1: Reshape from input_shape to left_elem_shape. Since both have
408    // the same total number of elements and the data is C-contiguous, this
409    // is a no-op on the buffer — only the interpretation changes.
410
411    // Step 2: Transpose from left_elem_shape to right_elem_shape order.
412    let elem_numel: usize = left_elem_shape.iter().product();
413    let mut transposed = vec![<T as num_traits::Zero>::zero(); elem_numel];
414
415    for src_flat in 0..elem_numel {
416        let src_coords = flat_to_coords(src_flat, &left_elem_shape);
417        let mut dst_coords = vec![0usize; right_elem_shape.len()];
418        for (dst_dim, &src_dim) in perm.iter().enumerate() {
419            dst_coords[dst_dim] = src_coords[src_dim];
420        }
421        let dst_flat = coords_to_flat(&dst_coords, &right_elem_shape);
422        transposed[dst_flat] = data[src_flat];
423    }
424
425    // Step 3: The transposed buffer now has right_elem_shape layout.
426    // Reshape to the output_shape (merging groups). This is again a
427    // reinterpretation — same buffer, different shape.
428    Ok(transposed)
429}
430
431// ---------------------------------------------------------------------------
432// Public API — rearrange
433// ---------------------------------------------------------------------------
434
/// Rearrange tensor dimensions using an einops-style pattern.
///
/// Convenience wrapper around [`rearrange_with`] for patterns that need no
/// explicit axis sizes (i.e. no ambiguous splits).
///
/// # Errors
/// Returns an error if the pattern is malformed, if the left side's axis
/// count does not match the input rank, or if the two sides do not name
/// exactly the same set of axes.
///
/// # Examples
/// ```ignore
/// // Flatten spatial dims: [B, C, H, W] -> [B, C*H*W]
/// rearrange(&t, "b c h w -> b (c h w)")?;
///
/// // Transpose: [B, H, W, C] -> [B, C, H, W]
/// rearrange(&t, "b h w c -> b c h w")?;
///
/// // Merge dims: [B, H, W, C] -> [B, H*W, C]
/// rearrange(&t, "b h w c -> b (h w) c")?;
/// ```
pub fn rearrange<T: Float>(input: &Tensor<T>, pattern: &str) -> FerrotorchResult<Tensor<T>> {
    rearrange_with(input, pattern, &[])
}
451
452/// Rearrange with explicit axis sizes for ambiguous splits.
453///
454/// # Examples
455/// ```ignore
456/// // Split a dimension: [B, C*H, W] -> [B, C, H, W] with C=3
457/// rearrange_with(&t, "b (c h) w -> b c h w", &[("c", 3)])?;
458/// ```
459pub fn rearrange_with<T: Float>(
460    input: &Tensor<T>,
461    pattern: &str,
462    axes_lengths: &[(&str, usize)],
463) -> FerrotorchResult<Tensor<T>> {
464    let parsed = parse_pattern(pattern)?;
465    let sizes = resolve_sizes(&parsed, input.shape(), axes_lengths)?;
466
467    let left_names = flatten_axes(&parsed.left);
468    let right_names = flatten_axes(&parsed.right);
469
470    // For rearrange, left and right must name exactly the same set of axes.
471    let mut left_sorted = left_names.clone();
472    left_sorted.sort();
473    let mut right_sorted = right_names.clone();
474    right_sorted.sort();
475    if left_sorted != right_sorted {
476        return Err(FerrotorchError::InvalidArgument {
477            message: format!(
478                "einops rearrange: left axes {:?} and right axes {:?} must name \
479                 the same set of axes (use `repeat` for new axes, `reduce` for removed axes)",
480                left_names, right_names
481            ),
482        });
483    }
484
485    let out_shape = output_shape(&parsed.right, &sizes);
486    let data = input.data()?;
487    let result_data = rearrange_impl(data, input.shape(), &parsed, &sizes, &out_shape)?;
488
489    Tensor::from_storage(TensorStorage::cpu(result_data), out_shape, false)
490}
491
492// ---------------------------------------------------------------------------
493// Public API — repeat
494// ---------------------------------------------------------------------------
495
496/// Repeat tensor elements along new or existing axes.
497///
498/// Axes on the right that do not appear on the left are new dimensions and
499/// must have their size specified in `axes_lengths`.
500///
501/// # Examples
502/// ```ignore
503/// // Add a batch dim by repeating: [H, W] -> [B, H, W]
504/// repeat(&t, "h w -> b h w", &[("b", 4)])?;
505///
506/// // Tile: [C] -> [C, 3]
507/// repeat(&t, "c -> c n", &[("n", 3)])?;
508/// ```
509pub fn repeat<T: Float>(
510    input: &Tensor<T>,
511    pattern: &str,
512    axes_lengths: &[(&str, usize)],
513) -> FerrotorchResult<Tensor<T>> {
514    let parsed = parse_pattern(pattern)?;
515    let sizes = resolve_sizes(&parsed, input.shape(), axes_lengths)?;
516
517    let left_names = flatten_axes(&parsed.left);
518    let right_names = flatten_axes(&parsed.right);
519
520    // Every left axis must appear on the right.
521    for name in &left_names {
522        if !right_names.contains(name) {
523            return Err(FerrotorchError::InvalidArgument {
524                message: format!(
525                    "einops repeat: left axis '{name}' does not appear on the right — \
526                     use `reduce` to remove axes"
527                ),
528            });
529        }
530    }
531
532    // Identify new axes (on right but not left).
533    let _new_axes: Vec<&String> = right_names
534        .iter()
535        .filter(|n| !left_names.contains(n))
536        .collect();
537
538    // Build the right elementary shape and the output shape.
539    let right_elem_shape = elementary_shape(&parsed.right, &sizes);
540    let out_shape = output_shape(&parsed.right, &sizes);
541
542    // Strategy: iterate over every element of the output (in right
543    // elementary order), map its coordinates back to the input.
544    let out_numel: usize = right_elem_shape.iter().product();
545    let left_elem_shape = elementary_shape(&parsed.left, &sizes);
546    let data = input.data()?;
547
548    let mut result = Vec::with_capacity(out_numel);
549    for dst_flat in 0..out_numel {
550        let dst_coords = flat_to_coords(dst_flat, &right_elem_shape);
551        // Map each right-elementary coordinate to a left-elementary coordinate.
552        let mut src_coords = Vec::with_capacity(left_elem_shape.len());
553        for (i, name) in right_names.iter().enumerate() {
554            if left_names.contains(name) {
555                src_coords.push(dst_coords[i]);
556            }
557            // New axes are simply ignored (they tile/repeat).
558        }
559        let src_flat = coords_to_flat(&src_coords, &left_elem_shape);
560        result.push(data[src_flat]);
561    }
562
563    // The result buffer is in right-elementary order. Reshape to out_shape
564    // (which merges groups). Since it's the same total size, just reinterpret.
565    Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
566}
567
568// ---------------------------------------------------------------------------
569// Public API — reduce
570// ---------------------------------------------------------------------------
571
572/// Reduce along axes that appear on the left but not the right.
573///
574/// # Examples
575/// ```ignore
576/// // Global average pool: [B, C, H, W] -> [B, C]
577/// reduce(&t, "b c h w -> b c", EinopsReduction::Mean)?;
578///
579/// // Sum over batch: [B, C] -> [C]
580/// reduce(&t, "b c -> c", EinopsReduction::Sum)?;
581/// ```
582pub fn reduce<T: Float>(
583    input: &Tensor<T>,
584    pattern: &str,
585    reduction: EinopsReduction,
586) -> FerrotorchResult<Tensor<T>> {
587    let parsed = parse_pattern(pattern)?;
588    let sizes = resolve_sizes(&parsed, input.shape(), &[])?;
589
590    let left_names = flatten_axes(&parsed.left);
591    let right_names = flatten_axes(&parsed.right);
592
593    // Every right axis must appear on the left.
594    for name in &right_names {
595        if !left_names.contains(name) {
596            return Err(FerrotorchError::InvalidArgument {
597                message: format!(
598                    "einops reduce: right axis '{name}' does not appear on the left — \
599                     use `repeat` to add new axes"
600                ),
601            });
602        }
603    }
604
605    // Identify reduced axes (on left but not right).
606    let reduced_axes: Vec<&String> = left_names
607        .iter()
608        .filter(|n| !right_names.contains(n))
609        .collect();
610
611    if reduced_axes.is_empty() {
612        return Err(FerrotorchError::InvalidArgument {
613            message: "einops reduce: no axes are being reduced — use `rearrange` instead".into(),
614        });
615    }
616
617    // Build the output elementary shape.
618    let left_elem_shape = elementary_shape(&parsed.left, &sizes);
619    let right_elem_shape = elementary_shape(&parsed.right, &sizes);
620    let out_shape = output_shape(&parsed.right, &sizes);
621
622    let out_numel: usize = right_elem_shape.iter().product();
623    let data = input.data()?;
624    let in_numel: usize = left_elem_shape.iter().product();
625
626    // Compute how many elements are reduced per output element.
627    let reduce_count: usize = reduced_axes
628        .iter()
629        .map(|name| sizes.get(name.as_str()).unwrap())
630        .product();
631
632    // Accumulate: for each input element, figure out which output element
633    // it contributes to.
634    // Initialize accumulators.
635    let init_val = match reduction {
636        EinopsReduction::Sum | EinopsReduction::Mean => <T as num_traits::Zero>::zero(),
637        EinopsReduction::Max => T::neg_infinity(),
638        EinopsReduction::Min => T::infinity(),
639    };
640    let mut accum = vec![init_val; out_numel];
641
642    for src_flat in 0..in_numel {
643        let src_coords = flat_to_coords(src_flat, &left_elem_shape);
644        // Map to output coordinates: keep only the axes that survive.
645        let mut dst_coords = Vec::with_capacity(right_elem_shape.len());
646        for (i, name) in left_names.iter().enumerate() {
647            if right_names.contains(name) {
648                dst_coords.push(src_coords[i]);
649            }
650        }
651        let dst_flat = coords_to_flat(&dst_coords, &right_elem_shape);
652
653        let val = data[src_flat];
654        match reduction {
655            EinopsReduction::Sum | EinopsReduction::Mean => {
656                accum[dst_flat] = accum[dst_flat] + val;
657            }
658            EinopsReduction::Max => {
659                if val > accum[dst_flat] {
660                    accum[dst_flat] = val;
661                }
662            }
663            EinopsReduction::Min => {
664                if val < accum[dst_flat] {
665                    accum[dst_flat] = val;
666                }
667            }
668        }
669    }
670
671    // For mean, divide by the number of reduced elements.
672    if reduction == EinopsReduction::Mean {
673        let n = T::from(reduce_count).unwrap();
674        for v in &mut accum {
675            *v = *v / n;
676        }
677    }
678
679    Tensor::from_storage(TensorStorage::cpu(accum), out_shape, false)
680}
681
682// ---------------------------------------------------------------------------
683// Tests
684// ---------------------------------------------------------------------------
685
686#[cfg(test)]
687mod tests {
688    use super::*;
689
690    /// Helper: create a leaf tensor.
691    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
692        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
693    }
694
695    // -----------------------------------------------------------------------
696    // rearrange tests
697    // -----------------------------------------------------------------------
698
699    #[test]
700    fn test_rearrange_identity() {
701        // "b c h w -> b c h w" should be a no-op.
702        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
703        let t = leaf(&data, &[2, 3, 2, 2]);
704        let r = rearrange(&t, "b c h w -> b c h w").unwrap();
705        assert_eq!(r.shape(), &[2, 3, 2, 2]);
706        assert_eq!(r.data().unwrap(), data.as_slice());
707    }
708
709    #[test]
710    fn test_rearrange_flatten() {
711        // "b c h w -> b (c h w)" merges c, h, w.
712        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
713        let t = leaf(&data, &[2, 3, 2, 2]); // B=2, C=3, H=2, W=2
714        let r = rearrange(&t, "b c h w -> b (c h w)").unwrap();
715        assert_eq!(r.shape(), &[2, 12]);
716        assert_eq!(r.data().unwrap(), data.as_slice());
717    }
718
719    #[test]
720    fn test_rearrange_transpose_nhwc_to_nchw() {
721        // "b h w c -> b c h w" transposes.
722        // Input shape: [1, 2, 2, 3] (B=1, H=2, W=2, C=3)
723        // Output shape: [1, 3, 2, 2]
724        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
725        let t = leaf(&data, &[1, 2, 2, 3]);
726        let r = rearrange(&t, "b h w c -> b c h w").unwrap();
727        assert_eq!(r.shape(), &[1, 3, 2, 2]);
728
729        // Verify specific elements.
730        // Input[0,0,0,:] = [0,1,2], Input[0,0,1,:] = [3,4,5]
731        // Input[0,1,0,:] = [6,7,8], Input[0,1,1,:] = [9,10,11]
732        // Output[0,c,h,w] = Input[0,h,w,c]
733        // Output[0,0,0,0] = Input[0,0,0,0] = 0
734        // Output[0,0,0,1] = Input[0,0,1,0] = 3
735        // Output[0,0,1,0] = Input[0,1,0,0] = 6
736        // Output[0,0,1,1] = Input[0,1,1,0] = 9
737        // Output[0,1,0,0] = Input[0,0,0,1] = 1
738        // etc.
739        let out = r.data().unwrap();
740        assert_eq!(out[0], 0.0);  // [0,0,0,0]
741        assert_eq!(out[1], 3.0);  // [0,0,0,1]
742        assert_eq!(out[2], 6.0);  // [0,0,1,0]
743        assert_eq!(out[3], 9.0);  // [0,0,1,1]
744        assert_eq!(out[4], 1.0);  // [0,1,0,0]
745        assert_eq!(out[5], 4.0);  // [0,1,0,1]
746    }
747
748    #[test]
749    fn test_rearrange_split_with_axes_lengths() {
750        // "b (c h) w -> b c h w" with c=3 splits dimension 1.
751        // Input: [2, 6, 4] -> Output: [2, 3, 2, 4]
752        let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
753        let t = leaf(&data, &[2, 6, 4]);
754        let r = rearrange_with(&t, "b (c h) w -> b c h w", &[("c", 3)]).unwrap();
755        assert_eq!(r.shape(), &[2, 3, 2, 4]);
756
757        // The data should be the same since (c h) is already in order and
758        // we're just splitting.
759        assert_eq!(r.data().unwrap(), data.as_slice());
760    }
761
762    #[test]
763    fn test_rearrange_merge_dims() {
764        // "b h w c -> b (h w) c" merges h and w.
765        // Input: [1, 2, 3, 4] -> Output: [1, 6, 4]
766        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
767        let t = leaf(&data, &[1, 2, 3, 4]);
768        let r = rearrange(&t, "b h w c -> b (h w) c").unwrap();
769        assert_eq!(r.shape(), &[1, 6, 4]);
770        // Data stays the same since h and w are adjacent and in order.
771        assert_eq!(r.data().unwrap(), data.as_slice());
772    }
773
774    // -----------------------------------------------------------------------
775    // repeat tests
776    // -----------------------------------------------------------------------
777
778    #[test]
779    fn test_repeat_new_batch_dim() {
780        // "h w -> b h w" adds a batch dimension.
781        let data = vec![1.0f32, 2.0, 3.0, 4.0];
782        let t = leaf(&data, &[2, 2]);
783        let r = repeat(&t, "h w -> b h w", &[("b", 3)]).unwrap();
784        assert_eq!(r.shape(), &[3, 2, 2]);
785
786        let out = r.data().unwrap();
787        // Each batch should be a copy of the original.
788        assert_eq!(&out[0..4], &[1.0, 2.0, 3.0, 4.0]);
789        assert_eq!(&out[4..8], &[1.0, 2.0, 3.0, 4.0]);
790        assert_eq!(&out[8..12], &[1.0, 2.0, 3.0, 4.0]);
791    }
792
793    #[test]
794    fn test_repeat_tile() {
795        // "c -> c n" tiles a 1-D tensor.
796        let data = vec![10.0f32, 20.0, 30.0];
797        let t = leaf(&data, &[3]);
798        let r = repeat(&t, "c -> c n", &[("n", 2)]).unwrap();
799        assert_eq!(r.shape(), &[3, 2]);
800
801        let out = r.data().unwrap();
802        assert_eq!(out, &[10.0, 10.0, 20.0, 20.0, 30.0, 30.0]);
803    }
804
805    // -----------------------------------------------------------------------
806    // reduce tests
807    // -----------------------------------------------------------------------
808
809    #[test]
810    fn test_reduce_mean_spatial() {
811        // "b c h w -> b c" — global average pool.
812        // B=1, C=2, H=2, W=2
813        // Channel 0: [1, 2, 3, 4] mean = 2.5
814        // Channel 1: [5, 6, 7, 8] mean = 6.5
815        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
816        let t = leaf(&data, &[1, 2, 2, 2]);
817        let r = reduce(&t, "b c h w -> b c", EinopsReduction::Mean).unwrap();
818        assert_eq!(r.shape(), &[1, 2]);
819        let out = r.data().unwrap();
820        assert!((out[0] - 2.5).abs() < 1e-6, "expected 2.5, got {}", out[0]);
821        assert!((out[1] - 6.5).abs() < 1e-6, "expected 6.5, got {}", out[1]);
822    }
823
824    #[test]
825    fn test_reduce_sum_batch() {
826        // "b c -> c" — sum over batch.
827        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
828        let t = leaf(&data, &[3, 2]); // B=3, C=2
829        let r = reduce(&t, "b c -> c", EinopsReduction::Sum).unwrap();
830        assert_eq!(r.shape(), &[2]);
831        let out = r.data().unwrap();
832        // c=0: 1 + 3 + 5 = 9
833        // c=1: 2 + 4 + 6 = 12
834        assert!((out[0] - 9.0).abs() < 1e-6);
835        assert!((out[1] - 12.0).abs() < 1e-6);
836    }
837
838    #[test]
839    fn test_reduce_max() {
840        // "b c -> c" — max over batch.
841        let data = vec![1.0f32, 5.0, 3.0, 2.0, 4.0, 6.0];
842        let t = leaf(&data, &[3, 2]);
843        let r = reduce(&t, "b c -> c", EinopsReduction::Max).unwrap();
844        assert_eq!(r.shape(), &[2]);
845        let out = r.data().unwrap();
846        assert!((out[0] - 4.0).abs() < 1e-6); // max(1, 3, 4)
847        assert!((out[1] - 6.0).abs() < 1e-6); // max(5, 2, 6)
848    }
849
850    #[test]
851    fn test_reduce_min() {
852        // "b c -> c" — min over batch.
853        let data = vec![1.0f32, 5.0, 3.0, 2.0, 4.0, 6.0];
854        let t = leaf(&data, &[3, 2]);
855        let r = reduce(&t, "b c -> c", EinopsReduction::Min).unwrap();
856        assert_eq!(r.shape(), &[2]);
857        let out = r.data().unwrap();
858        assert!((out[0] - 1.0).abs() < 1e-6); // min(1, 3, 4)
859        assert!((out[1] - 2.0).abs() < 1e-6); // min(5, 2, 6)
860    }
861
862    // -----------------------------------------------------------------------
863    // Error tests
864    // -----------------------------------------------------------------------
865
866    #[test]
867    fn test_invalid_pattern_no_arrow() {
868        let t = leaf(&[1.0, 2.0, 3.0], &[3]);
869        assert!(rearrange(&t, "a b c").is_err());
870    }
871
872    #[test]
873    fn test_mismatched_axis_count() {
874        let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
875        // Left side has 3 axes but tensor has 2 dims.
876        assert!(rearrange(&t, "a b c -> a b c").is_err());
877    }
878
879    #[test]
880    fn test_rearrange_missing_axis_on_right() {
881        // "b c h w -> b c" would be a reduce, not a rearrange.
882        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
883        let t = leaf(&data, &[2, 3, 2, 2]);
884        assert!(rearrange(&t, "b c h w -> b c").is_err());
885    }
886
887    #[test]
888    fn test_rearrange_extra_axis_on_right() {
889        // "b c -> b c n" would be a repeat, not a rearrange.
890        let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
891        assert!(rearrange(&t, "b c -> b c n").is_err());
892    }
893
894    #[test]
895    fn test_repeat_missing_new_axis_size() {
896        let t = leaf(&[1.0, 2.0], &[2]);
897        // "c -> c n" but no size given for n.
898        assert!(repeat(&t, "c -> c n", &[]).is_err());
899    }
900
901    #[test]
902    fn test_reduce_no_reduction() {
903        // "b c -> b c" reduces nothing — should error.
904        let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
905        assert!(reduce(&t, "b c -> b c", EinopsReduction::Sum).is_err());
906    }
907
908    #[test]
909    fn test_unmatched_paren() {
910        let t = leaf(&[1.0, 2.0], &[2]);
911        assert!(rearrange(&t, "(a -> a").is_err());
912    }
913
914    #[test]
915    fn test_duplicate_axis_name() {
916        let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
917        assert!(rearrange(&t, "a a -> a a").is_err());
918    }
919
920    // -----------------------------------------------------------------------
921    // Parser tests
922    // -----------------------------------------------------------------------
923
924    #[test]
925    fn test_parse_simple() {
926        let p = parse_pattern("b c h w -> b c h w").unwrap();
927        assert_eq!(flatten_axes(&p.left), vec!["b", "c", "h", "w"]);
928        assert_eq!(flatten_axes(&p.right), vec!["b", "c", "h", "w"]);
929    }
930
931    #[test]
932    fn test_parse_groups() {
933        let p = parse_pattern("b c h w -> b (c h w)").unwrap();
934        assert_eq!(p.right.len(), 2); // b, (c h w)
935        match &p.right[1] {
936            AxisSpec::Group(names) => assert_eq!(names, &["c", "h", "w"]),
937            _ => panic!("expected Group"),
938        }
939    }
940
941    #[test]
942    fn test_parse_left_group() {
943        let p = parse_pattern("b (c h) w -> b c h w").unwrap();
944        assert_eq!(p.left.len(), 3); // b, (c h), w
945        match &p.left[1] {
946            AxisSpec::Group(names) => assert_eq!(names, &["c", "h"]),
947            _ => panic!("expected Group"),
948        }
949    }
950}