use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
/// Tensor contraction of `a` and `b` over the given axes (NumPy-style `tensordot`).
///
/// * `TensorDotAxes::Int(n)` contracts the last `n` axes of `a` against the
///   first `n` axes of `b`, in order.
/// * `TensorDotAxes::Explicit(a_axes, b_axes)` (and its `Arrays` alias)
///   contracts `a_axes[i]` of `a` against `b_axes[i]` of `b`.
///
/// The result shape is the free (non-contracted) dims of `a` followed by the
/// free dims of `b`; a fully contracted result is a 0-d scalar.
///
/// # Errors
/// Returns an error when the axis lists differ in length, an axis index is
/// out of range, paired axes have different sizes, or `n` exceeds either
/// tensor's rank.
pub fn tensordot(a: &Tensor, b: &Tensor, axes: TensorDotAxes) -> TorshResult<Tensor> {
    match axes {
        TensorDotAxes::Int(n) => {
            let a_shape = a.shape();
            let b_shape = b.shape();
            if n > a_shape.ndim() || n > b_shape.ndim() {
                return Err(torsh_core::TorshError::invalid_argument_with_context(
                    "Number of axes to contract exceeds tensor dimensions",
                    "tensordot",
                ));
            }
            // Pair the last `n` axes of `a` with the first `n` axes of `b`.
            for i in 0..n {
                let a_axis = a_shape.ndim() - n + i;
                if a_shape.dims()[a_axis] != b_shape.dims()[i] {
                    return Err(torsh_core::TorshError::invalid_argument_with_context(
                        "Axes to contract must have the same size",
                        "tensordot",
                    ));
                }
            }
            let a_free = a_shape.ndim() - n;
            let a_free_size: usize = a_shape.dims()[..a_free].iter().product();
            let b_free_size: usize = b_shape.dims()[n..].iter().product();
            let contract_size: usize = a_shape.dims()[a_free..].iter().product();
            // Contracted axes are already trailing on `a` and leading on `b`,
            // so a plain reshape keeps the element pairing correct here.
            let lhs = a.view(&[a_free_size as i32, contract_size as i32])?;
            let rhs = b.view(&[contract_size as i32, b_free_size as i32])?;
            let mut result_shape: Vec<usize> = Vec::new();
            result_shape.extend(&a_shape.dims()[..a_free]);
            result_shape.extend(&b_shape.dims()[n..]);
            matmul_reshape(&lhs, &rhs, &result_shape)
        }
        TensorDotAxes::Explicit(a_axes, b_axes) => {
            if a_axes.len() != b_axes.len() {
                return Err(torsh_core::TorshError::invalid_argument_with_context(
                    "Number of axes to contract must be equal",
                    "tensordot",
                ));
            }
            let a_shape = a.shape();
            let b_shape = b.shape();
            for (&a_axis, &b_axis) in a_axes.iter().zip(b_axes.iter()) {
                if a_axis >= a_shape.ndim() || b_axis >= b_shape.ndim() {
                    return Err(torsh_core::TorshError::invalid_argument_with_context(
                        "Axis index out of range",
                        "tensordot",
                    ));
                }
                if a_shape.dims()[a_axis] != b_shape.dims()[b_axis] {
                    return Err(torsh_core::TorshError::invalid_argument_with_context(
                        "Contracted axes must have the same size",
                        "tensordot",
                    ));
                }
            }
            let contract_size: usize = a_axes.iter().map(|&axis| a_shape.dims()[axis]).product();
            let a_free_dims: Vec<usize> = (0..a_shape.ndim())
                .filter(|i| !a_axes.contains(i))
                .collect();
            let b_free_dims: Vec<usize> = (0..b_shape.ndim())
                .filter(|i| !b_axes.contains(i))
                .collect();
            let a_free_size: usize = a_free_dims.iter().map(|&i| a_shape.dims()[i]).product();
            let b_free_size: usize = b_free_dims.iter().map(|&i| b_shape.dims()[i]).product();
            // BUG FIX: the previous code reshaped `a`/`b` directly with `view`,
            // which silently pairs the wrong elements unless the contracted
            // axes already happen to be the trailing axes of `a` (in the given
            // order) and the leading axes of `b`. Move the axes into that
            // layout first so the flat layout matches the contraction.
            let a_perm: Vec<usize> = a_free_dims.iter().chain(a_axes.iter()).copied().collect();
            let b_perm: Vec<usize> = b_axes.iter().chain(b_free_dims.iter()).copied().collect();
            let a_ordered = reorder_axes(a, &a_perm)?;
            let b_ordered = reorder_axes(b, &b_perm)?;
            let lhs = a_ordered.view(&[a_free_size as i32, contract_size as i32])?;
            let rhs = b_ordered.view(&[contract_size as i32, b_free_size as i32])?;
            let mut result_shape: Vec<usize> = Vec::new();
            for &dim in &a_free_dims {
                result_shape.push(a_shape.dims()[dim]);
            }
            for &dim in &b_free_dims {
                result_shape.push(b_shape.dims()[dim]);
            }
            matmul_reshape(&lhs, &rhs, &result_shape)
        }
        TensorDotAxes::Arrays(a_axes, b_axes) => {
            // `Arrays` is an alias for `Explicit`; delegate.
            tensordot(a, b, TensorDotAxes::Explicit(a_axes, b_axes))
        }
    }
}

