use std::collections::BTreeMap;
use std::sync::Arc;
use crate::autograd::autocast_ops::autocast_guard;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
#[derive(Debug, Clone)]
struct ParsedEquation {
input_subscripts: Vec<Vec<char>>,
output_subscripts: Vec<char>,
}
fn parse_equation(equation: &str, n_inputs: usize) -> FerrotorchResult<ParsedEquation> {
let equation = equation.replace(' ', "");
let (lhs, output_subscripts) = if let Some((lhs, rhs)) = equation.split_once("->") {
let out: Vec<char> = rhs.chars().collect();
for &c in &out {
if !c.is_ascii_lowercase() {
return Err(FerrotorchError::InvalidArgument {
message: format!("einsum: invalid character '{c}' in output subscripts"),
});
}
}
(lhs.to_string(), out)
} else {
let lhs = equation.clone();
let mut counts: BTreeMap<char, usize> = BTreeMap::new();
for c in lhs.chars() {
if c == ',' {
continue;
}
if !c.is_ascii_lowercase() {
return Err(FerrotorchError::InvalidArgument {
message: format!("einsum: invalid character '{c}' in subscripts"),
});
}
*counts.entry(c).or_insert(0) += 1;
}
let out: Vec<char> = counts
.into_iter()
.filter(|&(_, count)| count == 1)
.map(|(c, _)| c)
.collect();
(lhs, out)
};
let input_parts: Vec<&str> = lhs.split(',').collect();
if input_parts.len() != n_inputs {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einsum: equation has {} input subscripts but {} tensors were provided",
input_parts.len(),
n_inputs
),
});
}
let input_subscripts: Vec<Vec<char>> = input_parts
.iter()
.map(|part| {
let chars: Vec<char> = part.chars().collect();
for &c in &chars {
if !c.is_ascii_lowercase() {
return Err(FerrotorchError::InvalidArgument {
message: format!("einsum: invalid character '{c}' in input subscripts"),
});
}
}
Ok(chars)
})
.collect::<FerrotorchResult<Vec<_>>>()?;
Ok(ParsedEquation {
input_subscripts,
output_subscripts,
})
}
fn build_dim_map<T: Float>(
parsed: &ParsedEquation,
inputs: &[&Tensor<T>],
) -> FerrotorchResult<BTreeMap<char, usize>> {
let mut dim_map: BTreeMap<char, usize> = BTreeMap::new();
for (i, (subs, tensor)) in parsed
.input_subscripts
.iter()
.zip(inputs.iter())
.enumerate()
{
if subs.len() != tensor.ndim() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einsum: input {} has {} subscripts but tensor has {} dimensions",
i,
subs.len(),
tensor.ndim()
),
});
}
for (axis, &c) in subs.iter().enumerate() {
let size = tensor.shape()[axis];
if let Some(&existing) = dim_map.get(&c) {
if existing != size {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"einsum: index '{c}' has inconsistent sizes: {} vs {}",
existing, size
),
});
}
} else {
dim_map.insert(c, size);
}
}
}
for &c in &parsed.output_subscripts {
if !dim_map.contains_key(&c) {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einsum: output index '{c}' does not appear in any input subscripts"
),
});
}
}
Ok(dim_map)
}
fn einsum_single<T: Float>(
parsed: &ParsedEquation,
input: &Tensor<T>,
dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
let in_subs = &parsed.input_subscripts[0];
let out_subs = &parsed.output_subscripts;
let out_shape: Vec<usize> = out_subs.iter().map(|c| dim_map[c]).collect();
let out_numel: usize = if out_shape.is_empty() {
1
} else {
out_shape.iter().product()
};
let data = input.data_vec()?;
let in_shape = input.shape();
let summed_indices: Vec<char> = in_subs
.iter()
.filter(|c| !out_subs.contains(c))
.copied()
.collect::<Vec<_>>();
let summed_unique: Vec<char> = {
let mut v = summed_indices.clone();
v.sort();
v.dedup();
v.into_iter().filter(|c| !out_subs.contains(c)).collect()
};
let in_strides: Vec<usize> = {
let mut strides = vec![1usize; in_shape.len()];
for i in (0..in_shape.len().saturating_sub(1)).rev() {
strides[i] = strides[i + 1] * in_shape[i + 1];
}
strides
};
let summed_sizes: Vec<usize> = summed_unique.iter().map(|c| dim_map[c]).collect();
let summed_numel: usize = if summed_sizes.is_empty() {
1
} else {
summed_sizes.iter().product()
};
let mut result = vec![<T as num_traits::Zero>::zero(); out_numel];
for (out_idx, result_elem) in result.iter_mut().enumerate() {
let mut out_multi = vec![0usize; out_subs.len()];
{
let mut remainder = out_idx;
for i in (0..out_subs.len()).rev() {
let size = dim_map[&out_subs[i]];
out_multi[i] = remainder % size;
remainder /= size;
}
}
let mut idx_vals: BTreeMap<char, usize> = BTreeMap::new();
for (i, &c) in out_subs.iter().enumerate() {
idx_vals.insert(c, out_multi[i]);
}
let mut acc = <T as num_traits::Zero>::zero();
for s_idx in 0..summed_numel {
let mut remainder = s_idx;
let mut valid = true;
for i in (0..summed_unique.len()).rev() {
let val = remainder % summed_sizes[i];
remainder /= summed_sizes[i];
idx_vals.insert(summed_unique[i], val);
}
let mut first_occurrence: BTreeMap<char, Option<usize>> = BTreeMap::new();
for &c in in_subs.iter() {
let val = idx_vals[&c];
match first_occurrence.get(&c) {
Some(Some(prev_val)) => {
if *prev_val != val {
valid = false;
break;
}
}
_ => {
first_occurrence.insert(c, Some(val));
}
}
}
if !valid {
continue;
}
let mut flat_idx = 0usize;
for (axis, &c) in in_subs.iter().enumerate() {
flat_idx += idx_vals[&c] * in_strides[axis];
}
acc += data[flat_idx];
}
*result_elem = acc;
}
Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
}
fn einsum_two<T: Float>(
parsed: &ParsedEquation,
a: &Tensor<T>,
b: &Tensor<T>,
dim_map: &BTreeMap<char, usize>,
) -> FerrotorchResult<Tensor<T>> {
let a_subs = &parsed.input_subscripts[0];
let b_subs = &parsed.input_subscripts[1];
let out_subs = &parsed.output_subscripts;
let mut batch_chars: Vec<char> = Vec::new();
let mut free_a_chars: Vec<char> = Vec::new();
let mut free_b_chars: Vec<char> = Vec::new();
let mut contract_chars: Vec<char> = Vec::new();
let a_unique: Vec<char> = {
let mut v = a_subs.clone();
v.sort();
v.dedup();
v
};
let b_unique: Vec<char> = {
let mut v = b_subs.clone();
v.sort();
v.dedup();
v
};
for &c in &a_unique {
let in_b = b_unique.contains(&c);
let in_out = out_subs.contains(&c);
match (in_b, in_out) {
(true, true) => batch_chars.push(c),
(true, false) => contract_chars.push(c),
(false, true) => free_a_chars.push(c),
(false, false) => {
free_a_chars.push(c); }
}
}
for &c in &b_unique {
if !a_unique.contains(&c) && out_subs.contains(&c) {
free_b_chars.push(c);
}
}
let batch_sizes: Vec<usize> = batch_chars.iter().map(|c| dim_map[c]).collect();
let free_a_sizes: Vec<usize> = free_a_chars.iter().map(|c| dim_map[c]).collect();
let free_b_sizes: Vec<usize> = free_b_chars.iter().map(|c| dim_map[c]).collect();
let contract_sizes: Vec<usize> = contract_chars.iter().map(|c| dim_map[c]).collect();
let batch_total: usize = batch_sizes.iter().product::<usize>().max(1);
let free_a_total: usize = free_a_sizes.iter().product::<usize>().max(1);
let free_b_total: usize = free_b_sizes.iter().product::<usize>().max(1);
let contract_total: usize = contract_sizes.iter().product::<usize>().max(1);
let a_data = a.data_vec()?;
let b_data = b.data_vec()?;
let a_shape = a.shape();
let b_shape = b.shape();
let a_strides = row_major_strides(a_shape);
let b_strides = row_major_strides(b_shape);
let a_char_to_axis: BTreeMap<char, Vec<usize>> = {
let mut m: BTreeMap<char, Vec<usize>> = BTreeMap::new();
for (axis, &c) in a_subs.iter().enumerate() {
m.entry(c).or_default().push(axis);
}
m
};
let b_char_to_axis: BTreeMap<char, Vec<usize>> = {
let mut m: BTreeMap<char, Vec<usize>> = BTreeMap::new();
for (axis, &c) in b_subs.iter().enumerate() {
m.entry(c).or_default().push(axis);
}
m
};
fn decode_multi(flat: usize, sizes: &[usize]) -> Vec<usize> {
let mut result = vec![0usize; sizes.len()];
let mut remainder = flat;
for i in (0..sizes.len()).rev() {
result[i] = remainder % sizes[i];
remainder /= sizes[i];
}
result
}
#[inline]
#[allow(clippy::too_many_arguments)]
fn compute_a_flat(
batch_chars: &[char],
batch_vals: &[usize],
free_a_chars: &[char],
free_a_vals: &[usize],
contract_chars: &[char],
contract_vals: &[usize],
a_char_to_axis: &BTreeMap<char, Vec<usize>>,
a_strides: &[usize],
) -> usize {
let mut flat = 0usize;
for (i, &c) in batch_chars.iter().enumerate() {
if let Some(axes) = a_char_to_axis.get(&c) {
for &ax in axes {
flat += batch_vals[i] * a_strides[ax];
}
}
}
for (i, &c) in free_a_chars.iter().enumerate() {
if let Some(axes) = a_char_to_axis.get(&c) {
for &ax in axes {
flat += free_a_vals[i] * a_strides[ax];
}
}
}
for (i, &c) in contract_chars.iter().enumerate() {
if let Some(axes) = a_char_to_axis.get(&c) {
for &ax in axes {
flat += contract_vals[i] * a_strides[ax];
}
}
}
flat
}
#[inline]
#[allow(clippy::too_many_arguments)]
fn compute_b_flat(
batch_chars: &[char],
batch_vals: &[usize],
free_b_chars: &[char],
free_b_vals: &[usize],
contract_chars: &[char],
contract_vals: &[usize],
b_char_to_axis: &BTreeMap<char, Vec<usize>>,
b_strides: &[usize],
) -> usize {
let mut flat = 0usize;
for (i, &c) in batch_chars.iter().enumerate() {
if let Some(axes) = b_char_to_axis.get(&c) {
for &ax in axes {
flat += batch_vals[i] * b_strides[ax];
}
}
}
for (i, &c) in contract_chars.iter().enumerate() {
if let Some(axes) = b_char_to_axis.get(&c) {
for &ax in axes {
flat += contract_vals[i] * b_strides[ax];
}
}
}
for (i, &c) in free_b_chars.iter().enumerate() {
if let Some(axes) = b_char_to_axis.get(&c) {
for &ax in axes {
flat += free_b_vals[i] * b_strides[ax];
}
}
}
flat
}
let gemm_size = batch_total * free_a_total * free_b_total;
let mut gemm_result = vec![<T as num_traits::Zero>::zero(); gemm_size];
for bi in 0..batch_total {
let batch_vals = decode_multi(bi, &batch_sizes);
for fa in 0..free_a_total {
let free_a_vals = decode_multi(fa, &free_a_sizes);
for fb in 0..free_b_total {
let free_b_vals = decode_multi(fb, &free_b_sizes);
let mut acc = <T as num_traits::Zero>::zero();
for ci in 0..contract_total {
let contract_vals = decode_multi(ci, &contract_sizes);
let a_flat = compute_a_flat(
&batch_chars,
&batch_vals,
&free_a_chars,
&free_a_vals,
&contract_chars,
&contract_vals,
&a_char_to_axis,
&a_strides,
);
let b_flat = compute_b_flat(
&batch_chars,
&batch_vals,
&free_b_chars,
&free_b_vals,
&contract_chars,
&contract_vals,
&b_char_to_axis,
&b_strides,
);
acc += a_data[a_flat] * b_data[b_flat];
}
gemm_result[bi * (free_a_total * free_b_total) + fa * free_b_total + fb] = acc;
}
}
}
let intermediate_chars: Vec<char> = batch_chars
.iter()
.chain(free_a_chars.iter())
.chain(free_b_chars.iter())
.copied()
.collect();
let intermediate_sizes: Vec<usize> = batch_sizes
.iter()
.chain(free_a_sizes.iter())
.chain(free_b_sizes.iter())
.copied()
.collect();
if intermediate_chars == *out_subs {
let out_shape: Vec<usize> = out_subs.iter().map(|c| dim_map[c]).collect();
return Tensor::from_storage(TensorStorage::cpu(gemm_result), out_shape, false);
}
let out_shape: Vec<usize> = out_subs.iter().map(|c| dim_map[c]).collect();
let out_numel: usize = if out_shape.is_empty() {
1
} else {
out_shape.iter().product()
};
let perm: Vec<usize> = out_subs
.iter()
.map(|c| {
intermediate_chars
.iter()
.position(|ic| ic == c)
.expect("output char must exist in intermediate")
})
.collect();
let inter_strides = row_major_strides(&intermediate_sizes);
let mut result = vec![<T as num_traits::Zero>::zero(); out_numel];
for (out_flat, result_elem) in result.iter_mut().enumerate() {
let out_multi = decode_multi(out_flat, &out_shape);
let mut inter_flat = 0usize;
for (out_axis, &inter_axis) in perm.iter().enumerate() {
inter_flat += out_multi[out_axis] * inter_strides[inter_axis];
}
*result_elem = gemm_result[inter_flat];
}
Tensor::from_storage(TensorStorage::cpu(result), out_shape, false)
}
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
let ndim = shape.len();
if ndim == 0 {
return vec![];
}
let mut strides = vec![1usize; ndim];
for i in (0..ndim.saturating_sub(1)).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
strides
}
pub fn einsum<T: Float>(equation: &str, inputs: &[&Tensor<T>]) -> FerrotorchResult<Tensor<T>> {
if inputs.is_empty() || inputs.len() > 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"einsum: expected 1 or 2 input tensors, got {}",
inputs.len()
),
});
}
let parsed = parse_equation(equation, inputs.len())?;
let dim_map = build_dim_map(&parsed, inputs)?;
let result = match inputs.len() {
1 => einsum_single(&parsed, inputs[0], &dim_map)?,
2 => einsum_two(&parsed, inputs[0], inputs[1], &dim_map)?,
_ => unreachable!(),
};
Ok(result)
}
pub fn einsum_differentiable<T: Float>(
equation: &str,
inputs: &[&Tensor<T>],
) -> FerrotorchResult<Tensor<T>> {
autocast_guard("einsum");
let result = einsum(equation, inputs)?;
let any_requires_grad = inputs.iter().any(|t| t.requires_grad());
if is_grad_enabled() && any_requires_grad {
let device = result.device();
let wrapped = match inputs.len() {
1 => {
let grad_fn = Arc::new(EinsumBackwardSingle {
equation: equation.to_string(),
input: inputs[0].clone(),
});
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
}
2 => {
let grad_fn = Arc::new(EinsumBackwardTwo {
equation: equation.to_string(),
a: inputs[0].clone(),
b: inputs[1].clone(),
});
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
}
_ => Ok(result),
}?;
Ok(wrapped)
} else {
Ok(result)
}
}
#[derive(Debug)]
struct EinsumBackwardSingle<T: Float> {
equation: String,
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for EinsumBackwardSingle<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !self.input.requires_grad() {
return Ok(vec![None]);
}
let (lhs, rhs) = self
.equation
.split_once("->")
.unwrap_or((&self.equation, ""));
let in_subs: Vec<char> = lhs.chars().filter(|c| c.is_ascii_lowercase()).collect();
let out_subs: Vec<char> = rhs.chars().collect();
let has_repeated = {
let mut seen = std::collections::HashSet::new();
in_subs.iter().any(|c| !seen.insert(c))
};
if has_repeated {
let in_shape: Vec<usize> = self.input.shape().to_vec();
let in_numel = self.input.numel();
let mut grad_data = vec![<T as num_traits::Zero>::zero(); in_numel];
let grad_out_data = grad_output.data_vec()?;
let out_strides = row_major_strides(grad_output.shape());
for (flat, grad_elem) in grad_data.iter_mut().enumerate().take(in_numel) {
let mut multi = vec![0usize; in_subs.len()];
{
let mut rem = flat;
for i in (0..in_subs.len()).rev() {
multi[i] = rem % in_shape[i];
rem /= in_shape[i];
}
}
let mut char_val: BTreeMap<char, usize> = BTreeMap::new();
let mut valid = true;
for (axis, &c) in in_subs.iter().enumerate() {
match char_val.get(&c) {
Some(&prev) if prev != multi[axis] => {
valid = false;
break;
}
_ => {
char_val.insert(c, multi[axis]);
}
}
}
if !valid {
continue; }
let mut out_flat = 0usize;
for (oi, &oc) in out_subs.iter().enumerate() {
out_flat += char_val[&oc] * out_strides[oi];
}
*grad_elem = if out_subs.is_empty() {
grad_out_data[0]
} else {
grad_out_data[out_flat]
};
}
let grad_tensor = Tensor::from_storage(TensorStorage::cpu(grad_data), in_shape, false)?;
return Ok(vec![Some(grad_tensor)]);
}
if out_subs.is_empty() {
let scalar_val = grad_output.item()?;
let grad_data = vec![scalar_val; self.input.numel()];
let grad_tensor = Tensor::from_storage(
TensorStorage::cpu(grad_data),
self.input.shape().to_vec(),
false,
)?;
return Ok(vec![Some(grad_tensor)]);
}
let reverse_eq = format!("{}->{}", rhs, lhs);
let grad_a = einsum(&reverse_eq, &[grad_output])?;
Ok(vec![Some(grad_a)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"EinsumBackward"
}
}
#[derive(Debug)]
struct EinsumBackwardTwo<T: Float> {
equation: String,
a: Tensor<T>,
b: Tensor<T>,
}
impl<T: Float> EinsumBackwardTwo<T> {
fn backward_equation(&self, target: usize) -> (String, usize, usize) {
let (lhs, rhs) = self
.equation
.split_once("->")
.unwrap_or((&self.equation, ""));
let parts: Vec<&str> = lhs.split(',').collect();
let a_subs = parts[0];
let b_subs = parts[1];
let out_subs = rhs;
if target == 0 {
let eq = format!("{},{}->{}", out_subs, b_subs, a_subs);
(eq, 0, 1) } else {
let eq = format!("{},{}->{}", a_subs, out_subs, b_subs);
(eq, 1, 0) }
}
}
impl<T: Float> GradFn<T> for EinsumBackwardTwo<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let grad_a = if self.a.requires_grad() {
let (eq, _, _) = self.backward_equation(0);
Some(einsum(&eq, &[grad_output, &self.b])?)
} else {
None
};
let grad_b = if self.b.requires_grad() {
let (eq, _, _) = self.backward_equation(1);
Some(einsum(&eq, &[&self.a, grad_output])?)
} else {
None
};
Ok(vec![grad_a, grad_b])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a, &self.b]
}
fn name(&self) -> &'static str {
"EinsumBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
}
fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len()
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: {a} vs {e} (diff {})",
(a - e).abs()
);
}
}
#[test]
fn test_einsum_mm() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
assert_eq!(c.shape(), &[2, 2]);
assert_close(c.data().unwrap(), &[19.0, 22.0, 43.0, 50.0], 1e-6);
}
#[test]
fn test_einsum_bmm() {
#[rustfmt::skip]
let a_data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0,
1.0, 0.0, 0.0, 1.0,
];
#[rustfmt::skip]
let b_data: Vec<f32> = vec![
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
];
let a = t(&a_data, &[2, 2, 2]);
let b = t(&b_data, &[2, 2, 2]);
let c = einsum("bij,bjk->bik", &[&a, &b]).unwrap();
assert_eq!(c.shape(), &[2, 2, 2]);
let d = c.data().unwrap();
assert_close(&d[0..4], &[19.0, 22.0, 43.0, 50.0], 1e-6);
assert_close(&d[4..8], &[9.0, 10.0, 11.0, 12.0], 1e-6);
}
#[test]
fn test_einsum_trace() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let c = einsum("ii->", &[&a]).unwrap();
assert!(c.is_scalar());
assert!((c.item().unwrap() - 5.0).abs() < 1e-6);
}
#[test]
fn test_einsum_outer_product() {
let a = t(&[1.0, 2.0, 3.0], &[3]);
let b = t(&[4.0, 5.0], &[2]);
let c = einsum("i,j->ij", &[&a, &b]).unwrap();
assert_eq!(c.shape(), &[3, 2]);
assert_close(c.data().unwrap(), &[4.0, 5.0, 8.0, 10.0, 12.0, 15.0], 1e-6);
}
#[test]
fn test_einsum_transpose() {
let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let c = einsum("ij->ji", &[&a]).unwrap();
assert_eq!(c.shape(), &[3, 2]);
assert_close(c.data().unwrap(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0], 1e-6);
}
#[test]
fn test_einsum_sum_all() {
let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let c = einsum("ij->", &[&a]).unwrap();
assert!(c.is_scalar());
assert!((c.item().unwrap() - 21.0).abs() < 1e-6);
}
#[test]
fn test_einsum_sum_axis() {
let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let c = einsum("ij->i", &[&a]).unwrap();
assert_eq!(c.shape(), &[2]);
assert_close(c.data().unwrap(), &[6.0, 15.0], 1e-6);
}
#[test]
fn test_einsum_implicit_mm() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
let c = einsum("ij,jk", &[&a, &b]).unwrap();
assert_eq!(c.shape(), &[2, 2]);
assert_close(c.data().unwrap(), &[19.0, 22.0, 43.0, 50.0], 1e-6);
}
#[test]
fn test_einsum_backward_mm() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = leaf(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
let c = einsum_differentiable("ij,jk->ik", &[&a, &b]).unwrap();
assert_eq!(c.shape(), &[2, 2]);
let c_data = c.data().unwrap();
let loss_val: f32 = c_data.iter().sum();
#[derive(Debug)]
struct SumBackward<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
fn backward(
&self,
_grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
let g = Tensor::from_storage(
TensorStorage::cpu(ones),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(g)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![loss_val]),
vec![],
Arc::new(SumBackward { input: c }),
)
.unwrap();
loss.backward().unwrap();
let a_grad = a.grad().unwrap().expect("a should have grad");
let b_grad = b.grad().unwrap().expect("b should have grad");
assert_eq!(a_grad.shape(), &[2, 2]);
assert_eq!(b_grad.shape(), &[2, 2]);
assert_close(a_grad.data().unwrap(), &[11.0, 15.0, 11.0, 15.0], 1e-5);
assert_close(b_grad.data().unwrap(), &[4.0, 4.0, 6.0, 6.0], 1e-5);
}
#[test]
fn test_einsum_invalid_equation() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
assert!(einsum("ij,jk,kl->il", &[&a, &b]).is_err());
assert!(einsum("ijk,jk->ik", &[&a, &b]).is_err());
let c = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
assert!(einsum("ij,jk->ik", &[&c, &a]).is_err());
assert!(einsum("i1,1j->ij", &[&a, &b]).is_err());
}
#[test]
fn test_einsum_diagonal() {
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let c = einsum("ii->i", &[&a]).unwrap();
assert_eq!(c.shape(), &[2]);
assert_close(c.data().unwrap(), &[1.0, 4.0], 1e-6);
}
#[test]
fn test_einsum_dot() {
let a = t(&[1.0, 2.0, 3.0], &[3]);
let b = t(&[4.0, 5.0, 6.0], &[3]);
let c = einsum("i,i->", &[&a, &b]).unwrap();
assert!(c.is_scalar());
assert!((c.item().unwrap() - 32.0).abs() < 1e-6);
}
#[test]
fn test_einsum_non_square_mm() {
let a = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let b = t(
&[
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
&[3, 4],
);
let c = einsum("ij,jk->ik", &[&a, &b]).unwrap();
assert_eq!(c.shape(), &[2, 4]);
assert_close(
c.data().unwrap(),
&[38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0],
1e-5,
);
}
#[test]
fn test_einsum_differentiable_fires_autocast_guard() {
use crate::autograd::autocast::{AutocastDtype, autocast, set_autocast_debug};
use crate::autograd::autocast_ops::{AutocastCategory, drain_autocast_events};
set_autocast_debug(true);
let a = t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
drain_autocast_events();
let _ = einsum_differentiable("ij,jk->ik", &[&a, &b]).unwrap();
assert!(drain_autocast_events().is_empty());
autocast(AutocastDtype::F16, || {
drain_autocast_events();
let _ = einsum_differentiable("ij,jk->ik", &[&a, &b]).unwrap();
let events = drain_autocast_events();
assert_eq!(events.len(), 1);
assert_eq!(events[0].op, "einsum");
assert_eq!(events[0].category, AutocastCategory::ReducedPrecision);
});
}
}