use scirs2_core::ndarray::{Array, ArrayD, ArrayViewD, IxDyn};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// Errors reported while parsing an einsum equation or checking it against
/// the supplied operands.
#[derive(Debug, Clone, PartialEq)]
pub enum EinsumError {
    /// Malformed subscript text, or an operand count that does not match the
    /// number of input subscripts.
    ParseError(String),
    /// Operand ranks or dimension sizes that are inconsistent with the subscripts.
    ShapeMismatch(String),
    /// An output label that does not occur in any input subscript.
    UnknownIndex(String),
    /// Ellipses that expand to different ranks in different operands.
    EllipsisNotSupported(String),
}

impl fmt::Display for EinsumError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every message is prefixed with the error category so callers can
        // surface it directly.
        match self {
            Self::ParseError(msg) => write!(f, "einsum parse error: {msg}"),
            Self::ShapeMismatch(msg) => write!(f, "einsum shape mismatch: {msg}"),
            Self::UnknownIndex(msg) => write!(f, "einsum unknown index: {msg}"),
            Self::EllipsisNotSupported(msg) => write!(f, "einsum ellipsis not supported: {msg}"),
        }
    }
}

impl std::error::Error for EinsumError {}
/// One element of a parsed subscript term.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum IndexSpec {
    /// A single ASCII-letter index label.
    Label(char),
    /// The `...` placeholder for a run of broadcast/batch axes.
    Ellipsis,
}

