use std::any::TypeId;
use std::sync::Arc;
use crate::autograd::no_grad::{is_grad_enabled, no_grad};
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::gpu_dispatch::gpu_backend;
use crate::ops::elementwise::{fast_cos, fast_sin, unary_map};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Returns true when the monomorphized element type `T` is exactly `f32`
/// (the only dtype the GPU kernels currently accept).
#[inline]
fn is_f32<T: Float>() -> bool {
    let concrete = TypeId::of::<T>();
    concrete == TypeId::of::<f32>()
}
/// True when `a`'s gradient must be tracked: the global autograd switch is
/// on AND the tensor itself requests gradients.
#[inline]
fn needs_grad_unary<T: Float>(a: &Tensor<T>) -> bool {
    if !is_grad_enabled() {
        return false;
    }
    a.requires_grad()
}
// Backward node for `exp`. Saves both endpoints of the op: the input is
// needed for the graph edge (`inputs()`), the output doubles as the local
// derivative since d/dx exp(x) = exp(x).
#[derive(Debug)]
struct ExpBackward<T: Float> {
    // Forward input, reported via `inputs()` during graph traversal.
    input: Tensor<T>,
    // Forward output exp(x); multiplied with grad_output in `backward`.
    output: Tensor<T>,
}
impl<T: Float> GradFn<T> for ExpBackward<T> {
    /// d/dx exp(x) = exp(x): the upstream gradient is scaled element-wise by
    /// the saved forward output.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let da = if grad_output.is_cuda() {
            // Multiply on-device with autograd disabled so the backward pass
            // itself does not extend the graph.
            no_grad(|| crate::grad_fns::arithmetic::mul(grad_output, &self.output))?
        } else {
            let go = grad_output.data()?;
            let out = self.output.data()?;
            let mut grad_a: Vec<T> = Vec::new();
            for (&g, &o) in go.iter().zip(out.iter()) {
                grad_a.push(g * o);
            }
            Tensor::from_storage(
                TensorStorage::cpu(grad_a),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(da)])
    }
    /// Single graph edge back to the saved input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "ExpBackward"
    }
}
/// Element-wise natural exponential, `exp(x)`.
///
/// f32 CUDA tensors are dispatched to the GPU backend; every other tensor
/// goes through the vectorized CPU kernel. When gradient tracking applies,
/// the result is attached to the graph via an `ExpBackward` node that saves
/// the output (d/dx exp(x) = exp(x)).
pub fn exp<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if !(input.is_cuda() && is_f32::<T>()) {
        // CPU (or non-f32) path.
        let output = crate::ops::elementwise::fast_exp(input)?;
        if !needs_grad_unary(input) {
            return Ok(output);
        }
        let grad_fn = Arc::new(ExpBackward {
            input: input.clone(),
            output: output.clone(),
        });
        let (storage, shape) = output.into_storage_and_shape()?;
        return Tensor::from_operation(storage, shape, grad_fn);
    }
    // GPU f32 path.
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let handle = backend.exp_f32(input.gpu_handle()?)?;
    let storage = TensorStorage::gpu(handle);
    let shape = input.shape().to_vec();
    if !needs_grad_unary(input) {
        return Tensor::from_storage(storage, shape, false);
    }
    // Materialize the output first so the backward node can save it, then
    // recover storage + shape to build the graph-attached tensor.
    let output = Tensor::from_storage(storage, shape.clone(), false)?;
    let grad_fn = Arc::new(ExpBackward {
        input: input.clone(),
        output: output.clone(),
    });
    let (s, sh) = output.into_storage_and_shape()?;
    Tensor::from_operation(s, sh, grad_fn)
}
// Backward node for `log`. Only the input is saved: d/dx ln(x) = 1/x, so
// the gradient is grad_output divided by the saved input.
#[derive(Debug)]
struct LogBackward<T: Float> {
    // Forward input x; used both for the graph edge and the 1/x factor.
    input: Tensor<T>,
}
impl<T: Float> GradFn<T> for LogBackward<T> {
    /// d/dx ln(x) = 1/x: divides the upstream gradient by the saved input.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let da = if grad_output.is_cuda() {
            // Divide on-device with autograd disabled.
            no_grad(|| crate::grad_fns::arithmetic::div(grad_output, &self.input))?
        } else {
            let go = grad_output.data()?;
            let xs = self.input.data()?;
            let mut grad_a: Vec<T> = Vec::new();
            for (&g, &x) in go.iter().zip(xs.iter()) {
                grad_a.push(g / x);
            }
            Tensor::from_storage(
                TensorStorage::cpu(grad_a),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(da)])
    }
    /// Single graph edge back to the saved input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "LogBackward"
    }
}
/// Element-wise natural logarithm, `ln(x)`.
///
/// f32 CUDA tensors use the GPU backend; all other tensors are mapped on
/// the CPU. With gradient tracking, a `LogBackward` node saving the input
/// is attached (d/dx ln(x) = 1/x).
pub fn log<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if !(input.is_cuda() && is_f32::<T>()) {
        // CPU (or non-f32) path.
        let output = unary_map(input, |x| x.ln())?;
        if !needs_grad_unary(input) {
            return Ok(output);
        }
        let grad_fn = Arc::new(LogBackward {
            input: input.clone(),
        });
        let (storage, shape) = output.into_storage_and_shape()?;
        return Tensor::from_operation(storage, shape, grad_fn);
    }
    // GPU f32 path.
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let handle = backend.log_f32(input.gpu_handle()?)?;
    let storage = TensorStorage::gpu(handle);
    let shape = input.shape().to_vec();
    if !needs_grad_unary(input) {
        return Tensor::from_storage(storage, shape, false);
    }
    let grad_fn = Arc::new(LogBackward {
        input: input.clone(),
    });
    Tensor::from_operation(storage, shape, grad_fn)
}
// Backward node for `sin`. Saves the input so the backward pass can
// recompute cos(x): d/dx sin(x) = cos(x).
#[derive(Debug)]
struct SinBackward<T: Float> {
    // Forward input x; cos(x) is recomputed from it during backward.
    input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SinBackward<T> {
    /// d/dx sin(x) = cos(x): multiplies the upstream gradient by cos of the
    /// saved input.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let da = if grad_output.is_cuda() {
            // Recompute cos(x) and multiply on-device, with autograd disabled
            // so the backward pass does not extend the graph.
            no_grad(|| {
                let cos_x = cos(&self.input)?;
                crate::grad_fns::arithmetic::mul(grad_output, &cos_x)
            })?
        } else {
            let go = grad_output.data()?;
            let xs = self.input.data()?;
            let mut grad_a: Vec<T> = Vec::new();
            for (&g, &x) in go.iter().zip(xs.iter()) {
                grad_a.push(g * x.cos());
            }
            Tensor::from_storage(
                TensorStorage::cpu(grad_a),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(da)])
    }
    /// Single graph edge back to the saved input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "SinBackward"
    }
}
/// Element-wise sine with autograd support.
///
/// Forward always runs through `fast_sin` (no dedicated GPU dispatch here);
/// when gradient tracking applies, a `SinBackward` node saving the input is
/// attached.
pub fn sin<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let output = fast_sin(input)?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let grad_fn = Arc::new(SinBackward {
        input: input.clone(),
    });
    let (storage, shape) = output.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
