use std::ops::{Add, Mul};
use crate::tensor::Tensor;
use crate::errors::{EtensorError, EtensorResult};
use crate::autograd::tape::record;
use crate::autograd::nodes::{
AddBackward, MulBackward, MatMulBackward,
SumAllBackward, ReluBackward, SigmoidBackward
};
use crate::backends::traits::Backend;
use crate::backends::cpu::CpuBackend;
/// Front door for tensor operations.
///
/// Every public method validates its operands and routes the computation to the backend
/// that owns the operands' device. When any input has `requires_grad` set, it marks the
/// output accordingly and pushes the matching backward node onto the autograd tape via
/// [`record`].
pub struct Dispatcher;

impl Dispatcher {
    /// Element-wise sum `a + b`.
    ///
    /// Records an [`AddBackward`] node when either input requires gradients. Only the
    /// inputs that require gradients get their id stored.
    ///
    /// # Errors
    /// - [`EtensorError::ShapeMismatch`] if `a` and `b` have different dims (checked first).
    /// - [`EtensorError::DeviceMismatch`] / [`EtensorError::DTypeMismatch`] if they live on
    ///   different devices or hold different dtypes.
    /// - [`EtensorError::InternalError`] for devices whose backend is still pending, plus any
    ///   error the CPU backend reports.
    pub fn add(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
        Self::check_same_shape(a, b)?;
        Self::check_same_device_and_dtype(a, b)?;

        let mut out_tensor = Self::dispatch(&a.device, || CpuBackend::add(a, b))?;

        let requires_grad = a.requires_grad || b.requires_grad;
        out_tensor.requires_grad = requires_grad;
        if requires_grad {
            record(Box::new(AddBackward {
                output_id: out_tensor.id,
                lhs_id: a.requires_grad.then_some(a.id),
                rhs_id: b.requires_grad.then_some(b.id),
            }));
        }
        Ok(out_tensor)
    }

    /// Element-wise product `a * b`.
    ///
    /// Records a [`MulBackward`] node when either input requires gradients. Both operands'
    /// data buffers are cloned into the node.
    ///
    /// # Errors
    /// Same conditions and order as [`Dispatcher::add`].
    pub fn mul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
        Self::check_same_shape(a, b)?;
        Self::check_same_device_and_dtype(a, b)?;

        let mut out_tensor = Self::dispatch(&a.device, || CpuBackend::mul(a, b))?;

        let requires_grad = a.requires_grad || b.requires_grad;
        out_tensor.requires_grad = requires_grad;
        if requires_grad {
            record(Box::new(MulBackward {
                output_id: out_tensor.id,
                lhs_id: a.requires_grad.then_some(a.id),
                rhs_id: b.requires_grad.then_some(b.id),
                lhs_data: a.data.clone(),
                rhs_data: b.data.clone(),
            }));
        }
        Ok(out_tensor)
    }

    /// Matrix product of `a` and `b`, computed by the backend's `matmul` kernel.
    ///
    /// No shape check happens here. NOTE(review): shape compatibility is presumably
    /// enforced by the backend kernel; confirm. Records a [`MatMulBackward`] node,
    /// including both operands' data and shapes, when either input requires gradients.
    ///
    /// # Errors
    /// - [`EtensorError::DeviceMismatch`] / [`EtensorError::DTypeMismatch`] on mismatched
    ///   operands (device checked first).
    /// - [`EtensorError::InternalError`] for pending devices, plus any backend error.
    pub fn matmul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
        Self::check_same_device_and_dtype(a, b)?;

        let mut out_tensor = Self::dispatch(&a.device, || CpuBackend::matmul(a, b))?;

        let requires_grad = a.requires_grad || b.requires_grad;
        out_tensor.requires_grad = requires_grad;
        if requires_grad {
            record(Box::new(MatMulBackward {
                output_id: out_tensor.id,
                lhs_id: a.requires_grad.then_some(a.id),
                rhs_id: b.requires_grad.then_some(b.id),
                lhs_data: a.data.clone(),
                rhs_data: b.data.clone(),
                lhs_shape: a.shape.clone(),
                rhs_shape: b.shape.clone(),
            }));
        }
        Ok(out_tensor)
    }

    /// Sum over every element of `a`, computed by the backend's `sum_all` kernel.
    ///
    /// Records a [`SumAllBackward`] node, which keeps the input's shape, when `a`
    /// requires gradients.
    ///
    /// # Errors
    /// [`EtensorError::InternalError`] for pending devices, plus any backend error.
    pub fn sum_all(a: &Tensor) -> EtensorResult<Tensor> {
        let mut out_tensor = Self::dispatch(&a.device, || CpuBackend::sum_all(a))?;

        out_tensor.requires_grad = a.requires_grad;
        if a.requires_grad {
            record(Box::new(SumAllBackward {
                output_id: out_tensor.id,
                input_id: a.id,
                input_shape: a.shape.clone(),
            }));
        }
        Ok(out_tensor)
    }

    /// Applies the backend's `relu` kernel to `a`.
    ///
    /// Records a [`ReluBackward`] node, which keeps a copy of the input data, when `a`
    /// requires gradients.
    ///
    /// # Errors
    /// [`EtensorError::InternalError`] for pending devices, plus any backend error.
    pub fn relu(a: &Tensor) -> EtensorResult<Tensor> {
        let mut out_tensor = Self::dispatch(&a.device, || CpuBackend::relu(a))?;

        out_tensor.requires_grad = a.requires_grad;
        if a.requires_grad {
            record(Box::new(ReluBackward {
                output_id: out_tensor.id,
                input_id: a.id,
                input_data: a.data.clone(),
            }));
        }
        Ok(out_tensor)
    }

    /// Applies the backend's `sigmoid` kernel to `a`.
    ///
    /// Records a [`SigmoidBackward`] node when `a` requires gradients. Unlike `relu`, the
    /// node keeps a copy of the forward *output*, not the input.
    ///
    /// # Errors
    /// [`EtensorError::InternalError`] for pending devices, plus any backend error.
    pub fn sigmoid(a: &Tensor) -> EtensorResult<Tensor> {
        let mut out_tensor = Self::dispatch(&a.device, || CpuBackend::sigmoid(a))?;

        out_tensor.requires_grad = a.requires_grad;
        if a.requires_grad {
            record(Box::new(SigmoidBackward {
                output_id: out_tensor.id,
                input_id: a.id,
                output_data: out_tensor.data.clone(),
            }));
        }
        Ok(out_tensor)
    }

    /// Runs `cpu_kernel` when `device` is the CPU.
    ///
    /// The CUDA-native and Torch devices (present only under their cargo features) have
    /// no backend yet and yield an [`EtensorError::InternalError`] without running anything.
    fn dispatch(
        device: &crate::device::Device,
        cpu_kernel: impl FnOnce() -> EtensorResult<Tensor>,
    ) -> EtensorResult<Tensor> {
        match device {
            crate::device::Device::Cpu => cpu_kernel(),
            #[cfg(feature = "cuda-native")]
            crate::device::Device::CudaNative(_) => {
                Err(EtensorError::InternalError("CUDA pending.".to_string()))
            }
            #[cfg(feature = "torch")]
            crate::device::Device::CudaTorch(_) => {
                Err(EtensorError::InternalError("Torch pending.".to_string()))
            }
        }
    }

    /// Fails with [`EtensorError::ShapeMismatch`] unless `a` and `b` have identical dims.
    /// `a`'s dims are reported as `expected` and `b`'s as `got`.
    fn check_same_shape(a: &Tensor, b: &Tensor) -> EtensorResult<()> {
        if a.shape.dims != b.shape.dims {
            return Err(EtensorError::ShapeMismatch {
                expected: a.shape.dims.clone(),
                got: b.shape.dims.clone(),
            });
        }
        Ok(())
    }

    /// Fails unless `a` and `b` share a device and a dtype. The device is checked first,
    /// so a device mismatch wins when both differ. `a` is `expected`, `b` is `got`.
    fn check_same_device_and_dtype(a: &Tensor, b: &Tensor) -> EtensorResult<()> {
        if a.device != b.device {
            return Err(EtensorError::DeviceMismatch {
                expected: a.device.to_string(),
                got: b.device.to_string(),
            });
        }
        if a.dtype != b.dtype {
            return Err(EtensorError::DTypeMismatch {
                expected: a.dtype.to_string(),
                got: b.dtype.to_string(),
            });
        }
        Ok(())
    }
}
/// `&a + &b` for tensors: a panicking convenience wrapper over [`Dispatcher::add`].
impl Add for &Tensor {
    type Output = Tensor;

