use crate::tensor::Tensor;
use anyhow::{Result, anyhow};
/// Arithmetic and basic math operations that yield a new [`Tensor`].
///
/// Every method borrows its operand(s) immutably and hands back the outcome by
/// value inside a [`Result`], so failures are reported to the caller instead of
/// panicking.
pub trait ArithmeticOps {
    // ---- tensor (op) tensor -------------------------------------------------

    /// Adds `other` to `self`.
    ///
    /// The `Tensor` implementation in this file broadcasts the operands and
    /// returns an error when their shapes are not broadcast-compatible.
    fn add(&self, other: &Tensor) -> Result<Tensor>;

    /// Subtracts `other` from `self`.
    ///
    /// Same broadcasting and error behaviour as [`ArithmeticOps::add`].
    fn sub(&self, other: &Tensor) -> Result<Tensor>;

    /// Multiplies `self` by `other`.
    ///
    /// Same broadcasting and error behaviour as [`ArithmeticOps::add`].
    fn mul(&self, other: &Tensor) -> Result<Tensor>;

    /// Divides `self` by `other`.
    ///
    /// Same broadcasting and error behaviour as [`ArithmeticOps::add`].
    fn div(&self, other: &Tensor) -> Result<Tensor>;

    // ---- tensor (op) scalar -------------------------------------------------

    /// Adds `scalar` to `self`.
    fn add_scalar(&self, scalar: f32) -> Result<Tensor>;

    /// Subtracts `scalar` from `self`.
    fn sub_scalar(&self, scalar: f32) -> Result<Tensor>;

    /// Multiplies `self` by `scalar`.
    fn mul_scalar(&self, scalar: f32) -> Result<Tensor>;

    /// Divides `self` by `scalar`.
    ///
    /// The `Tensor` implementation in this file returns an error when `scalar`
    /// is zero.
    fn div_scalar(&self, scalar: f32) -> Result<Tensor>;

    // ---- unary --------------------------------------------------------------

    /// Negates `self`.
    fn neg(&self) -> Result<Tensor>;

    /// Takes the absolute value of `self`.
    fn abs(&self) -> Result<Tensor>;

    /// Raises `self` to the power `exponent`.
    fn pow(&self, exponent: f32) -> Result<Tensor>;

    /// Takes the square root of `self`.
    fn sqrt(&self) -> Result<Tensor>;

    /// Applies the exponential function to `self`.
    fn exp(&self) -> Result<Tensor>;