// Backward node for `cos`. Saves the input so the backward pass can
// recompute -sin(x): d/dx cos(x) = -sin(x).
#[derive(Debug)]
struct CosBackward<T: Float> {
    // Forward input x; -sin(x) is recomputed from it during backward.
    input: Tensor<T>,
}
impl<T: Float> GradFn<T> for CosBackward<T> {
    /// d/dx cos(x) = -sin(x): multiplies the upstream gradient by the
    /// negated sine of the saved input.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let da = if grad_output.is_cuda() {
            // Recompute -sin(x) and multiply on-device, with autograd
            // disabled so the backward pass does not extend the graph.
            no_grad(|| {
                let sin_x = sin(&self.input)?;
                let neg_sin = crate::grad_fns::arithmetic::neg(&sin_x)?;
                crate::grad_fns::arithmetic::mul(grad_output, &neg_sin)
            })?
        } else {
            let go = grad_output.data()?;
            let xs = self.input.data()?;
            let mut grad_a: Vec<T> = Vec::new();
            for (&g, &x) in go.iter().zip(xs.iter()) {
                grad_a.push(g * (-x.sin()));
            }
            Tensor::from_storage(
                TensorStorage::cpu(grad_a),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(da)])
    }
    /// Single graph edge back to the saved input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "CosBackward"
    }
}
/// Element-wise cosine with autograd support.
///
/// Forward always runs through `fast_cos` (no dedicated GPU dispatch here);
/// when gradient tracking applies, a `CosBackward` node saving the input is
/// attached.
pub fn cos<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let output = fast_cos(input)?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let grad_fn = Arc::new(CosBackward {
        input: input.clone(),
    });
    let (storage, shape) = output.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
