use crate::tensor::TensorStorage;
use crate::{Result, Tensor, TensorError};
use scirs2_core::numeric::{One, Zero};
use std::collections::HashMap;
#[cfg(any(
feature = "blas-openblas",
feature = "blas-oxiblas",
feature = "blas-mkl"
))]
use super::blas::try_blas_optimized_patterns;
use super::cache::execute_contraction_path;
use super::patterns::try_optimize_common_patterns;
use super::utils::compute_optimal_path;
/// Evaluates an Einstein-summation expression over `operands`.
///
/// `equation` uses standard einsum notation, e.g. `"ij,jk->ik"`. When the
/// `->` output part is omitted, the output subscript is inferred (indices
/// appearing exactly once across all inputs, sorted — see
/// `infer_output_subscript`).
///
/// Dispatch order: BLAS-optimized patterns (feature-gated, CPU-only), then
/// generic common-pattern shortcuts, then dedicated unary/binary kernels,
/// then an optimized pairwise contraction path for 3+ operands.
///
/// # Errors
/// Returns `TensorError::invalid_argument` when `operands` is empty, the
/// equation cannot be parsed, or the operand count does not match the
/// equation; returns `TensorError::device_mismatch` when operands live on
/// different devices.
pub fn einsum<T>(equation: &str, operands: &[&Tensor<T>]) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    if operands.is_empty() {
        return Err(TensorError::invalid_argument(
            "At least one operand is required for einsum".to_string(),
        ));
    }
    let (input_subscripts, output_subscript) = parse_einsum_equation(equation)?;
    if input_subscripts.len() != operands.len() {
        return Err(TensorError::invalid_argument(format!(
            "Number of operands ({}) does not match equation ({})",
            operands.len(),
            input_subscripts.len()
        )));
    }
    // Every operand must live on the same device as the first one.
    let first_device = operands[0].device();
    for operand in operands {
        if operand.device() != first_device {
            return Err(TensorError::device_mismatch(
                "einsum",
                &first_device.to_string(),
                &operand.device().to_string(),
            ));
        }
        // Storage kind is only inspected here, not acted upon; GPU-backed
        // tensors fall through to the generic paths below.
        match &operand.storage {
            TensorStorage::Cpu(_) => {}
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => {
            }
        }
    }
    // Fast path: hand recognized patterns to BLAS when available.
    // NOTE(review): this cfg gate lists `blas-accelerate`, but the import of
    // `try_blas_optimized_patterns` at the top of the file is gated on
    // `blas-openblas`/`blas-oxiblas`/`blas-mkl` — the two feature sets do not
    // match. An accelerate-only (+std) build would fail to resolve the
    // symbol, and an oxiblas-only build imports it unused. Confirm the
    // intended feature set and align both gates.
    #[cfg(any(
        all(feature = "blas-openblas", feature = "std"),
        all(feature = "blas-mkl", feature = "std"),
        all(feature = "blas-accelerate", feature = "std")
    ))]
    {
        // BLAS kernels only apply when every operand is CPU-resident.
        let all_cpu = operands.iter().all(|op| match &op.storage {
            TensorStorage::Cpu(_) => true,
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => false,
        });
        if all_cpu {
            if let Some(blas_result) = try_blas_optimized_patterns(equation, operands) {
                return blas_result;
            }
        }
    }
    // Non-BLAS shortcuts for common patterns (always compiled in).
    if let Some(optimized_result) = try_optimize_common_patterns(equation, operands) {
        return optimized_result;
    }
    // General fallback, dispatched on operand count.
    match operands.len() {
        1 => einsum_unary(&input_subscripts[0], &output_subscript, operands[0]),
        2 => einsum_binary(
            &input_subscripts[0],
            &input_subscripts[1],
            &output_subscript,
            operands[0],
            operands[1],
        ),
        _ => {
            // 3+ operands: plan a pairwise contraction order, then execute it.
            let contraction_path = compute_optimal_path(&input_subscripts, &output_subscript)?;
            execute_contraction_path(operands, &contraction_path)
        }
    }
}
/// Splits an einsum `equation` into per-operand input subscripts and the
/// output subscript.
///
/// Accepts both explicit form (`"ij,jk->ik"`) and implicit form (`"ij,jk"`),
/// where the output is inferred via `infer_output_subscript`. Whitespace
/// around each subscript is trimmed.
///
/// # Errors
/// Returns `TensorError::invalid_argument` when the equation contains more
/// than one `->` separator.
pub fn parse_einsum_equation(equation: &str) -> Result<(Vec<String>, String)> {
    let parts: Vec<&str> = equation.split("->").collect();
    if parts.len() > 2 {
        return Err(TensorError::invalid_argument(
            "Invalid einsum equation: too many '->' separators".to_string(),
        ));
    }
    // `split("->")` always yields at least one piece, so indexing is safe.
    let input_part = parts[0];
    // `split(',')` likewise always yields at least one (possibly empty)
    // subscript, so the list can never be empty — an empty subscript
    // legitimately denotes a scalar operand. (The previous `is_empty()`
    // check here was unreachable dead code.)
    let input_subscripts: Vec<String> = input_part
        .split(',')
        .map(|s| s.trim().to_string())
        .collect();
    let output_subscript = if parts.len() == 2 {
        // Explicit output, e.g. "ij,jk->ik"; may be empty for a scalar result.
        parts[1].trim().to_string()
    } else {
        // Implicit form: output is every index occurring exactly once, sorted.
        infer_output_subscript(&input_subscripts)?
    };
    Ok((input_subscripts, output_subscript))
}
/// Derives the implicit output subscript for an equation written without `->`.
///
/// Following einsum convention, the output consists of every alphabetic index
/// that appears exactly once across all input subscripts, sorted in character
/// order; repeated indices are contracted away.
pub fn infer_output_subscript(input_subscripts: &[String]) -> Result<String> {
    // Tally occurrences of each alphabetic index character across all inputs.
    let mut counts: HashMap<char, usize> = HashMap::new();
    let index_chars = input_subscripts
        .iter()
        .flat_map(|subscript| subscript.chars())
        .filter(|ch| ch.is_alphabetic());
    for ch in index_chars {
        *counts.entry(ch).or_insert(0) += 1;
    }
    // Indices seen exactly once survive into the output, in sorted order.
    let mut survivors: Vec<char> = counts
        .into_iter()
        .filter(|&(_, n)| n == 1)
        .map(|(ch, _)| ch)
        .collect();
    survivors.sort_unstable();
    Ok(survivors.into_iter().collect())
}
/// Evaluates a single-operand einsum (`input -> output`): pure permutation,
/// full or partial reduction, 2-D diagonal extraction ("ii->i"), and trace
/// ("ii->").
///
/// # Errors
/// Returns `TensorError::invalid_argument` when the subscript length does not
/// match the operand rank, when the output references an index missing from
/// the input, or when the pattern is not supported.
pub(super) fn einsum_unary<T>(
    input_subscript: &str,
    output_subscript: &str,
    operand: &Tensor<T>,
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let input_chars: Vec<char> = input_subscript.chars().collect();
    let output_chars: Vec<char> = output_subscript.chars().collect();
    if input_chars.len() != operand.shape().rank() {
        return Err(TensorError::invalid_argument(format!(
            "Input subscript length ({}) does not match tensor rank ({})",
            input_chars.len(),
            operand.shape().rank()
        )));
    }
    // A repeated input index (e.g. "ii") denotes a diagonal, not a reduction.
    // Handle these first: the generic reduction path below would otherwise
    // sum over ALL elements — e.g. "ii->" must be the trace (sum of the
    // diagonal), not the full sum.
    let mut sorted_chars = input_chars.clone();
    sorted_chars.sort_unstable();
    let has_repeated_index = sorted_chars.windows(2).any(|w| w[0] == w[1]);
    if has_repeated_index {
        if input_chars.len() == 2 && input_chars[0] == input_chars[1] {
            // "ii->i": the main diagonal as a vector.
            if output_chars.len() == 1 && output_chars[0] == input_chars[0] {
                return extract_diagonal(operand);
            }
            // "ii->": trace, i.e. the sum of the diagonal.
            if output_chars.is_empty() {
                let diagonal = extract_diagonal(operand)?;
                return crate::ops::sum(&diagonal, None, false);
            }
        }
        // Other repeated-index patterns (e.g. "iij->j") are not implemented;
        // error out rather than silently produce a wrong reduction.
        return Err(TensorError::invalid_argument(format!(
            "Unsupported unary einsum: {input_subscript} -> {output_subscript}"
        )));
    }
    // Pure permutation: same index set, possibly reordered ("ij->ji").
    if input_chars.len() == output_chars.len()
        && input_chars.iter().all(|c| output_chars.contains(c))
    {
        let mut permutation = Vec::new();
        for &output_char in &output_chars {
            if let Some(pos) = input_chars.iter().position(|&c| c == output_char) {
                permutation.push(pos);
            } else {
                return Err(TensorError::invalid_argument(format!(
                    "Output character '{output_char}' not found in input"
                )));
            }
        }
        return crate::ops::manipulation::transpose_axes(operand, Some(&permutation));
    }
    // Full reduction to a scalar ("ij->").
    if output_chars.is_empty() {
        return crate::ops::sum(operand, None, false);
    }
    // Partial reduction: sum over every axis whose index is absent from the
    // output, then permute the surviving axes into output order if needed.
    // (The previous implementation returned the reduced tensor in input
    // order, silently mis-shaping e.g. "ijk->ki".)
    let axes_to_reduce: Vec<i32> = input_chars
        .iter()
        .copied()
        .enumerate()
        .filter(|(_, c)| !output_chars.contains(c))
        .map(|(i, _)| i as i32)
        .collect();
    if !axes_to_reduce.is_empty() {
        let reduced = crate::ops::sum(operand, Some(&axes_to_reduce), false)?;
        // Surviving indices, still in input order after reduction.
        let remaining: Vec<char> = input_chars
            .iter()
            .copied()
            .filter(|c| output_chars.contains(c))
            .collect();
        if remaining == output_chars {
            return Ok(reduced);
        }
        let mut permutation = Vec::with_capacity(output_chars.len());
        for &output_char in &output_chars {
            if let Some(pos) = remaining.iter().position(|&c| c == output_char) {
                permutation.push(pos);
            } else {
                return Err(TensorError::invalid_argument(format!(
                    "Output character '{output_char}' not found in input"
                )));
            }
        }
        return crate::ops::manipulation::transpose_axes(&reduced, Some(&permutation));
    }
    Err(TensorError::invalid_argument(format!(
        "Unsupported unary einsum: {input_subscript} -> {output_subscript}"
    )))
}
/// Evaluates a two-operand einsum. Supported patterns: matrix multiplication
/// ("ij,jk->ik" with pairwise-distinct indices), elementwise multiplication
/// (identical subscripts throughout), and full inner product (identical
/// input subscripts, scalar output).
///
/// # Errors
/// Returns `TensorError::invalid_argument` when a subscript length does not
/// match its tensor's rank, or when the pattern is not supported.
pub(super) fn einsum_binary<T>(
    left_subscript: &str,
    right_subscript: &str,
    output_subscript: &str,
    left: &Tensor<T>,
    right: &Tensor<T>,
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    let left_chars: Vec<char> = left_subscript.chars().collect();
    let right_chars: Vec<char> = right_subscript.chars().collect();
    let output_chars: Vec<char> = output_subscript.chars().collect();
    if left_chars.len() != left.shape().rank() {
        return Err(TensorError::invalid_argument(format!(
            "Left subscript length ({}) does not match tensor rank ({})",
            left_chars.len(),
            left.shape().rank()
        )));
    }
    if right_chars.len() != right.shape().rank() {
        return Err(TensorError::invalid_argument(format!(
            "Right subscript length ({}) does not match tensor rank ({})",
            right_chars.len(),
            right.shape().rank()
        )));
    }
    // Matrix multiplication: "ij,jk->ik" with i, j, k pairwise distinct.
    // The distinctness guards stop diagonal patterns such as "ii,ik->ik" or
    // "ij,jj->ij" (which select diagonals, not contractions) from being
    // silently — and wrongly — evaluated as a matmul.
    if left_chars.len() == 2
        && right_chars.len() == 2
        && output_chars.len() == 2
        && left_chars[1] == right_chars[0]
        && left_chars[0] == output_chars[0]
        && right_chars[1] == output_chars[1]
        && left_chars[0] != left_chars[1]
        && right_chars[0] != right_chars[1]
        && output_chars[0] != output_chars[1]
    {
        return crate::ops::matmul(left, right);
    }
    // A repeated index within one operand's subscript (e.g. "ii,ii->") means
    // "diagonal" in einsum notation and is not expressible as a plain
    // elementwise multiply/sum; fall through to the unsupported error rather
    // than return a wrong result. (Identical subscripts mean one check covers
    // both operands.)
    let mut sorted_left = left_chars.clone();
    sorted_left.sort_unstable();
    let left_has_repeats = sorted_left.windows(2).any(|w| w[0] == w[1]);
    if !left_has_repeats {
        // Elementwise multiplication: "ij,ij->ij".
        if left_subscript == right_subscript && left_subscript == output_subscript {
            return left.mul(right);
        }
        // Full inner product: "ij,ij->" — multiply elementwise, then sum all.
        if left_subscript == right_subscript && output_subscript.is_empty() {
            let elementwise = left.mul(right)?;
            return crate::ops::sum(&elementwise, None, false);
        }
    }
    Err(TensorError::invalid_argument(format!(
        "Unsupported binary einsum: {left_subscript},{right_subscript} -> {output_subscript}"
    )))
}
/// Returns the main diagonal of a 2-D tensor as a 1-D tensor.
///
/// For a rectangular matrix the diagonal has length `min(rows, cols)`.
///
/// # Errors
/// Returns `TensorError::invalid_argument` when the tensor is not 2-D or a
/// diagonal element cannot be read.
pub fn extract_diagonal<T>(tensor: &Tensor<T>) -> Result<Tensor<T>>
where
    T: Clone + Default + Zero + Send + Sync + 'static,
{
    let dims = tensor.shape().dims();
    if dims.len() != 2 {
        return Err(TensorError::invalid_argument(
            "Diagonal extraction requires 2D tensor".to_string(),
        ));
    }
    let len = dims[0].min(dims[1]);
    // Read element (i, i) for each i, short-circuiting on the first failure.
    let diagonal: Vec<T> = (0..len)
        .map(|i| {
            tensor.get(&[i, i]).ok_or_else(|| {
                TensorError::invalid_argument(
                    "Failed to extract diagonal element".to_string(),
                )
            })
        })
        .collect::<Result<Vec<T>>>()?;
    Tensor::from_vec(diagonal, &[len])
}