use crate::prelude_dev::*;
/// Builds the einsum subscript string that expresses a `tensordot` contraction.
///
/// `dim_a` and `dim_b` are the numbers of dimensions of operands A and B.
/// `axes` selects which axes are contracted:
///
/// - `AxesPairIndex::None` is treated as `AxesPairIndex::Val(2)`.
/// - `AxesPairIndex::Val(n)` pairs the last `n` axes of A, in order, with the
///   first `n` axes of B.
/// - `AxesPairIndex::Pair(axes_a, axes_b)` pairs the two lists element-wise
///   after each has been passed through `normalize_axes_index`.
///
/// Labels are drawn from `a..=z` followed by `A..=Z`. Contracted pairs are
/// labelled first, then the remaining axes of A, then the remaining axes of B.
/// The output subscript lists the uncontracted axes of A followed by those of
/// B. For example, `dim_a = 3`, `dim_b = 3` and `Val(2)` yields
/// `"cab, abd -> cd"`.
///
/// # Errors
///
/// Returns an error when:
///
/// - `axes` cannot be converted into `AxesPairIndex<isize>`;
/// - `normalize_axes_index` rejects either axis list;
/// - `n` is negative (`InvalidValue`) or exceeds `dim_a` or `dim_b`
///   (`InvalidLayout`);
/// - the two axis lists differ in length, reference an axis outside the
///   operand, or name the same axis twice (`InvalidValue`);
/// - more than 52 distinct labels would be required (`InvalidValue`).
pub fn tensordot_to_einsum_str(
    dim_a: usize,
    dim_b: usize,
    axes: impl TryInto<AxesPairIndex<isize>, Error: Into<Error>>,
) -> Result<String> {
    let axes = axes.try_into().map_err(Into::into)?;
    // An unspecified `axes` falls back to contracting two axes.
    let axes = match axes {
        AxesPairIndex::None => AxesPairIndex::Val(2),
        given => given,
    };

    let (axes_a, axes_b): (Vec<usize>, Vec<usize>) = match axes {
        AxesPairIndex::None => unreachable!("`None` has been replaced by `Val(2)` above"),
        AxesPairIndex::Pair(axes_a, axes_b) => (
            // NOTE(review): the cast assumes normalized indices are non-negative -- confirm
            // against `normalize_axes_index`.
            normalize_axes_index(axes_a, dim_a, false, false)?.into_iter().map(|x| x as usize).collect(),
            normalize_axes_index(axes_b, dim_b, false, false)?.into_iter().map(|x| x as usize).collect(),
        ),
        AxesPairIndex::Val(n) => {
            if n < 0 {
                return rstsr_raise!(InvalidValue, "n must be non-negative")?;
            }
            let n = n as usize;
            if n > dim_a || n > dim_b {
                return rstsr_raise!(
                    InvalidLayout,
                    "n must be less than or equal to the number of dimensions of both tensors"
                )?;
            }
            // Last `n` axes of A against the first `n` axes of B.
            ((dim_a - n..dim_a).collect(), (0..n).collect())
        },
    };

    // The axis lists come from the caller, so report problems as errors rather than panicking.
    if axes_a.len() != axes_b.len() {
        return rstsr_raise!(InvalidValue, "axes must have same length")?;
    }

    let mut label_gen = (b'a'..=b'z').chain(b'A'..=b'Z').map(char::from);
    let mut a_labels: Vec<Option<char>> = vec![None; dim_a];
    let mut b_labels: Vec<Option<char>> = vec![None; dim_b];

    // Contracted axes: each (A axis, B axis) pair shares one label.
    for (&i, &j) in axes_a.iter().zip(&axes_b) {
        let Some(slot_a) = a_labels.get_mut(i) else {
            return rstsr_raise!(InvalidValue, "axis out of bounds for tensor A")?;
        };
        let Some(slot_b) = b_labels.get_mut(j) else {
            return rstsr_raise!(InvalidValue, "axis out of bounds for tensor B")?;
        };
        // A slot that is already filled means the axis was listed twice; overwriting it would
        // leave a label on only one operand and silently describe a different contraction.
        if slot_a.is_some() || slot_b.is_some() {
            return rstsr_raise!(InvalidValue, "axes must not contain repeated entries")?;
        }
        let label = label_gen.next().ok_or_else(|| rstsr_error!(InvalidValue, "ran out of unique labels"))?;
        *slot_a = Some(label);
        *slot_b = Some(label);
    }

    // Free axes: a slot still empty here is uncontracted, so it receives a fresh label that is
    // also appended to the output subscript. A is processed before B, which fixes both the
    // label order and the output axis order.
    let mut out_str = String::new();
    let mut subscript_of = |slots: &[Option<char>]| -> Result<String> {
        let mut subscript = String::with_capacity(slots.len());
        for slot in slots {
            let label = match *slot {
                Some(shared) => shared,
                None => {
                    let fresh =
                        label_gen.next().ok_or_else(|| rstsr_error!(InvalidValue, "ran out of unique labels"))?;
                    out_str.push(fresh);
                    fresh
                },
            };
            subscript.push(label);
        }
        Ok(subscript)
    };
    let a_str = subscript_of(&a_labels)?;
    let b_str = subscript_of(&b_labels)?;

    Ok(format!("{}, {} -> {}", a_str, b_str, out_str))
}