// Backward node for `clamp`. The gradient passes through unchanged where
// min <= x <= max (inclusive) and is zeroed where the forward pass clamped.
#[derive(Debug)]
struct ClampBackward<T: Float> {
    // Forward input; its position relative to [min, max] decides the mask.
    input: Tensor<T>,
    // Lower clamp bound used in the forward pass.
    min: T,
    // Upper clamp bound used in the forward pass.
    max: T,
}
impl<T: Float> GradFn<T> for ClampBackward<T> {
    /// Subgradient of clamp: identity inside [min, max] (inclusive at both
    /// bounds), zero outside.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let zero = <T as num_traits::Zero>::zero();
        // Shared pass-through predicate for both device paths.
        let in_range = |x: T| x >= self.min && x <= self.max;
        let da = if grad_output.is_cuda() {
            // Build a 0/1 mask on the CPU from the saved input, upload it to
            // the gradient's device, and multiply there (autograd disabled).
            let one = <T as num_traits::One>::one();
            let input_cpu = self.input.cpu()?;
            let x_data = input_cpu.data()?;
            let mask_data: Vec<T> = x_data
                .iter()
                .map(|&x| if in_range(x) { one } else { zero })
                .collect();
            let mask_cpu = Tensor::from_storage(
                TensorStorage::cpu(mask_data),
                self.input.shape().to_vec(),
                false,
            )?;
            let mask_gpu = mask_cpu.to(grad_output.device())?;
            no_grad(|| crate::grad_fns::arithmetic::mul(grad_output, &mask_gpu))?
        } else {
            let go_data = grad_output.data()?;
            let x_data = self.input.data()?;
            let grad_a: Vec<T> = go_data
                .iter()
                .zip(x_data.iter())
                .map(|(&g, &x)| if in_range(x) { g } else { zero })
                .collect();
            Tensor::from_storage(
                TensorStorage::cpu(grad_a),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(da)])
    }
    /// Single graph edge back to the saved input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }
    fn name(&self) -> &'static str {
        "ClampBackward"
    }
}
/// Element-wise clamp of `input` into the closed interval `[min, max]`.
///
/// CPU-mapped in all cases; with gradient tracking, a `ClampBackward` node
/// saving the input and both bounds is attached.
///
/// NOTE(review): bounds are not validated — when `min > max` the lower
/// bound wins (values below `min` map to `min` before the `max` check).
/// Confirm this matches the intended contract.
pub fn clamp<T: Float>(input: &Tensor<T>, min: T, max: T) -> FerrotorchResult<Tensor<T>> {
    let output = unary_map(input, |x| {
        if x < min {
            min
        } else if x > max {
            max
        } else {
            // Also the NaN path: both comparisons are false for NaN, so NaN
            // propagates unchanged.
            x
        }
    })?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let grad_fn = Arc::new(ClampBackward {
        input: input.clone(),
        min,
        max,
    });
    let (storage, shape) = output.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, grad_fn)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Helper: rank-0 (scalar) CPU tensor holding a single value.
    fn leaf_scalar(val: f32, requires_grad: bool) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
    }
    // Helper: 1-D CPU tensor built from a slice.
    fn leaf_vec(data: &[f32], requires_grad: bool) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            vec![data.len()],
            requires_grad,
        )
        .unwrap()
    }
    // Helper: asserts a scalar tensor's value is within `tol` of `expected`.
    fn assert_scalar_approx(t: &Tensor<f32>, expected: f32, tol: f32) {
        let val = t.item().unwrap();
        assert!(
            (val - expected).abs() < tol,
            "expected {expected}, got {val}"
        );
    }
    // --- forward-pass checks ---
    #[test]
    fn test_exp_forward() {
        let a = leaf_vec(&[0.0, 1.0, 2.0], false);
        let c = exp(&a).unwrap();
        let d = c.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-5);
        assert!((d[1] - std::f32::consts::E).abs() < 1e-5);
        assert!((d[2] - std::f32::consts::E * std::f32::consts::E).abs() < 1e-4);
    }
    #[test]
    fn test_log_forward() {
        let a = leaf_vec(
            &[
                1.0,
                std::f32::consts::E,
                std::f32::consts::E * std::f32::consts::E,
            ],
            false,
        );
        let c = log(&a).unwrap();
        let d = c.data().unwrap();
        assert!((d[0] - 0.0).abs() < 1e-5);
        assert!((d[1] - 1.0).abs() < 1e-5);
        assert!((d[2] - 2.0).abs() < 1e-4);
    }
    #[test]
    fn test_sin_forward() {
        let a = leaf_vec(
            &[0.0, std::f32::consts::FRAC_PI_2, std::f32::consts::PI],
            false,
        );
        let c = sin(&a).unwrap();
        let d = c.data().unwrap();
        assert!((d[0] - 0.0).abs() < 1e-6);
        assert!((d[1] - 1.0).abs() < 1e-6);
        assert!(d[2].abs() < 1e-6);
    }
    #[test]
    fn test_cos_forward() {
        let a = leaf_vec(
            &[0.0, std::f32::consts::FRAC_PI_2, std::f32::consts::PI],
            false,
        );
        let c = cos(&a).unwrap();
        let d = c.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-6);
        assert!(d[1].abs() < 1e-6);
        assert!((d[2] - (-1.0)).abs() < 1e-6);
    }
    #[test]
    fn test_clamp_forward() {
        let a = leaf_vec(&[-2.0, 0.5, 1.5, 3.0], false);
        let c = clamp(&a, 0.0, 2.0).unwrap();
        assert_eq!(c.data().unwrap(), &[0.0, 0.5, 1.5, 2.0]);
    }
    // --- backward-pass checks against hand-computed derivatives ---
    #[test]
    fn test_exp_backward() {
        // d/dx exp(x) at x=1 is e.
        let a = leaf_scalar(1.0, true);
        let c = exp(&a).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), std::f32::consts::E, 1e-5);
    }
    #[test]
    fn test_log_backward() {
        // d/dx ln(x) at x=2 is 1/2.
        let a = leaf_scalar(2.0, true);
        let c = log(&a).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.5, 1e-6);
    }
    #[test]
    fn test_sin_backward() {
        // d/dx sin(x) at x=0 is cos(0) = 1.
        let a = leaf_scalar(0.0, true);
        let c = sin(&a).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
    }
    #[test]
    fn test_sin_backward_pi_over_3() {
        // cos(pi/3) = 1/2.
        let a = leaf_scalar(std::f32::consts::FRAC_PI_3, true);
        let c = sin(&a).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.5, 1e-5);
    }
    #[test]
    fn test_cos_backward() {
        // d/dx cos(x) at x=0 is -sin(0) = 0.
        let a = leaf_scalar(0.0, true);
        let c = cos(&a).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.0, 1e-6);
    }
    #[test]
    fn test_cos_backward_pi_over_2() {
        // -sin(pi/2) = -1.
        let a = leaf_scalar(std::f32::consts::FRAC_PI_2, true);
        let c = cos(&a).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), -1.0, 1e-5);
    }
    #[test]
    fn test_clamp_backward_interior() {
        // Inside [min, max]: gradient passes through as 1.
        let a = leaf_scalar(1.5, true);
        let c = clamp(&a, 0.0, 2.0).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
    }
    #[test]
    fn test_clamp_backward_clamped_low() {
        // Below min: gradient is zeroed.
        let a = leaf_scalar(-1.0, true);
        let c = clamp(&a, 0.0, 2.0).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.0, 1e-6);
    }
    #[test]
    fn test_clamp_backward_clamped_high() {
        // Above max: gradient is zeroed.
        let a = leaf_scalar(5.0, true);
        let c = clamp(&a, 0.0, 2.0).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.0, 1e-6);
    }
    // --- chained ops: inverse pairs and composition ---
    #[test]
    fn test_chain_exp_log() {
        // ln(exp(x)) = x, so the end-to-end gradient is 1.
        let a = leaf_scalar(3.0, true);
        let b = exp(&a).unwrap();
        let c = log(&b).unwrap();
        c.backward().unwrap();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-4);
    }
    #[test]
    fn test_chain_sin_cos() {
        // d/dx cos(sin(x)) = -sin(sin(x)) * cos(x).
        let a = leaf_scalar(0.5, true);
        let b = sin(&a).unwrap();
        let c = cos(&b).unwrap();
        c.backward().unwrap();
        let expected = -(0.5_f32.sin().sin()) * 0.5_f32.cos();
        assert_scalar_approx(&a.grad().unwrap().unwrap(), expected, 1e-4);
    }
    // --- no graph node should be created when tracking is off ---
    #[test]
    fn test_exp_no_grad_fn_when_not_tracking() {
        let a = leaf_scalar(1.0, false);
        let c = exp(&a).unwrap();
        assert!(c.grad_fn().is_none());
    }
    #[test]
    fn test_log_no_grad_fn_when_not_tracking() {
        let a = leaf_scalar(1.0, false);
        let c = log(&a).unwrap();
        assert!(c.grad_fn().is_none());
    }
    #[test]
    fn test_clamp_no_grad_fn_when_not_tracking() {
        let a = leaf_scalar(1.0, false);
        let c = clamp(&a, 0.0, 2.0).unwrap();
        assert!(c.grad_fn().is_none());
    }
    // Helper: central-difference check of an analytic gradient.
    fn numerical_grad_check(f: impl Fn(f32) -> f32, x: f32, analytic_grad: f32, tol: f32) {
        let h = 1e-4_f32;
        let numerical = (f(x + h) - f(x - h)) / (2.0 * h);
        assert!(
            (analytic_grad - numerical).abs() < tol,
            "analytic={analytic_grad}, numerical={numerical}",
        );
    }
    // --- autograd vs numerical differentiation ---
    #[test]
    fn test_exp_numerical_grad() {
        let x = 1.5_f32;
        let a = leaf_scalar(x, true);
        let c = exp(&a).unwrap();
        c.backward().unwrap();
        let g = a.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(|v| v.exp(), x, g, 1e-3);
    }
    #[test]
    fn test_log_numerical_grad() {
        let x = 2.0_f32;
        let a = leaf_scalar(x, true);
        let c = log(&a).unwrap();
        c.backward().unwrap();
        let g = a.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(|v| v.ln(), x, g, 1e-3);
    }
    #[test]
    fn test_sin_numerical_grad() {
        let x = 1.0_f32;
        let a = leaf_scalar(x, true);
        let c = sin(&a).unwrap();
        c.backward().unwrap();
        let g = a.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(|v| v.sin(), x, g, 1e-3);
    }
    #[test]
    fn test_cos_numerical_grad() {
        let x = 1.0_f32;
        let a = leaf_scalar(x, true);
        let c = cos(&a).unwrap();
        c.backward().unwrap();
        let g = a.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(|v| v.cos(), x, g, 1e-3);
    }
    #[test]
    fn test_clamp_numerical_grad_interior() {
        let x = 0.5_f32;
        let a = leaf_scalar(x, true);
        let c = clamp(&a, 0.0, 1.0).unwrap();
        c.backward().unwrap();
        let g = a.grad().unwrap().unwrap().item().unwrap();
        numerical_grad_check(|v| v.clamp(0.0, 1.0), x, g, 1e-3);
    }
}