use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
use super::core::NormOrd;
/// Multiplies a sequence of 2-D matrices left-to-right: `m0 @ m1 @ ... @ mn`.
///
/// Returns an error if the slice is empty, if any entry is not 2-D, or if
/// adjacent matrices have incompatible inner dimensions. A single matrix is
/// returned as a clone without performing any multiplication.
pub fn chain_matmul(matrices: &[Tensor]) -> TorshResult<Tensor> {
    // An empty chain has no well-defined product.
    if matrices.is_empty() {
        return Err(TorshError::invalid_argument_with_context(
            "chain_matmul requires at least one matrix",
            "chain_matmul",
        ));
    }
    // A chain of one matrix is the identity operation.
    if let [only] = matrices {
        return Ok(only.clone());
    }
    // Every operand must be a plain 2-D matrix.
    for (idx, mat) in matrices.iter().enumerate() {
        if mat.shape().ndim() != 2 {
            return Err(TorshError::invalid_argument_with_context(
                &format!("Matrix {} is not 2D", idx),
                "chain_matmul",
            ));
        }
    }
    // Validate that each adjacent pair shares an inner dimension
    // before doing any multiplication work.
    for pair in matrices.windows(2) {
        let (lhs, rhs) = (&pair[0], &pair[1]);
        let lhs_cols = lhs.shape().dims()[1];
        let rhs_rows = rhs.shape().dims()[0];
        if lhs_cols != rhs_rows {
            return Err(TorshError::invalid_argument_with_context(
                &format!(
                    "Matrix dimensions incompatible: [{} x {}] @ [{} x {}]",
                    lhs.shape().dims()[0],
                    lhs_cols,
                    rhs_rows,
                    rhs.shape().dims()[1]
                ),
                "chain_matmul",
            ));
        }
    }
    // Fold the product left-to-right, propagating any matmul failure.
    matrices[1..]
        .iter()
        .try_fold(matrices[0].clone(), |acc, mat| acc.matmul(mat))
}
/// Computes a norm of `tensor`.
///
/// * `ord` — which norm to take; defaults to the Frobenius norm when `None`.
/// * `dim` — optional dimensions to reduce over; when `None` the reduction
///   spans all elements.
/// * `keepdim` — whether reduced dimensions are retained with size 1. Only the
///   per-dimension `sum_dim` calls below honor it; the final `max`/`min` in the
///   Inf/NegInf branches do not.
pub fn norm(
    tensor: &Tensor,
    ord: Option<NormOrd>,
    dim: Option<Vec<isize>>,
    keepdim: bool,
) -> TorshResult<Tensor> {
    let ord = ord.unwrap_or(NormOrd::Fro);
    match ord {
        NormOrd::Fro => {
            // Frobenius norm: sqrt(sum(x^2)), over the given dims or globally.
            let squared = tensor.pow(2.0)?;
            let sum = if let Some(dims) = dim {
                // NOTE(review): dims are reduced one at a time; when
                // `keepdim` is false, later entries in `dims` index the
                // already-reduced tensor — callers may need pre-adjusted
                // indices. Confirm intended semantics.
                let mut result = squared;
                for &d in dims.iter() {
                    result = result.sum_dim(&[d as i32], keepdim)?;
                }
                result
            } else {
                squared.sum()?
            };
            sum.sqrt()
        }
        NormOrd::Nuclear => {
            // Delegates to the dedicated nuclear-norm implementation.
            crate::reduction::norm_nuclear(tensor)
        }
        NormOrd::Inf => {
            if let Some(dims) = dim {
                // NOTE(review): this sums |x| over `dims` and then takes a
                // global max — the usual inf-norm would take max(|x|) along
                // `dims` instead, and `keepdim` is not applied to the final
                // max. Confirm this is the intended behavior.
                let abs_tensor = tensor.abs()?;
                let mut result = abs_tensor;
                for &d in dims.iter() {
                    result = result.sum_dim(&[d as i32], keepdim)?;
                }
                result.max(None, false)
            } else {
                // Global inf-norm: max(|x|) over all elements.
                tensor.abs()?.max(None, false)
            }
        }
        NormOrd::NegInf => {
            if let Some(dims) = dim {
                // NOTE(review): same concern as Inf above — sums |x| over
                // `dims` before taking a global min rather than min(|x|)
                // along `dims`; `keepdim` is ignored by `min()`. Verify.
                let abs_tensor = tensor.abs()?;
                let mut result = abs_tensor;
                for &d in dims.iter() {
                    result = result.sum_dim(&[d as i32], keepdim)?;
                }
                result.min()
            } else {
                // Global -inf-norm: min(|x|) over all elements.
                tensor.abs()?.min()
            }
        }
        NormOrd::P(p) => {
            // General p-norm: (sum(|x|^p))^(1/p).
            let abs_p = tensor.abs()?.pow(p)?;
            let sum = if let Some(dims) = dim {
                // Same sequential-reduction caveat as the Frobenius branch.
                let mut result = abs_p;
                for &d in dims.iter() {
                    result = result.sum_dim(&[d as i32], keepdim)?;
                }
                result
            } else {
                abs_p.sum()?
            };
            sum.pow(1.0 / p)
        }
    }
}
/// Batched matrix multiplication of two 3-D tensors.
///
/// `input` has shape `(batch, n, k)` and `mat2` has shape `(batch, k, m)`;
/// the result has shape `(batch, n, m)`. Returns an error if either tensor
/// is not 3-D, if the batch sizes differ, or if the inner dimensions do not
/// match. Data is pulled to the host and multiplied with a naive triple loop.
pub fn bmm(input: &Tensor, mat2: &Tensor) -> TorshResult<Tensor> {
    // Both operands must be rank-3: (batch, rows, cols).
    if input.shape().ndim() != 3 || mat2.shape().ndim() != 3 {
        return Err(TorshError::invalid_argument_with_context(
            "Batch matrix multiplication requires 3D tensors (batch, rows, cols)",
            "bmm",
        ));
    }
    let lhs_shape = input.shape();
    let lhs = lhs_shape.dims();
    let rhs_shape = mat2.shape();
    let rhs = rhs_shape.dims();
    // The leading (batch) dimensions must agree.
    if lhs[0] != rhs[0] {
        return Err(TorshError::invalid_argument_with_context(
            &format!("Batch sizes don't match: {} vs {}", lhs[0], rhs[0]),
            "bmm",
        ));
    }
    // Inner dimensions must agree: (n x k) @ (k x m).
    if lhs[2] != rhs[1] {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Matrix dimensions incompatible: [{} x {}] @ [{} x {}]",
                lhs[1], lhs[2], rhs[1], rhs[2]
            ),
            "bmm",
        ));
    }
    let (batch, rows, inner, cols) = (lhs[0], lhs[1], lhs[2], rhs[2]);
    let lhs_data = input.to_vec()?;
    let rhs_data = mat2.to_vec()?;
    let mut out = vec![0.0f32; batch * rows * cols];
    // Row-major layout: hoist each batch's base offset out of the inner loops.
    for b in 0..batch {
        let lhs_base = b * rows * inner;
        let rhs_base = b * inner * cols;
        let out_base = b * rows * cols;
        for i in 0..rows {
            for j in 0..cols {
                // Dot product of row i of the left matrix with column j of
                // the right matrix (accumulated in ascending k order).
                let dot: f32 = (0..inner)
                    .map(|k| {
                        lhs_data[lhs_base + i * inner + k]
                            * rhs_data[rhs_base + k * cols + j]
                    })
                    .sum();
                out[out_base + i * cols + j] = dot;
            }
        }
    }
    Tensor::from_data(out, vec![batch, rows, cols], input.device())
}
/// Batched matrix-multiply-accumulate: `beta * input + alpha * (batch1 @ batch2)`.
///
/// `batch1` and `batch2` are multiplied batch-wise via [`bmm`], the product is
/// scaled by `alpha`, `input` is scaled by `beta`, and the two are summed
/// element-wise. Any error from the underlying ops is propagated.
pub fn baddbmm(
    input: &Tensor,
    batch1: &Tensor,
    batch2: &Tensor,
    beta: f32,
    alpha: f32,
) -> TorshResult<Tensor> {
    let product = bmm(batch1, batch2)?;
    // Scale each term, then combine (same op order as the element-wise add).
    let weighted_input = input.mul_scalar(beta)?;
    let weighted_product = product.mul_scalar(alpha)?;
    weighted_input.add_op(&weighted_product)
}