use std::collections::{HashMap, HashSet};
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
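/// Reduction applied by [`reduce`] over the axes that appear on the left
/// side of a pattern but not on the right.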
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EinopsReduction {
Mean,
Sum,
Max,
Min,
}
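/// One term on a pattern side: a bare axis name like `h`, or a
/// parenthesized group like `(c h)` that merges or splits axes.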
#[derive(Debug, Clone, PartialEq)]
enum AxisSpec {
Single(String),
Group(Vec<String>),
}
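/// A pattern split at `->` into its left (input) and right (output) axis specs.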
#[derive(Debug)]
struct ParsedPattern {
left: Vec<AxisSpec>,
right: Vec<AxisSpec>,
}
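/// Flattens axis specs into the names of their elementary axes, in order.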
fn flatten_axes(specs: &[AxisSpec]) -> Vec<String> {
let mut out = Vec::new();
for spec in specs {
match spec {
AxisSpec::Single(name) => out.push(name.clone()),
AxisSpec::Group(names) => out.extend(names.iter().cloned()),
}
}
out
}
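/// Parses one side of a pattern into axis specs. Whitespace separates
/// terms; `(a b)` groups sub-axes into a single tensor dimension.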
fn parse_side(s: &str) -> FerrotorchResult<Vec<AxisSpec>> {
let s = s.trim();
let mut specs = Vec::new();
let mut chars = s.chars().peekable();
while let Some(&c) = chars.peek() {
if c.is_whitespace() {
chars.next();
continue;
}
if c == '(' {
chars.next();
let mut group = Vec::new();
loop {
while let Some(&c2) = chars.peek() {
if c2.is_whitespace() {
chars.next();
} else {
break;
}
}
                match chars.peek() {
                    None => {
                        return Err(FerrotorchError::InvalidArgument {
                            message: "einops: unmatched '(' in pattern".into(),
                        });
                    }
                    Some(&')') => {
                        chars.next();
                        break;
                    }
                    Some(&'(') => {
                        return Err(FerrotorchError::InvalidArgument {
                            message: "einops: nested parentheses are not supported".into(),
                        });
                    }
                    _ => {}
                }
let name = read_axis_name(&mut chars)?;
if name.is_empty() {
return Err(FerrotorchError::InvalidArgument {
message: "einops: empty axis name inside parentheses".into(),
});
}
group.push(name);
}
if group.is_empty() {
return Err(FerrotorchError::InvalidArgument {
message: "einops: empty parenthesized group".into(),
});
}
specs.push(AxisSpec::Group(group));
} else if c.is_ascii_alphanumeric() || c == '_' {
let name = read_axis_name(&mut chars)?;
specs.push(AxisSpec::Single(name));
} else {
return Err(FerrotorchError::InvalidArgument {
message: format!("einops: unexpected character '{c}' in pattern"),
});
}
}
Ok(specs)
}
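/// Reads a maximal run of ASCII alphanumerics and underscores as an axis
/// name. Returns an empty string if the next character cannot start a name.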
fn read_axis_name(
chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
) -> FerrotorchResult<String> {
let mut name = String::new();
while let Some(&c) = chars.peek() {
if c.is_ascii_alphanumeric() || c == '_' {
name.push(c);
chars.next();
} else {
break;
}
}
Ok(name)
}
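/// Splits `pattern` at `->`, parses both sides, and rejects duplicate axis
/// names within either side.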
fn parse_pattern(pattern: &str) -> FerrotorchResult<ParsedPattern> {
let pattern = pattern.trim();
let (left_str, right_str) =
pattern
.split_once("->")
.ok_or_else(|| FerrotorchError::InvalidArgument {
message: format!("einops: pattern must contain '->', got: \"{pattern}\""),
})?;
let left = parse_side(left_str)?;
let right = parse_side(right_str)?;
let left_names = flatten_axes(&left);
let right_names = flatten_axes(&right);
    let mut seen = HashSet::new();
    for name in &left_names {
        if !seen.insert(name.as_str()) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("einops: duplicate axis name '{name}' on left side of pattern"),
            });
        }
    }
    seen.clear();
    for name in &right_names {
        if !seen.insert(name.as_str()) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("einops: duplicate axis name '{name}' on right side of pattern"),
            });
        }
    }
Ok(ParsedPattern { left, right })
}
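/// Determines the size of every elementary axis from the input shape and
/// the user-supplied `axes_lengths`. Within a split group, at most one
/// sub-axis may be left unspecified; its size is inferred by division.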
fn resolve_sizes(
pattern: &ParsedPattern,
input_shape: &[usize],
axes_lengths: &[(&str, usize)],
) -> FerrotorchResult<HashMap<String, usize>> {
let left_flat = flatten_axes(&pattern.left);
let right_flat = flatten_axes(&pattern.right);
let left_dim_count = pattern.left.len();
if left_dim_count != input_shape.len() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einops: left side of pattern has {} axes but input tensor has {} dimensions",
left_dim_count,
input_shape.len()
),
});
}
let user_sizes: HashMap<&str, usize> = axes_lengths.iter().copied().collect();
let mut sizes: HashMap<String, usize> = HashMap::new();
for (dim_idx, spec) in pattern.left.iter().enumerate() {
let dim_size = input_shape[dim_idx];
match spec {
AxisSpec::Single(name) => {
sizes.insert(name.clone(), dim_size);
}
AxisSpec::Group(names) => {
let mut unknown_idx: Option<usize> = None;
let mut known_product: usize = 1;
for (i, name) in names.iter().enumerate() {
if let Some(&sz) = user_sizes.get(name.as_str()) {
sizes.insert(name.clone(), sz);
known_product *= sz;
} else if let Some(&sz) = sizes.get(name) {
known_product *= sz;
} else {
if unknown_idx.is_some() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einops: cannot infer sizes for split '({})' — \
provide sizes for all but one sub-axis via axes_lengths",
names.join(" ")
),
});
}
unknown_idx = Some(i);
}
}
if let Some(ui) = unknown_idx {
if known_product == 0 || dim_size % known_product != 0 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einops: dimension {} (size {}) is not divisible by \
known product {} for split '({})'",
dim_idx,
dim_size,
known_product,
names.join(" ")
),
});
}
sizes.insert(names[ui].clone(), dim_size / known_product);
                } else if known_product != dim_size {
                    return Err(FerrotorchError::ShapeMismatch {
                        message: format!(
                            "einops: split '({})' product {} does not match dimension {} size {}",
                            names.join(" "),
                            known_product,
                            dim_idx,
                            dim_size
                        ),
                    });
                }
}
}
}
for name in &right_flat {
if !sizes.contains_key(name) {
if let Some(&sz) = user_sizes.get(name.as_str()) {
sizes.insert(name.clone(), sz);
} else if !left_flat.contains(name) {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einops: axis '{name}' appears on the right but not the left \
and has no size in axes_lengths"
),
});
}
}
}
Ok(sizes)
}
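/// Computes the output shape: a bare axis contributes its size and a group
/// contributes the product of its sub-axis sizes.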
fn output_shape(right: &[AxisSpec], sizes: &HashMap<String, usize>) -> Vec<usize> {
right
.iter()
.map(|spec| match spec {
AxisSpec::Single(name) => *sizes.get(name).unwrap(),
AxisSpec::Group(names) => names.iter().map(|n| sizes.get(n).unwrap()).product(),
})
.collect()
}
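/// Converts a flat row-major index into per-dimension coordinates for `shape`.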
fn flat_to_coords(mut flat: usize, shape: &[usize]) -> Vec<usize> {
let ndim = shape.len();
let mut coords = vec![0usize; ndim];
for d in (0..ndim).rev() {
coords[d] = flat % shape[d];
flat /= shape[d];
}
coords
}
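/// Converts per-dimension coordinates into a flat row-major index for `shape`.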
fn coords_to_flat(coords: &[usize], shape: &[usize]) -> usize {
let mut flat = 0usize;
let mut stride = 1usize;
for d in (0..shape.len()).rev() {
flat += coords[d] * stride;
stride *= shape[d];
}
flat
}
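/// Expands axis specs into the fully split ("elementary") shape, with one
/// entry per sub-axis of every group.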
fn elementary_shape(specs: &[AxisSpec], sizes: &HashMap<String, usize>) -> Vec<usize> {
let mut shape = Vec::new();
for spec in specs {
match spec {
AxisSpec::Single(name) => shape.push(*sizes.get(name).unwrap()),
AxisSpec::Group(names) => {
for n in names {
shape.push(*sizes.get(n).unwrap());
}
}
}
}
shape
}
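/// Core rearrange kernel: visits every element of the input viewed in its
/// elementary (fully split) shape, permutes its coordinates into the output
/// axis order, and writes it at the corresponding flat index.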
fn rearrange_impl<T: Float>(
    data: &[T],
    pattern: &ParsedPattern,
    sizes: &HashMap<String, usize>,
) -> FerrotorchResult<Vec<T>> {
let left_names = flatten_axes(&pattern.left);
let right_names = flatten_axes(&pattern.right);
let left_elem_shape = elementary_shape(&pattern.left, sizes);
let right_elem_shape = elementary_shape(&pattern.right, sizes);
    // For each output axis, find the position of the matching input axis.
    // The caller has already checked that both sides name the same axes,
    // so this lookup cannot fail.
    let perm: Vec<usize> = right_names
        .iter()
        .map(|name| {
            left_names
                .iter()
                .position(|n| n == name)
                .expect("einops rearrange: axis sets validated by caller")
        })
        .collect();
let elem_numel: usize = left_elem_shape.iter().product();
let mut transposed = vec![<T as num_traits::Zero>::zero(); elem_numel];
for (src_flat, &val) in data.iter().enumerate().take(elem_numel) {
let src_coords = flat_to_coords(src_flat, &left_elem_shape);
let mut dst_coords = vec![0usize; right_elem_shape.len()];
for (dst_dim, &src_dim) in perm.iter().enumerate() {
dst_coords[dst_dim] = src_coords[src_dim];
}
let dst_flat = coords_to_flat(&dst_coords, &right_elem_shape);
transposed[dst_flat] = val;
}
Ok(transposed)
}
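/// Rearranges `input` according to an einops-style `pattern` such as
/// `"b c h w -> b (c h w)"`, with no explicit axis sizes.
///
/// Shorthand for [`rearrange_with`] with an empty `axes_lengths` slice.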
pub fn rearrange<T: Float>(input: &Tensor<T>, pattern: &str) -> FerrotorchResult<Tensor<T>> {
rearrange_with(input, pattern, &[])
}
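/// Rearranges `input` according to an einops-style `pattern`, e.g.
/// `"b (c h) w -> b c h w"` with `axes_lengths = &[("c", 3)]` to pin the
/// sizes of split sub-axes that cannot be inferred from the input shape.
///
/// Both sides must name exactly the same set of axes; use [`repeat`] to
/// introduce axes and [`reduce`] to drop them.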
pub fn rearrange_with<T: Float>(
input: &Tensor<T>,
pattern: &str,
axes_lengths: &[(&str, usize)],
) -> FerrotorchResult<Tensor<T>> {
let parsed = parse_pattern(pattern)?;
let sizes = resolve_sizes(&parsed, input.shape(), axes_lengths)?;
let left_names = flatten_axes(&parsed.left);
let right_names = flatten_axes(&parsed.right);
let mut left_sorted = left_names.clone();
left_sorted.sort();
let mut right_sorted = right_names.clone();
right_sorted.sort();
if left_sorted != right_sorted {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einops rearrange: left axes {:?} and right axes {:?} must name \
the same set of axes (use `repeat` for new axes, `reduce` for removed axes)",
left_names, right_names
),
});
}
let out_shape = output_shape(&parsed.right, &sizes);
let device = input.device();
let data = input.data_vec()?;
    let result_data = rearrange_impl(&data, &parsed, &sizes)?;
let t = Tensor::from_storage(TensorStorage::cpu(result_data), out_shape, false)?;
Ok(if device.is_cuda() { t.to(device)? } else { t })
}
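/// Broadcasts `input` along new axes that appear only on the right side of
/// `pattern`, e.g. `"h w -> b h w"` with `axes_lengths = &[("b", 3)]`.
/// Every left axis must also appear on the right.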
pub fn repeat<T: Float>(
input: &Tensor<T>,
pattern: &str,
axes_lengths: &[(&str, usize)],
) -> FerrotorchResult<Tensor<T>> {
let parsed = parse_pattern(pattern)?;
let sizes = resolve_sizes(&parsed, input.shape(), axes_lengths)?;
let left_names = flatten_axes(&parsed.left);
let right_names = flatten_axes(&parsed.right);
for name in &left_names {
if !right_names.contains(name) {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einops repeat: left axis '{name}' does not appear on the right — \
use `reduce` to remove axes"
),
});
}
}
let right_elem_shape = elementary_shape(&parsed.right, &sizes);
let out_shape = output_shape(&parsed.right, &sizes);
let out_numel: usize = right_elem_shape.iter().product();
let left_elem_shape = elementary_shape(&parsed.left, &sizes);
let device = input.device();
let data = input.data_vec()?;
    // Record, for each left axis, its position in the right-hand elementary
    // order, so source coordinates are read out in *left* order even when
    // the pattern lists the shared axes in a different order.
    let src_positions: Vec<usize> = left_names
        .iter()
        .map(|name| {
            right_names
                .iter()
                .position(|n| n == name)
                .expect("validated above: every left axis appears on the right")
        })
        .collect();
    let mut result = Vec::with_capacity(out_numel);
    for dst_flat in 0..out_numel {
        let dst_coords = flat_to_coords(dst_flat, &right_elem_shape);
        let src_coords: Vec<usize> = src_positions.iter().map(|&p| dst_coords[p]).collect();
        let src_flat = coords_to_flat(&src_coords, &left_elem_shape);
        result.push(data[src_flat]);
    }
let t = Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)?;
Ok(if device.is_cuda() { t.to(device)? } else { t })
}
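/// Reduces `input` over the axes that appear on the left side of `pattern`
/// but not on the right, e.g. `"b c h w -> b c"` with
/// [`EinopsReduction::Mean`] averages over `h` and `w`.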
pub fn reduce<T: Float>(
input: &Tensor<T>,
pattern: &str,
reduction: EinopsReduction,
) -> FerrotorchResult<Tensor<T>> {
let parsed = parse_pattern(pattern)?;
let sizes = resolve_sizes(&parsed, input.shape(), &[])?;
let left_names = flatten_axes(&parsed.left);
let right_names = flatten_axes(&parsed.right);
for name in &right_names {
if !left_names.contains(name) {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einops reduce: right axis '{name}' does not appear on the left — \
use `repeat` to add new axes"
),
});
}
}
let reduced_axes: Vec<&String> = left_names
.iter()
.filter(|n| !right_names.contains(n))
.collect();
if reduced_axes.is_empty() {
return Err(FerrotorchError::InvalidArgument {
message: "einops reduce: no axes are being reduced — use `rearrange` instead".into(),
});
}
let left_elem_shape = elementary_shape(&parsed.left, &sizes);
let right_elem_shape = elementary_shape(&parsed.right, &sizes);
let out_shape = output_shape(&parsed.right, &sizes);
let out_numel: usize = right_elem_shape.iter().product();
let device = input.device();
let data = input.data_vec()?;
let in_numel: usize = left_elem_shape.iter().product();
let reduce_count: usize = reduced_axes
.iter()
.map(|name| sizes.get(name.as_str()).unwrap())
.product();
let init_val = match reduction {
EinopsReduction::Sum | EinopsReduction::Mean => <T as num_traits::Zero>::zero(),
EinopsReduction::Max => T::neg_infinity(),
EinopsReduction::Min => T::infinity(),
};
let mut accum = vec![init_val; out_numel];
    // Record, for each kept (right) axis, its position in the left-hand
    // elementary order, so destination coordinates are built in *right*
    // order even when the pattern reorders the kept axes.
    let dst_positions: Vec<usize> = right_names
        .iter()
        .map(|name| {
            left_names
                .iter()
                .position(|n| n == name)
                .expect("validated above: every right axis appears on the left")
        })
        .collect();
    for (src_flat, &val) in data.iter().enumerate().take(in_numel) {
        let src_coords = flat_to_coords(src_flat, &left_elem_shape);
        let dst_coords: Vec<usize> = dst_positions.iter().map(|&p| src_coords[p]).collect();
        let dst_flat = coords_to_flat(&dst_coords, &right_elem_shape);
match reduction {
EinopsReduction::Sum | EinopsReduction::Mean => {
accum[dst_flat] += val;
}
EinopsReduction::Max => {
if val > accum[dst_flat] {
accum[dst_flat] = val;
}
}
EinopsReduction::Min => {
if val < accum[dst_flat] {
accum[dst_flat] = val;
}
}
}
}
if reduction == EinopsReduction::Mean {
let n = T::from(reduce_count).unwrap();
for v in &mut accum {
*v = *v / n;
}
}
let t = Tensor::from_storage(TensorStorage::cpu(accum), out_shape, false)?;
Ok(if device.is_cuda() { t.to(device)? } else { t })
}
#[cfg(test)]
mod tests {
use super::*;
fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
#[test]
fn test_rearrange_identity() {
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let t = leaf(&data, &[2, 3, 2, 2]);
let r = rearrange(&t, "b c h w -> b c h w").unwrap();
assert_eq!(r.shape(), &[2, 3, 2, 2]);
assert_eq!(r.data().unwrap(), data.as_slice());
}
#[test]
fn test_rearrange_flatten() {
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = leaf(&data, &[2, 3, 2, 2]);
        let r = rearrange(&t, "b c h w -> b (c h w)").unwrap();
assert_eq!(r.shape(), &[2, 12]);
assert_eq!(r.data().unwrap(), data.as_slice());
}
#[test]
fn test_rearrange_transpose_nhwc_to_nchw() {
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let t = leaf(&data, &[1, 2, 2, 3]);
let r = rearrange(&t, "b h w c -> b c h w").unwrap();
assert_eq!(r.shape(), &[1, 3, 2, 2]);
let out = r.data().unwrap();
        assert_eq!(out[0], 0.0);
        assert_eq!(out[1], 3.0);
        assert_eq!(out[2], 6.0);
        assert_eq!(out[3], 9.0);
        assert_eq!(out[4], 1.0);
        assert_eq!(out[5], 4.0);
    }
#[test]
fn test_rearrange_split_with_axes_lengths() {
let data: Vec<f32> = (0..48).map(|i| i as f32).collect();
let t = leaf(&data, &[2, 6, 4]);
let r = rearrange_with(&t, "b (c h) w -> b c h w", &[("c", 3)]).unwrap();
assert_eq!(r.shape(), &[2, 3, 2, 4]);
assert_eq!(r.data().unwrap(), data.as_slice());
}
#[test]
fn test_rearrange_merge_dims() {
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let t = leaf(&data, &[1, 2, 3, 4]);
let r = rearrange(&t, "b h w c -> b (h w) c").unwrap();
assert_eq!(r.shape(), &[1, 6, 4]);
assert_eq!(r.data().unwrap(), data.as_slice());
}
#[test]
fn test_repeat_new_batch_dim() {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let t = leaf(&data, &[2, 2]);
let r = repeat(&t, "h w -> b h w", &[("b", 3)]).unwrap();
assert_eq!(r.shape(), &[3, 2, 2]);
let out = r.data().unwrap();
assert_eq!(&out[0..4], &[1.0, 2.0, 3.0, 4.0]);
assert_eq!(&out[4..8], &[1.0, 2.0, 3.0, 4.0]);
assert_eq!(&out[8..12], &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_repeat_tile() {
let data = vec![10.0f32, 20.0, 30.0];
let t = leaf(&data, &[3]);
let r = repeat(&t, "c -> c n", &[("n", 2)]).unwrap();
assert_eq!(r.shape(), &[3, 2]);
let out = r.data().unwrap();
assert_eq!(out, &[10.0, 10.0, 20.0, 20.0, 30.0, 30.0]);
}
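    #[test]
    fn test_repeat_reorders_shared_axes() {
        // Shared axes appear in a different order on the right, so source
        // coordinates must be remapped by name rather than copied positionally.
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = leaf(&data, &[2, 3]);
        let r = repeat(&t, "h w -> b w h", &[("b", 2)]).unwrap();
        assert_eq!(r.shape(), &[2, 3, 2]);
        let out = r.data().unwrap();
        assert_eq!(&out[0..6], &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        assert_eq!(&out[6..12], &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }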
#[test]
fn test_reduce_mean_spatial() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let t = leaf(&data, &[1, 2, 2, 2]);
let r = reduce(&t, "b c h w -> b c", EinopsReduction::Mean).unwrap();
assert_eq!(r.shape(), &[1, 2]);
let out = r.data().unwrap();
assert!((out[0] - 2.5).abs() < 1e-6, "expected 2.5, got {}", out[0]);
assert!((out[1] - 6.5).abs() < 1e-6, "expected 6.5, got {}", out[1]);
}
#[test]
fn test_reduce_sum_batch() {
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = leaf(&data, &[3, 2]);
        let r = reduce(&t, "b c -> c", EinopsReduction::Sum).unwrap();
assert_eq!(r.shape(), &[2]);
let out = r.data().unwrap();
assert!((out[0] - 9.0).abs() < 1e-6);
assert!((out[1] - 12.0).abs() < 1e-6);
}
#[test]
fn test_reduce_max() {
let data = vec![1.0f32, 5.0, 3.0, 2.0, 4.0, 6.0];
let t = leaf(&data, &[3, 2]);
let r = reduce(&t, "b c -> c", EinopsReduction::Max).unwrap();
assert_eq!(r.shape(), &[2]);
let out = r.data().unwrap();
        assert!((out[0] - 4.0).abs() < 1e-6);
        assert!((out[1] - 6.0).abs() < 1e-6);
    }
#[test]
fn test_reduce_min() {
let data = vec![1.0f32, 5.0, 3.0, 2.0, 4.0, 6.0];
let t = leaf(&data, &[3, 2]);
let r = reduce(&t, "b c -> c", EinopsReduction::Min).unwrap();
assert_eq!(r.shape(), &[2]);
let out = r.data().unwrap();
        assert!((out[0] - 1.0).abs() < 1e-6);
        assert!((out[1] - 2.0).abs() < 1e-6);
    }
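    #[test]
    fn test_reduce_sum_reorders_kept_axes() {
        // `b` is reduced away while the kept axes swap order, so destination
        // coordinates must follow the right-hand axis order.
        let data: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let t = leaf(&data, &[2, 2, 2]);
        let r = reduce(&t, "a b c -> c a", EinopsReduction::Sum).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        let out = r.data().unwrap();
        assert_eq!(out, &[2.0, 10.0, 4.0, 12.0]);
    }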
#[test]
fn test_invalid_pattern_no_arrow() {
let t = leaf(&[1.0, 2.0, 3.0], &[3]);
assert!(rearrange(&t, "a b c").is_err());
}
#[test]
fn test_mismatched_axis_count() {
let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
assert!(rearrange(&t, "a b c -> a b c").is_err());
}
#[test]
fn test_rearrange_missing_axis_on_right() {
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let t = leaf(&data, &[2, 3, 2, 2]);
assert!(rearrange(&t, "b c h w -> b c").is_err());
}
#[test]
fn test_rearrange_extra_axis_on_right() {
let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
assert!(rearrange(&t, "b c -> b c n").is_err());
}
#[test]
fn test_repeat_missing_new_axis_size() {
let t = leaf(&[1.0, 2.0], &[2]);
assert!(repeat(&t, "c -> c n", &[]).is_err());
}
#[test]
fn test_reduce_no_reduction() {
let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
assert!(reduce(&t, "b c -> b c", EinopsReduction::Sum).is_err());
}
#[test]
fn test_unmatched_paren() {
let t = leaf(&[1.0, 2.0], &[2]);
assert!(rearrange(&t, "(a -> a").is_err());
}
#[test]
fn test_duplicate_axis_name() {
let t = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
assert!(rearrange(&t, "a a -> a a").is_err());
}
#[test]
fn test_parse_simple() {
let p = parse_pattern("b c h w -> b c h w").unwrap();
assert_eq!(flatten_axes(&p.left), vec!["b", "c", "h", "w"]);
assert_eq!(flatten_axes(&p.right), vec!["b", "c", "h", "w"]);
}
#[test]
fn test_parse_groups() {
let p = parse_pattern("b c h w -> b (c h w)").unwrap();
        assert_eq!(p.right.len(), 2);
        match &p.right[1] {
AxisSpec::Group(names) => assert_eq!(names, &["c", "h", "w"]),
_ => panic!("expected Group"),
}
}
#[test]
fn test_parse_left_group() {
let p = parse_pattern("b (c h) w -> b c h w").unwrap();
        assert_eq!(p.left.len(), 3);
        match &p.left[1] {
AxisSpec::Group(names) => assert_eq!(names, &["c", "h"]),
_ => panic!("expected Group"),
}
}
}