use std::collections::BTreeMap;
use std::sync::Arc;
use crate::autograd::autocast_ops::autocast_guard;
use crate::autograd::no_grad::{is_grad_enabled, no_grad};
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Subscripts of an einsum equation after parsing and validation.
#[derive(Debug, Clone)]
struct ParsedEquation {
    /// One subscript list per operand, in operand order
    /// (`"ij,jk"` → `[['i', 'j'], ['j', 'k']]`).
    input_subscripts: Vec<Vec<char>>,
    /// Output subscripts: the text after `->` or, in implicit mode, every index that
    /// occurs exactly once on the left-hand side, in alphabetical order.
    output_subscripts: Vec<char>,
}

/// Parses an einsum `equation` for `n_inputs` operands.
///
/// Spaces are ignored. Only lowercase ASCII letters are accepted as subscripts (no
/// ellipsis). Without `->` the output is implicit: the indices appearing exactly once
/// across all inputs, sorted alphabetically.
///
/// # Errors
/// `InvalidArgument` when a subscript character is invalid, an explicit output
/// subscript is repeated, or the number of comma-separated input terms differs from
/// `n_inputs`.
fn parse_equation(equation: &str, n_inputs: usize) -> FerrotorchResult<ParsedEquation> {
    let equation = equation.replace(' ', "");
    let (lhs, output_subscripts) = if let Some((lhs, rhs)) = equation.split_once("->") {
        let out: Vec<char> = rhs.chars().collect();
        for &c in &out {
            if !c.is_ascii_lowercase() {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("einsum: invalid character '{c}' in output subscripts"),
                });
            }
        }
        // A repeated output index (e.g. "ij->ii") is not supported by any evaluation
        // path: the CPU two-operand gather can index past the end of its buffer and the
        // other paths silently mis-compute. Reject it up front.
        if let Some(pos) = (1..out.len()).find(|&p| out[..p].contains(&out[p])) {
            let c = out[pos];
            return Err(FerrotorchError::InvalidArgument {
                message: format!("einsum: output subscript '{c}' appears more than once"),
            });
        }
        (lhs.to_string(), out)
    } else {
        // Implicit mode: count every index over all inputs; the output keeps the ones
        // that occur exactly once (BTreeMap iteration yields them alphabetically).
        let lhs = equation.clone();
        let mut counts: BTreeMap<char, usize> = BTreeMap::new();
        for c in lhs.chars() {
            if c == ',' {
                continue;
            }
            if !c.is_ascii_lowercase() {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("einsum: invalid character '{c}' in subscripts"),
                });
            }
            *counts.entry(c).or_insert(0) += 1;
        }
        let out: Vec<char> = counts
            .into_iter()
            .filter(|&(_, count)| count == 1)
            .map(|(c, _)| c)
            .collect();
        (lhs, out)
    };
    let input_parts: Vec<&str> = lhs.split(',').collect();
    if input_parts.len() != n_inputs {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "einsum: equation has {} input subscripts but {} tensors were provided",
                input_parts.len(),
                n_inputs
            ),
        });
    }
    let input_subscripts: Vec<Vec<char>> = input_parts
        .iter()
        .map(|part| {
            let chars: Vec<char> = part.chars().collect();
            for &c in &chars {
                if !c.is_ascii_lowercase() {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("einsum: invalid character '{c}' in input subscripts"),
                    });
                }
            }
            Ok(chars)
        })
        .collect::<FerrotorchResult<Vec<_>>>()?;
    Ok(ParsedEquation {
        input_subscripts,
        output_subscripts,
    })
}
/// Maps every subscript letter to its extent, validating ranks and size agreement.
///
/// # Errors
/// * `InvalidArgument` when an operand's rank differs from its subscript count, or an
///   output index appears in no input.
/// * `ShapeMismatch` when the same letter labels axes of different sizes.
fn build_dim_map<T: Float>(
    parsed: &ParsedEquation,
    inputs: &[&Tensor<T>],
) -> FerrotorchResult<BTreeMap<char, usize>> {
    let mut extents: BTreeMap<char, usize> = BTreeMap::new();
    let operands = parsed.input_subscripts.iter().zip(inputs.iter());
    for (i, (subs, tensor)) in operands.enumerate() {
        let rank = tensor.ndim();
        if subs.len() != rank {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "einsum: input {} has {} subscripts but tensor has {} dimensions",
                    i,
                    subs.len(),
                    rank
                ),
            });
        }
        let shape = tensor.shape();
        for (axis, &c) in subs.iter().enumerate() {
            let size = shape[axis];
            match extents.get(&c) {
                Some(&existing) if existing != size => {
                    return Err(FerrotorchError::ShapeMismatch {
                        message: format!(
                            "einsum: index '{c}' has inconsistent sizes: {existing} vs {size}"
                        ),
                    });
                }
                Some(_) => {}
                None => {
                    extents.insert(c, size);
                }
            }
        }
    }
    // Every output letter must be bound by at least one input axis.
    if let Some(c) = parsed
        .output_subscripts
        .iter()
        .find(|&&c| !extents.contains_key(&c))
    {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "einsum: output index '{c}' does not appear in any input subscripts"
            ),
        });
    }
    Ok(extents)
}
/// Single-operand einsum on a CUDA tensor, composed from existing device ops.
///
/// Equations with a repeated input index are routed to `einsum_single_repeated_gpu`.
/// Otherwise every input axis absent from the output is summed away with `sum_dim`
/// (last axis first, so the axis numbers still to be visited stay valid), and the
/// surviving axes are permuted into output order and made contiguous. The reduction
/// and permutation run inside `no_grad`.
///
/// # Errors
/// `InvalidArgument` if an output index is missing from the input subscripts;
/// `Internal` if the reduced shape disagrees with `dim_map`; any error from the
/// underlying ops.
fn einsum_single_gpu<T: Float>(
    parsed: &ParsedEquation,
    input: &Tensor<T>,
    dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
    let in_subs = &parsed.input_subscripts[0];
    let out_subs = &parsed.output_subscripts;
    if has_duplicate_chars(in_subs) {
        return einsum_single_repeated_gpu(in_subs, out_subs, input, dim_map);
    }
    for &c in out_subs {
        if !in_subs.contains(&c) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "einsum: output index '{c}' does not appear in any input subscripts"
                ),
            });
        }
    }
    no_grad(|| {
        let mut keep_chars: Vec<char> = in_subs.clone();
        let mut current = input.clone();
        // Reverse order: removing axis `k` never shifts axes < k.
        for axis in (0..in_subs.len()).rev() {
            if !out_subs.contains(&in_subs[axis]) {
                current = crate::grad_fns::reduction::sum_dim(&current, axis as i64, false)?;
                keep_chars.remove(axis);
            }
        }
        if keep_chars == *out_subs {
            let out_shape: Vec<usize> = out_subs.iter().map(|c| dim_map[c]).collect();
            if current.shape() != out_shape.as_slice() {
                return Err(FerrotorchError::Internal {
                    message: format!(
                        "einsum_single_gpu: shape mismatch after reduction: got {:?} expected {:?}",
                        current.shape(),
                        out_shape
                    ),
                });
            }
            return Ok(current);
        }
        // Surviving axes are exactly the output letters, just in a different order.
        let perm: Vec<usize> = out_subs
            .iter()
            .map(|c| {
                keep_chars
                    .iter()
                    .position(|kc| kc == c)
                    .expect("out_subs char must exist in keep_chars (validated above)")
            })
            .collect();
        let permuted = crate::methods::permute_t(&current, &perm)?;
        let materialised = crate::methods::contiguous_t(&permuted)?;
        Ok(materialised)
    })
}
/// GPU single-operand einsum for inputs with a repeated index (e.g. `"ii->i"`).
///
/// The repeated axes are first collapsed onto their diagonal by
/// `diagonalize_repeats_gpu`; the resulting repeat-free equation is then evaluated by
/// `einsum_single_gpu`.
///
/// # Errors
/// `NotImplementedOnCuda` for fewer than two input axes; `InvalidArgument` if an output
/// index is absent after diagonalisation; any error from the delegated calls.
fn einsum_single_repeated_gpu<T: Float>(
    in_subs: &[char],
    out_subs: &[char],
    input: &Tensor<T>,
    dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
    let too_few_axes = in_subs.len() < 2;
    if too_few_axes {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "einsum_repeated_index",
        });
    }
    let (diag_subs, diag) = diagonalize_repeats_gpu(in_subs, input)?;
    if let Some(c) = out_subs.iter().find(|&oc| !diag_subs.contains(oc)) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "einsum: output index '{c}' does not appear in any input subscripts"
            ),
        });
    }
    let reduced_equation = ParsedEquation {
        input_subscripts: vec![diag_subs],
        output_subscripts: out_subs.to_vec(),
    };
    einsum_single_gpu(&reduced_equation, &diag, dim_map)
}
/// Collapses repeated subscripts of one operand onto their diagonal.
///
/// For every letter that labels several axes, those axes are merged into one axis
/// whose stride is the sum of their strides, and the result is materialised through
/// `as_strided` / `as_strided_copy`. Returns the de-duplicated subscripts (first
/// occurrence order) and the materialised tensor. Operands without repeats are returned
/// unchanged (cloned).
///
/// # Errors
/// `Internal` when subscripts/shape/strides disagree in length, `ShapeMismatch` when a
/// repeated letter labels axes of different sizes, `InvalidArgument` on stride-sum
/// overflow, plus any error from the strided ops.
fn diagonalize_repeats_gpu<T: Float>(
    in_subs: &[char],
    input: &Tensor<T>,
) -> FerrotorchResult<(Vec<char>, Tensor<T>)> {
    if !has_duplicate_chars(in_subs) {
        return Ok((in_subs.to_vec(), input.clone()));
    }
    let in_strides = input.strides();
    let in_shape = input.shape();
    let rank = in_subs.len();
    if in_strides.len() != rank || in_shape.len() != rank {
        return Err(FerrotorchError::Internal {
            message: format!(
                "diagonalize_repeats_gpu: subs/shape/strides length mismatch: \
                 {} vs {} vs {}",
                rank,
                in_shape.len(),
                in_strides.len()
            ),
        });
    }
    let mut uniq_subs: Vec<char> = Vec::with_capacity(rank);
    let mut uniq_sizes: Vec<usize> = Vec::with_capacity(rank);
    let mut uniq_strides: Vec<isize> = Vec::with_capacity(rank);
    for (axis, &c) in in_subs.iter().enumerate() {
        match uniq_subs.iter().position(|&u| u == c) {
            None => {
                uniq_subs.push(c);
                uniq_sizes.push(in_shape[axis]);
                uniq_strides.push(in_strides[axis]);
            }
            Some(slot) => {
                if uniq_sizes[slot] != in_shape[axis] {
                    return Err(FerrotorchError::ShapeMismatch {
                        message: format!(
                            "einsum: repeated index '{c}' addresses incompatible sizes \
                             {} vs {}",
                            uniq_sizes[slot], in_shape[axis]
                        ),
                    });
                }
                // Walking the diagonal advances every merged axis at once, so the
                // merged stride is the sum of the individual strides.
                let extra = in_strides[axis];
                uniq_strides[slot] = uniq_strides[slot].checked_add(extra).ok_or_else(|| {
                    FerrotorchError::InvalidArgument {
                        message: "einsum diagonalisation: stride sum overflowed".into(),
                    }
                })?;
            }
        }
    }
    let view = input.as_strided(&uniq_sizes, &uniq_strides, None)?;
    let materialised = view.as_strided_copy(&uniq_sizes, &uniq_strides, None)?;
    Ok((uniq_subs, materialised))
}
/// Returns `true` if any character occurs more than once in `chars`.
///
/// Stops at the first repeat.
fn has_duplicate_chars(chars: &[char]) -> bool {
    let mut seen = std::collections::HashSet::with_capacity(chars.len());
    // `insert` returns false for an already-present value, which ends `all` early.
    !chars.iter().all(|&c| seen.insert(c))
}
fn einsum_single<T: Float>(
parsed: &ParsedEquation,
input: &Tensor<T>,
dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
if input.is_cuda() {
return einsum_single_gpu(parsed, input, dim_map);
}
let in_subs = &parsed.input_subscripts[0];
let out_subs = &parsed.output_subscripts;
let out_shape: Vec<usize> = out_subs.iter().map(|c| dim_map[c]).collect();
let out_numel: usize = if out_shape.is_empty() {
1
} else {
out_shape.iter().product()
};
let data = input.data_vec()?;
let in_shape = input.shape();
let summed_indices: Vec<char> = in_subs
.iter()
.filter(|c| !out_subs.contains(c))
.copied()
.collect::<Vec<_>>();
let summed_unique: Vec<char> = {
let mut v = summed_indices.clone();
v.sort_unstable();
v.dedup();
v.into_iter().filter(|c| !out_subs.contains(c)).collect()
};
let in_strides: Vec<usize> = {
let mut strides = vec![1usize; in_shape.len()];
for i in (0..in_shape.len().saturating_sub(1)).rev() {
strides[i] = strides[i + 1] * in_shape[i + 1];
}
strides
};
let summed_sizes: Vec<usize> = summed_unique.iter().map(|c| dim_map[c]).collect();
let summed_numel: usize = if summed_sizes.is_empty() {
1
} else {
summed_sizes.iter().product()
};
let mut result = vec![<T as num_traits::Zero>::zero(); out_numel];
for (out_idx, result_elem) in result.iter_mut().enumerate() {
let mut out_multi = vec![0usize; out_subs.len()];
{
let mut remainder = out_idx;
for i in (0..out_subs.len()).rev() {
let size = dim_map[&out_subs[i]];
out_multi[i] = remainder % size;
remainder /= size;
}
}
let mut idx_vals: BTreeMap<char, usize> = BTreeMap::new();
for (i, &c) in out_subs.iter().enumerate() {
idx_vals.insert(c, out_multi[i]);
}
let mut acc = <T as num_traits::Zero>::zero();
for s_idx in 0..summed_numel {
let mut remainder = s_idx;
let mut valid = true;
for i in (0..summed_unique.len()).rev() {
let val = remainder % summed_sizes[i];
remainder /= summed_sizes[i];
idx_vals.insert(summed_unique[i], val);
}
let mut first_occurrence: BTreeMap<char, Option<usize>> = BTreeMap::new();
for &c in in_subs {
let val = idx_vals[&c];
match first_occurrence.get(&c) {
Some(Some(prev_val)) => {
if *prev_val != val {
valid = false;
break;
}
}
_ => {
first_occurrence.insert(c, Some(val));
}
}
}
if !valid {
continue;
}
let mut flat_idx = 0usize;
for (axis, &c) in in_subs.iter().enumerate() {
flat_idx += idx_vals[&c] * in_strides[axis];
}
acc += data[flat_idx];
}
*result_elem = acc;
}
Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
}
/// GPU (device) path for two-operand einsum.
///
/// Repeated indices inside either operand are first collapsed by
/// `diagonalize_repeats_gpu`. Inside `no_grad`, a series of shape-specific fast paths is
/// then tried in order: 2-D matmul, 3-D batched matmul, elementwise product, dot
/// product, outer product, scalar-times-vector, matrix-vector and vector-matrix. The
/// first match wins; anything else goes to `einsum_two_gpu_general`. The order of the
/// checks matters, because the later, looser patterns would also accept some earlier
/// cases.
fn einsum_two_gpu<T: Float>(
parsed: &ParsedEquation,
a: &Tensor<T>,
b: &Tensor<T>,
dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
let a_subs_orig = &parsed.input_subscripts[0];
let b_subs_orig = &parsed.input_subscripts[1];
let out_subs = &parsed.output_subscripts;
// Collapse repeated letters within each operand onto their diagonal; from here on
// `a_subs` / `b_subs` are repeat-free and `a` / `b` refer to the collapsed tensors.
let (a_subs_owned, a_diagonalised) = diagonalize_repeats_gpu(a_subs_orig, a)?;
let (b_subs_owned, b_diagonalised) = diagonalize_repeats_gpu(b_subs_orig, b)?;
let a_subs = &a_subs_owned;
let b_subs = &b_subs_owned;
let a = &a_diagonalised;
let b = &b_diagonalised;
// Defensive: diagonalisation removes repeats, so this should not trigger.
if has_duplicate_chars(a_subs) || has_duplicate_chars(b_subs) {
return Err(FerrotorchError::NotImplementedOnCuda {
op: "einsum_repeated_index",
});
}
no_grad(|| {
// Fast path 1: 2-D x 2-D contracting one shared letter, e.g. "ij,jk->ik".
// `a` is oriented to [a_other, c], `b` to [c, b_other], multiplied, and transposed
// if the output lists b_other first.
if a_subs.len() == 2
&& b_subs.len() == 2
&& out_subs.len() == 2
&& a_subs[0] != a_subs[1]
&& b_subs[0] != b_subs[1]
&& out_subs[0] != out_subs[1]
{
let contracted: Option<char> = a_subs
.iter()
.copied()
.find(|c| b_subs.contains(c) && !out_subs.contains(c));
if let Some(c) = contracted {
let a_other = if a_subs[0] == c { a_subs[1] } else { a_subs[0] };
let b_other = if b_subs[0] == c { b_subs[1] } else { b_subs[0] };
if a_other != b_other && out_subs.contains(&a_other) && out_subs.contains(&b_other)
{
let a_oriented = if a_subs[1] == c {
a.clone()
} else {
let permuted = crate::methods::permute_t(a, &[1, 0])?;
crate::methods::contiguous_t(&permuted)?
};
let b_oriented = if b_subs[0] == c {
b.clone()
} else {
let permuted = crate::methods::permute_t(b, &[1, 0])?;
crate::methods::contiguous_t(&permuted)?
};
let mm =
crate::grad_fns::linalg::matmul_differentiable(&a_oriented, &b_oriented)?;
if out_subs[0] == a_other && out_subs[1] == b_other {
return Ok(mm);
}
let permuted = crate::methods::permute_t(&mm, &[1, 0])?;
return crate::methods::contiguous_t(&permuted);
}
}
}
// Fast path 2: 3-D x 3-D sharing a leading batch letter that also leads the output,
// e.g. "bij,bjk->bik". Same orientation scheme as above on the last two axes, via bmm.
if a_subs.len() == 3
&& b_subs.len() == 3
&& out_subs.len() == 3
&& a_subs[0] == b_subs[0]
&& a_subs[0] == out_subs[0]
{
let bat = a_subs[0];
let a_uniq = a_subs[0] != a_subs[1] && a_subs[1] != a_subs[2] && a_subs[0] != a_subs[2];
let b_uniq = b_subs[0] != b_subs[1] && b_subs[1] != b_subs[2] && b_subs[0] != b_subs[2];
if a_uniq
&& b_uniq
&& bat != out_subs[1]
&& bat != out_subs[2]
&& out_subs[1] != out_subs[2]
{
let a_non_batch = [a_subs[1], a_subs[2]];
let b_non_batch = [b_subs[1], b_subs[2]];
let contracted: Option<char> = a_non_batch
.iter()
.copied()
.find(|c| b_non_batch.contains(c) && !out_subs.contains(c));
if let Some(c) = contracted {
let a_other = if a_subs[1] == c { a_subs[2] } else { a_subs[1] };
let b_other = if b_subs[1] == c { b_subs[2] } else { b_subs[1] };
if a_other != b_other
&& out_subs.contains(&a_other)
&& out_subs.contains(&b_other)
{
let a_oriented = if a_subs[2] == c {
a.clone()
} else {
let permuted = crate::methods::permute_t(a, &[0, 2, 1])?;
crate::methods::contiguous_t(&permuted)?
};
let b_oriented = if b_subs[1] == c {
b.clone()
} else {
let permuted = crate::methods::permute_t(b, &[0, 2, 1])?;
crate::methods::contiguous_t(&permuted)?
};
let result = crate::grad_fns::linalg::bmm(&a_oriented, &b_oriented)?;
if out_subs[1] == a_other && out_subs[2] == b_other {
return Ok(result);
}
let permuted = crate::methods::permute_t(&result, &[0, 2, 1])?;
return crate::methods::contiguous_t(&permuted);
}
}
}
}
// Fast path 3: identical subscripts everywhere, e.g. "ij,ij->ij" — elementwise mul.
if a_subs == b_subs && b_subs.as_slice() == out_subs.as_slice() {
return crate::grad_fns::arithmetic::mul(a, b);
}
// Fast path 4: dot product "i,i->".
if a_subs.len() == 1 && b_subs.as_slice() == a_subs.as_slice() && out_subs.is_empty() {
let prod = crate::grad_fns::arithmetic::mul(a, b)?;
return crate::grad_fns::reduction::sum(&prod);
}
// Fast path 5: outer product "i,j->ij" as a [n,1] * [1,m] multiply.
if a_subs.len() == 1
&& b_subs.len() == 1
&& a_subs[0] != b_subs[0]
&& out_subs.len() == 2
&& out_subs[0] == a_subs[0]
&& out_subs[1] == b_subs[0]
{
let a_unsq = crate::grad_fns::shape::unsqueeze(a, 1)?;
let b_unsq = crate::grad_fns::shape::unsqueeze(b, 0)?;
return crate::grad_fns::arithmetic::mul(&a_unsq, &b_unsq);
}
// Fast paths 6/7: a 0-d operand times a 1-d operand (",i->i" and "i,->i").
if a_subs.is_empty() && b_subs.len() == 1 && out_subs.as_slice() == b_subs.as_slice() {
return crate::grad_fns::arithmetic::mul(a, b);
}
if b_subs.is_empty() && a_subs.len() == 1 && out_subs.as_slice() == a_subs.as_slice() {
return crate::grad_fns::arithmetic::mul(a, b);
}
// Fast path 8: matrix-vector "ij,j->i" via matmul with `b` as a column.
if a_subs.len() == 2
&& b_subs.len() == 1
&& out_subs.len() == 1
&& a_subs[1] == b_subs[0]
&& a_subs[0] == out_subs[0]
&& a_subs[0] != a_subs[1]
{
let b_unsq = crate::grad_fns::shape::unsqueeze(b, 1)?; let mm_result = crate::grad_fns::linalg::matmul_differentiable(a, &b_unsq)?; return crate::grad_fns::shape::squeeze(&mm_result, 1);
}
// Fast path 9: vector-matrix "i,ij->j" via matmul with `a` as a row.
if a_subs.len() == 1
&& b_subs.len() == 2
&& out_subs.len() == 1
&& a_subs[0] == b_subs[0]
&& b_subs[1] == out_subs[0]
&& b_subs[0] != b_subs[1]
{
let a_unsq = crate::grad_fns::shape::unsqueeze(a, 0)?; let mm_result = crate::grad_fns::linalg::matmul_differentiable(&a_unsq, b)?; return crate::grad_fns::shape::squeeze(&mm_result, 0);
}
// No fast path matched: permute/reshape into a single bmm.
einsum_two_gpu_general(a_subs, b_subs, out_subs, a, b, dim_map)
})
}
/// General GPU two-operand einsum expressed as one batched matmul.
///
/// Steps:
/// 1. Sum away indices that occur in only one operand and not in the output
///    (`reduce_lone_axes`).
/// 2. Classify the remaining subscripts as batch / free-A / free-B / contracted.
/// 3. Permute `a` to `[batch, free_a, contract]` and `b` to `[batch, contract, free_b]`,
///    flatten each group to one axis, and call `bmm`.
/// 4. Reshape back to `[batch.., free_a.., free_b..]` and permute into output order.
///
/// Group extents are plain products. The empty product is already 1, and a zero-sized
/// index must yield a zero-sized group, otherwise the 3-D reshape would request more
/// elements than the operand holds.
///
/// # Errors
/// `NotImplementedOnCuda` if either operand still has a repeated subscript; `Internal`
/// if the classification invariants are violated; any error from the underlying ops.
fn einsum_two_gpu_general<T: Float>(
    a_subs: &[char],
    b_subs: &[char],
    out_subs: &[char],
    a: &Tensor<T>,
    b: &Tensor<T>,
    dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
    if has_duplicate_chars(a_subs) || has_duplicate_chars(b_subs) {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "einsum_general",
        });
    }
    let a_only_lone: Vec<char> = a_subs
        .iter()
        .copied()
        .filter(|c| !b_subs.contains(c) && !out_subs.contains(c))
        .collect();
    let b_only_lone: Vec<char> = b_subs
        .iter()
        .copied()
        .filter(|c| !a_subs.contains(c) && !out_subs.contains(c))
        .collect();
    let (a_reduced_subs, a_reduced) = reduce_lone_axes(a_subs, &a_only_lone, a)?;
    let (b_reduced_subs, b_reduced) = reduce_lone_axes(b_subs, &b_only_lone, b)?;

    // Operands are repeat-free (checked above), so each letter is classified once.
    let mut batch_chars: Vec<char> = Vec::new();
    let mut free_a_chars: Vec<char> = Vec::new();
    let mut free_b_chars: Vec<char> = Vec::new();
    let mut contract_chars: Vec<char> = Vec::new();
    for &c in &a_reduced_subs {
        match (b_reduced_subs.contains(&c), out_subs.contains(&c)) {
            (true, true) => batch_chars.push(c),
            (true, false) => contract_chars.push(c),
            (false, true) => free_a_chars.push(c),
            (false, false) => {
                return Err(FerrotorchError::Internal {
                    message: format!(
                        "einsum_two_gpu_general: lone-A char '{c}' survived reduction"
                    ),
                });
            }
        }
    }
    for &c in &b_reduced_subs {
        if !a_reduced_subs.contains(&c) && out_subs.contains(&c) {
            free_b_chars.push(c);
        }
    }

    let a_perm = gpu_operand_permutation(
        "A",
        &a_reduced_subs,
        &[
            ("batch", batch_chars.as_slice()),
            ("free-A", free_a_chars.as_slice()),
            ("contract", contract_chars.as_slice()),
        ],
    )?;
    let b_perm = gpu_operand_permutation(
        "B",
        &b_reduced_subs,
        &[
            ("batch", batch_chars.as_slice()),
            ("contract", contract_chars.as_slice()),
            ("free-B", free_b_chars.as_slice()),
        ],
    )?;
    let a_perm_view = crate::methods::permute_t(&a_reduced, &a_perm)?;
    let a_permuted = crate::methods::contiguous_t(&a_perm_view)?;
    let b_perm_view = crate::methods::permute_t(&b_reduced, &b_perm)?;
    let b_permuted = crate::methods::contiguous_t(&b_perm_view)?;

    let batch_sizes: Vec<usize> = batch_chars.iter().map(|c| dim_map[c]).collect();
    let free_a_sizes: Vec<usize> = free_a_chars.iter().map(|c| dim_map[c]).collect();
    let free_b_sizes: Vec<usize> = free_b_chars.iter().map(|c| dim_map[c]).collect();
    let contract_sizes: Vec<usize> = contract_chars.iter().map(|c| dim_map[c]).collect();
    // No `.max(1)`: an empty group is already 1, and a zero extent must stay 0.
    let batch_total: usize = batch_sizes.iter().product();
    let free_a_total: usize = free_a_sizes.iter().product();
    let free_b_total: usize = free_b_sizes.iter().product();
    let contract_total: usize = contract_sizes.iter().product();

    let a_3d = crate::grad_fns::shape::reshape(
        &a_permuted,
        &[
            batch_total as isize,
            free_a_total as isize,
            contract_total as isize,
        ],
    )?;
    let b_3d = crate::grad_fns::shape::reshape(
        &b_permuted,
        &[
            batch_total as isize,
            contract_total as isize,
            free_b_total as isize,
        ],
    )?;
    let bmm_result = crate::grad_fns::linalg::bmm(&a_3d, &b_3d)?;

    let intermediate_shape: Vec<isize> = batch_sizes
        .iter()
        .chain(&free_a_sizes)
        .chain(&free_b_sizes)
        .map(|&n| n as isize)
        .collect();
    let intermediate = crate::grad_fns::shape::reshape(&bmm_result, &intermediate_shape)?;
    let intermediate_chars: Vec<char> = batch_chars
        .iter()
        .chain(free_a_chars.iter())
        .chain(free_b_chars.iter())
        .copied()
        .collect();
    if intermediate_chars == *out_subs {
        return Ok(intermediate);
    }
    if intermediate_chars.len() != out_subs.len() {
        return Err(FerrotorchError::Internal {
            message: format!(
                "einsum_two_gpu_general: intermediate has {} axes, output has {}",
                intermediate_chars.len(),
                out_subs.len()
            ),
        });
    }
    let out_perm: Vec<usize> = out_subs
        .iter()
        .map(|c| {
            intermediate_chars
                .iter()
                .position(|ic| ic == c)
                .ok_or_else(|| FerrotorchError::Internal {
                    message: format!(
                        "einsum_two_gpu_general: out char '{c}' missing from intermediate"
                    ),
                })
        })
        .collect::<FerrotorchResult<Vec<_>>>()?;
    let permuted_view = crate::methods::permute_t(&intermediate, &out_perm)?;
    crate::methods::contiguous_t(&permuted_view)
}