/// A fully parsed einsum equation.
#[derive(Clone, Debug)]
struct ParsedEq {
    /// One subscript per input operand, in operand order.
    input_specs: Vec<Vec<IndexSpec>>,
    /// The output subscript (explicit, or inferred in implicit mode).
    output_spec: Vec<IndexSpec>,
    /// Number of axes the ellipsis stands for; `None` until operand shapes
    /// have been inspected, or when no input uses an ellipsis.
    ellipsis_rank: Option<usize>,
}
/// Parses a single subscript term such as `"ij"` or `"...ij"`.
///
/// Each ASCII letter becomes an [`IndexSpec::Label`]; the exact three-character
/// sequence `...` becomes [`IndexSpec::Ellipsis`] and may appear at most once.
///
/// # Errors
/// Returns `EinsumError::ParseError` when a `.` does not begin a full `...`,
/// when a second ellipsis is found, or when any other character appears.
fn parse_subscript(s: &str) -> Result<Vec<IndexSpec>, EinsumError> {
    let mut specs: Vec<IndexSpec> = Vec::with_capacity(s.len());
    let mut rest = s.chars();
    while let Some(ch) = rest.next() {
        match ch {
            '.' => {
                // The two characters after the first dot must both be dots.
                let well_formed = rest.next() == Some('.') && rest.next() == Some('.');
                if !well_formed {
                    return Err(EinsumError::ParseError(format!(
                        "malformed ellipsis in subscript '{s}'"
                    )));
                }
                if specs.contains(&IndexSpec::Ellipsis) {
                    return Err(EinsumError::ParseError(format!(
                        "multiple ellipses in subscript '{s}'"
                    )));
                }
                specs.push(IndexSpec::Ellipsis);
            }
            label if label.is_ascii_alphabetic() => specs.push(IndexSpec::Label(label)),
            c => {
                return Err(EinsumError::ParseError(format!(
                    "unexpected character '{c}' in subscript '{s}'"
                )));
            }
        }
    }
    Ok(specs)
}
fn parse_einsum(eq: &str) -> Result<ParsedEq, EinsumError> {
let (inputs_str, output_str_opt) = if let Some(arrow) = eq.find("->") {
(&eq[..arrow], Some(&eq[arrow + 2..]))
} else {
(eq, None)
};
let input_parts: Vec<&str> = inputs_str.split(',').collect();
if input_parts.is_empty() {
return Err(EinsumError::ParseError(
"equation has no input subscripts".to_owned(),
));
}
let mut input_specs: Vec<Vec<IndexSpec>> = Vec::with_capacity(input_parts.len());
for part in &input_parts {
input_specs.push(parse_subscript(part.trim())?);
}
let output_spec = if let Some(out) = output_str_opt {
parse_subscript(out.trim())?
} else {
let mut counts: HashMap<char, usize> = HashMap::new();
for specs in &input_specs {
for s in specs {
if let IndexSpec::Label(c) = s {
*counts.entry(*c).or_insert(0) += 1;
}
}
}
let mut singles: Vec<char> = counts
.iter()
.filter(|(_, &v)| v == 1)
.map(|(&k, _)| k)
.collect();
singles.sort_unstable();
singles.into_iter().map(IndexSpec::Label).collect()
};
Ok(ParsedEq {
input_specs,
output_spec,
ellipsis_rank: None,
})
}
/// Validates the operands against the parsed subscripts and returns the extent
/// of every index label.
///
/// Side effect: stores the inferred ellipsis rank in `parsed.ellipsis_rank`
/// (`None` when no input uses `...`). Ellipsis axes are keyed by private-use
/// characters starting at U+E000, one per axis.
///
/// # Errors
/// * `ParseError` if the number of operands differs from the number of input
///   subscripts.
/// * `ShapeMismatch` if an operand's rank is too small for its explicit labels,
///   does not match its expanded subscript, or a label has inconsistent sizes.
/// * `EllipsisNotSupported` if ellipses in different operands cover different
///   numbers of axes.
/// * `UnknownIndex` if an output label never appears in an input.
fn resolve_sizes(
    parsed: &mut ParsedEq,
    ops: &[ArrayViewD<f64>],
) -> Result<HashMap<char, usize>, EinsumError> {
    let n_specs = parsed.input_specs.len();
    if n_specs != ops.len() {
        return Err(EinsumError::ParseError(format!(
            "equation has {} input subscripts but {} operands were supplied",
            n_specs,
            ops.len()
        )));
    }

    // Pass 1: infer the single ellipsis rank shared by all operands.
    let mut ell_rank: Option<usize> = None;
    for (i, (specs, op)) in parsed.input_specs.iter().zip(ops).enumerate() {
        if !specs.contains(&IndexSpec::Ellipsis) {
            continue;
        }
        let explicit_count = specs.iter().filter(|s| **s != IndexSpec::Ellipsis).count();
        let op_rank = op.ndim();
        if op_rank < explicit_count {
            return Err(EinsumError::ShapeMismatch(format!(
                "operand {i} has rank {op_rank} but subscript has {explicit_count} explicit indices"
            )));
        }
        let this_ell = op_rank - explicit_count;
        if let Some(prev) = ell_rank {
            if prev != this_ell {
                return Err(EinsumError::EllipsisNotSupported(format!(
                    "operand {i} gives ellipsis rank {this_ell} but earlier operand gave {prev}; \
                     broadcasting ellipsis across different batch ranks is not supported"
                )));
            }
        } else {
            ell_rank = Some(this_ell);
        }
    }
    parsed.ellipsis_rank = ell_rank;

    // Pass 2: expand every subscript and record (and cross-check) label sizes.
    let ell_chars: Vec<char> = (0..ell_rank.unwrap_or(0))
        .map(|k| char::from_u32(0xE000 + k as u32).unwrap_or('_'))
        .collect();
    let mut sizes: HashMap<char, usize> = HashMap::new();
    for (i, (specs, op)) in parsed.input_specs.iter().zip(ops).enumerate() {
        let expanded = expand_ellipsis(specs, &ell_chars);
        let op_shape = op.shape();
        if expanded.len() != op_shape.len() {
            return Err(EinsumError::ShapeMismatch(format!(
                "operand {i}: subscript has {} indices after expansion but operand has rank {}",
                expanded.len(),
                op_shape.len()
            )));
        }
        for (label, &dim_size) in expanded.iter().zip(op_shape) {
            // First sighting records the size; later sightings must agree.
            let known = *sizes.entry(*label).or_insert(dim_size);
            if known != dim_size {
                return Err(EinsumError::ShapeMismatch(format!(
                    "index '{label}' has size {} from one operand but {} from operand {i}",
                    known, dim_size
                )));
            }
        }
    }

    // Pass 3: the output may only use labels that some input defined.
    let missing = expand_ellipsis(&parsed.output_spec, &ell_chars)
        .into_iter()
        .find(|label| !sizes.contains_key(label));
    if let Some(label) = missing {
        return Err(EinsumError::UnknownIndex(format!(
            "output index '{label}' does not appear in any input"
        )));
    }
    Ok(sizes)
}
/// Flattens a subscript into plain labels, replacing the ellipsis (if any)
/// with the given per-axis placeholder characters.
fn expand_ellipsis(specs: &[IndexSpec], ell_chars: &[char]) -> Vec<char> {
    specs
        .iter()
        .flat_map(|spec| match spec {
            IndexSpec::Label(c) => std::slice::from_ref(c),
            IndexSpec::Ellipsis => ell_chars,
        })
        .copied()
        .collect()
}
/// Evaluates an einsum contraction by direct iteration over every combination
/// of label values.
///
/// `input_expanded[i]` holds the ellipsis-expanded labels of `ops[i]`, and
/// `output_expanded` the labels of the result. Labels absent from the output
/// are summed over. A label may appear more than once in the output, in which
/// case contributions are written only on that output diagonal. For example,
/// output `"ii"` fills the main diagonal and leaves the rest zero. This is what
/// the gradient of an input with repeated labels requires.
///
/// `sizes` should contain the extent of every label; a missing label is treated
/// as size 1. If any label has extent 0, the (zero-filled) result is returned
/// immediately.
///
/// Iteration visits the distinct output labels first, then the summed labels
/// in sorted order, with the last one varying fastest. This fixes the
/// floating-point accumulation order.
///
/// Never returns `Err`; the `Result` is kept for interface uniformity.
fn einsum_general(
    input_expanded: &[Vec<char>],
    output_expanded: &[char],
    sizes: &HashMap<char, usize>,
    ops: &[ArrayViewD<f64>],
) -> Result<ArrayD<f64>, EinsumError> {
    let dim_of = |c: &char| *sizes.get(c).unwrap_or(&1);

    // Distinct output labels (first-occurrence order), then summed labels sorted.
    let mut iter_labels: Vec<char> = Vec::with_capacity(output_expanded.len());
    for &c in output_expanded {
        if !iter_labels.contains(&c) {
            iter_labels.push(c);
        }
    }
    let mut sum_labels: Vec<char> = input_expanded
        .iter()
        .flatten()
        .copied()
        .filter(|c| !iter_labels.contains(c))
        .collect::<HashSet<char>>()
        .into_iter()
        .collect();
    sum_labels.sort_unstable();
    iter_labels.extend(sum_labels);

    let iter_sizes: Vec<usize> = iter_labels.iter().map(dim_of).collect();
    let label_pos: HashMap<char, usize> = iter_labels
        .iter()
        .enumerate()
        .map(|(i, &c)| (c, i))
        .collect();
    // Every input and output label is in `iter_labels` by construction, so
    // these lookups cannot fail. Repeated labels map to the same loop position,
    // which is what selects diagonals.
    let op_pos_maps: Vec<Vec<usize>> = input_expanded
        .iter()
        .map(|spec| spec.iter().map(|c| label_pos[c]).collect())
        .collect();
    let out_pos_map: Vec<usize> = output_expanded.iter().map(|c| label_pos[c]).collect();

    let out_shape: Vec<usize> = output_expanded.iter().map(dim_of).collect();
    let mut result: ArrayD<f64> = Array::zeros(IxDyn(&out_shape));
    if iter_sizes.contains(&0) {
        // Either the output is empty or a summed axis is empty: result is all zeros.
        return Ok(result);
    }

    // Odometer over all loop labels, reusing index buffers across iterations.
    let n_dims = iter_sizes.len();
    let mut multi = vec![0usize; n_dims];
    let mut op_indices: Vec<Vec<usize>> =
        op_pos_maps.iter().map(|m| vec![0usize; m.len()]).collect();
    let mut out_index = vec![0usize; out_pos_map.len()];
    loop {
        let mut prod = 1.0_f64;
        for ((op, pos_map), idx) in ops.iter().zip(&op_pos_maps).zip(&mut op_indices) {
            for (slot, &p) in idx.iter_mut().zip(pos_map) {
                *slot = multi[p];
            }
            prod *= op[idx.as_slice()];
        }
        for (slot, &p) in out_index.iter_mut().zip(&out_pos_map) {
            *slot = multi[p];
        }
        result[out_index.as_slice()] += prod;

        // Advance (last axis fastest); finish after the final combination.
        let mut axis = n_dims;
        loop {
            if axis == 0 {
                return Ok(result);
            }
            axis -= 1;
            multi[axis] += 1;
            if multi[axis] < iter_sizes[axis] {
                break;
            }
            multi[axis] = 0;
        }
    }
}
/// Tries to evaluate a one-operand einsum with a specialised kernel.
///
/// Recognised patterns:
/// * a pure axis permutation with distinct labels (e.g. `ij->ji`);
/// * the trace `ii->`;
/// * the diagonal `ii->i`.
///
/// Returns `None` when no pattern applies, so the caller can fall back to the
/// general evaluator.
fn shortcut_single(
    input: &[char],
    output: &[char],
    sizes: &HashMap<char, usize>,
    op: &ArrayViewD<f64>,
) -> Option<ArrayD<f64>> {
    let in_set: HashSet<char> = input.iter().copied().collect();
    let out_set: HashSet<char> = output.iter().copied().collect();
    // Pure permutation. Labels must be distinct: a repeated label (a diagonal,
    // e.g. `ii->ii`) would yield a repeated axis, and `permuted_axes` panics on that.
    if input.len() == output.len() && in_set.len() == input.len() && in_set == out_set {
        let axes: Vec<usize> = output
            .iter()
            .map(|c| input.iter().position(|x| x == c).unwrap_or(0))
            .collect();
        if axes.len() == op.ndim() {
            return Some(op.view().permuted_axes(IxDyn(&axes)).to_owned());
        }
    }

    // Trace: `ii->`.
    if output.is_empty() && input.len() == 2 && input[0] == input[1] {
        let n = *sizes.get(&input[0]).unwrap_or(&0);
        let trace = (0..n).fold(0.0_f64, |acc, i| acc + op[[i, i]]);
        return Some(Array::from_elem(IxDyn(&[]), trace));
    }

    // Diagonal: `ii->i`.
    if output.len() == 1 && input.len() == 2 && input[0] == input[1] && output[0] == input[0] {
        let n = *sizes.get(&input[0]).unwrap_or(&0);
        let diag: Vec<f64> = (0..n).map(|i| op[[i, i]]).collect();
        return Some(
            Array::from_shape_vec(IxDyn(&[n]), diag).unwrap_or_else(|_| Array::zeros(IxDyn(&[n]))),
        );
    }

    None
}
/// Tries to evaluate a two-operand einsum with a specialised kernel.
///
/// Recognised patterns:
/// * `ij,jk->ik` with three distinct labels, evaluated as a matrix product via `dot`;
/// * `S,S->` where `S` has no repeated label, evaluated as the sum of element-wise products;
/// * `i,j->ij` with `i != j`, evaluated as an outer product.
///
/// A label repeated within one operand denotes a diagonal. Such cases are
/// deliberately excluded, because these kernels would sum over the whole
/// matrix. `None` is returned so the caller falls back to the general evaluator.
fn shortcut_double(
    in0: &[char],
    in1: &[char],
    output: &[char],
    sizes: &HashMap<char, usize>,
    a: &ArrayViewD<f64>,
    b: &ArrayViewD<f64>,
) -> Option<ArrayD<f64>> {
    let all_distinct = |labels: &[char]| -> bool {
        labels
            .iter()
            .enumerate()
            .all(|(i, c)| !labels[..i].contains(c))
    };

    if in0.len() == 2 && in1.len() == 2 && output.len() == 2 {
        let (ai, aj) = (in0[0], in0[1]);
        let (bj, bk) = (in1[0], in1[1]);
        let (oi, ok) = (output[0], output[1]);
        // `ai`, `aj`, `bk` must be pairwise distinct. Equations such as
        // `ii,ik->ik` or `ij,jj->ij` match the positional pattern but are not
        // matrix products.
        let labels_distinct = ai != aj && aj != bk && ai != bk;
        if labels_distinct && aj == bj && ai == oi && bk == ok {
            let m = *sizes.get(&ai).unwrap_or(&0);
            let k = *sizes.get(&aj).unwrap_or(&0);
            let n = *sizes.get(&bk).unwrap_or(&0);
            if a.shape() == [m, k] && b.shape() == [k, n] {
                let a2 = a
                    .view()
                    .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                    .ok()?;
                let b2 = b
                    .view()
                    .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                    .ok()?;
                let c = a2.dot(&b2);
                return Some(c.into_dyn());
            }
        }
    }

    // Full contraction of two identically-labelled operands. Identical labels
    // imply identical shapes, so logical-order iteration pairs elements correctly.
    if in0 == in1 && output.is_empty() && all_distinct(in0) {
        let sum: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        return Some(Array::from_elem(IxDyn(&[]), sum));
    }

    if in0.len() == 1 && in1.len() == 1 && output.len() == 2 {
        let (li, lj) = (in0[0], in1[0]);
        if li != lj && output[0] == li && output[1] == lj {
            let ni = *sizes.get(&li).unwrap_or(&0);
            let nj = *sizes.get(&lj).unwrap_or(&0);
            let mut c = Array::zeros(IxDyn(&[ni, nj]));
            for i in 0..ni {
                for j in 0..nj {
                    c[[i, j]] = a[[i]] * b[[j]];
                }
            }
            return Some(c);
        }
    }

    None
}
/// Evaluates an Einstein-summation expression over `f64` operands.
///
/// `eq` uses NumPy-style subscripts:
/// * comma-separated terms of ASCII letters, one term per operand;
/// * an optional `...` per term for axes shared by all operands;
/// * an optional `->output`.
///
/// Without `->`, the output is inferred from the inputs.
///
/// One- and two-operand equations first try specialised kernels (permutation,
/// trace, diagonal, matrix product, inner and outer products). If none applies,
/// they fall back to direct iteration. Three or more operands are contracted
/// pairwise, left to right.
///
/// # Errors
/// Returns an [`EinsumError`] when:
/// * `eq` is malformed;
/// * the operand count does not match the number of subscript terms;
/// * operand shapes disagree with the subscripts;
/// * an output label is unknown;
/// * ellipses expand to different ranks in different operands.
pub fn einsum(eq: &str, ops: &[ArrayViewD<f64>]) -> Result<ArrayD<f64>, EinsumError> {
    let mut parsed = parse_einsum(eq)?;
    let sizes = resolve_sizes(&mut parsed, ops)?;

    // Each ellipsis axis is represented by its own private-use code point.
    let ell_chars: Vec<char> = (0..parsed.ellipsis_rank.unwrap_or(0))
        .map(|k| char::from_u32(0xE000 + k as u32).unwrap_or('_'))
        .collect();
    let input_expanded: Vec<Vec<char>> = parsed
        .input_specs
        .iter()
        .map(|spec| expand_ellipsis(spec, &ell_chars))
        .collect();
    let output_expanded = expand_ellipsis(&parsed.output_spec, &ell_chars);

    let fast_path = match ops {
        [only] => shortcut_single(&input_expanded[0], &output_expanded, &sizes, only),
        [lhs, rhs] => shortcut_double(
            &input_expanded[0],
            &input_expanded[1],
            &output_expanded,
            &sizes,
            lhs,
            rhs,
        ),
        _ => {
            return einsum_multi(&input_expanded, &output_expanded, &sizes, ops, &ell_chars);
        }
    };
    match fast_path {
        Some(result) => Ok(result),
        None => einsum_general(&input_expanded, &output_expanded, &sizes, ops),
    }
}
/// Contracts three or more operands pairwise, from left to right.
///
/// After each intermediate step, only the labels still needed are kept (in
/// sorted order). A label is needed if the final output uses it or an operand
/// not yet consumed uses it; every other label is summed out immediately. The
/// last step produces `output_expanded` directly. `_ell_chars` is unused.
fn einsum_multi(
    input_expanded: &[Vec<char>],
    output_expanded: &[char],
    sizes: &HashMap<char, usize>,
    ops: &[ArrayViewD<f64>],
    _ell_chars: &[char],
) -> Result<ArrayD<f64>, EinsumError> {
    // Labels of `lhs ∪ rhs` still required by the output or by any operand
    // from index `first_pending` onward, sorted.
    let surviving = |lhs: &[char], rhs: &[char], first_pending: usize| -> Vec<char> {
        let required: HashSet<char> = output_expanded
            .iter()
            .chain(input_expanded[first_pending..].iter().flatten())
            .copied()
            .collect();
        let union: HashSet<char> = lhs.iter().chain(rhs).copied().collect();
        let mut kept: Vec<char> = union.into_iter().filter(|c| required.contains(c)).collect();
        kept.sort_unstable();
        kept
    };

    let mut labels = surviving(&input_expanded[0], &input_expanded[1], 2);
    let mut acc = einsum_general(
        &[input_expanded[0].clone(), input_expanded[1].clone()],
        &labels,
        sizes,
        &[ops[0].view(), ops[1].view()],
    )?;

    let last = ops.len() - 1;
    for step in 2..ops.len() {
        let next_labels = if step == last {
            output_expanded.to_vec()
        } else {
            surviving(&labels, &input_expanded[step], step + 1)
        };
        acc = einsum_general(
            &[labels, input_expanded[step].clone()],
            &next_labels,
            sizes,
            &[acc.view(), ops[step].view()],
        )?;
        labels = next_labels;
    }
    Ok(acc)
}
/// Computes the gradient of `einsum(eq, ops)` with respect to every operand.
///
/// `grad_out` is the upstream gradient and must have exactly the shape of the
/// forward result. The gradient for operand `k` is the einsum of `grad_out`
/// with every other operand, evaluated to operand `k`'s own (expanded)
/// subscript. Labels of operand `k` that occur nowhere else are broadcast
/// using the extents recorded from the forward operands.
///
/// Returns one array per operand, in operand order.
///
/// # Panics
/// Panics in any of these cases:
/// * `eq` cannot be parsed;
/// * the operands are inconsistent with `eq`;
/// * `grad_out`'s shape differs from the forward output shape.
pub fn einsum_grad(
    eq: &str,
    grad_out: ArrayViewD<f64>,
    ops: &[ArrayViewD<f64>],
) -> Vec<ArrayD<f64>> {
    let mut parsed = parse_einsum(eq).expect("einsum_grad: failed to parse equation");
    let sizes = resolve_sizes(&mut parsed, ops)
        .expect("einsum_grad: failed to resolve sizes from operands");
    let ell_rank = parsed.ellipsis_rank.unwrap_or(0);
    let ell_chars: Vec<char> = (0..ell_rank)
        .map(|k| char::from_u32(0xE000 + k as u32).unwrap_or('_'))
        .collect();
    let input_expanded: Vec<Vec<char>> = parsed
        .input_specs
        .iter()
        .map(|s| expand_ellipsis(s, &ell_chars))
        .collect();
    let output_expanded: Vec<char> = expand_ellipsis(&parsed.output_spec, &ell_chars);

    // The evaluator indexes `grad_out` with extents taken from the forward
    // operands. A wrongly-shaped upstream gradient would otherwise panic deep
    // inside indexing (wrong rank) or, if larger, silently contribute only a
    // sub-block and produce a wrong gradient.
    // `resolve_sizes` guarantees every output label is present in `sizes`.
    let expected_shape: Vec<usize> = output_expanded.iter().map(|c| sizes[c]).collect();
    assert_eq!(
        grad_out.shape(),
        expected_shape.as_slice(),
        "einsum_grad: grad_out shape does not match the einsum output shape"
    );

    (0..ops.len())
        .map(|k| {
            // d(out)/d(op_k): contract grad_out with all other operands,
            // producing op_k's labels.
            let grad_input_specs: Vec<Vec<char>> = std::iter::once(output_expanded.clone())
                .chain(
                    input_expanded
                        .iter()
                        .enumerate()
                        .filter(|&(i, _)| i != k)
                        .map(|(_, s)| s.clone()),
                )
                .collect();
            let grad_ops: Vec<ArrayViewD<f64>> = std::iter::once(grad_out.view())
                .chain(
                    ops.iter()
                        .enumerate()
                        .filter(|&(i, _)| i != k)
                        .map(|(_, op)| op.view()),
                )
                .collect();
            einsum_general(&grad_input_specs, &input_expanded[k], &sizes, &grad_ops)
                .expect("einsum_grad: failed to compute gradient")
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, s, Array1, Array2, Array3};

    /// True when both arrays share a shape and every element pair differs by less than `tol`.
    fn approx_eq(a: &ArrayD<f64>, b: &ArrayD<f64>, tol: f64) -> bool {
        a.shape() == b.shape() && a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < tol)
    }

    #[test]
    fn test_matmul() {
        let lhs: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let rhs: Array2<f64> = array![[5.0, 6.0], [7.0, 8.0]];
        let got = einsum("ij,jk->ik", &[lhs.view().into_dyn(), rhs.view().into_dyn()]).unwrap();
        let want: Array2<f64> = array![[19.0, 22.0], [43.0, 50.0]];
        assert!(approx_eq(&got, &want.into_dyn(), 1e-10));
    }

    #[test]
    fn test_trace() {
        let m: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let tr = einsum("ii->", &[m.view().into_dyn()]).unwrap();
        assert!((tr[[]] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_diagonal() {
        let m: Array2<f64> = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let diag = einsum("ii->i", &[m.view().into_dyn()]).unwrap();
        let want: Array1<f64> = array![1.0, 5.0, 9.0];
        assert!(approx_eq(&diag, &want.into_dyn(), 1e-10));
    }

    #[test]
    fn test_transpose() {
        let m: Array2<f64> = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let t = einsum("ij->ji", &[m.view().into_dyn()]).unwrap();
        assert_eq!(t.shape(), &[3, 2]);
        for (idx, want) in [([0, 0], 1.0), ([0, 1], 4.0), ([1, 0], 2.0), ([2, 1], 6.0)] {
            assert!((t[idx] - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_outer_product() {
        let u: Array1<f64> = array![1.0, 2.0, 3.0];
        let v: Array1<f64> = array![4.0, 5.0];
        let outer = einsum("i,j->ij", &[u.view().into_dyn(), v.view().into_dyn()]).unwrap();
        assert_eq!(outer.shape(), &[3, 2]);
        assert!((outer[[0, 0]] - 4.0).abs() < 1e-10);
        assert!((outer[[2, 1]] - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_inner_product() {
        let u: Array1<f64> = array![1.0, 2.0, 3.0];
        let v: Array1<f64> = array![4.0, 5.0, 6.0];
        let dot = einsum("i,i->", &[u.view().into_dyn(), v.view().into_dyn()]).unwrap();
        assert!((dot[[]] - 32.0).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius() {
        let m: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let eye: Array2<f64> = array![[1.0, 0.0], [0.0, 1.0]];
        let inner = einsum("ij,ij->", &[m.view().into_dyn(), eye.view().into_dyn()]).unwrap();
        assert!((inner[[]] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_batched_matmul_explicit() {
        let a: Array3<f64> =
            Array3::from_shape_fn((2, 2, 2), |(b, i, j)| ((b * 4 + i * 2 + j) as f64) + 1.0);
        let b: Array3<f64> =
            Array3::from_shape_fn((2, 2, 3), |(bb, j, k)| ((bb * 6 + j * 3 + k) as f64) + 1.0);
        let c = einsum("bij,bjk->bik", &[a.view().into_dyn(), b.view().into_dyn()]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 3]);

        let expected0: Array2<f64> = a.slice(s![0, .., ..]).dot(&b.slice(s![0, .., ..]));
        for i in 0..2 {
            for k in 0..3 {
                assert!(
                    (c[[0, i, k]] - expected0[[i, k]]).abs() < 1e-8,
                    "mismatch at [0,{i},{k}]: {} vs {}",
                    c[[0, i, k]],
                    expected0[[i, k]]
                );
            }
        }
    }

    #[test]
    fn test_ellipsis_batched_matmul() {
        let a: Array3<f64> =
            Array3::from_shape_fn((3, 2, 4), |(b, i, j)| (b * 8 + i * 4 + j) as f64);
        let b: Array3<f64> =
            Array3::from_shape_fn((3, 4, 5), |(bb, j, k)| (bb * 20 + j * 5 + k) as f64);
        let c = einsum(
            "...ij,...jk->...ik",
            &[a.view().into_dyn(), b.view().into_dyn()],
        )
        .unwrap();
        assert_eq!(c.shape(), &[3, 2, 5]);

        let expected1: Array2<f64> = a.slice(s![1, .., ..]).dot(&b.slice(s![1, .., ..]));
        for i in 0..2 {
            for k in 0..5 {
                assert!(
                    (c[[1, i, k]] - expected1[[i, k]]).abs() < 1e-8,
                    "mismatch at [1,{i},{k}]"
                );
            }
        }
    }

    #[test]
    fn test_matmul_gradient() {
        let a: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let b: Array2<f64> = array![[5.0, 6.0], [7.0, 8.0]];
        let upstream: Array2<f64> = Array2::ones((2, 2));
        let grads = einsum_grad(
            "ij,jk->ik",
            upstream.view().into_dyn(),
            &[a.view().into_dyn(), b.view().into_dyn()],
        );
        assert_eq!(grads.len(), 2);
        assert_eq!(grads[0].shape(), &[2, 2]);
        assert_eq!(grads[1].shape(), &[2, 2]);

        // dA = G · Bᵀ, dB = Aᵀ · G with G = ones.
        let expected_da: Array2<f64> = array![[11.0, 15.0], [11.0, 15.0]];
        assert!(
            approx_eq(&grads[0], &expected_da.into_dyn(), 1e-8),
            "dA mismatch: {:?}",
            grads[0]
        );
        let expected_db: Array2<f64> = array![[4.0, 4.0], [6.0, 6.0]];
        assert!(
            approx_eq(&grads[1], &expected_db.into_dyn(), 1e-8),
            "dB mismatch: {:?}",
            grads[1]
        );
    }

    #[test]
    fn test_shape_mismatch_error() {
        let square: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let wide: Array2<f64> = array![[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]];
        // `j` is 3 in the first operand but 2 in the second.
        let result = einsum(
            "ij,jk->ik",
            &[wide.view().into_dyn(), square.view().into_dyn()],
        );
        assert!(matches!(result, Err(EinsumError::ShapeMismatch(_))));
    }

    #[test]
    fn test_parse_error() {
        let m: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let result = einsum("ij!k->ik", &[m.view().into_dyn()]);
        assert!(matches!(result, Err(EinsumError::ParseError(_))));
    }

    #[test]
    fn test_three_operand() {
        let a: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let identity: Array2<f64> = array![[1.0, 0.0], [0.0, 1.0]];
        let doubled: Array2<f64> = array![[2.0, 0.0], [0.0, 2.0]];
        let result = einsum(
            "ij,jk,kl->il",
            &[
                a.view().into_dyn(),
                identity.view().into_dyn(),
                doubled.view().into_dyn(),
            ],
        )
        .unwrap();
        let expected: Array2<f64> = array![[2.0, 4.0], [6.0, 8.0]];
        assert!(
            approx_eq(&result, &expected.into_dyn(), 1e-8),
            "3-op result: {:?}",
            result
        );
    }
}