use super::utils::{batch_transpose, cache_friendly_trace, compute_outer_product};
use crate::{Result, Tensor};
use scirs2_core::numeric::{One, Zero};
#[cfg(feature = "gpu")]
use super::gpu::{
gpu_einsum_batched_matmul, gpu_einsum_diagonal, gpu_einsum_matmul, gpu_einsum_outer_product,
gpu_einsum_trace, gpu_einsum_transpose, gpu_einsum_vector_dot,
};
/// Attempt to dispatch a recognized einsum `equation` to a specialized
/// fast-path implementation instead of the generic contraction engine.
///
/// Returns `Some(result)` when the pattern was handled, or `None` so the
/// caller can fall back to the general einsum path. When the `gpu` feature
/// is enabled and any operand is stored on the GPU, dispatch is delegated
/// to [`try_optimize_gpu_patterns`].
pub fn try_optimize_common_patterns<T>(
    equation: &str,
    operands: &[&Tensor<T>],
) -> Option<Result<Tensor<T>>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    // Route to the GPU fast paths when any operand resides in GPU storage.
    // Without the `gpu` feature there is no GPU storage, so this check
    // disappears entirely.
    #[cfg(feature = "gpu")]
    {
        let any_on_gpu = operands
            .iter()
            .any(|t| matches!(&t.storage, crate::tensor::TensorStorage::Gpu(_)));
        if any_on_gpu {
            return try_optimize_gpu_patterns(equation, operands);
        }
    }
    // CPU fast paths: match the equation text together with the operand
    // count via slice patterns (exactly one / exactly two operands).
    match (equation, operands) {
        // Batched matrix multiplication, under either index naming.
        ("bij,bjk->bik" | "bik,bkj->bij", &[a, b]) => Some(crate::ops::matmul(a, b)),
        // Trace of a square matrix.
        ("ii->", &[a]) => Some(cache_friendly_trace(a)),
        // Vector dot product.
        ("i,i->", &[a, b]) => Some(crate::ops::dot(a, b)),
        // Outer product of two vectors.
        ("i,j->ij", &[a, b]) => Some(compute_outer_product(a, b)),
        // Batched matrix transpose.
        ("bij->bji", &[a]) => Some(batch_transpose(a)),
        _ => None,
    }
}
/// GPU fast paths for recognized einsum patterns.
///
/// Returns `Some(result)` when `equation` (with the matching operand count)
/// maps onto a specialized GPU kernel, or `None` so the caller can fall
/// back to the general einsum path.
#[cfg(feature = "gpu")]
pub fn try_optimize_gpu_patterns<T>(
    equation: &str,
    operands: &[&Tensor<T>],
) -> Option<Result<Tensor<T>>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    match equation {
        // Matrix multiplication.
        "ij,jk->ik" if operands.len() == 2 => Some(gpu_einsum_matmul(operands[0], operands[1])),
        // Batched matrix multiplication.
        "bij,bjk->bik" if operands.len() == 2 => {
            Some(gpu_einsum_batched_matmul(operands[0], operands[1]))
        }
        // Matrix transpose.
        "ij->ji" if operands.len() == 1 => Some(gpu_einsum_transpose(operands[0])),
        // Diagonal extraction.
        "ii->i" if operands.len() == 1 => Some(gpu_einsum_diagonal(operands[0])),
        // Element-wise (Hadamard) product. Matched by exact equality: the
        // previous `starts_with("ij,ij->ij")` guard also accepted longer,
        // different equations such as "ij,ij->ijk".
        "ij,ij->ij" if operands.len() == 2 => Some(operands[0].mul(operands[1])),
        // Full sum-reduction of an element-wise product. Exact equality
        // replaces the former `starts_with("ij,ij->") && ends_with("->")`
        // combination, which only ever matched this exact string anyway.
        "ij,ij->" if operands.len() == 2 => {
            let elementwise = operands[0].mul(operands[1]);
            Some(elementwise.and_then(|t| crate::ops::sum(&t, None, false)))
        }
        // Outer product of two vectors.
        "i,j->ij" if operands.len() == 2 => {
            Some(gpu_einsum_outer_product(operands[0], operands[1]))
        }
        // Vector dot product.
        "i,i->" if operands.len() == 2 => Some(gpu_einsum_vector_dot(operands[0], operands[1])),
        // Trace of a square matrix.
        "ii->" if operands.len() == 1 => Some(gpu_einsum_trace(operands[0])),
        _ => None,
    }
}
/// Stub used when the `gpu` feature is disabled: no GPU fast paths exist,
/// so every equation falls through to the general einsum path.
///
/// Mirrors the signature of the feature-gated GPU variant so callers can
/// invoke it unconditionally.
#[cfg(not(feature = "gpu"))]
pub fn try_optimize_gpu_patterns<T>(
    _equation: &str,
    _operands: &[&Tensor<T>],
) -> Option<Result<Tensor<T>>>
where
    T: Clone
        + Default
        + Zero
        + One
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>
        + Send
        + Sync
        + 'static
        + bytemuck::Pod
        + bytemuck::Zeroable,
{
    // No GPU support compiled in; always defer to the generic path.
    None
}