    /// Applies the logarithm to `self`.
    fn log(&self) -> Result<Tensor>;
}
impl ArithmeticOps for Tensor {
fn add(&self, other: &Tensor) -> Result<Tensor> {
if !self.is_broadcastable_with(other) {
return Err(anyhow!(
"Cannot broadcast tensors with shapes {:?} and {:?}",
self.shape(),
other.shape()
));
}
let result_candle = self.candle_tensor().broadcast_add(other.candle_tensor())?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn sub(&self, other: &Tensor) -> Result<Tensor> {
if !self.is_broadcastable_with(other) {
return Err(anyhow!(
"Cannot broadcast tensors with shapes {:?} and {:?}",
self.shape(),
other.shape()
));
}
let result_candle = self.candle_tensor().broadcast_sub(other.candle_tensor())?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn mul(&self, other: &Tensor) -> Result<Tensor> {
if !self.is_broadcastable_with(other) {
return Err(anyhow!(
"Cannot broadcast tensors with shapes {:?} and {:?}",
self.shape(),
other.shape()
));
}
let result_candle = self.candle_tensor().broadcast_mul(other.candle_tensor())?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn div(&self, other: &Tensor) -> Result<Tensor> {
if !self.is_broadcastable_with(other) {
return Err(anyhow!(
"Cannot broadcast tensors with shapes {:?} and {:?}",
self.shape(),
other.shape()
));
}
let result_candle = self.candle_tensor().broadcast_div(other.candle_tensor())?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn add_scalar(&self, scalar: f32) -> Result<Tensor> {
let result_candle = (self.candle_tensor() + scalar as f64)?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn sub_scalar(&self, scalar: f32) -> Result<Tensor> {
let result_candle = (self.candle_tensor() - scalar as f64)?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn mul_scalar(&self, scalar: f32) -> Result<Tensor> {
let result_candle = (self.candle_tensor() * scalar as f64)?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn div_scalar(&self, scalar: f32) -> Result<Tensor> {
if scalar == 0.0 {
return Err(anyhow!("Division by zero"));
}
let result_candle = (self.candle_tensor() / scalar as f64)?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn neg(&self) -> Result<Tensor> {
let result_candle = self.candle_tensor().neg()?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn abs(&self) -> Result<Tensor> {
let result_candle = self.candle_tensor().abs()?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn pow(&self, exponent: f32) -> Result<Tensor> {
let result_candle = self.candle_tensor().powf(exponent as f64)?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn sqrt(&self) -> Result<Tensor> {
let result_candle = self.candle_tensor().sqrt()?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn exp(&self) -> Result<Tensor> {
let result_candle = self.candle_tensor().exp()?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
fn log(&self) -> Result<Tensor> {
let result_candle = self.candle_tensor().log()?;
Ok(Tensor::from_candle(
result_candle,
self.dtype(),
self.layout(),
))
}
}
/// Clamping and activation functions built on top of [`ArithmeticOps`].
impl Tensor {
    /// Limits the tensor's values to the range `[min, max]` via the backend's
    /// `clamp`.
    ///
    /// The result carries the dtype and layout of `self`.
    ///
    /// # Errors
    ///
    /// * either bound is NaN — `min > max` is always false for NaN, so without
    ///   this explicit check an invalid range would slip through validation;
    /// * `min` is greater than `max`;
    /// * the backend operation fails.
    pub fn clamp(&self, min: f32, max: f32) -> Result<Tensor> {
        if min.is_nan() || max.is_nan() {
            return Err(anyhow!(
                "Clamp bounds must not be NaN (min: {}, max: {})",
                min,
                max
            ));
        }
        if min > max {
            return Err(anyhow!(
                "Min value {} is greater than max value {}",
                min,
                max
            ));
        }
        let clamped = self
            .candle_tensor()
            .clamp(f64::from(min), f64::from(max))?;
        Ok(Tensor::from_candle(clamped, self.dtype(), self.layout()))
    }

    /// Rectified linear unit, implemented as `clamp(0.0, +inf)`.
    pub fn relu(&self) -> Result<Tensor> {
        self.clamp(0.0, f32::INFINITY)
    }

    /// Logistic sigmoid, computed as `1 / (1 + exp(-x))`.
    ///
    /// The constant `1` is materialised as a one-element tensor with the dtype
    /// and layout of `self` and combined with the input through broadcasting.
    pub fn sigmoid(&self) -> Result<Tensor> {
        let exp_neg_x = self.neg()?.exp()?;
        let one = Tensor::ones(vec![1], self.dtype(), self.layout())?;
        let denominator = one.add(&exp_neg_x)?;
        one.div(&denominator)
    }

    /// Hyperbolic tangent, delegated to the backend tensor's `tanh`.
    ///
    /// The result carries the dtype and layout of `self`.
    pub fn tanh(&self) -> Result<Tensor> {
        let squashed = self.candle_tensor().tanh()?;
        Ok(Tensor::from_candle(squashed, self.dtype(), self.layout()))
    }

    /// GELU activation using the tanh approximation:
    ///
    /// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
    ///
    /// The constants `1` and `0.5` are applied with `add_scalar` /
    /// `mul_scalar`, matching how `0.044715` and `sqrt(2 / pi)` are applied,
    /// rather than by allocating one-element tensors and broadcasting them.
    /// This removes two fallible allocations from the call.
    pub fn gelu(&self) -> Result<Tensor> {
        let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();

        // 0.044715 * x^3
        let cubic_term = self.pow(3.0)?.mul_scalar(0.044715)?;
        // sqrt(2 / pi) * (x + 0.044715 * x^3)
        let tanh_arg = self.add(&cubic_term)?.mul_scalar(sqrt_2_over_pi)?;
        // 1 + tanh(...)
        let one_plus_tanh = tanh_arg.tanh()?.add_scalar(1.0)?;

        // (0.5 * x) * (1 + tanh(...)) — same multiplication order as before.
        self.mul_scalar(0.5)?.mul(&one_plus_tanh)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    /// Expands to a `Tensor::from_data` call for an `F32`, row-major tensor.
    ///
    /// A macro is used instead of a helper function so the argument types are
    /// whatever `Tensor::from_data` itself accepts.
    macro_rules! f32_tensor {
        ($data:expr, $shape:expr) => {
            Tensor::from_data($data, $shape, DataType::F32, TensorLayout::RowMajor)
        };
    }

    /// Element-wise add / sub / mul / div on two 2x2 tensors.
    #[test]
    fn test_arithmetic_operations() -> Result<()> {
        let lhs = f32_tensor!(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;
        let rhs = f32_tensor!(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2])?;

        assert_eq!(lhs.add(&rhs)?.to_vec()?, vec![3.0, 3.0, 4.0, 6.0]);
        assert_eq!(lhs.sub(&rhs)?.to_vec()?, vec![-1.0, 1.0, 2.0, 2.0]);
        assert_eq!(lhs.mul(&rhs)?.to_vec()?, vec![2.0, 2.0, 3.0, 8.0]);
        assert_eq!(lhs.div(&rhs)?.to_vec()?, vec![0.5, 2.0, 3.0, 2.0]);
        Ok(())
    }

    /// Scalar addition and multiplication on a 2x2 tensor.
    #[test]
    fn test_scalar_operations() -> Result<()> {
        let input = f32_tensor!(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;

        assert_eq!(input.add_scalar(5.0)?.to_vec()?, vec![6.0, 7.0, 8.0, 9.0]);
        assert_eq!(input.mul_scalar(2.0)?.to_vec()?, vec![2.0, 4.0, 6.0, 8.0]);
        Ok(())
    }

    /// A `[3]` tensor is broadcast across the rows of a `[2, 3]` tensor.
    #[test]
    fn test_broadcasting() -> Result<()> {
        let matrix = f32_tensor!(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;
        let row = f32_tensor!(vec![10.0, 20.0, 30.0], vec![3])?;

        assert_eq!(
            matrix.add(&row)?.to_vec()?,
            vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]
        );
        Ok(())
    }

    /// `relu`, `abs` and `neg` on a vector containing negative, zero and
    /// positive values.
    #[test]
    fn test_activation_functions() -> Result<()> {
        let input = f32_tensor!(vec![-1.0, 0.0, 1.0, 2.0], vec![4])?;

        assert_eq!(input.relu()?.to_vec()?, vec![0.0, 0.0, 1.0, 2.0]);
        assert_eq!(input.abs()?.to_vec()?, vec![1.0, 0.0, 1.0, 2.0]);
        assert_eq!(input.neg()?.to_vec()?, vec![1.0, 0.0, -1.0, -2.0]);
        Ok(())
    }

    /// `sigmoid(0)` must be 0.5 within a small tolerance.
    #[test]
    fn test_sigmoid() -> Result<()> {
        let at_zero = f32_tensor!(vec![0.0], vec![1])?.sigmoid()?.to_vec()?;
        assert!((at_zero[0] - 0.5).abs() < 1e-6);
        Ok(())
    }

    /// Incompatible shapes, division by a zero scalar and an inverted clamp
    /// range must all be reported as errors.
    #[test]
    fn test_error_handling() {
        let len_two = f32_tensor!(vec![1.0, 2.0], vec![2]).unwrap();
        let len_three = f32_tensor!(vec![1.0, 2.0, 3.0], vec![3]).unwrap();

        assert!(len_two.add(&len_three).is_err());
        assert!(len_two.div_scalar(0.0).is_err());
        assert!(len_two.clamp(5.0, 1.0).is_err());
    }
}