/// Builds the axis permutation that reorders an operand with subscripts `subs` into
/// the concatenation of `groups`, where each group is `(label, letters)`.
///
/// `side` (`"A"`/`"B"`) and the group labels only feed the `Internal` error messages,
/// which are raised when a letter is missing or the groups do not cover every axis.
fn gpu_operand_permutation(
    side: &str,
    subs: &[char],
    groups: &[(&str, &[char])],
) -> FerrotorchResult<Vec<usize>> {
    let mut perm: Vec<usize> = Vec::with_capacity(subs.len());
    for &(label, chars) in groups {
        for &c in chars {
            let axis = subs.iter().position(|&x| x == c).ok_or_else(|| {
                FerrotorchError::Internal {
                    message: format!(
                        "einsum_two_gpu_general: {label} char '{c}' missing from {side}"
                    ),
                }
            })?;
            perm.push(axis);
        }
    }
    if perm.len() != subs.len() {
        return Err(FerrotorchError::Internal {
            message: format!(
                "einsum_two_gpu_general: {side} permutation has {} axes, expected {}",
                perm.len(),
                subs.len()
            ),
        });
    }
    Ok(perm)
}
/// Sums `tensor` over every axis whose subscript is listed in `lone_chars`.
///
/// Returns the remaining subscripts (original order) together with the reduced
/// tensor. When `lone_chars` is empty the inputs are returned unchanged (cloned).
fn reduce_lone_axes<T: Float>(
    subs: &[char],
    lone_chars: &[char],
    tensor: &Tensor<T>,
) -> FerrotorchResult<(Vec<char>, Tensor<T>)> {
    if lone_chars.is_empty() {
        return Ok((subs.to_vec(), tensor.clone()));
    }
    let mut current_subs = subs.to_vec();
    let mut current = tensor.clone();
    // Walk from the last axis to the first so that removing an axis never shifts the
    // numbers of axes that are still to be visited.
    for axis in (0..subs.len()).rev() {
        if lone_chars.contains(&subs[axis]) {
            current = crate::grad_fns::reduction::sum_dim(&current, axis as i64, false)?;
            current_subs.remove(axis);
        }
    }
    Ok((current_subs, current))
}
/// Two-operand einsum.
///
/// If either operand is on CUDA, both must share a device and the work is delegated to
/// `einsum_two_gpu`. On the CPU every subscript is classified as:
///
/// * **batch**: present in both operands and in the output;
/// * **free-A / free-B**: present in exactly one operand and in the output;
/// * **contracted**: absent from the output. This covers letters shared by both
///   operands *and* letters that occur in only one of them, since the result sums
///   over all of them.
///
/// The result is computed as an explicit batched GEMM laid out as
/// `[batch.., free_a.., free_b..]`, followed by a gather into output order. Repeated
/// letters inside one operand (e.g. `"ii,i->i"`) read the diagonal, because every axis
/// carrying the same letter uses the same index value.
///
/// # Errors
/// `DeviceMismatch` for operands on different devices when one is on CUDA; any error
/// from reading the operand data or building the result tensor.
fn einsum_two<T: Float>(
    parsed: &ParsedEquation,
    a: &Tensor<T>,
    b: &Tensor<T>,
    dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
    if a.is_cuda() || b.is_cuda() {
        if a.device() != b.device() {
            return Err(FerrotorchError::DeviceMismatch {
                expected: a.device(),
                got: b.device(),
            });
        }
        return einsum_two_gpu(parsed, a, b, dim_map);
    }

    /// Row-major decode of `flat` into `dst` using the extents in `sizes`.
    fn decode_into(mut flat: usize, sizes: &[usize], dst: &mut [usize]) {
        for (d, &extent) in dst.iter_mut().zip(sizes).rev() {
            *d = flat % extent;
            flat /= extent;
        }
    }

    /// Linear offset addressed by `vals` through `(slot, stride)` terms.
    fn flat_offset(terms: &[(usize, usize)], vals: &[usize]) -> usize {
        terms.iter().map(|&(slot, stride)| vals[slot] * stride).sum()
    }

    /// Resolves every operand axis to `(slot in slot_chars, row-major stride)`.
    fn axis_terms(subs: &[char], strides: &[usize], slot_chars: &[char]) -> Vec<(usize, usize)> {
        subs.iter()
            .zip(strides)
            .map(|(c, &stride)| {
                let slot = slot_chars
                    .iter()
                    .position(|s| s == c)
                    .expect("every operand subscript is classified into exactly one slot");
                (slot, stride)
            })
            .collect()
    }

    let a_subs = &parsed.input_subscripts[0];
    let b_subs = &parsed.input_subscripts[1];
    let out_subs = &parsed.output_subscripts;
    let a_unique: Vec<char> = {
        let mut v = a_subs.clone();
        v.sort_unstable();
        v.dedup();
        v
    };
    let b_unique: Vec<char> = {
        let mut v = b_subs.clone();
        v.sort_unstable();
        v.dedup();
        v
    };

    let mut batch_chars: Vec<char> = Vec::new();
    let mut free_a_chars: Vec<char> = Vec::new();
    let mut free_b_chars: Vec<char> = Vec::new();
    let mut contract_chars: Vec<char> = Vec::new();
    for &c in &a_unique {
        match (b_unique.contains(&c), out_subs.contains(&c)) {
            (true, true) => batch_chars.push(c),
            (false, true) => free_a_chars.push(c),
            // Not in the output: summed over, whether or not `b` also carries it.
            (_, false) => contract_chars.push(c),
        }
    }
    for &c in &b_unique {
        if a_unique.contains(&c) {
            continue; // already classified above
        }
        if out_subs.contains(&c) {
            free_b_chars.push(c);
        } else {
            // Only in `b` and not in the output: summed over as well.
            contract_chars.push(c);
        }
    }

    let batch_sizes: Vec<usize> = batch_chars.iter().map(|c| dim_map[c]).collect();
    let free_a_sizes: Vec<usize> = free_a_chars.iter().map(|c| dim_map[c]).collect();
    let free_b_sizes: Vec<usize> = free_b_chars.iter().map(|c| dim_map[c]).collect();
    let contract_sizes: Vec<usize> = contract_chars.iter().map(|c| dim_map[c]).collect();
    let out_shape: Vec<usize> = out_subs.iter().map(|c| dim_map[c]).collect();
    if out_shape.contains(&0) {
        return Tensor::from_storage(TensorStorage::cpu(Vec::new()), out_shape, false);
    }

    let a_data = a.data_vec()?;
    let b_data = b.data_vec()?;
    let a_strides = row_major_strides(a.shape());
    let b_strides = row_major_strides(b.shape());

    // Index-vector layout: [batch.., free_a.., free_b.., contract..].
    let slot_chars: Vec<char> = batch_chars
        .iter()
        .chain(&free_a_chars)
        .chain(&free_b_chars)
        .chain(&contract_chars)
        .copied()
        .collect();
    let n_outer = batch_chars.len() + free_a_chars.len() + free_b_chars.len();
    let outer_sizes: Vec<usize> = batch_sizes
        .iter()
        .chain(&free_a_sizes)
        .chain(&free_b_sizes)
        .copied()
        .collect();
    // Empty products are 1; a zero contracted extent yields an all-zero result.
    let outer_total: usize = outer_sizes.iter().product();
    let contract_total: usize = contract_sizes.iter().product();
    let a_terms = axis_terms(a_subs, &a_strides, &slot_chars);
    let b_terms = axis_terms(b_subs, &b_strides, &slot_chars);

    let zero = <T as num_traits::Zero>::zero();
    let mut vals = vec![0usize; slot_chars.len()];
    let mut gemm_result: Vec<T> = Vec::with_capacity(outer_total);
    for outer in 0..outer_total {
        decode_into(outer, &outer_sizes, &mut vals[..n_outer]);
        let mut acc = zero;
        for ci in 0..contract_total {
            decode_into(ci, &contract_sizes, &mut vals[n_outer..]);
            acc += a_data[flat_offset(&a_terms, &vals)] * b_data[flat_offset(&b_terms, &vals)];
        }
        gemm_result.push(acc);
    }

    let intermediate_chars = &slot_chars[..n_outer];
    if intermediate_chars == out_subs.as_slice() {
        return Tensor::from_storage(TensorStorage::cpu(gemm_result), out_shape, false);
    }
    // Gather the GEMM result into output order.
    let inter_strides = row_major_strides(&outer_sizes);
    let gather_terms: Vec<(usize, usize)> = out_subs
        .iter()
        .enumerate()
        .map(|(out_axis, c)| {
            let inter_axis = intermediate_chars
                .iter()
                .position(|ic| ic == c)
                .expect("output char must exist in intermediate");
            (out_axis, inter_strides[inter_axis])
        })
        .collect();
    let out_numel: usize = out_shape.iter().product();
    let mut out_multi = vec![0usize; out_shape.len()];
    let mut result: Vec<T> = Vec::with_capacity(out_numel);
    for out_flat in 0..out_numel {
        decode_into(out_flat, &out_shape, &mut out_multi);
        result.push(gemm_result[flat_offset(&gather_terms, &out_multi)]);
    }
    Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
}
/// Element strides of a contiguous row-major (C-order) array with the given `shape`.
///
/// The last axis has stride 1 and each earlier axis strides over the product of the
/// extents after it. A 0-d shape has no axes and yields an empty vector.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return Vec::new();
    }
    // Running suffix products over shape[1..], innermost first, prefixed by the unit
    // stride of the last axis; reversed at the end into axis order.
    let mut strides: Vec<usize> = std::iter::once(1)
        .chain(shape[1..].iter().rev().scan(1usize, |acc, &extent| {
            *acc *= extent;
            Some(*acc)
        }))
        .collect();
    strides.reverse();
    strides
}
/// Evaluates the Einstein-summation `equation` over one or two tensors.
///
/// Subscripts are lowercase ASCII letters. The output may be explicit (`"ij,jk->ik"`)
/// or implicit (`"ij,jk"`). CPU operands use reference loops; CUDA operands are
/// composed from device ops. For autograd support use [`einsum_differentiable`].
///
/// # Errors
/// `InvalidArgument` for a wrong operand count or malformed equation, `ShapeMismatch`
/// for inconsistent index extents, and whatever the single/two-operand paths report
/// (e.g. `DeviceMismatch`, `NotImplementedOnCuda`).
pub fn einsum<T: Float>(equation: &str, inputs: &[&Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
    let n_inputs = inputs.len();
    if !(1..=2).contains(&n_inputs) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "einsum: expected 1 or 2 input tensors, got {}",
                n_inputs
            ),
        });
    }
    let parsed = parse_equation(equation, n_inputs)?;
    let dim_map = build_dim_map(&parsed, inputs)?;
    match *inputs {
        [only] => einsum_single(&parsed, only, &dim_map),
        [lhs, rhs] => einsum_two(&parsed, lhs, rhs, &dim_map),
        _ => unreachable!(),
    }
}
/// Autograd-aware [`einsum`].
///
/// Calls `autocast_guard("einsum")` and runs the forward computation. Then, only when
/// grad mode is enabled and at least one input requires grad, it rebuilds the result via
/// `Tensor::from_operation` with an `EinsumBackwardSingle` / `EinsumBackwardTwo` node
/// that keeps the equation and the inputs. Otherwise the plain result is returned.
pub fn einsum_differentiable<T: Float>(
    equation: &str,
    inputs: &[&Tensor<T>],
) -> FerrotorchResult<Tensor<T>> {
    autocast_guard("einsum");
    let forward = einsum(equation, inputs)?;
    let any_requires_grad = inputs.iter().any(|t| t.requires_grad());
    if !(is_grad_enabled() && any_requires_grad) {
        return Ok(forward);
    }
    let tracked = match *inputs {
        [input] => {
            let node = Arc::new(EinsumBackwardSingle {
                equation: equation.to_string(),
                input: input.clone(),
            });
            let (storage, shape) = forward.into_storage_and_shape()?;
            Tensor::from_operation(storage, shape, node)
        }
        [lhs, rhs] => {
            let node = Arc::new(EinsumBackwardTwo {
                equation: equation.to_string(),
                a: lhs.clone(),
                b: rhs.clone(),
            });
            let (storage, shape) = forward.into_storage_and_shape()?;
            Tensor::from_operation(storage, shape, node)
        }
        _ => Ok(forward),
    }?;
    Ok(tracked)
}
#[derive(Debug)]
struct EinsumBackwardSingle<T: Float> {
equation: String,
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for EinsumBackwardSingle<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !self.input.requires_grad() {
return Ok(vec![None]);
}
let (lhs, rhs) = self
.equation
.split_once("->")
.unwrap_or((&self.equation, ""));
let in_subs: Vec<char> = lhs.chars().filter(|c| c.is_ascii_lowercase()).collect();
let out_subs: Vec<char> = rhs.chars().collect();
if has_duplicate_chars(&in_subs) {
return self.backward_repeated_index(grad_output, &in_subs, &out_subs);
}
for &c in &out_subs {
if !in_subs.contains(&c) {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einsum backward: output index '{c}' does not appear in input subscripts"
),
});
}
}
let in_shape = self.input.shape();
let dropped_chars: Vec<char> = in_subs
.iter()
.filter(|c| !out_subs.contains(c))
.copied()
.collect();
let intermediate_chars: Vec<char> = out_subs
.iter()
.chain(dropped_chars.iter())
.copied()
.collect();
let dim_size = |c: char| -> usize {
for (axis, &ic) in in_subs.iter().enumerate() {
if ic == c {
return in_shape[axis];
}
}
unreachable!("dim_size called for char not in in_subs")
};
let intermediate_shape: Vec<usize> =
intermediate_chars.iter().map(|&c| dim_size(c)).collect();
let unsqueezed_shape: Vec<usize> = (0..intermediate_chars.len())
.map(|i| {
if i < out_subs.len() {
intermediate_shape[i]
} else {
1
}
})
.collect();
let grad_unsq = if grad_output.shape() == unsqueezed_shape.as_slice() {
grad_output.clone()
} else if grad_output.is_contiguous() {
grad_output.view_reshape(unsqueezed_shape.clone())?
} else {
grad_output
.contiguous()?
.view_reshape(unsqueezed_shape.clone())?
};
let grad_expanded = if intermediate_shape.is_empty()
|| grad_unsq.shape() == intermediate_shape.as_slice()
{
grad_unsq
} else {
crate::grad_fns::shape::expand(&grad_unsq, &intermediate_shape)?
};
if intermediate_chars == in_subs {
return Ok(vec![Some(crate::methods::contiguous_t(&grad_expanded)?)]);
}
let perm: Vec<usize> = in_subs
.iter()
.map(|c| {
intermediate_chars
.iter()
.position(|ic| ic == c)
.expect("in_subs char must exist in intermediate_chars")
})
.collect();
let permuted = crate::methods::permute_t(&grad_expanded, &perm)?;
let grad_input = crate::methods::contiguous_t(&permuted)?;
Ok(vec![Some(grad_input)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"EinsumBackward"
}
}
impl<T: Float> EinsumBackwardSingle<T> {
    /// Gradient for single-operand equations whose input repeats an index
    /// (e.g. `"ii->i"` or `"ii->"`).
    ///
    /// An input element whose repeated axes all carry the same index value maps to
    /// exactly one output element and receives that element's upstream gradient. Every
    /// other element receives zero. The work is done on host data (`data_vec`) and the
    /// result is built with `TensorStorage::cpu`.
    fn backward_repeated_index(
        &self,
        grad_output: &Tensor<T>,
        in_subs: &[char],
        out_subs: &[char],
    ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let in_shape: Vec<usize> = self.input.shape().to_vec();
        let in_numel = self.input.numel();
        let zero = <T as num_traits::Zero>::zero();
        let mut grad_data = vec![zero; in_numel];
        let upstream = grad_output.data_vec()?;
        let upstream_strides = row_major_strides(grad_output.shape());

        let mut coords = vec![0usize; in_subs.len()];
        let mut bound: BTreeMap<char, usize> = BTreeMap::new();
        for (flat, slot) in grad_data.iter_mut().enumerate() {
            // Row-major decode of `flat` into per-axis coordinates.
            let mut rem = flat;
            for axis in (0..in_subs.len()).rev() {
                coords[axis] = rem % in_shape[axis];
                rem /= in_shape[axis];
            }
            // Bind each letter to its first coordinate; any disagreement means the
            // element is off the diagonal and gets no gradient.
            bound.clear();
            let on_diagonal = in_subs.iter().zip(&coords).all(|(&c, &v)| match bound.get(&c) {
                Some(&prev) => prev == v,
                None => {
                    bound.insert(c, v);
                    true
                }
            });
            if !on_diagonal {
                continue;
            }
            let out_flat: usize = out_subs
                .iter()
                .enumerate()
                .map(|(oi, oc)| bound[oc] * upstream_strides[oi])
                .sum();
            *slot = if out_subs.is_empty() {
                upstream[0]
            } else {
                upstream[out_flat]
            };
        }
        let grad_tensor = Tensor::from_storage(TensorStorage::cpu(grad_data), in_shape, false)?;
        Ok(vec![Some(grad_tensor)])
    }
}
/// Autograd node for a two-operand einsum.
///
/// Keeps the forward equation and both operands so the backward pass can build
/// and evaluate the transposed equation for each operand's gradient.
#[derive(Debug)]
struct EinsumBackwardTwo<T: Float> {
/// Forward einsum equation string. Presumably stored exactly as passed to the
/// forward call. NOTE(review): confirm whether callers normalize it first.
equation: String,
/// First operand of the forward einsum.
a: Tensor<T>,
/// Second operand of the forward einsum.
b: Tensor<T>,
}
impl<T: Float> EinsumBackwardTwo<T> {
fn backward_equation(&self, target: usize) -> (String, usize, usize) {
let (lhs, rhs) = self
.equation
.split_once("->")
.unwrap_or((&self.equation, ""));
let parts: Vec<&str> = lhs.split(',').collect();
let a_subs = parts[0];
let b_subs = parts[1];
let out_subs = rhs;
if target == 0 {
let eq = format!("{out_subs},{b_subs}->{a_subs}");
(eq, 0, 1) } else {
let eq = format!("{a_subs},{out_subs}->{b_subs}");
(eq, 1, 0) }
}
}
impl<T: Float> GradFn<T> for EinsumBackwardTwo<T> {
    /// Returns `[grad_a, grad_b]`, where each gradient is another einsum over
    /// the upstream gradient and the other operand. An entry is `None` when its
    /// operand does not require grad, and its equation is then never built.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_a = self
            .a
            .requires_grad()
            .then(|| {
                let (equation, _, _) = self.backward_equation(0);
                einsum(&equation, &[grad_output, &self.b])
            })
            .transpose()?;
        let grad_b = self
            .b
            .requires_grad()
            .then(|| {
                let (equation, _, _) = self.backward_equation(1);
                einsum(&equation, &[&self.a, grad_output])
            })
            .transpose()?;
        Ok(vec![grad_a, grad_b])
    }

    /// Both forward operands, in forward order.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.a, &self.b])
    }

    /// Shared node name for all einsum backward nodes.
    fn name(&self) -> &'static str {
        "EinsumBackward"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    /// Constant f32 tensor that does not require grad.
    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    /// Leaf f32 tensor that requires grad.
    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
    }

    /// Element-wise comparison with an absolute tolerance; also checks lengths.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "index {i}: {a} vs {e} (diff {})",
                (a - e).abs()
            );
        }
    }

    #[test]
    fn test_einsum_zero_batch_bilinear_eq() {
        // An empty batch dimension must carry through to the output shape.
        let x1 = t(&[], &[0, 3]);
        let w = t(
            &(0..(4 * 3 * 2)).map(|i| i as f32).collect::<Vec<_>>(),
            &[4, 3, 2],
        );
        let out = einsum("bi,oij->boj", &[&x1, &w]).unwrap();
        assert_eq!(out.shape(), &[0, 4, 2]);
        assert_eq!(out.data().unwrap().len(), 0);
    }

    #[test]
    fn test_einsum_zero_contracted_dim_zero_filled() {
        // Contracting over an empty dimension yields an all-zero result.
        let a = t(&[], &[2, 0]);
        let b = t(&[], &[0, 3]);
        let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        assert_close(c.data().unwrap(), &[0.0; 6], 1e-6);
    }

    #[test]
    fn test_einsum_mm() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        assert_close(c.data().unwrap(), &[19.0, 22.0, 43.0, 50.0], 1e-6);
    }

    #[test]
    fn test_einsum_bmm() {
        #[rustfmt::skip]
        let a_data: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0,
            1.0, 0.0, 0.0, 1.0,
        ];
        #[rustfmt::skip]
        let b_data: Vec<f32> = vec![
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
        ];
        let a = t(&a_data, &[2, 2, 2]);
        let b = t(&b_data, &[2, 2, 2]);
        let c = einsum("bij,bjk->bik", &[&a, &b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
        let d = c.data().unwrap();
        assert_close(&d[0..4], &[19.0, 22.0, 43.0, 50.0], 1e-6);
        assert_close(&d[4..8], &[9.0, 10.0, 11.0, 12.0], 1e-6);
    }

    #[test]
    fn test_einsum_trace() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let c = einsum("ii->", &[&a]).unwrap();
        assert!(c.is_scalar());
        assert!((c.item().unwrap() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_einsum_outer_product() {
        let a = t(&[1.0, 2.0, 3.0], &[3]);
        let b = t(&[4.0, 5.0], &[2]);
        let c = einsum("i,j->ij", &[&a, &b]).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_close(c.data().unwrap(), &[4.0, 5.0, 8.0, 10.0, 12.0, 15.0], 1e-6);
    }

    #[test]
    fn test_einsum_transpose() {
        let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let c = einsum("ij->ji", &[&a]).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_close(c.data().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 1e-6);
    }

    #[test]
    fn test_einsum_sum_all() {
        let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let c = einsum("ij->", &[&a]).unwrap();
        assert!(c.is_scalar());
        assert!((c.item().unwrap() - 21.0).abs() < 1e-6);
    }

    #[test]
    fn test_einsum_sum_axis() {
        let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let c = einsum("ij->i", &[&a]).unwrap();
        assert_eq!(c.shape(), &[2]);
        assert_close(c.data().unwrap(), &[6.0, 15.0], 1e-6);
    }

    #[test]
    fn test_einsum_implicit_mm() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let c = einsum("ij,jk", &[&a, &b]).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        assert_close(c.data().unwrap(), &[19.0, 22.0, 43.0, 50.0], 1e-6);
    }

    #[test]
    fn test_einsum_backward_mm() {
        let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = leaf(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let c = einsum_differentiable("ij,jk->ik", &[&a, &b]).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        let c_data = c.data().unwrap();
        let loss_val: f32 = c_data.iter().sum();

        // Minimal sum node: feeds an all-ones gradient back into `c`.
        #[derive(Debug)]
        struct SumBackward<T: Float> {
            input: Tensor<T>,
        }
        impl<T: Float> GradFn<T> for SumBackward<T> {
            fn backward(
                &self,
                _grad_output: &Tensor<T>,
            ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
                let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
                let g = Tensor::from_storage(
                    TensorStorage::cpu(ones),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(g)])
            }
            fn inputs(&self) -> Vec<&Tensor<T>> {
                vec![&self.input]
            }
            fn name(&self) -> &'static str {
                "SumBackward"
            }
        }

        let loss = Tensor::from_operation(
            TensorStorage::cpu(vec![loss_val]),
            vec![],
            Arc::new(SumBackward { input: c }),
        )
        .unwrap();
        loss.backward().unwrap();
        let a_grad = a.grad().unwrap().expect("a should have grad");
        let b_grad = b.grad().unwrap().expect("b should have grad");
        assert_eq!(a_grad.shape(), &[2, 2]);
        assert_eq!(b_grad.shape(), &[2, 2]);
        // dL/dA = 1 @ B^T (row sums of B); dL/dB = A^T @ 1 (column sums of A).
        assert_close(a_grad.data().unwrap(), &[11.0, 15.0, 11.0, 15.0], 1e-5);
        assert_close(b_grad.data().unwrap(), &[4.0, 4.0, 6.0, 6.0], 1e-5);
    }

    #[test]
    fn test_einsum_backward_equation_explicit() {
        let node = EinsumBackwardTwo {
            equation: "ij,jk->ik".to_string(),
            a: t(&[0.0; 4], &[2, 2]),
            b: t(&[0.0; 4], &[2, 2]),
        };
        assert_eq!(node.backward_equation(0), ("ik,jk->ij".to_string(), 0, 1));
        assert_eq!(node.backward_equation(1), ("ij,ik->jk".to_string(), 1, 0));
    }

    #[test]
    fn test_einsum_backward_two_grad_fn_direct() {
        // Only `a` requires grad, so `b`'s slot must be `None`.
        let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let node = EinsumBackwardTwo {
            equation: "ij,jk->ik".to_string(),
            a,
            b,
        };
        let ones = t(&[1.0; 4], &[2, 2]);
        let grads = node.backward(&ones).unwrap();
        assert_eq!(grads.len(), 2);
        let ga = grads[0].as_ref().expect("a requires grad");
        assert_eq!(ga.shape(), &[2, 2]);
        assert_close(ga.data().unwrap(), &[11.0, 15.0, 11.0, 15.0], 1e-5);
        assert!(grads[1].is_none(), "b does not require grad");
    }

    #[test]
    fn test_einsum_backward_trace() {
        // d(trace A)/dA is the identity matrix.
        let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let c = einsum_differentiable("ii->", &[&a]).unwrap();
        c.backward().unwrap();
        let g = a.grad().unwrap().expect("a should have grad");
        assert_eq!(g.shape(), &[2, 2]);
        assert_close(g.data().unwrap(), &[1.0, 0.0, 0.0, 1.0], 1e-6);
    }

    #[test]
    fn test_einsum_invalid_equation() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        // Subscript-group count does not match the number of tensors.
        assert!(einsum("ij,jk,kl->il", &[&a, &b]).is_err());
        // Subscript count does not match tensor rank.
        assert!(einsum("ijk,jk->ik", &[&a, &b]).is_err());
        // Contracted dimensions disagree in size.
        let c = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        assert!(einsum("ij,jk->ik", &[&c, &a]).is_err());
        // Non-lowercase subscript characters.
        assert!(einsum("i1,1j->ij", &[&a, &b]).is_err());
    }

    #[test]
    fn test_einsum_diagonal() {
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let c = einsum("ii->i", &[&a]).unwrap();
        assert_eq!(c.shape(), &[2]);
        assert_close(c.data().unwrap(), &[1.0, 4.0], 1e-6);
    }

    #[test]
    fn test_einsum_dot() {
        let a = t(&[1.0, 2.0, 3.0], &[3]);
        let b = t(&[4.0, 5.0, 6.0], &[3]);
        let c = einsum("i,i->", &[&a, &b]).unwrap();
        assert!(c.is_scalar());
        assert!((c.item().unwrap() - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_einsum_non_square_mm() {
        let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = t(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[3, 4],
        );
        let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
        assert_eq!(c.shape(), &[2, 4]);
        assert_close(
            c.data().unwrap(),
            &[38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0],
            1e-5,
        );
    }

    #[test]
    fn test_einsum_differentiable_fires_autocast_guard() {
        use crate::autograd::autocast::{AutocastDtype, autocast, set_autocast_debug};
        use crate::autograd::autocast_ops::{AutocastCategory, drain_autocast_events};
        // NOTE(review): the debug flag is never switched back off; confirm it is
        // thread-local (or otherwise harmless) so other tests are not affected.
        set_autocast_debug(true);
        let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        // Outside an autocast region no event is recorded.
        drain_autocast_events();
        let _ = einsum_differentiable("ij,jk->ik", &[&a, &b]).unwrap();
        assert!(drain_autocast_events().is_empty());
        // Inside an autocast region exactly one reduced-precision event fires.
        autocast(AutocastDtype::F16, || {
            drain_autocast_events();
            let _ = einsum_differentiable("ij,jk->ik", &[&a, &b]).unwrap();
            let events = drain_autocast_events();
            assert_eq!(events.len(), 1);
            assert_eq!(events[0].op, "einsum");
            assert_eq!(events[0].category, AutocastCategory::ReducedPrecision);
        });
    }
}