use std::any::TypeId;
use std::sync::Arc;
use crate::autograd::no_grad::{is_grad_enabled, no_grad};
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::ops::elementwise::{binary_map, scalar_map, unary_map, fast_add, fast_mul};
use crate::shape::broadcast_shapes;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
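/// True when `T` is `f64`; used to choose between the f32 and f64 GPU kernels.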
#[inline]
fn is_f64<T: Float>() -> bool {
TypeId::of::<T>() == TypeId::of::<f64>()
}
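/// A binary op records a grad node only when grad mode is enabled and at
/// least one input requires gradients.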
#[inline]
fn needs_grad<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> bool {
is_grad_enabled() && (a.requires_grad() || b.requires_grad())
}
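/// Unary counterpart of `needs_grad`.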
#[inline]
fn needs_grad_unary<T: Float>(a: &Tensor<T>) -> bool {
is_grad_enabled() && a.requires_grad()
}
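/// Sums `grad` down to `target_shape`, undoing broadcasting: extra leading
/// dimensions and dimensions that were expanded from size 1 are reduced by
/// summation. Assumes `grad` has at least as many dimensions as the target,
/// which holds for any gradient of a broadcast output. The reduction runs on
/// the CPU (CUDA gradients are copied over and the result moved back), and
/// the returned tensor carries no autograd history.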
fn reduce_grad_to_shape<T: Float>(
grad: &Tensor<T>,
target_shape: &[usize],
) -> FerrotorchResult<Tensor<T>> {
let grad_shape = grad.shape();
if grad_shape == target_shape {
return Ok(grad.clone());
}
if target_shape.is_empty() {
return crate::grad_fns::reduction::sum(grad);
}
let device = grad.device();
let cpu_grad = if grad.is_cuda() {
grad.cpu()?
} else {
grad.clone()
};
let grad_data = cpu_grad.data()?;
let grad_ndim = grad_shape.len();
let target_ndim = target_shape.len();
let padded_target: Vec<usize> = if target_ndim < grad_ndim {
let mut p = vec![1usize; grad_ndim - target_ndim];
p.extend_from_slice(target_shape);
p
} else {
target_shape.to_vec()
};
    let out_numel: usize = target_shape.iter().product();
    let mut result = vec![<T as num_traits::Zero>::zero(); out_numel];
    // Row-major strides for the target shape.
    let mut target_strides = vec![1usize; target_ndim];
    for td in (0..target_ndim.saturating_sub(1)).rev() {
        target_strides[td] = target_strides[td + 1] * target_shape[td + 1];
    }
    // Leading gradient dims with no counterpart in the target get summed away.
    let offset = grad_ndim - target_ndim;
    // Per-dimension coordinate scratch buffer, sized to the actual rank and
    // reused across iterations.
    let mut coords = vec![0usize; grad_ndim];
    for i in 0..grad_data.len() {
        // Unflatten the linear index into coordinates of the gradient shape.
        let mut rem = i;
        for d in (0..grad_ndim).rev() {
            coords[d] = rem % grad_shape[d];
            rem /= grad_shape[d];
        }
        // Project onto the target: dims broadcast from size 1 collapse to
        // index 0, so all their contributions accumulate into one slot.
        let mut flat = 0usize;
        for td in 0..target_ndim {
            let gd = td + offset;
            let coord = if padded_target[gd] == 1 { 0 } else { coords[gd] };
            flat += coord * target_strides[td];
        }
        result[flat] = result[flat] + grad_data[i];
    }
let reduced =
Tensor::from_storage(TensorStorage::cpu(result), target_shape.to_vec(), false)?;
if device.is_cuda() {
reduced.to(device)
} else {
Ok(reduced)
}
}
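/// Gradient of `a + b`: `da = g` and `db = g`, each reduced back to the
/// corresponding input's pre-broadcast shape.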
#[derive(Debug)]
struct AddBackward<T: Float> {
a: Tensor<T>,
b: Tensor<T>,
}
impl<T: Float> GradFn<T> for AddBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let da = if self.a.requires_grad() {
Some(reduce_grad_to_shape(grad_output, self.a.shape())?)
} else {
None
};
let db = if self.b.requires_grad() {
Some(reduce_grad_to_shape(grad_output, self.b.shape())?)
} else {
None
};
Ok(vec![da, db])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a, &self.b]
}
fn name(&self) -> &'static str {
"AddBackward"
}
}
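/// Elementwise `a + b` with broadcasting: GPU kernels for CUDA tensors, the
/// `fast_add` path on CPU, plus an `AddBackward` node when gradients are on.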
pub fn add<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.device() != b.device() {
return Err(FerrotorchError::DeviceMismatch { expected: a.device(), got: b.device() });
}
if a.is_cuda() {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(FerrotorchError::DeviceUnavailable)?;
let needs_broadcast = a.shape() != b.shape();
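        // Only an f32 broadcast kernel is dispatched here: f64 CUDA tensors
        // with mismatched shapes also fall through to `broadcast_add_f32`.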
let (handle, out_shape) = if needs_broadcast {
let out_shape = broadcast_shapes(a.shape(), b.shape())?;
let h = backend.broadcast_add_f32(
a.gpu_handle()?, b.gpu_handle()?,
a.shape(), b.shape(), &out_shape,
)?;
(h, out_shape)
} else if is_f64::<T>() {
(backend.add_f64(a.gpu_handle()?, b.gpu_handle()?)?, a.shape().to_vec())
} else {
(backend.add_f32(a.gpu_handle()?, b.gpu_handle()?)?, a.shape().to_vec())
};
let storage = TensorStorage::gpu(handle);
if needs_grad(a, b) {
Tensor::from_operation(
storage,
out_shape,
Arc::new(AddBackward { a: a.clone(), b: b.clone() }),
)
} else {
Tensor::from_storage(storage, out_shape, false)
}
} else {
let result = fast_add(a, b)?;
if needs_grad(a, b) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(AddBackward { a: a.clone(), b: b.clone() }),
)
} else {
Ok(result)
}
}
}
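/// Gradient of `a - b`: `da = g`, `db = -g`.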
#[derive(Debug)]
struct SubBackward<T: Float> {
a: Tensor<T>,
b: Tensor<T>,
}
impl<T: Float> GradFn<T> for SubBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let da = if self.a.requires_grad() {
Some(reduce_grad_to_shape(grad_output, self.a.shape())?)
} else {
None
};
let db = if self.b.requires_grad() {
let neg_grad = no_grad(|| neg(grad_output))?;
Some(reduce_grad_to_shape(&neg_grad, self.b.shape())?)
} else {
None
};
Ok(vec![da, db])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a, &self.b]
}
fn name(&self) -> &'static str {
"SubBackward"
}
}
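/// Elementwise `a - b` with broadcasting; dispatch mirrors `add`.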
pub fn sub<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.device() != b.device() {
return Err(FerrotorchError::DeviceMismatch { expected: a.device(), got: b.device() });
}
if a.is_cuda() {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(FerrotorchError::DeviceUnavailable)?;
let needs_broadcast = a.shape() != b.shape();
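        // As in `add`, broadcasting dispatches only to the f32 kernel.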
let (handle, out_shape) = if needs_broadcast {
let out_shape = broadcast_shapes(a.shape(), b.shape())?;
let h = backend.broadcast_sub_f32(
a.gpu_handle()?, b.gpu_handle()?,
a.shape(), b.shape(), &out_shape,
)?;
(h, out_shape)
} else if is_f64::<T>() {
(backend.sub_f64(a.gpu_handle()?, b.gpu_handle()?)?, a.shape().to_vec())
} else {
(backend.sub_f32(a.gpu_handle()?, b.gpu_handle()?)?, a.shape().to_vec())
};
let storage = TensorStorage::gpu(handle);
if needs_grad(a, b) {
Tensor::from_operation(
storage,
out_shape,
Arc::new(SubBackward { a: a.clone(), b: b.clone() }),
)
} else {
Tensor::from_storage(storage, out_shape, false)
}
} else {
let result = binary_map(a, b, |x, y| x - y)?;
if needs_grad(a, b) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(SubBackward { a: a.clone(), b: b.clone() }),
)
} else {
Ok(result)
}
}
}
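/// Gradient of `a * b`: `da = g * b`, `db = g * a`.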
#[derive(Debug)]
struct MulBackward<T: Float> {
a: Tensor<T>,
b: Tensor<T>,
}
impl<T: Float> GradFn<T> for MulBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
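        // Double backward: grad_output is itself part of a graph, so compute
        // the local gradients with tracked ops to keep the chain alive.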
if grad_output.requires_grad() || grad_output.grad_fn().is_some() {
let da = if self.a.requires_grad() {
let raw = mul(grad_output, &self.b)?;
Some(reduce_grad_to_shape(&raw, self.a.shape())?)
} else {
None
};
let db = if self.b.requires_grad() {
let raw = mul(grad_output, &self.a)?;
Some(reduce_grad_to_shape(&raw, self.b.shape())?)
} else {
None
};
return Ok(vec![da, db]);
}
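        // First-order case: compute the local gradients under no_grad so no
        // extra graph is built.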
let da = if self.a.requires_grad() {
let raw = no_grad(|| mul(grad_output, &self.b))?;
Some(reduce_grad_to_shape(&raw, self.a.shape())?)
} else {
None
};
let db = if self.b.requires_grad() {
let raw = no_grad(|| mul(grad_output, &self.a))?;
Some(reduce_grad_to_shape(&raw, self.b.shape())?)
} else {
None
};
Ok(vec![da, db])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a, &self.b]
}
fn name(&self) -> &'static str {
"MulBackward"
}
}
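/// Elementwise `a * b` with broadcasting; dispatch mirrors `add`.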
pub fn mul<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.device() != b.device() {
return Err(FerrotorchError::DeviceMismatch { expected: a.device(), got: b.device() });
}
if a.is_cuda() {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(FerrotorchError::DeviceUnavailable)?;
let needs_broadcast = a.shape() != b.shape();
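        // As in `add`, broadcasting dispatches only to the f32 kernel.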
let (handle, out_shape) = if needs_broadcast {
let out_shape = broadcast_shapes(a.shape(), b.shape())?;
let h = backend.broadcast_mul_f32(
a.gpu_handle()?, b.gpu_handle()?,
a.shape(), b.shape(), &out_shape,
)?;
(h, out_shape)
} else if is_f64::<T>() {
(backend.mul_f64(a.gpu_handle()?, b.gpu_handle()?)?, a.shape().to_vec())
} else {
(backend.mul_f32(a.gpu_handle()?, b.gpu_handle()?)?, a.shape().to_vec())
};
let storage = TensorStorage::gpu(handle);
if needs_grad(a, b) {
Tensor::from_operation(
storage,
out_shape,
Arc::new(MulBackward { a: a.clone(), b: b.clone() }),
)
} else {
Tensor::from_storage(storage, out_shape, false)
}
} else {
let result = fast_mul(a, b)?;
if needs_grad(a, b) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(MulBackward { a: a.clone(), b: b.clone() }),
)
} else {
Ok(result)
}
}
}
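/// Gradient of `a / b`: `da = g / b`, `db = -g * a / b^2`.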
#[derive(Debug)]
struct DivBackward<T: Float> {
a: Tensor<T>,
b: Tensor<T>,
}
impl<T: Float> GradFn<T> for DivBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
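        // On CUDA, compose the gradient from existing ops under no_grad; the
        // CPU path below fuses each gradient into a single pass instead.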
if grad_output.is_cuda() {
let da = if self.a.requires_grad() {
let raw = no_grad(|| div(grad_output, &self.b))?;
Some(reduce_grad_to_shape(&raw, self.a.shape())?)
} else {
None
};
let db = if self.b.requires_grad() {
let raw = no_grad(|| {
let neg_go = neg(grad_output)?;
let neg_go_a = mul(&neg_go, &self.a)?;
let b_sq = mul(&self.b, &self.b)?;
div(&neg_go_a, &b_sq)
})?;
Some(reduce_grad_to_shape(&raw, self.b.shape())?)
} else {
None
};
return Ok(vec![da, db]);
}
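        // Note: these fused loops pair elements positionally, so they assume
        // grad_output, a, and b all have the same element count.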
let go_data = grad_output.data()?;
let b_data = self.b.data()?;
let da = if self.a.requires_grad() {
let grad_a: Vec<T> = go_data
.iter()
.zip(b_data.iter())
.map(|(&g, &b)| g / b)
.collect();
let raw = Tensor::from_storage(
TensorStorage::cpu(grad_a),
grad_output.shape().to_vec(),
false,
)?;
Some(reduce_grad_to_shape(&raw, self.a.shape())?)
} else {
None
};
let db = if self.b.requires_grad() {
let a_data = self.a.data()?;
let grad_b: Vec<T> = go_data
.iter()
.zip(a_data.iter().zip(b_data.iter()))
.map(|(&g, (&a, &b))| -g * a / (b * b))
.collect();
let raw = Tensor::from_storage(
TensorStorage::cpu(grad_b),
grad_output.shape().to_vec(),
false,
)?;
Some(reduce_grad_to_shape(&raw, self.b.shape())?)
} else {
None
};
Ok(vec![da, db])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a, &self.b]
}
fn name(&self) -> &'static str {
"DivBackward"
}
}
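/// Elementwise `a / b`. Unlike `add`/`sub`/`mul`, the forward pass currently
/// runs only through the CPU `binary_map` path, with no GPU kernel dispatch.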
pub fn div<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let result = binary_map(a, b, |x, y| x / y)?;
if needs_grad(a, b) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(DivBackward {
a: a.clone(),
b: b.clone(),
}),
)
} else {
Ok(result)
}
}
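/// Gradient of `-a`: `da = -g`.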
#[derive(Debug)]
struct NegBackward<T: Float> {
a: Tensor<T>,
}
impl<T: Float> GradFn<T> for NegBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let da = if self.a.requires_grad() {
Some(no_grad(|| neg(grad_output))?)
} else {
None
};
Ok(vec![da])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a]
}
fn name(&self) -> &'static str {
"NegBackward"
}
}
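/// Elementwise negation with CPU/CUDA dispatch.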
pub fn neg<T: Float>(a: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if a.is_cuda() {
let backend = crate::gpu_dispatch::gpu_backend()
.ok_or(FerrotorchError::DeviceUnavailable)?;
let handle = if is_f64::<T>() {
backend.neg_f64(a.gpu_handle()?)?
} else {
backend.neg_f32(a.gpu_handle()?)?
};
let storage = TensorStorage::gpu(handle);
let shape = a.shape().to_vec();
if needs_grad_unary(a) {
Tensor::from_operation(
storage,
shape,
Arc::new(NegBackward { a: a.clone() }),
)
} else {
Tensor::from_storage(storage, shape, false)
}
} else {
let result = unary_map(a, |x| -x)?;
if needs_grad_unary(a) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(NegBackward { a: a.clone() }),
)
} else {
Ok(result)
}
}
}
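/// Gradient of `a.powf(exp)` for a fixed scalar exponent:
/// `da = exp * a^(exp - 1) * g`.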
#[derive(Debug)]
struct PowBackward<T: Float> {
a: Tensor<T>,
exp: f64,
}
impl<T: Float> GradFn<T> for PowBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let da = if self.a.requires_grad() {
            // Double backward: grad_output is itself part of a graph, so
            // build `exp * a^(exp - 1) * g` with tracked ops. This branch
            // keeps the exponent constant on the CPU, so it assumes CPU
            // operands; a CUDA grad_output without a graph takes the branch
            // below instead.
            if grad_output.requires_grad() || grad_output.grad_fn().is_some() {
                let a_pow = pow(&self.a, self.exp - 1.0)?;
                let exp_t = T::from(self.exp).unwrap();
                let exp_tensor = Tensor::from_storage(
                    TensorStorage::cpu(vec![exp_t; self.a.numel().max(1)]),
                    self.a.shape().to_vec(),
                    false,
                )?;
                let scaled = mul(&exp_tensor, &a_pow)?;
                Some(mul(grad_output, &scaled)?)
            } else if grad_output.is_cuda() {
let da = no_grad(|| {
let a_pow = pow(&self.a, self.exp - 1.0)?;
let exp_t = T::from(self.exp).unwrap();
let exp_tensor = Tensor::from_storage(
TensorStorage::cpu(vec![exp_t; self.a.numel().max(1)]),
self.a.shape().to_vec(),
false,
)?;
let exp_gpu = exp_tensor.to(self.a.device())?;
let scaled = mul(&exp_gpu, &a_pow)?;
mul(grad_output, &scaled)
})?;
Some(da)
} else {
let go_data = grad_output.data()?;
let a_data = self.a.data()?;
let exp_t = T::from(self.exp).unwrap();
let exp_m1 = T::from(self.exp - 1.0).unwrap();
let grad_a: Vec<T> = go_data
.iter()
.zip(a_data.iter())
.map(|(&g, &a)| g * exp_t * a.powf(exp_m1))
.collect();
Some(Tensor::from_storage(
TensorStorage::cpu(grad_a),
self.a.shape().to_vec(),
false,
)?)
}
} else {
None
};
Ok(vec![da])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a]
}
fn name(&self) -> &'static str {
"PowBackward"
}
}
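/// Raises each element to the fixed scalar power `exp` (CPU `scalar_map`
/// path only).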
pub fn pow<T: Float>(a: &Tensor<T>, exp: f64) -> FerrotorchResult<Tensor<T>> {
let exp_t = T::from(exp).unwrap();
let result = scalar_map(a, exp_t, |x, e| x.powf(e))?;
if needs_grad_unary(a) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(PowBackward {
a: a.clone(),
exp,
}),
)
} else {
Ok(result)
}
}
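/// Gradient of `sqrt(a)`: `da = g / (2 * sqrt(a))`.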
#[derive(Debug)]
struct SqrtBackward<T: Float> {
a: Tensor<T>,
}
impl<T: Float> GradFn<T> for SqrtBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let da = if self.a.requires_grad() {
if grad_output.is_cuda() {
let da = no_grad(|| {
let sqrt_a = sqrt(&self.a)?;
let two_t = T::from(2.0).unwrap();
let two_tensor = Tensor::from_storage(
TensorStorage::cpu(vec![two_t; self.a.numel().max(1)]),
self.a.shape().to_vec(),
false,
)?;
let two_gpu = two_tensor.to(self.a.device())?;
let denom = mul(&two_gpu, &sqrt_a)?;
div(grad_output, &denom)
})?;
Some(da)
} else {
let go_data = grad_output.data()?;
let a_data = self.a.data()?;
let two = T::from(2.0).unwrap();
let grad_a: Vec<T> = go_data
.iter()
.zip(a_data.iter())
.map(|(&g, &a)| g / (two * a.sqrt()))
.collect();
Some(Tensor::from_storage(
TensorStorage::cpu(grad_a),
self.a.shape().to_vec(),
false,
)?)
}
} else {
None
};
Ok(vec![da])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a]
}
fn name(&self) -> &'static str {
"SqrtBackward"
}
}
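/// Elementwise square root (CPU `unary_map` path only).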
pub fn sqrt<T: Float>(a: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let result = unary_map(a, |x| x.sqrt())?;
if needs_grad_unary(a) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(SqrtBackward { a: a.clone() }),
)
} else {
Ok(result)
}
}
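/// Gradient of `|a|`: `da = g * sign(a)`, with `sign(0)` taken as 0 (the
/// usual subgradient convention).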
#[derive(Debug)]
struct AbsBackward<T: Float> {
a: Tensor<T>,
}
impl<T: Float> GradFn<T> for AbsBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let da = if self.a.requires_grad() {
if grad_output.is_cuda() {
let a_cpu = self.a.cpu()?;
let a_data = a_cpu.data()?;
let zero = <T as num_traits::Zero>::zero();
let one = <T as num_traits::One>::one();
let sign_data: Vec<T> = a_data
.iter()
.map(|&a| {
if a > zero {
one
} else if a < zero {
-one
} else {
zero
}
})
.collect();
let sign_cpu = Tensor::from_storage(
TensorStorage::cpu(sign_data),
self.a.shape().to_vec(),
false,
)?;
let sign_gpu = sign_cpu.to(grad_output.device())?;
Some(no_grad(|| mul(grad_output, &sign_gpu))?)
} else {
let go_data = grad_output.data()?;
let a_data = self.a.data()?;
let zero = <T as num_traits::Zero>::zero();
let one = <T as num_traits::One>::one();
let grad_a: Vec<T> = go_data
.iter()
.zip(a_data.iter())
.map(|(&g, &a)| {
let sign = if a > zero {
one
} else if a < zero {
-one
} else {
zero
};
g * sign
})
.collect();
Some(Tensor::from_storage(
TensorStorage::cpu(grad_a),
self.a.shape().to_vec(),
false,
)?)
}
} else {
None
};
Ok(vec![da])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.a]
}
fn name(&self) -> &'static str {
"AbsBackward"
}
}
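/// Elementwise absolute value (CPU `unary_map` path only).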
pub fn abs<T: Float>(a: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let result = unary_map(a, |x| x.abs())?;
if needs_grad_unary(a) {
let storage = TensorStorage::cpu(result.data()?.to_vec());
Tensor::from_operation(
storage,
result.shape().to_vec(),
Arc::new(AbsBackward { a: a.clone() }),
)
} else {
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn leaf_scalar(val: f32, requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
}
fn leaf_vec(data: &[f32], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], requires_grad)
.unwrap()
}
fn assert_scalar_approx(t: &Tensor<f32>, expected: f32, tol: f32) {
let val = t.item().unwrap();
assert!(
(val - expected).abs() < tol,
"expected {expected}, got {val}"
);
}
#[test]
fn test_add_forward() {
let a = leaf_vec(&[1.0, 2.0, 3.0], false);
let b = leaf_vec(&[4.0, 5.0, 6.0], false);
let c = add(&a, &b).unwrap();
assert_eq!(c.data().unwrap(), &[5.0, 7.0, 9.0]);
}
#[test]
fn test_sub_forward() {
let a = leaf_vec(&[10.0, 20.0, 30.0], false);
let b = leaf_vec(&[1.0, 2.0, 3.0], false);
let c = sub(&a, &b).unwrap();
assert_eq!(c.data().unwrap(), &[9.0, 18.0, 27.0]);
}
#[test]
fn test_mul_forward() {
let a = leaf_vec(&[2.0, 3.0, 4.0], false);
let b = leaf_vec(&[5.0, 6.0, 7.0], false);
let c = mul(&a, &b).unwrap();
assert_eq!(c.data().unwrap(), &[10.0, 18.0, 28.0]);
}
#[test]
fn test_div_forward() {
let a = leaf_vec(&[10.0, 20.0, 30.0], false);
let b = leaf_vec(&[2.0, 5.0, 10.0], false);
let c = div(&a, &b).unwrap();
assert_eq!(c.data().unwrap(), &[5.0, 4.0, 3.0]);
}
#[test]
fn test_neg_forward() {
let a = leaf_vec(&[1.0, -2.0, 3.0], false);
let c = neg(&a).unwrap();
assert_eq!(c.data().unwrap(), &[-1.0, 2.0, -3.0]);
}
#[test]
fn test_pow_forward() {
let a = leaf_vec(&[2.0, 3.0, 4.0], false);
let c = pow(&a, 2.0).unwrap();
let d = c.data().unwrap();
assert!((d[0] - 4.0).abs() < 1e-6);
assert!((d[1] - 9.0).abs() < 1e-6);
assert!((d[2] - 16.0).abs() < 1e-6);
}
#[test]
fn test_sqrt_forward() {
let a = leaf_vec(&[4.0, 9.0, 16.0], false);
let c = sqrt(&a).unwrap();
let d = c.data().unwrap();
assert!((d[0] - 2.0).abs() < 1e-6);
assert!((d[1] - 3.0).abs() < 1e-6);
assert!((d[2] - 4.0).abs() < 1e-6);
}
#[test]
fn test_abs_forward() {
let a = leaf_vec(&[-3.0, 0.0, 5.0], false);
let c = abs(&a).unwrap();
assert_eq!(c.data().unwrap(), &[3.0, 0.0, 5.0]);
}
#[test]
fn test_add_backward() {
let a = leaf_scalar(2.0, true);
let b = leaf_scalar(3.0, true);
let c = add(&a, &b).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
assert_scalar_approx(&b.grad().unwrap().unwrap(), 1.0, 1e-6);
}
#[test]
fn test_sub_backward() {
let a = leaf_scalar(5.0, true);
let b = leaf_scalar(3.0, true);
let c = sub(&a, &b).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
assert_scalar_approx(&b.grad().unwrap().unwrap(), -1.0, 1e-6);
}
#[test]
fn test_mul_backward() {
let a = leaf_scalar(2.0, true);
let b = leaf_scalar(3.0, true);
let c = mul(&a, &b).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 3.0, 1e-6);
assert_scalar_approx(&b.grad().unwrap().unwrap(), 2.0, 1e-6);
}
#[test]
fn test_div_backward() {
let a = leaf_scalar(6.0, true);
let b = leaf_scalar(4.0, true);
let c = div(&a, &b).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.25, 1e-6);
assert_scalar_approx(&b.grad().unwrap().unwrap(), -0.375, 1e-6);
}
#[test]
fn test_neg_backward() {
let a = leaf_scalar(7.0, true);
let c = neg(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), -1.0, 1e-6);
}
#[test]
fn test_pow_backward() {
let a = leaf_scalar(2.0, true);
let c = pow(&a, 3.0).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 12.0, 1e-5);
}
#[test]
fn test_sqrt_backward() {
let a = leaf_scalar(4.0, true);
let c = sqrt(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.25, 1e-6);
}
#[test]
fn test_abs_backward_positive() {
let a = leaf_scalar(3.0, true);
let c = abs(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
}
#[test]
fn test_abs_backward_negative() {
let a = leaf_scalar(-3.0, true);
let c = abs(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), -1.0, 1e-6);
}
#[test]
fn test_add_no_grad_fn_when_inputs_detached() {
let a = leaf_scalar(2.0, false);
let b = leaf_scalar(3.0, false);
let c = add(&a, &b).unwrap();
assert!(c.grad_fn().is_none());
}
#[test]
fn test_mul_partial_requires_grad() {
let a = leaf_scalar(3.0, true);
let b = leaf_scalar(5.0, false);
let c = mul(&a, &b).unwrap();
assert!(c.grad_fn().is_some());
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 5.0, 1e-6);
assert!(b.grad().unwrap().is_none());
}
#[test]
fn test_no_grad_context_skips_backward() {
use crate::autograd::no_grad::no_grad;
let a = leaf_scalar(2.0, true);
let b = leaf_scalar(3.0, true);
let c = no_grad(|| add(&a, &b)).unwrap();
assert!(c.grad_fn().is_none());
}
#[test]
fn test_chain_mul_add() {
let a = leaf_scalar(2.0, true);
let b = leaf_scalar(3.0, true);
let c = mul(&a, &b).unwrap();
let d = add(&c, &b).unwrap();
d.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 3.0, 1e-6);
assert_scalar_approx(&b.grad().unwrap().unwrap(), 3.0, 1e-6);
}
#[test]
fn test_chain_div_sub() {
let a = leaf_scalar(3.0, true);
let b = leaf_scalar(2.0, true);
let d = div(&a, &b).unwrap();
let e = sub(&d, &a).unwrap();
e.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), -0.5, 1e-5);
assert_scalar_approx(&b.grad().unwrap().unwrap(), -0.75, 1e-5);
}
#[test]
fn test_chain_sqrt_pow() {
let a = leaf_scalar(9.0, true);
let s = sqrt(&a).unwrap();
let c = pow(&s, 2.0).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-5);
}
#[test]
fn test_neg_double() {
let a = leaf_scalar(5.0, true);
let b = neg(&a).unwrap();
let c = neg(&b).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
}
#[test]
fn test_mul_vector_backward() {
let a = leaf_vec(&[1.0, 2.0, 3.0], true);
let b = leaf_vec(&[4.0, 5.0, 6.0], true);
let c = mul(&a, &b).unwrap();
let c_data = c.data().unwrap().to_vec();
let total: f32 = c_data.iter().sum();
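        // Reduce the product to a scalar loss through the minimal SumBackward
        // node defined at the bottom of this module.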
let sum_backward = SumBackward { input: c.clone() };
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(sum_backward),
)
.unwrap();
loss.backward().unwrap();
let a_grad = a.grad().unwrap().unwrap();
let a_g = a_grad.data().unwrap();
assert!((a_g[0] - 4.0).abs() < 1e-6);
assert!((a_g[1] - 5.0).abs() < 1e-6);
assert!((a_g[2] - 6.0).abs() < 1e-6);
let b_grad = b.grad().unwrap().unwrap();
let b_g = b_grad.data().unwrap();
assert!((b_g[0] - 1.0).abs() < 1e-6);
assert!((b_g[1] - 2.0).abs() < 1e-6);
assert!((b_g[2] - 3.0).abs() < 1e-6);
}
#[derive(Debug)]
struct SumBackward<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let ones_data = vec![<T as num_traits::One>::one(); self.input.numel()];
let ones = Tensor::from_storage(
TensorStorage::cpu(ones_data),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(ones)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
}