    /// Element-wise sum of `self` and `rhs`.
    ///
    /// # Panics
    /// Panics when [`Dispatcher::add`] returns an error, for example on a shape, device or
    /// dtype mismatch. The panic message includes the error's `Debug` output.
    fn add(self, rhs: Self) -> Self::Output {
        let outcome = Dispatcher::add(self, rhs);
        outcome.expect("Tensor addition failed!")
    }
}
/// `&a * &b` for tensors: a panicking convenience wrapper over [`Dispatcher::mul`].
impl Mul for &Tensor {
    type Output = Tensor;

    /// Element-wise product of `self` and `rhs`.
    ///
    /// # Panics
    /// Panics when [`Dispatcher::mul`] returns an error, for example on a shape, device or
    /// dtype mismatch. The panic message includes the error's `Debug` output.
    fn mul(self, rhs: Self) -> Self::Output {
        let outcome = Dispatcher::mul(self, rhs);
        outcome.expect("Tensor multiplication failed!")
    }
}
#[cfg(test)]
mod tests {
    //! Forward-pass checks for `Dispatcher` and the `&Tensor` operator overloads.
    use super::*;
    use crate::{buffer::Buffer, device::Device, dtypes::DType, shape::Shape};

    /// Builds a rank-1 `F32` CPU tensor that owns `values`.
    fn cpu_f32(values: Vec<f32>, requires_grad: bool) -> Tensor {
        let shape = Shape::new(vec![values.len()]);
        Tensor::new(Buffer::from_f32_vec(values), shape, Device::Cpu, DType::F32, requires_grad)
    }

    #[test]
    fn test_dispatcher_add_forward_logic() {
        let lhs = cpu_f32(vec![1.0, 2.0, 3.0], false);
        let rhs = cpu_f32(vec![4.0, 5.0, 6.0], false);

        let sum = Dispatcher::add(&lhs, &rhs).unwrap();

        assert_eq!(sum.data.as_f32_slice().unwrap(), &[5.0, 7.0, 9.0]);
        // Neither input tracks gradients, so the output must not either.
        assert!(!sum.requires_grad);
    }

    #[test]
    fn test_operator_overloading() {
        let x = cpu_f32(vec![2.0, 4.0], false);
        let y = cpu_f32(vec![3.0, 5.0], false);

        // Each case pairs an operator result with the values it must hold.
        let cases = [(&x + &y, [5.0, 9.0]), (&x * &y, [6.0, 20.0])];
        for (got, want) in cases {
            assert_eq!(got.data.as_f32_slice().unwrap(), &want);
        }
    }
}