use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{creation::zeros, Tensor};
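/// Computes the pairwise Minkowski (p-norm) distances between the rows of
/// two 2-D tensors of shapes `[n1, d]` and `[n2, d]`, returning an
/// `[n1, n2]` distance matrix. `p = 1`, `p = 2`, and `p = f32::INFINITY`
/// select the Manhattan, Euclidean, and Chebyshev distances respectively.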
pub fn cdist(x1: &Tensor, x2: &Tensor, p: f32) -> TorshResult<Tensor> {
if x1.shape().ndim() != 2 || x2.shape().ndim() != 2 {
return Err(TorshError::invalid_argument_with_context(
"Input tensors must be 2-dimensional",
"cdist",
));
}
    if x1.shape().dims()[1] != x2.shape().dims()[1] {
        return Err(TorshError::invalid_argument_with_context(
            "Input tensors must have the same number of features",
            "cdist",
        ));
    }
    // Reject negative p, for which the Minkowski distance is undefined.
    if p < 0.0 {
        return Err(TorshError::invalid_argument_with_context(
            "p must be non-negative",
            "cdist",
        ));
    }
let n1 = x1.shape().dims()[0];
let n2 = x2.shape().dims()[0];
let d = x1.shape().dims()[1];
let distances = zeros(&[n1, n2])?;
    // Naive O(n1 * n2 * d) scalar loop. For p = infinity, `dist` tracks the
    // running maximum of |diff|; otherwise it accumulates p-th powers.
    for i in 0..n1 {
        for j in 0..n2 {
            let mut dist = 0.0f32;
            for k in 0..d {
                let diff = x1.get(&[i, k])? - x2.get(&[j, k])?;
                if p == 2.0 {
                    dist += diff * diff;
                } else if p == 1.0 {
                    dist += diff.abs();
                } else if p == f32::INFINITY {
                    dist = dist.max(diff.abs());
                } else {
                    dist += diff.abs().powf(p);
                }
            }
            // Take the final root; L1 and L-infinity need no root.
            let final_dist = if p == 2.0 {
                dist.sqrt()
            } else if p == 1.0 || p == f32::INFINITY {
                dist
            } else {
                dist.powf(1.0 / p)
            };
            distances.set(&[i, j], final_dist)?;
        }
    }
Ok(distances)
}
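/// Evaluates an Einstein-summation `equation` over `operands`. Only the
/// equations matched explicitly below are supported; any other equation
/// returns `TorshError::Other`. Whitespace in the equation is ignored.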
pub fn einsum(equation: &str, operands: &[Tensor]) -> TorshResult<Tensor> {
    // Strip whitespace so "ij, jk -> ik" matches the patterns below.
    let equation: String = equation.chars().filter(|c| !c.is_whitespace()).collect();
    let parts: Vec<&str> = equation.split("->").collect();
    if parts.len() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Invalid einsum equation format",
            "einsum",
        ));
    }
    let input_specs = parts[0].split(',').collect::<Vec<_>>();
    // Currently unused; the supported equations are matched verbatim below.
    let _output_spec = parts[1];
    if input_specs.len() != operands.len() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Expected {} operands, got {}",
                input_specs.len(),
                operands.len()
            ),
            "einsum",
        ));
    }
    match equation.as_str() {
"ij,jk->ik" if operands.len() == 2 => {
operands[0].matmul(&operands[1])
}
"ii->i" if operands.len() == 1 => {
extract_diagonal(&operands[0])
}
"ij->ji" if operands.len() == 1 => {
operands[0].transpose(-2, -1)
}
"ij->" if operands.len() == 1 => {
operands[0].sum()
}
"ij->i" if operands.len() == 1 => {
operands[0].sum_dim(&[1], false)
}
"ij->j" if operands.len() == 1 => {
operands[0].sum_dim(&[0], false)
}
"ij,ij->ij" if operands.len() == 2 => {
operands[0].mul(&operands[1])
}
"ij,ij->" if operands.len() == 2 => {
let prod = operands[0].mul(&operands[1])?;
prod.sum()
}
"i,i->" if operands.len() == 2 => {
let prod = operands[0].mul(&operands[1])?;
prod.sum()
}
"i,i->i" if operands.len() == 2 => {
operands[0].mul(&operands[1])
}
"i,j->ij" if operands.len() == 2 => {
let a = operands[0].view(&[-1, 1])?;
let b = operands[1].view(&[1, -1])?;
a.matmul(&b)
}
"i->" if operands.len() == 1 => {
operands[0].sum()
}
"ijk,ikl->ijl" if operands.len() == 2 => {
operands[0].matmul(&operands[1])
}
"ijk->ikj" if operands.len() == 1 => {
operands[0].transpose(-2, -1)
}
"ijk->ij" if operands.len() == 1 => {
operands[0].sum_dim(&[2], false)
}
"ijk->jk" if operands.len() == 1 => {
operands[0].sum_dim(&[0], false)
}
"ijk->ik" if operands.len() == 1 => {
operands[0].sum_dim(&[1], false)
}
"ijk->" if operands.len() == 1 => {
operands[0].sum()
}
"ijk,ijk->ijk" if operands.len() == 2 => {
operands[0].mul(&operands[1])
}
"ijk,ijk->" if operands.len() == 2 => {
let prod = operands[0].mul(&operands[1])?;
prod.sum()
}
"ii->" if operands.len() == 1 => {
let diag = extract_diagonal(&operands[0])?;
diag.sum()
}
"ijk->jik" if operands.len() == 1 => {
operands[0].permute(&[1, 0, 2])
}
"ijk->kji" if operands.len() == 1 => {
operands[0].permute(&[2, 1, 0])
}
"ijkl->ikjl" if operands.len() == 1 => {
operands[0].permute(&[0, 2, 1, 3])
}
_ => {
Err(TorshError::Other(format!(
"Einsum equation '{}' not yet implemented",
equation
)))
}
}
}
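/// Extracts the main diagonal of a 2-D tensor as a 1-D tensor of length
/// `min(rows, cols)`.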
fn extract_diagonal(tensor: &Tensor) -> TorshResult<Tensor> {
let shape = tensor.shape();
if shape.ndim() != 2 {
return Err(TorshError::invalid_argument_with_context(
"Diagonal extraction requires 2D tensor",
"extract_diagonal",
));
}
let size = shape.dims()[0].min(shape.dims()[1]);
let diag = zeros(&[size])?;
for i in 0..size {
diag.set(&[i], tensor.get(&[i, i])?)?;
}
Ok(diag)
}
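// Element-wise unary math functions: thin functional wrappers over the
// corresponding Tensor methods.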
pub fn abs(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.abs()
}
pub fn exp(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.exp()
}
pub fn log(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.ln()
}
pub fn log2(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.log2()
}
pub fn log10(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.log10()
}
pub fn sin(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.sin()
}
pub fn cos(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.cos()
}
pub fn tan(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.tan()
}
pub fn asin(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.asin()
}
pub fn acos(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.acos()
}
pub fn atan(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.atan()
}
pub fn sinh(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.sinh()
}
pub fn cosh(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.cosh()
}
pub fn tanh(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.tanh()
}
pub fn sqrt(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.sqrt()
}
pub fn rsqrt(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.rsqrt()
}
pub fn square(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.square()
}
pub fn reciprocal(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.reciprocal()
}
pub fn pow(tensor: &Tensor, exponent: f32) -> TorshResult<Tensor> {
tensor.pow(exponent)
}
pub fn pow_tensor(base: &Tensor, exponent: &Tensor) -> TorshResult<Tensor> {
base.pow_tensor(exponent)
}
pub fn floor(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.floor()
}
pub fn ceil(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.ceil()
}
pub fn round(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.round()
}
pub fn trunc(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.trunc()
}
pub fn frac(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.fract()
}
pub fn sign(tensor: &Tensor) -> TorshResult<Tensor> {
tensor.sign()
}
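// Element-wise comparisons, returning boolean tensors. The *_scalar
// variants compare every element against a single f32.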
pub fn eq(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor1.eq(tensor2)
}
pub fn ne(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor1.ne(tensor2)
}
pub fn lt(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor1.lt(tensor2)
}
pub fn le(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor1.le(tensor2)
}
pub fn gt(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor1.gt(tensor2)
}
pub fn ge(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor1.ge(tensor2)
}
pub fn eq_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor.eq_scalar(scalar)
}
pub fn ne_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor.ne_scalar(scalar)
}
pub fn lt_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor.lt_scalar(scalar)
}
pub fn le_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor.le_scalar(scalar)
}
pub fn gt_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor.gt_scalar(scalar)
}
pub fn ge_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<torsh_tensor::Tensor<bool>> {
tensor.ge_scalar(scalar)
}
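// Logical operations on float tensors: an element is treated as true iff
// it is nonzero (hence the ne_scalar(0.0) conversions).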
pub fn logical_and(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
let bool1 = tensor1.ne_scalar(0.0)?;
let bool2 = tensor2.ne_scalar(0.0)?;
bool1.logical_and(&bool2)
}
pub fn logical_or(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
let bool1 = tensor1.ne_scalar(0.0)?;
let bool2 = tensor2.ne_scalar(0.0)?;
bool1.logical_or(&bool2)
}
pub fn logical_xor(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
let bool1 = tensor1.ne_scalar(0.0)?;
let bool2 = tensor2.ne_scalar(0.0)?;
bool1.logical_xor(&bool2)
}
pub fn logical_not(tensor: &Tensor) -> TorshResult<torsh_tensor::Tensor<bool>> {
    // An element is false iff it equals zero, so NOT(x) is simply x == 0.
    tensor.eq_scalar(0.0)
}
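// Element-wise minimum/maximum and clamping. The *_scalar variants
// materialize the scalar as a tensor via ones_like + mul_scalar before
// delegating to the tensor-tensor ops.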
pub fn minimum(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<Tensor> {
tensor1.minimum(tensor2)
}
pub fn maximum(tensor1: &Tensor, tensor2: &Tensor) -> TorshResult<Tensor> {
tensor1.maximum(tensor2)
}
pub fn minimum_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<Tensor> {
let scalar_tensor = tensor.ones_like()?.mul_scalar(scalar)?;
tensor.minimum(&scalar_tensor)
}
pub fn maximum_scalar(tensor: &Tensor, scalar: f32) -> TorshResult<Tensor> {
let scalar_tensor = tensor.ones_like()?.mul_scalar(scalar)?;
tensor.maximum(&scalar_tensor)
}
pub fn clamp(tensor: &Tensor, min: f32, max: f32) -> TorshResult<Tensor> {
let min_clamped = maximum_scalar(tensor, min)?;
minimum_scalar(&min_clamped, max)
}
pub fn clamp_min(tensor: &Tensor, min: f32) -> TorshResult<Tensor> {
maximum_scalar(tensor, min)
}
pub fn clamp_max(tensor: &Tensor, max: f32) -> TorshResult<Tensor> {
minimum_scalar(tensor, max)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::random_ops::randn;
use torsh_tensor::creation::from_vec;
#[test]
fn test_cdist() {
let x1 = from_vec(
vec![0.0f32, 0.0, 1.0, 0.0, 0.0, 1.0],
&[3, 2],
torsh_core::device::DeviceType::Cpu,
)
.unwrap();
let x2 = from_vec(
vec![1.0f32, 1.0, 2.0, 2.0],
&[2, 2],
torsh_core::device::DeviceType::Cpu,
)
.unwrap();
let distances = cdist(&x1, &x2, 2.0).unwrap();
assert_eq!(distances.shape().dims(), &[3, 2]);
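        // Spot-check one value: the distance from (0, 0) to (1, 1) is sqrt(2).
        assert!((distances.get(&[0, 0]).unwrap() - 2.0f32.sqrt()).abs() < 1e-6);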
}
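    // Sketch test for the non-Euclidean branches of cdist, using the same
    // from_vec/get conventions as test_cdist above.
    #[test]
    fn test_cdist_manhattan_and_chebyshev() {
        let x1 = from_vec(
            vec![0.0f32, 0.0],
            &[1, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .unwrap();
        let x2 = from_vec(
            vec![3.0f32, 4.0],
            &[1, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .unwrap();
        // L1: |3| + |4| = 7; L-infinity: max(|3|, |4|) = 4.
        let l1 = cdist(&x1, &x2, 1.0).unwrap();
        assert!((l1.get(&[0, 0]).unwrap() - 7.0).abs() < 1e-6);
        let linf = cdist(&x1, &x2, f32::INFINITY).unwrap();
        assert!((linf.get(&[0, 0]).unwrap() - 4.0).abs() < 1e-6);
    }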
#[test]
fn test_einsum_matmul() {
let a = randn(&[3, 4], None, None, None).unwrap();
let b = randn(&[4, 5], None, None, None).unwrap();
let result = einsum("ij,jk->ik", &[a.clone(), b.clone()]).unwrap();
let expected = a.matmul(&b).unwrap();
assert_eq!(result.shape(), expected.shape());
}
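    #[test]
    fn test_einsum_outer_product() {
        // Shape-only check for "i,j->ij", mirroring test_einsum_matmul.
        let a = randn(&[3], None, None, None).unwrap();
        let b = randn(&[4], None, None, None).unwrap();
        let result = einsum("i,j->ij", &[a, b]).unwrap();
        assert_eq!(result.shape().dims(), &[3, 4]);
    }
    #[test]
    fn test_einsum_trace_matches_diagonal_sum() {
        // The "ii->" path should agree in shape with summing the explicit
        // diagonal; a value-level comparison is left to integration tests.
        let t = from_vec(
            vec![1.0f32, 2.0, 3.0, 4.0],
            &[2, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .unwrap();
        let trace = einsum("ii->", &[t.clone()]).unwrap();
        let expected = extract_diagonal(&t).unwrap().sum().unwrap();
        assert_eq!(trace.shape(), expected.shape());
    }
    #[test]
    fn test_clamp_range() {
        // clamp should pin values into [min, max]; assumes 1-D indexing via
        // get(&[i]) as used by extract_diagonal.
        let t = from_vec(
            vec![-2.0f32, 0.5, 3.0],
            &[3],
            torsh_core::device::DeviceType::Cpu,
        )
        .unwrap();
        let c = clamp(&t, 0.0, 1.0).unwrap();
        assert!((c.get(&[0]).unwrap() - 0.0).abs() < 1e-6);
        assert!((c.get(&[1]).unwrap() - 0.5).abs() < 1e-6);
        assert!((c.get(&[2]).unwrap() - 1.0).abs() < 1e-6);
    }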
}