/// Reorder `t`'s axes to `perm` and materialize a contiguous tensor so the
/// subsequent 2-D `view` is valid. Identity permutations skip the data move.
/// NOTE(review): assumes torsh's `Tensor::permute(&[usize])` and
/// `contiguous()` mirror the PyTorch API — confirm against torsh-tensor.
fn reorder_axes(t: &Tensor, perm: &[usize]) -> TorshResult<Tensor> {
    if perm.iter().enumerate().all(|(i, &p)| i == p) {
        Ok(t.clone())
    } else {
        t.permute(perm)?.contiguous()
    }
}

/// Multiply an `[m, k]` tensor by a `[k, n]` tensor and view the product as
/// `out_dims` (a 0-d scalar when `out_dims` is empty).
fn matmul_reshape(lhs: &Tensor, rhs: &Tensor, out_dims: &[usize]) -> TorshResult<Tensor> {
    let product = lhs.matmul(rhs)?;
    if out_dims.is_empty() {
        product.view(&[])
    } else {
        let dims_i32: Vec<i32> = out_dims.iter().map(|&d| d as i32).collect();
        product.view(&dims_i32)
    }
}
/// Axis specification for `tensordot`, mirroring NumPy's `axes` argument.
#[derive(Debug, Clone)]
pub enum TensorDotAxes {
    /// Contract the last `n` axes of the first tensor with the first `n`
    /// axes of the second, in order.
    Int(usize),
    /// Contract axis `a_axes[i]` of the first tensor with axis `b_axes[i]`
    /// of the second; both lists must have the same length.
    Explicit(Vec<usize>, Vec<usize>),
    /// Array-pair spelling of the axes; handled identically to `Explicit`
    /// (kept as a separate variant for API compatibility).
    Arrays(Vec<usize>, Vec<usize>),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    // NOTE(review): these tests assert only output *shapes*; contracted
    // values are never compared against a reference result.

    /// `Int(1)` on two 2-D tensors is an ordinary matrix product.
    #[test]
    fn test_tensordot_simple_matrix_multiplication() -> TorshResult<()> {
        let lhs = randn(&[3, 4], None, None, None)?;
        let rhs = randn(&[4, 5], None, None, None)?;
        let out = tensordot(&lhs, &rhs, TensorDotAxes::Int(1))?;
        assert_eq!(out.shape().dims(), &[3, 5]);
        Ok(())
    }

    /// Explicit axes `[1]`/`[0]` reproduce the matrix product above.
    #[test]
    fn test_tensordot_with_explicit_axes() -> TorshResult<()> {
        let lhs = randn(&[3, 4], None, None, None)?;
        let rhs = randn(&[4, 5], None, None, None)?;
        let out = tensordot(&lhs, &rhs, TensorDotAxes::Explicit(vec![1], vec![0]))?;
        assert_eq!(out.shape().dims(), &[3, 5]);
        Ok(())
    }

    /// The `Arrays` variant must behave exactly like `Explicit`.
    #[test]
    fn test_tensordot_with_arrays() -> TorshResult<()> {
        let lhs = randn(&[3, 4], None, None, None)?;
        let rhs = randn(&[4, 5], None, None, None)?;
        let out = tensordot(&lhs, &rhs, TensorDotAxes::Arrays(vec![1], vec![0]))?;
        assert_eq!(out.shape().dims(), &[3, 5]);
        Ok(())
    }

    /// `Int(2)` contracts the trailing `[4, 5]` of `lhs` against the
    /// leading `[4, 5]` of `rhs`.
    #[test]
    fn test_tensordot_higher_order() -> TorshResult<()> {
        let lhs = randn(&[2, 3, 4, 5], None, None, None)?;
        let rhs = randn(&[4, 5, 6], None, None, None)?;
        let out = tensordot(&lhs, &rhs, TensorDotAxes::Int(2))?;
        assert_eq!(out.shape().dims(), &[2, 3, 6]);
        Ok(())
    }

    /// Out-of-order axis pairs: result is free dims of `lhs` ([2, 5])
    /// followed by free dims of `rhs` ([6, 7]).
    #[test]
    fn test_tensordot_multiple_explicit_axes() -> TorshResult<()> {
        let lhs = randn(&[2, 3, 4, 5], None, None, None)?;
        let rhs = randn(&[4, 6, 3, 7], None, None, None)?;
        let out = tensordot(&lhs, &rhs, TensorDotAxes::Explicit(vec![2, 1], vec![0, 2]))?;
        assert_eq!(out.shape().dims(), &[2, 5, 6, 7]);
        Ok(())
    }

    /// Contracting every axis yields a 0-d scalar.
    #[test]
    fn test_tensordot_scalar_result() -> TorshResult<()> {
        let lhs = randn(&[3, 4], None, None, None)?;
        let rhs = randn(&[3, 4], None, None, None)?;
        let out = tensordot(&lhs, &rhs, TensorDotAxes::Explicit(vec![0, 1], vec![0, 1]))?;
        assert_eq!(out.shape().ndim(), 0);
        Ok(())
    }

    /// Axis lists of different lengths are rejected.
    #[test]
    fn test_tensordot_error_mismatched_axes_lengths() {
        let lhs = randn(&[3, 4], None, None, None).expect("randn should succeed");
        let rhs = randn(&[4, 5], None, None, None).expect("randn should succeed");
        assert!(tensordot(&lhs, &rhs, TensorDotAxes::Explicit(vec![1], vec![0, 1])).is_err());
    }

    /// An axis index at or past the tensor's rank is rejected.
    #[test]
    fn test_tensordot_error_axis_out_of_bounds() {
        let lhs = randn(&[3, 4], None, None, None).expect("randn should succeed");
        let rhs = randn(&[4, 5], None, None, None).expect("randn should succeed");
        assert!(tensordot(&lhs, &rhs, TensorDotAxes::Explicit(vec![2], vec![0])).is_err());
    }

    /// Paired axes with different sizes are rejected.
    #[test]
    fn test_tensordot_error_incompatible_sizes() {
        let lhs = randn(&[3, 4], None, None, None).expect("randn should succeed");
        let rhs = randn(&[5, 6], None, None, None).expect("randn should succeed");
        assert!(tensordot(&lhs, &rhs, TensorDotAxes::Explicit(vec![1], vec![0])).is_err());
    }

    /// `Int(n)` with `n` larger than either rank is rejected.
    #[test]
    fn test_tensordot_error_too_many_axes() {
        let lhs = randn(&[3, 4], None, None, None).expect("randn should succeed");
        let rhs = randn(&[4, 5], None, None, None).expect("randn should succeed");
        assert!(tensordot(&lhs, &rhs, TensorDotAxes::Int(3)).is_err());
    }

    /// tensordot takes an outer product over free dims — batch axes are NOT
    /// aligned, so both size-10 dims appear in the output.
    #[test]
    fn test_tensordot_batch_operations() -> TorshResult<()> {
        let batch_a = randn(&[10, 32, 64], None, None, None)?;
        let batch_b = randn(&[10, 64, 128], None, None, None)?;
        let out = tensordot(&batch_a, &batch_b, TensorDotAxes::Explicit(vec![2], vec![1]))?;
        assert_eq!(out.shape().dims(), &[10, 32, 10, 128]);
        Ok(())
    }

    /// Two 1-D tensors fully contracted give a 0-d dot product.
    #[test]
    fn test_tensordot_edge_case_1d_tensors() -> TorshResult<()> {
        let lhs = randn(&[5], None, None, None)?;
        let rhs = randn(&[5], None, None, None)?;
        let out = tensordot(&lhs, &rhs, TensorDotAxes::Int(1))?;
        assert_eq!(out.shape().ndim(), 0);
        Ok(())
    }
}