use std::any::TypeId;
use std::sync::Arc;
use crate::autograd::no_grad::{is_grad_enabled, no_grad};
use crate::bool_tensor::BoolTensor;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::gpu_dispatch::gpu_backend;
use crate::grad_fns::arithmetic::reduce_grad_to_shape;
use crate::ops::elementwise::{binary_map, fast_cos, fast_sin, unary_map};
use crate::shape::broadcast_shapes;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Returns `true` when the generic element type `T` is exactly `f32`.
#[inline]
fn is_f32<T: Float>() -> bool {
    let wanted = TypeId::of::<f32>();
    wanted == TypeId::of::<T>()
}
/// Returns `true` when the generic element type `T` is exactly `f64`.
#[inline]
fn is_f64<T: Float>() -> bool {
    let wanted = TypeId::of::<f64>();
    wanted == TypeId::of::<T>()
}
/// Returns `true` when the generic element type `T` is exactly `half::bf16`.
#[inline]
fn is_bf16<T: Float>() -> bool {
    let wanted = TypeId::of::<half::bf16>();
    wanted == TypeId::of::<T>()
}
/// Returns `true` when the generic element type `T` is exactly `half::f16`.
#[inline]
fn is_f16<T: Float>() -> bool {
    let wanted = TypeId::of::<half::f16>();
    wanted == TypeId::of::<T>()
}
/// Decides whether a unary op on `a` should attach a backward node.
///
/// True only when `is_grad_enabled()` reports true and `a` itself requires grad.
#[inline]
fn needs_grad_unary<T: Float>(a: &Tensor<T>) -> bool {
    if !is_grad_enabled() {
        return false;
    }
    a.requires_grad()
}
/// Backward node for [`exp`]; keeps the forward input and the forward output.
#[derive(Debug)]
struct ExpBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
    /// Forward result, reused in `backward` as the local derivative.
    output: Tensor<T>,
}

impl<T: Float> GradFn<T> for ExpBackward<T> {
    /// Returns `grad_output * output` for the single input, or `None` when the
    /// input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let grad_input = if grad_output.is_cuda() {
            // CUDA gradients use the tensor-level multiply, run inside `no_grad`.
            no_grad(|| crate::grad_fns::arithmetic::mul(grad_output, &self.output))?
        } else {
            let upstream = grad_output.data()?;
            let exp_vals = self.output.data()?;
            let products: Vec<T> = upstream
                .iter()
                .copied()
                .zip(exp_vals.iter().copied())
                .map(|(g, o)| g * o)
                .collect();
            Tensor::from_storage(
                TensorStorage::cpu(products),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "ExpBackward"
    }
}
/// Element-wise exponential of `input`.
///
/// Runs `exp_inner` inside `profile_op_scope` with the labels `"exp"` and
/// `"tensor_op"` and the input shape.
pub fn exp<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let shapes = [input.shape()];
    crate::profiler_hook::profile_op_scope("exp", "tensor_op", &shapes, || exp_inner(input))
}
/// Computes `exp` on `input` and attaches an `ExpBackward` node when needed.
///
/// CUDA tensors whose element type is f32, f64, bf16 or f16 are dispatched to
/// the GPU backend; every other case goes through `fast_exp`.
///
/// # Errors
/// Returns `FerrotorchError::DeviceUnavailable` when the GPU path is selected
/// but `gpu_backend()` returns `None`, and forwards errors from the helpers it
/// calls.
fn exp_inner<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let use_gpu =
        input.is_cuda() && (is_f32::<T>() || is_f64::<T>() || is_bf16::<T>() || is_f16::<T>());

    if !use_gpu {
        let output = crate::ops::elementwise::fast_exp(input)?;
        if !needs_grad_unary(input) {
            return Ok(output);
        }
        let node = Arc::new(ExpBackward {
            input: input.clone(),
            output: output.clone(),
        });
        let (storage, shape) = output.into_storage_and_shape()?;
        return Tensor::from_operation(storage, shape, node);
    }

    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    // The kernels are handed the handle of a contiguous tensor.
    let input = input.contiguous()?;
    let handle: crate::gpu_dispatch::GpuBufferHandle = crate::dispatch_floating_dtype!(
        T,
        "exp",
        f32 => backend.exp_f32(input.gpu_handle()?),
        f64 => backend.exp_f64(input.gpu_handle()?),
        bf16 => backend.exp_bf16_bf16(input.gpu_handle()?),
        f16 => backend.exp_f16(input.gpu_handle()?),
    )?;
    let storage = TensorStorage::gpu(handle);
    let shape = input.shape().to_vec();

    if !needs_grad_unary(&input) {
        return Tensor::from_storage(storage, shape, false);
    }

    // The backward node keeps its own handle on the forward result.
    let saved = Tensor::from_storage(storage, shape, false)?;
    let node = Arc::new(ExpBackward {
        input: input.clone(),
        output: saved.clone(),
    });
    let (out_storage, out_shape) = saved.into_storage_and_shape()?;
    Tensor::from_operation(out_storage, out_shape, node)
}
/// Backward node for [`log`]; keeps the forward input.
#[derive(Debug)]
struct LogBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for LogBackward<T> {
    /// Returns `grad_output / input` for the single input, or `None` when the
    /// input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let grad_input = if grad_output.is_cuda() {
            // CUDA gradients use the tensor-level divide, run inside `no_grad`.
            no_grad(|| crate::grad_fns::arithmetic::div(grad_output, &self.input))?
        } else {
            let upstream = grad_output.data()?;
            let xs = self.input.data()?;
            let quotients: Vec<T> = upstream
                .iter()
                .copied()
                .zip(xs.iter().copied())
                .map(|(g, x)| g / x)
                .collect();
            Tensor::from_storage(
                TensorStorage::cpu(quotients),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "LogBackward"
    }
}
/// Element-wise natural logarithm of `input`.
///
/// Runs `log_inner` inside `profile_op_scope` with the labels `"log"` and
/// `"tensor_op"` and the input shape.
pub fn log<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let shapes = [input.shape()];
    crate::profiler_hook::profile_op_scope("log", "tensor_op", &shapes, || log_inner(input))
}
/// Computes `log` on `input` and attaches a `LogBackward` node when needed.
///
/// CUDA tensors whose element type is f32, f64, bf16 or f16 are dispatched to
/// the GPU backend; every other case goes through `fast_log`.
///
/// # Errors
/// Returns `FerrotorchError::DeviceUnavailable` when the GPU path is selected
/// but `gpu_backend()` returns `None`, and forwards errors from the helpers it
/// calls.
fn log_inner<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let use_gpu =
        input.is_cuda() && (is_f32::<T>() || is_f64::<T>() || is_bf16::<T>() || is_f16::<T>());

    if !use_gpu {
        let output = crate::ops::elementwise::fast_log(input)?;
        if !needs_grad_unary(input) {
            return Ok(output);
        }
        let (storage, shape) = output.into_storage_and_shape()?;
        let node = Arc::new(LogBackward {
            input: input.clone(),
        });
        return Tensor::from_operation(storage, shape, node);
    }

    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    // The kernels are handed the handle of a contiguous tensor.
    let input = input.contiguous()?;
    let handle: crate::gpu_dispatch::GpuBufferHandle = crate::dispatch_floating_dtype!(
        T,
        "log",
        f32 => backend.log_f32(input.gpu_handle()?),
        f64 => backend.log_f64(input.gpu_handle()?),
        bf16 => backend.log_bf16_bf16(input.gpu_handle()?),
        f16 => backend.log_f16(input.gpu_handle()?),
    )?;
    let storage = TensorStorage::gpu(handle);
    let shape = input.shape().to_vec();

    if !needs_grad_unary(&input) {
        return Tensor::from_storage(storage, shape, false);
    }
    let node = Arc::new(LogBackward {
        input: input.clone(),
    });
    Tensor::from_operation(storage, shape, node)
}
/// Backward node for [`sin`]; keeps the forward input.
#[derive(Debug)]
struct SinBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for SinBackward<T> {
    /// Returns `grad_output * cos(input)` for the single input, or `None` when
    /// the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let grad_input = if grad_output.is_cuda() {
            // Compose tensor-level ops inside `no_grad` for CUDA gradients.
            no_grad(|| {
                let cosine = cos(&self.input)?;
                crate::grad_fns::arithmetic::mul(grad_output, &cosine)
            })?
        } else {
            let upstream = grad_output.data()?;
            let xs = self.input.data()?;
            let scaled: Vec<T> = upstream
                .iter()
                .copied()
                .zip(xs.iter().copied())
                .map(|(g, x)| g * x.cos())
                .collect();
            Tensor::from_storage(
                TensorStorage::cpu(scaled),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "SinBackward"
    }
}
/// Element-wise sine of `input`.
///
/// Runs `sin_inner` inside `profile_op_scope` with the labels `"sin"` and
/// `"tensor_op"` and the input shape.
pub fn sin<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let shapes = [input.shape()];
    crate::profiler_hook::profile_op_scope("sin", "tensor_op", &shapes, || sin_inner(input))
}
/// Computes `fast_sin(input)` and attaches a `SinBackward` node when
/// `needs_grad_unary` says gradients are required.
fn sin_inner<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let output = fast_sin(input)?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let (storage, shape) = output.into_storage_and_shape()?;
    let node = Arc::new(SinBackward {
        input: input.clone(),
    });
    Tensor::from_operation(storage, shape, node)
}
/// Backward node for [`cos`]; keeps the forward input.
#[derive(Debug)]
struct CosBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for CosBackward<T> {
    /// Returns `grad_output * (-sin(input))` for the single input, or `None`
    /// when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let grad_input = if grad_output.is_cuda() {
            // Compose tensor-level ops inside `no_grad` for CUDA gradients.
            no_grad(|| {
                let sine = sin(&self.input)?;
                let negated = crate::grad_fns::arithmetic::neg(&sine)?;
                crate::grad_fns::arithmetic::mul(grad_output, &negated)
            })?
        } else {
            let upstream = grad_output.data()?;
            let xs = self.input.data()?;
            let scaled: Vec<T> = upstream
                .iter()
                .copied()
                .zip(xs.iter().copied())
                .map(|(g, x)| g * (-x.sin()))
                .collect();
            Tensor::from_storage(
                TensorStorage::cpu(scaled),
                self.input.shape().to_vec(),
                false,
            )?
        };
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "CosBackward"
    }
}
/// Element-wise cosine of `input`.
///
/// Runs `cos_inner` inside `profile_op_scope` with the labels `"cos"` and
/// `"tensor_op"` and the input shape.
pub fn cos<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let shapes = [input.shape()];
    crate::profiler_hook::profile_op_scope("cos", "tensor_op", &shapes, || cos_inner(input))
}
/// Computes `fast_cos(input)` and attaches a `CosBackward` node when
/// `needs_grad_unary` says gradients are required.
fn cos_inner<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let output = fast_cos(input)?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let (storage, shape) = output.into_storage_and_shape()?;
    let node = Arc::new(CosBackward {
        input: input.clone(),
    });
    Tensor::from_operation(storage, shape, node)
}
/// Backward node for [`clamp`]; keeps the forward input and both bounds.
#[derive(Debug)]
struct ClampBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
    /// Lower clamp bound.
    min: T,
    /// Upper clamp bound.
    max: T,
}

impl<T: Float> GradFn<T> for ClampBackward<T> {
    /// Passes `grad_output` through where `min <= input <= max` and yields zero
    /// elsewhere (CPU path); f32/f64 CUDA gradients are delegated to the
    /// backend's `clamp_backward_*` kernels.
    ///
    /// # Errors
    /// `DeviceUnavailable` if the CUDA path is taken without a backend;
    /// `NotImplementedOnCuda` when either tensor is on CUDA but the element
    /// type is neither f32 nor f64 (or only the input is on CUDA).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        let grad_on_gpu = grad_output.is_cuda();
        if grad_on_gpu && (is_f32::<T>() || is_f64::<T>()) {
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let handle = if is_f32::<T>() {
                // Bounds are widened to f64 first, then narrowed for the f32 kernel.
                let lo = self.min.to_f64().unwrap_or(f64::NEG_INFINITY) as f32;
                let hi = self.max.to_f64().unwrap_or(f64::INFINITY) as f32;
                backend.clamp_backward_f32(
                    grad_output.gpu_handle()?,
                    self.input.gpu_handle()?,
                    lo,
                    hi,
                )?
            } else {
                let lo = self.min.to_f64().unwrap_or(f64::NEG_INFINITY);
                let hi = self.max.to_f64().unwrap_or(f64::INFINITY);
                backend.clamp_backward_f64(
                    grad_output.gpu_handle()?,
                    self.input.gpu_handle()?,
                    lo,
                    hi,
                )?
            };
            let grad_input = Tensor::from_storage(
                TensorStorage::gpu(handle),
                self.input.shape().to_vec(),
                false,
            )?;
            return Ok(vec![Some(grad_input)]);
        }

        if grad_on_gpu || self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "ClampBackward",
            });
        }

        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let zero = <T as num_traits::Zero>::zero();
        let (lo, hi) = (self.min, self.max);
        let masked: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| {
                let inside = x >= lo && x <= hi;
                if inside { g } else { zero }
            })
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(masked),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "ClampBackward"
    }
}
/// Clamps every element of `input` to the range bounded by `min` and `max`.
///
/// On the element-wise path, values below `min` become `min`, values above
/// `max` become `max`, and everything else is returned unchanged. CUDA tensors
/// with f32/f64 elements use the backend `clamp_f32` / `clamp_f64` kernels when
/// a backend is available. A `ClampBackward` node is attached when
/// `needs_grad_unary` requires it.
pub fn clamp<T: Float>(input: &Tensor<T>, min: T, max: T) -> FerrotorchResult<Tensor<T>> {
    if input.is_cuda()
        && (is_f32::<T>() || is_f64::<T>())
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
    {
        let input = input.contiguous()?;
        let handle = if is_f32::<T>() {
            let lo = min.to_f32().unwrap_or(f32::MIN);
            let hi = max.to_f32().unwrap_or(f32::MAX);
            backend.clamp_f32(input.gpu_handle()?, lo, hi)?
        } else {
            let lo = min.to_f64().unwrap_or(f64::MIN);
            let hi = max.to_f64().unwrap_or(f64::MAX);
            backend.clamp_f64(input.gpu_handle()?, lo, hi)?
        };
        let storage = TensorStorage::gpu(handle);
        let shape = input.shape().to_vec();
        if !needs_grad_unary(&input) {
            return Tensor::from_storage(storage, shape, false);
        }
        let node = Arc::new(ClampBackward {
            input: input.clone(),
            min,
            max,
        });
        return Tensor::from_operation(storage, shape, node);
    }

    let clamp_one = |v: T| {
        if v < min {
            min
        } else if v > max {
            max
        } else {
            v
        }
    };
    let output = unary_map(input, clamp_one)?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let (storage, shape) = output.into_storage_and_shape()?;
    let node = Arc::new(ClampBackward {
        input: input.clone(),
        min,
        max,
    });
    Tensor::from_operation(storage, shape, node)
}
/// Builds a zero-filled tensor with the same shape as `like`.
///
/// The result always uses CPU storage and is created with `false` as the final
/// `from_storage` flag. The element count is the plain product of the
/// dimensions: an empty (scalar) shape yields 1 because an empty product is 1,
/// and a shape containing a zero dimension yields 0, so the storage length
/// always agrees with the shape.
fn zeros_like_tensor<T: Float>(like: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let shape = like.shape().to_vec();
    // No `.max(1)` here: it would allocate a stray element for zero-sized shapes.
    let numel: usize = shape.iter().product();
    let zeros = vec![<T as num_traits::Zero>::zero(); numel];
    Tensor::from_storage(TensorStorage::cpu(zeros), shape, false)
}
/// Finalises a unary op result.
///
/// When `needs_grad_unary(input)` is false, `output` is returned untouched and
/// `make_grad` is never called. Otherwise `output` is split into storage and
/// shape and rebuilt through `Tensor::from_operation` with the node produced by
/// `make_grad`.
#[inline]
fn finish_unary<T: Float, G: GradFn<T> + 'static>(
    output: Tensor<T>,
    input: &Tensor<T>,
    make_grad: impl FnOnce() -> G,
) -> FerrotorchResult<Tensor<T>> {
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let (storage, shape) = output.into_storage_and_shape()?;
    let node = Arc::new(make_grad());
    Tensor::from_operation(storage, shape, node)
}
/// Backward node for [`tan`]; keeps the forward input and the forward output.
#[derive(Debug)]
struct TanBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
    /// Forward result (`tan(input)`), reused to form the derivative.
    output: Tensor<T>,
}

impl<T: Float> GradFn<T> for TanBackward<T> {
    /// Returns `grad_output * (1 + output^2)` as a CPU tensor, or `None` when
    /// the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let tan_vals = self.output.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(tan_vals.iter().copied())
            .map(|(g, t)| g * (unit + t * t))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "TanBackward"
    }
}
/// Element-wise tangent of `input`, computed through `unary_map`.
///
/// When gradients are needed, the result is also stored in a `TanBackward`
/// node so the backward pass can reuse it.
pub fn tan<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let output = unary_map(input, |v| v.tan())?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let (raw_storage, raw_shape) = output.into_storage_and_shape()?;
    let saved = Tensor::from_storage(raw_storage, raw_shape, false)?;
    let node = Arc::new(TanBackward {
        input: input.clone(),
        output: saved.clone(),
    });
    let (storage, shape) = saved.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Backward node for [`asin`]; keeps the forward input.
#[derive(Debug)]
struct AsinBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AsinBackward<T> {
    /// Returns `grad_output / sqrt(1 - input^2)` as a CPU tensor, or `None`
    /// when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (unit - x * x).sqrt())
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "AsinBackward"
    }
}
/// Element-wise arcsine of `input` via `unary_map`; `finish_unary` attaches an
/// `AsinBackward` node when gradients are needed.
pub fn asin<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.asin())?;
    let make_node = || AsinBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`acos`]; keeps the forward input.
#[derive(Debug)]
struct AcosBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AcosBackward<T> {
    /// Returns `-(grad_output / sqrt(1 - input^2))` as a CPU tensor, or `None`
    /// when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| -(g / (unit - x * x).sqrt()))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "AcosBackward"
    }
}
/// Element-wise arccosine of `input` via `unary_map`; `finish_unary` attaches
/// an `AcosBackward` node when gradients are needed.
pub fn acos<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.acos())?;
    let make_node = || AcosBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`atan`]; keeps the forward input.
#[derive(Debug)]
struct AtanBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AtanBackward<T> {
    /// Returns `grad_output / (1 + input^2)` as a CPU tensor, or `None` when
    /// the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (unit + x * x))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "AtanBackward"
    }
}
/// Element-wise arctangent of `input` via `unary_map`; `finish_unary` attaches
/// an `AtanBackward` node when gradients are needed.
pub fn atan<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.atan())?;
    let make_node = || AtanBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`sinh`]; keeps the forward input.
#[derive(Debug)]
struct SinhBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for SinhBackward<T> {
    /// Returns `grad_output * cosh(input)` as a CPU tensor, or `None` when the
    /// input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g * x.cosh())
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "SinhBackward"
    }
}
/// Element-wise hyperbolic sine of `input` via `unary_map`; `finish_unary`
/// attaches a `SinhBackward` node when gradients are needed.
pub fn sinh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.sinh())?;
    let make_node = || SinhBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`cosh`]; keeps the forward input.
#[derive(Debug)]
struct CoshBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for CoshBackward<T> {
    /// Returns `grad_output * sinh(input)` as a CPU tensor, or `None` when the
    /// input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g * x.sinh())
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "CoshBackward"
    }
}
/// Element-wise hyperbolic cosine of `input` via `unary_map`; `finish_unary`
/// attaches a `CoshBackward` node when gradients are needed.
pub fn cosh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.cosh())?;
    let make_node = || CoshBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`asinh`]; keeps the forward input.
#[derive(Debug)]
struct AsinhBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AsinhBackward<T> {
    /// Returns `grad_output / sqrt(input^2 + 1)` as a CPU tensor, or `None`
    /// when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (x * x + unit).sqrt())
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "AsinhBackward"
    }
}
/// Element-wise inverse hyperbolic sine of `input` via `unary_map`;
/// `finish_unary` attaches an `AsinhBackward` node when gradients are needed.
pub fn asinh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.asinh())?;
    let make_node = || AsinhBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`acosh`]; keeps the forward input.
#[derive(Debug)]
struct AcoshBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AcoshBackward<T> {
    /// Returns `grad_output / sqrt(input^2 - 1)` as a CPU tensor, or `None`
    /// when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (x * x - unit).sqrt())
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "AcoshBackward"
    }
}
/// Element-wise inverse hyperbolic cosine of `input` via `unary_map`;
/// `finish_unary` attaches an `AcoshBackward` node when gradients are needed.
pub fn acosh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.acosh())?;
    let make_node = || AcoshBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`atanh`]; keeps the forward input.
#[derive(Debug)]
struct AtanhBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for AtanhBackward<T> {
    /// Returns `grad_output / (1 - input^2)` as a CPU tensor, or `None` when
    /// the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (unit - x * x))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "AtanhBackward"
    }
}
/// Element-wise inverse hyperbolic tangent of `input` via `unary_map`;
/// `finish_unary` attaches an `AtanhBackward` node when gradients are needed.
pub fn atanh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.atanh())?;
    let make_node = || AtanhBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`exp2`]; keeps the forward input and the forward output.
#[derive(Debug)]
struct Exp2Backward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
    /// Forward result (`2^input`), reused to form the derivative.
    output: Tensor<T>,
}

impl<T: Float> GradFn<T> for Exp2Backward<T> {
    /// Returns `grad_output * output * ln(2)` as a CPU tensor, or `None` when
    /// the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let pow_vals = self.output.data()?;
        // Falls back to zero if the constant cannot be converted into `T`.
        let ln_two = T::from(std::f64::consts::LN_2).unwrap_or_else(<T as num_traits::Zero>::zero);
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(pow_vals.iter().copied())
            .map(|(g, r)| g * r * ln_two)
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "Exp2Backward"
    }
}
/// Element-wise base-2 exponential of `input`, computed through `unary_map`.
///
/// When gradients are needed, the result is also stored in an `Exp2Backward`
/// node so the backward pass can reuse it.
pub fn exp2<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let output = unary_map(input, |v| v.exp2())?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let (raw_storage, raw_shape) = output.into_storage_and_shape()?;
    let saved = Tensor::from_storage(raw_storage, raw_shape, false)?;
    let node = Arc::new(Exp2Backward {
        input: input.clone(),
        output: saved.clone(),
    });
    let (storage, shape) = saved.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Backward node for [`expm1`]; keeps the forward input and the forward output.
#[derive(Debug)]
struct Expm1Backward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
    /// Forward result (`exp(input) - 1`), reused to form the derivative.
    output: Tensor<T>,
}

impl<T: Float> GradFn<T> for Expm1Backward<T> {
    /// Returns `grad_output * (output + 1)` as a CPU tensor, or `None` when the
    /// input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let fwd_vals = self.output.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(fwd_vals.iter().copied())
            .map(|(g, r)| g * (r + unit))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "Expm1Backward"
    }
}
/// Element-wise `exp(x) - 1` of `input`, computed through `unary_map` with
/// `exp_m1`.
///
/// When gradients are needed, the result is also stored in an `Expm1Backward`
/// node so the backward pass can reuse it.
pub fn expm1<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let output = unary_map(input, |v| v.exp_m1())?;
    if !needs_grad_unary(input) {
        return Ok(output);
    }
    let (raw_storage, raw_shape) = output.into_storage_and_shape()?;
    let saved = Tensor::from_storage(raw_storage, raw_shape, false)?;
    let node = Arc::new(Expm1Backward {
        input: input.clone(),
        output: saved.clone(),
    });
    let (storage, shape) = saved.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Backward node for [`log2`]; keeps the forward input.
#[derive(Debug)]
struct Log2Backward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for Log2Backward<T> {
    /// Returns `grad_output / (input * ln(2))` as a CPU tensor, or `None` when
    /// the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        // Falls back to zero if the constant cannot be converted into `T`.
        let ln_two = T::from(std::f64::consts::LN_2).unwrap_or_else(<T as num_traits::Zero>::zero);
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (x * ln_two))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "Log2Backward"
    }
}
/// Element-wise base-2 logarithm of `input` via `unary_map`; `finish_unary`
/// attaches a `Log2Backward` node when gradients are needed.
pub fn log2<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.log2())?;
    let make_node = || Log2Backward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`log10`]; keeps the forward input.
#[derive(Debug)]
struct Log10Backward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for Log10Backward<T> {
    /// Returns `grad_output / (input * ln(10))` as a CPU tensor, or `None` when
    /// the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        // Falls back to zero if the constant cannot be converted into `T`.
        let ln_ten = T::from(std::f64::consts::LN_10).unwrap_or_else(<T as num_traits::Zero>::zero);
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (x * ln_ten))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "Log10Backward"
    }
}
/// Element-wise base-10 logarithm of `input` via `unary_map`; `finish_unary`
/// attaches a `Log10Backward` node when gradients are needed.
pub fn log10<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.log10())?;
    let make_node = || Log10Backward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`log1p`]; keeps the forward input.
#[derive(Debug)]
struct Log1pBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for Log1pBackward<T> {
    /// Returns `grad_output / (1 + input)` as a CPU tensor, or `None` when the
    /// input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        let unit = <T as num_traits::One>::one();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| g / (unit + x))
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "Log1pBackward"
    }
}
/// Element-wise `ln(1 + x)` of `input` via `unary_map` with `ln_1p`;
/// `finish_unary` attaches a `Log1pBackward` node when gradients are needed.
pub fn log1p<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.ln_1p())?;
    let make_node = || Log1pBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Shared backward node for ops whose gradient is reported as all zeros.
///
/// It is used in this file by `ceil`, `floor`, `round`, `trunc` and `sign`,
/// each passing its own label.
#[derive(Debug)]
struct ZerosLikeBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
    /// Label reported by `GradFn::name`.
    name: &'static str,
}

impl<T: Float> GradFn<T> for ZerosLikeBackward<T> {
    /// Ignores the incoming gradient and returns zeros shaped like the input,
    /// or `None` when the input does not require grad.
    fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let zeros = zeros_like_tensor(&self.input)?;
        Ok(vec![Some(zeros)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the label supplied at construction.
    fn name(&self) -> &'static str {
        self.name
    }
}
/// Element-wise ceiling of `input` via `unary_map`; the attached backward node
/// (`ZerosLikeBackward`, labelled `"CeilBackward"`) yields a zero gradient.
pub fn ceil<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.ceil())?;
    let make_node = || ZerosLikeBackward {
        input: input.clone(),
        name: "CeilBackward",
    };
    finish_unary(forward, input, make_node)
}
/// Element-wise floor of `input` via `unary_map`; the attached backward node
/// (`ZerosLikeBackward`, labelled `"FloorBackward"`) yields a zero gradient.
pub fn floor<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.floor())?;
    let make_node = || ZerosLikeBackward {
        input: input.clone(),
        name: "FloorBackward",
    };
    finish_unary(forward, input, make_node)
}
/// Element-wise rounding of `input` using `round_half_to_even`; the attached
/// backward node (`ZerosLikeBackward`, labelled `"RoundBackward"`) yields a
/// zero gradient.
pub fn round<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, round_half_to_even)?;
    let make_node = || ZerosLikeBackward {
        input: input.clone(),
        name: "RoundBackward",
    };
    finish_unary(forward, input, make_node)
}
/// Rounds `x` to the nearest integer, resolving exact `.5` ties toward the
/// even neighbour.
///
/// The sign of `x` is carried over to the result, so inputs in `[-0.5, -0.0]`
/// produce `-0.0` rather than `+0.0`. NaN and infinite inputs come back as NaN
/// and the same infinity respectively.
#[inline]
fn round_half_to_even<T: Float>(x: T) -> T {
    let zero = <T as num_traits::Zero>::zero();
    let one = <T as num_traits::One>::one();
    let two = T::from(2.0).unwrap_or_else(<T as num_traits::One>::one);
    let half = T::from(0.5).unwrap_or_else(<T as num_traits::Zero>::zero);

    let lower = x.floor();
    let frac = x - lower;
    let rounded = if frac < half {
        lower
    } else if frac > half {
        lower + one
    } else {
        // Exact tie (also reached for NaN/inf, where the comparisons are false):
        // keep `lower` when it is even, otherwise step up to the even neighbour.
        let halved = (lower / two).floor();
        if lower - halved * two == zero {
            lower
        } else {
            lower + one
        }
    };

    // `lower + one` turns -1 into +0.0; the rounded magnitude never changes sign
    // relative to `x`, so re-applying the sign of `x` only restores -0.0.
    <T as num_traits::Float>::copysign(rounded, x)
}
/// Element-wise truncation of `input` via `unary_map`; the attached backward
/// node (`ZerosLikeBackward`, labelled `"TruncBackward"`) yields a zero
/// gradient.
pub fn trunc<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v.trunc())?;
    let make_node = || ZerosLikeBackward {
        input: input.clone(),
        name: "TruncBackward",
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`frac`]; keeps the forward input.
#[derive(Debug)]
struct FracBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for FracBackward<T> {
    /// Copies `grad_output` unchanged into a CPU tensor shaped like the input,
    /// or returns `None` when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let copied = upstream.to_vec();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(copied),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "FracBackward"
    }
}
/// Element-wise fractional part of `input`, computed as `x - trunc(x)` via
/// `unary_map`; `finish_unary` attaches a `FracBackward` node when gradients
/// are needed.
pub fn frac<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let forward = unary_map(input, |v| v - v.trunc())?;
    let make_node = || FracBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Element-wise sign of `input`.
///
/// NaN and zero inputs map to zero; every other value maps to `signum()`. The
/// attached backward node (`ZerosLikeBackward`, labelled `"SignBackward"`)
/// yields a zero gradient.
pub fn sign<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let zero = <T as num_traits::Zero>::zero();
    let sign_of = |v: T| {
        let maps_to_zero = v.is_nan() || v == zero;
        if maps_to_zero { zero } else { v.signum() }
    };
    let forward = unary_map(input, sign_of)?;
    let make_node = || ZerosLikeBackward {
        input: input.clone(),
        name: "SignBackward",
    };
    finish_unary(forward, input, make_node)
}
/// Backward node for [`sinc`]; keeps the forward input.
#[derive(Debug)]
struct SincBackward<T: Float> {
    /// Tensor the forward op was applied to.
    input: Tensor<T>,
}

impl<T: Float> GradFn<T> for SincBackward<T> {
    /// Returns `grad_output * (cos(pi x) / x - sin(pi x) / (pi x^2))` as a CPU
    /// tensor, using zero where `x == 0`; `None` when the input does not
    /// require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let upstream = grad_output.data()?;
        let xs = self.input.data()?;
        // Falls back to zero if the constant cannot be converted into `T`.
        let pi = T::from(std::f64::consts::PI).unwrap_or_else(<T as num_traits::Zero>::zero);
        let zero = <T as num_traits::Zero>::zero();
        let scaled: Vec<T> = upstream
            .iter()
            .copied()
            .zip(xs.iter().copied())
            .map(|(g, x)| {
                if x == zero {
                    return zero;
                }
                let scaled_x = pi * x;
                g * (scaled_x.cos() / x - scaled_x.sin() / (scaled_x * x))
            })
            .collect();
        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(scaled),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// Returns the stored forward input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "SincBackward"
    }
}
/// Element-wise normalised sinc, `sin(pi x) / (pi x)`, with the value at
/// `x == 0` defined as one; `finish_unary` attaches a `SincBackward` node when
/// gradients are needed.
pub fn sinc<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    // Falls back to zero if the constant cannot be converted into `T`.
    let pi = T::from(std::f64::consts::PI).unwrap_or_else(<T as num_traits::Zero>::zero);
    let zero = <T as num_traits::Zero>::zero();
    let one = <T as num_traits::One>::one();
    let sinc_of = |v: T| {
        if v == zero {
            return one;
        }
        let scaled = pi * v;
        scaled.sin() / scaled
    };
    let forward = unary_map(input, sinc_of)?;
    let make_node = || SincBackward {
        input: input.clone(),
    };
    finish_unary(forward, input, make_node)
}
/// Decides whether a binary op on `a` and `b` should attach a backward node.
///
/// True only when `is_grad_enabled()` reports true and at least one operand
/// requires grad.
#[inline]
fn needs_grad_binary<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> bool {
    if !is_grad_enabled() {
        return false;
    }
    a.requires_grad() || b.requires_grad()
}
/// Backward node for [`atan2`]; keeps both forward operands.
#[derive(Debug)]
struct Atan2Backward<T: Float> {
    /// First operand of the forward op.
    y: Tensor<T>,
    /// Second operand of the forward op.
    x: Tensor<T>,
}

impl<T: Float> GradFn<T> for Atan2Backward<T> {
    /// Computes `grad * x / (x^2 + y^2)` for `y` and `-grad * y / (x^2 + y^2)`
    /// for `x`, using zero where the denominator is zero, then reduces each
    /// gradient to its operand's shape. Operands that do not require grad get
    /// `None`.
    ///
    /// # Errors
    /// Returns `NotImplementedOnCuda` if `grad_output` or either operand is on
    /// CUDA.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let any_cuda = grad_output.is_cuda() || self.y.is_cuda() || self.x.is_cuda();
        if any_cuda {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "atan2 backward",
            });
        }

        let out_shape = broadcast_shapes(self.y.shape(), self.x.shape())?;
        let zero = <T as num_traits::Zero>::zero();

        // Project each operand through `binary_map` so both line up element-wise
        // (presumably at the broadcast shape).
        let y_full = binary_map(&self.y, &self.x, |y, _| y)?;
        let x_full = binary_map(&self.y, &self.x, |_, x| x)?;
        let norm_sq = binary_map(&y_full, &x_full, |y, x| y * y + x * x)?;

        let norm_sq_data = norm_sq.data()?;
        let y_data = y_full.data()?;
        let x_data = x_full.data()?;
        let upstream = grad_output.data()?;

        let inv_norm: Vec<T> = norm_sq_data
            .iter()
            .copied()
            .map(|d| {
                if d == zero {
                    zero
                } else {
                    <T as num_traits::One>::one() / d
                }
            })
            .collect();

        let wrt_y: Vec<T> = upstream
            .iter()
            .zip(x_data.iter())
            .zip(inv_norm.iter())
            .map(|((g, x), r)| *g * *x * *r)
            .collect();
        let wrt_x: Vec<T> = upstream
            .iter()
            .zip(y_data.iter())
            .zip(inv_norm.iter())
            .map(|((g, y), r)| -(*g * *y * *r))
            .collect();

        let wrt_y_tensor =
            Tensor::from_storage(TensorStorage::cpu(wrt_y), out_shape.clone(), false)?;
        let wrt_x_tensor = Tensor::from_storage(TensorStorage::cpu(wrt_x), out_shape, false)?;

        let grad_y = self
            .y
            .requires_grad()
            .then(|| reduce_grad_to_shape(&wrt_y_tensor, self.y.shape()))
            .transpose()?;
        let grad_x = self
            .x
            .requires_grad()
            .then(|| reduce_grad_to_shape(&wrt_x_tensor, self.x.shape()))
            .transpose()?;
        Ok(vec![grad_y, grad_x])
    }

    /// Returns the stored operands in `[y, x]` order.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.y, &self.x]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "Atan2Backward"
    }
}
/// Element-wise `y.atan2(x)` computed through `binary_map`.
///
/// An `Atan2Backward` node is attached when `needs_grad_binary` requires it.
///
/// # Errors
/// Returns `NotImplementedOnCuda` if either operand is on CUDA.
pub fn atan2<T: Float>(y: &Tensor<T>, x: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if y.is_cuda() || x.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "atan2" });
    }
    let output = binary_map(y, x, |num, den| num.atan2(den))?;
    if !needs_grad_binary(y, x) {
        return Ok(output);
    }
    let (storage, shape) = output.into_storage_and_shape()?;
    let node = Arc::new(Atan2Backward {
        y: y.clone(),
        x: x.clone(),
    });
    Tensor::from_operation(storage, shape, node)
}
/// Returns a `BoolTensor` with the same shape as `input` whose entries are
/// `is_sign_negative` of each element.
///
/// # Errors
/// Returns `NotImplementedOnCuda` if `input` is on CUDA.
pub fn signbit<T: Float>(input: &Tensor<T>) -> FerrotorchResult<BoolTensor> {
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "signbit" });
    }
    let values = input.data()?;
    let flags: Vec<bool> = values
        .iter()
        .copied()
        .map(<T as num_traits::Float>::is_sign_negative)
        .collect();
    let dims = input.shape().to_vec();
    BoolTensor::from_vec(flags, dims)
}
/// Backward node for [`copysign`]; keeps both operands and the forward result.
#[derive(Debug)]
struct CopysignBackward<T: Float> {
    /// Operand that supplied the magnitude.
    magnitude: Tensor<T>,
    /// Operand that supplied the sign.
    sign: Tensor<T>,
    /// Forward result, used to recover the per-element sign flip.
    result: Tensor<T>,
}

impl<T: Float> GradFn<T> for CopysignBackward<T> {
    /// Gradient for `magnitude` is `grad * (result / magnitude)`, with zero
    /// where the magnitude is zero, reduced to the magnitude's shape. Gradient
    /// for `sign` is a CPU tensor of zeros in the sign's shape. Operands that
    /// do not require grad get `None`.
    ///
    /// # Errors
    /// Returns `NotImplementedOnCuda` if `grad_output` is on CUDA.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if grad_output.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "copysign backward",
            });
        }
        let zero = <T as num_traits::Zero>::zero();
        let out_shape = broadcast_shapes(self.magnitude.shape(), self.sign.shape())?;
        // Project the magnitude through `binary_map` so it lines up element-wise
        // with the result (presumably at the broadcast shape).
        let mag_full = binary_map(&self.magnitude, &self.sign, |m, _s| m)?;
        let res_data = self.result.data()?;
        let mag_data = mag_full.data()?;
        let upstream = grad_output.data()?;
        let raw: Vec<T> = upstream
            .iter()
            .zip(res_data.iter())
            .zip(mag_data.iter())
            .map(|((&g, &r), &m)| if m == zero { zero } else { g * (r / m) })
            .collect();
        let grad_full = Tensor::from_storage(TensorStorage::cpu(raw), out_shape, false)?;

        let grad_magnitude = if self.magnitude.requires_grad() {
            Some(reduce_grad_to_shape(&grad_full, self.magnitude.shape())?)
        } else {
            None
        };
        let grad_sign = if self.sign.requires_grad() {
            let sign_shape = self.sign.shape().to_vec();
            // Plain product: 1 for a scalar shape, 0 for zero-sized shapes, so the
            // storage length always agrees with the shape (no `.max(1)`).
            let numel: usize = sign_shape.iter().product();
            Some(Tensor::from_storage(
                TensorStorage::cpu(vec![zero; numel]),
                sign_shape,
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_magnitude, grad_sign])
    }

    /// Returns the stored operands in `[magnitude, sign]` order.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.magnitude, &self.sign]
    }

    /// Returns the node label.
    fn name(&self) -> &'static str {
        "CopysignBackward"
    }
}
/// Element-wise `copysign`: each output takes its absolute value from
/// `magnitude` and its sign from `sign` (via `num_traits::Float::copysign`).
///
/// When `needs_grad_binary` requests tracking, the result is attached to a
/// [`CopysignBackward`] node that also keeps a copy of the forward output.
///
/// # Errors
/// Returns [`FerrotorchError::NotImplementedOnCuda`] if either operand is on
/// CUDA; other errors are propagated from `binary_map` and tensor construction.
pub fn copysign<T: Float>(magnitude: &Tensor<T>, sign: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if magnitude.is_cuda() || sign.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "copysign" });
    }
    let forward = binary_map(magnitude, sign, |mag, sgn| {
        <T as num_traits::Float>::copysign(mag, sgn)
    })?;
    if !needs_grad_binary(magnitude, sign) {
        return Ok(forward);
    }
    // Build a detached tensor first so the backward node can hold the forward
    // values, then re-wrap the same storage as the tracked output.
    let (raw_storage, raw_shape) = forward.into_storage_and_shape()?;
    let detached = Tensor::from_storage(raw_storage, raw_shape, false)?;
    let node = Arc::new(CopysignBackward {
        magnitude: magnitude.clone(),
        sign: sign.clone(),
        result: detached.clone(),
    });
    let (storage, shape) = detached.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Backward node recorded by `hypot`: keeps both operands and the forward
/// output, which is the divisor in each partial derivative.
#[derive(Debug)]
struct HypotBackward<T: Float> {
// First operand; its gradient is `grad * a / result`.
a: Tensor<T>,
// Second operand; its gradient is `grad * b / result`.
b: Tensor<T>,
// Forward output saved at graph-construction time.
result: Tensor<T>,
}
impl<T: Float> GradFn<T> for HypotBackward<T> {
    /// Computes the gradients of `hypot(a, b)`.
    ///
    /// Each operand receives `grad * operand / result` element-wise, with the
    /// value forced to zero wherever the forward result is exactly zero. Both
    /// gradients are then reduced back to the shape of their operand.
    ///
    /// # Errors
    /// Returns [`FerrotorchError::NotImplementedOnCuda`] when `grad_output`
    /// is a CUDA tensor; other errors are propagated from the helpers used.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if grad_output.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "hypot backward",
            });
        }
        let full_shape = broadcast_shapes(self.a.shape(), self.b.shape())?;
        // Expand each operand to the combined shape so every buffer below can
        // be walked in lockstep.
        let a_full = binary_map(&self.a, &self.b, |lhs, _| lhs)?;
        let b_full = binary_map(&self.a, &self.b, |_, rhs| rhs)?;
        let upstream = grad_output.data()?;
        let hyp_vals = self.result.data()?;
        let a_vals = a_full.data()?;
        let b_vals = b_full.data()?;
        let zero = <T as num_traits::Zero>::zero();
        // Shared per-element rule: grad * operand / result, or zero when the
        // forward result is zero.
        let partial = |g: T, operand: T, r: T| if r == zero { zero } else { g * operand / r };
        let grads_a: Vec<T> = upstream
            .iter()
            .zip(a_vals.iter())
            .zip(hyp_vals.iter())
            .map(|((&g, &x), &r)| partial(g, x, r))
            .collect();
        let grads_b: Vec<T> = upstream
            .iter()
            .zip(b_vals.iter())
            .zip(hyp_vals.iter())
            .map(|((&g, &y), &r)| partial(g, y, r))
            .collect();
        let full_a = Tensor::from_storage(TensorStorage::cpu(grads_a), full_shape.clone(), false)?;
        let full_b = Tensor::from_storage(TensorStorage::cpu(grads_b), full_shape, false)?;

        let grad_a = self
            .a
            .requires_grad()
            .then(|| reduce_grad_to_shape(&full_a, self.a.shape()))
            .transpose()?;
        let grad_b = self
            .b
            .requires_grad()
            .then(|| reduce_grad_to_shape(&full_b, self.b.shape()))
            .transpose()?;
        Ok(vec![grad_a, grad_b])
    }

    /// Operands in the order their gradients are returned by `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.b]
    }

    /// Name of this backward node.
    fn name(&self) -> &'static str {
        "HypotBackward"
    }
}
/// Element-wise `hypot(a, b)` as computed by `num_traits::Float::hypot`.
///
/// When `needs_grad_binary` requests tracking, the result is attached to a
/// [`HypotBackward`] node that also keeps a copy of the forward output.
///
/// # Errors
/// Returns [`FerrotorchError::NotImplementedOnCuda`] if either operand is on
/// CUDA; other errors are propagated from `binary_map` and tensor construction.
pub fn hypot<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if a.is_cuda() || b.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "hypot" });
    }
    let forward = binary_map(a, b, |lhs, rhs| <T as num_traits::Float>::hypot(lhs, rhs))?;
    if !needs_grad_binary(a, b) {
        return Ok(forward);
    }
    // Build a detached tensor first so the backward node can hold the forward
    // values, then re-wrap the same storage as the tracked output.
    let (raw_storage, raw_shape) = forward.into_storage_and_shape()?;
    let detached = Tensor::from_storage(raw_storage, raw_shape, false)?;
    let node = Arc::new(HypotBackward {
        a: a.clone(),
        b: b.clone(),
        result: detached.clone(),
    });
    let (storage, shape) = detached.into_storage_and_shape()?;
    Tensor::from_operation(storage, shape, node)
}
/// Moves `x` by exactly one representable `f64` step, upward when `up` is
/// true and downward otherwise.
///
/// Either zero steps to the smallest-magnitude subnormal carrying the sign of
/// the requested direction. `x` must not be NaN (checked in debug builds).
#[inline]
#[allow(
    clippy::float_cmp,
    reason = "exact IEEE-754 zero test gates the cross-zero ULP branch; \
              an epsilon tolerance would corrupt nextafter's bit-exact step."
)]
fn f64_one_ulp(x: f64, up: bool) -> f64 {
    debug_assert!(!x.is_nan());
    if x == 0.0 {
        let tiniest = f64::from_bits(1);
        return if up { tiniest } else { -tiniest };
    }
    let raw = x.to_bits();
    // For positive values a larger bit pattern is a larger number; for
    // negative values the relation is reversed. So the pattern shrinks exactly
    // when the direction and the sign disagree.
    let stepped = if up != (x > 0.0) { raw - 1 } else { raw + 1 };
    f64::from_bits(stepped)
}
/// Moves `x` by exactly one representable `f32` step, upward when `up` is
/// true and downward otherwise.
///
/// Either zero steps to the smallest-magnitude subnormal carrying the sign of
/// the requested direction. `x` must not be NaN (checked in debug builds).
#[inline]
#[allow(
    clippy::float_cmp,
    reason = "exact IEEE-754 zero test gates the cross-zero ULP branch; \
              an epsilon tolerance would corrupt nextafter's bit-exact step."
)]
fn f32_one_ulp(x: f32, up: bool) -> f32 {
    debug_assert!(!x.is_nan());
    if x == 0.0 {
        let tiniest = f32::from_bits(1);
        return if up { tiniest } else { -tiniest };
    }
    let raw = x.to_bits();
    // For positive values a larger bit pattern is a larger number; for
    // negative values the relation is reversed. So the pattern shrinks exactly
    // when the direction and the sign disagree.
    let stepped = if up != (x > 0.0) { raw - 1 } else { raw + 1 };
    f32::from_bits(stepped)
}
/// One-step neighbour for a 16-bit float given as its raw bit pattern.
///
/// The caller supplies whether the value is zero and whether it is positive.
/// A zero steps to pattern `0x0001` (upward) or `0x8001` (downward); any other
/// value has its pattern incremented when the direction agrees with its sign
/// and decremented otherwise.
#[inline]
fn u16_one_ulp(bits: u16, is_zero: bool, is_positive: bool, up: bool) -> u16 {
    match (is_zero, up) {
        (true, true) => 0x0001,
        (true, false) => 0x8001,
        (false, _) if up == is_positive => bits + 1,
        (false, _) => bits - 1,
    }
}
/// Scalar kernel of `nextafter`: the representable value adjacent to `a` in
/// the direction of `b`.
///
/// * If either argument is NaN the result is NaN.
/// * If `a == b` the result is `b` itself.
/// * Otherwise `a` is stepped by one ULP at the precision selected by the
///   `is_f32` / `is_f16` / `is_bf16` checks; every other type is stepped
///   through `f64`.
///
/// Whenever a numeric conversion fails, `b` is returned as the fallback.
#[inline]
#[allow(
    clippy::float_cmp,
    reason = "exact IEEE-754 equality is the std::nextafter tie semantics: \
              on a == b (incl. signed-zero) the result is exactly b."
)]
fn nextafter_scalar<T: Float>(a: T, b: T) -> T {
    if a.is_nan() || b.is_nan() {
        return T::nan();
    }
    if a == b {
        return b;
    }
    let toward_larger = b > a;

    if is_f32::<T>() {
        let single = <T as num_traits::ToPrimitive>::to_f32(&a).unwrap_or(f32::NAN);
        return T::from(f32_one_ulp(single, toward_larger)).unwrap_or(b);
    }
    if is_f16::<T>() {
        let Some(half_a) = <half::f16 as num_traits::NumCast>::from(a) else {
            return b;
        };
        let stepped = u16_one_ulp(
            half_a.to_bits(),
            half_a == half::f16::ZERO,
            half_a > half::f16::ZERO,
            toward_larger,
        );
        return T::from(half::f16::from_bits(stepped)).unwrap_or(b);
    }
    if is_bf16::<T>() {
        let Some(brain_a) = <half::bf16 as num_traits::NumCast>::from(a) else {
            return b;
        };
        let stepped = u16_one_ulp(
            brain_a.to_bits(),
            brain_a == half::bf16::ZERO,
            brain_a > half::bf16::ZERO,
            toward_larger,
        );
        return T::from(half::bf16::from_bits(stepped)).unwrap_or(b);
    }

    // `f64` and any type not matched above take the same double-precision step.
    let double = <T as num_traits::ToPrimitive>::to_f64(&a).unwrap_or(f64::NAN);
    T::from(f64_one_ulp(double, toward_larger)).unwrap_or(b)
}
/// Backward node recorded by `nextafter`: keeps only the two operands, which
/// are compared element-wise to build the tie mask during backward.
#[derive(Debug)]
struct NextafterBackward<T: Float> {
// Starting value; receives the upstream gradient wherever it differs from `b`.
a: Tensor<T>,
// Direction target; its gradient is always a zero tensor.
b: Tensor<T>,
}
impl<T: Float> GradFn<T> for NextafterBackward<T> {
    /// Computes the gradients of `nextafter(a, b)`.
    ///
    /// The gradient for `a` is the upstream gradient wherever the two operands
    /// differ and zero where they are equal, reduced to `a`'s shape. The
    /// gradient for `b` is always a zero tensor.
    ///
    /// # Errors
    /// Returns [`FerrotorchError::NotImplementedOnCuda`] when the upstream
    /// gradient or either saved operand is a CUDA tensor; other errors are
    /// propagated from the helpers used.
    #[allow(
        clippy::float_cmp,
        reason = "the `a != b` mask is the exact upstream gradient gate per \
                  derivatives.yaml:1323 `at::where(self != other, grad, 0)`; \
                  an epsilon tolerance would misroute the tie's zero gradient."
    )]
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if grad_output.is_cuda() || self.a.is_cuda() || self.b.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "nextafter backward",
            });
        }
        let zero = <T as num_traits::Zero>::zero();
        let full_shape = broadcast_shapes(self.a.shape(), self.b.shape())?;
        // Expand both operands to the combined shape so they can be compared
        // position by position against the upstream gradient.
        let a_full = binary_map(&self.a, &self.b, |lhs, _| lhs)?;
        let b_full = binary_map(&self.a, &self.b, |_, rhs| rhs)?;
        let upstream = grad_output.data()?;
        let a_vals = a_full.data()?;
        let b_vals = b_full.data()?;
        let masked: Vec<T> = upstream
            .iter()
            .zip(a_vals.iter().zip(b_vals.iter()))
            .map(|(&g, (&lhs, &rhs))| if lhs != rhs { g } else { zero })
            .collect();
        let full_grad = Tensor::from_storage(TensorStorage::cpu(masked), full_shape, false)?;

        let grad_a = self
            .a
            .requires_grad()
            .then(|| reduce_grad_to_shape(&full_grad, self.a.shape()))
            .transpose()?;
        let grad_b = self
            .b
            .requires_grad()
            .then(|| {
                let b_shape = self.b.shape().to_vec();
                let len = b_shape.iter().product::<usize>().max(1);
                Tensor::from_storage(TensorStorage::cpu(vec![zero; len]), b_shape, false)
            })
            .transpose()?;
        Ok(vec![grad_a, grad_b])
    }

    /// Operands in the order their gradients are returned by `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.b]
    }

    /// Name of this backward node.
    fn name(&self) -> &'static str {
        "NextafterBackward"
    }
}
/// Element-wise `nextafter(a, b)`: for each pair, the representable value
/// adjacent to `a` in the direction of `b`, as computed by `nextafter_scalar`.
///
/// When `needs_grad_binary` requests tracking, the result is attached to a
/// [`NextafterBackward`] node holding the two operands.
///
/// # Errors
/// Returns [`FerrotorchError::NotImplementedOnCuda`] if either operand is on
/// CUDA; other errors are propagated from `binary_map` and tensor construction.
pub fn nextafter<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if a.is_cuda() || b.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "nextafter" });
    }
    let output = binary_map(a, b, nextafter_scalar)?;
    if !needs_grad_binary(a, b) {
        return Ok(output);
    }
    // The backward node stores only the operands, never the forward result, so
    // the storage is unpacked once and handed straight to `from_operation`
    // (same shape as `atan2`) instead of being re-wrapped in a throwaway tensor.
    let (storage, shape) = output.into_storage_and_shape()?;
    Tensor::from_operation(
        storage,
        shape,
        Arc::new(NextafterBackward {
            a: a.clone(),
            b: b.clone(),
        }),
    )
}
// Unit tests for the element-wise ops in this file: forward values, analytic
// gradients, finite-difference gradient checks, and the bit-exact `nextafter`
// contract.
#[cfg(test)]
mod tests {
use super::*;
// Builds a rank-0 (shape `[]`) f32 CPU tensor holding `val`.
fn leaf_scalar(val: f32, requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
}
// Builds a 1-D f32 CPU tensor whose length equals `data.len()`.
fn leaf_vec(data: &[f32], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
vec![data.len()],
requires_grad,
)
.unwrap()
}
// Asserts that the single value in `t` is strictly within `tol` of `expected`.
fn assert_scalar_approx(t: &Tensor<f32>, expected: f32, tol: f32) {
let val = t.item().unwrap();
assert!(
(val - expected).abs() < tol,
"expected {expected}, got {val}"
);
}
// ---- Forward values ----
#[test]
fn test_exp_forward() {
let a = leaf_vec(&[0.0, 1.0, 2.0], false);
let c = exp(&a).unwrap();
let d = c.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-5);
assert!((d[1] - std::f32::consts::E).abs() < 1e-5);
assert!((d[2] - std::f32::consts::E * std::f32::consts::E).abs() < 1e-4);
}
#[test]
fn test_log_forward() {
let a = leaf_vec(
&[
1.0,
std::f32::consts::E,
std::f32::consts::E * std::f32::consts::E,
],
false,
);
let c = log(&a).unwrap();
let d = c.data().unwrap();
assert!((d[0] - 0.0).abs() < 1e-5);
assert!((d[1] - 1.0).abs() < 1e-5);
assert!((d[2] - 2.0).abs() < 1e-4);
}
#[test]
fn test_sin_forward() {
let a = leaf_vec(
&[0.0, std::f32::consts::FRAC_PI_2, std::f32::consts::PI],
false,
);
let c = sin(&a).unwrap();
let d = c.data().unwrap();
assert!((d[0] - 0.0).abs() < 1e-6);
assert!((d[1] - 1.0).abs() < 1e-6);
assert!(d[2].abs() < 1e-6);
}
#[test]
fn test_cos_forward() {
let a = leaf_vec(
&[0.0, std::f32::consts::FRAC_PI_2, std::f32::consts::PI],
false,
);
let c = cos(&a).unwrap();
let d = c.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-6);
assert!(d[1].abs() < 1e-6);
assert!((d[2] - (-1.0)).abs() < 1e-6);
}
#[test]
fn test_clamp_forward() {
let a = leaf_vec(&[-2.0, 0.5, 1.5, 3.0], false);
let c = clamp(&a, 0.0, 2.0).unwrap();
assert_eq!(c.data().unwrap(), &[0.0, 0.5, 1.5, 2.0]);
}
// ---- Analytic gradients at hand-picked points ----
#[test]
fn test_exp_backward() {
let a = leaf_scalar(1.0, true);
let c = exp(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), std::f32::consts::E, 1e-5);
}
#[test]
fn test_log_backward() {
let a = leaf_scalar(2.0, true);
let c = log(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.5, 1e-6);
}
#[test]
fn test_sin_backward() {
let a = leaf_scalar(0.0, true);
let c = sin(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
}
#[test]
fn test_sin_backward_pi_over_3() {
let a = leaf_scalar(std::f32::consts::FRAC_PI_3, true);
let c = sin(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.5, 1e-5);
}
#[test]
fn test_cos_backward() {
let a = leaf_scalar(0.0, true);
let c = cos(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.0, 1e-6);
}
#[test]
fn test_cos_backward_pi_over_2() {
let a = leaf_scalar(std::f32::consts::FRAC_PI_2, true);
let c = cos(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), -1.0, 1e-5);
}
// clamp: gradient is 1 inside the interval and 0 once the input is clamped.
#[test]
fn test_clamp_backward_interior() {
let a = leaf_scalar(1.5, true);
let c = clamp(&a, 0.0, 2.0).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-6);
}
#[test]
fn test_clamp_backward_clamped_low() {
let a = leaf_scalar(-1.0, true);
let c = clamp(&a, 0.0, 2.0).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.0, 1e-6);
}
#[test]
fn test_clamp_backward_clamped_high() {
let a = leaf_scalar(5.0, true);
let c = clamp(&a, 0.0, 2.0).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.0, 1e-6);
}
// ---- Chain rule through two ops ----
// d/dx log(exp(x)) = 1.
#[test]
fn test_chain_exp_log() {
let a = leaf_scalar(3.0, true);
let b = exp(&a).unwrap();
let c = log(&b).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 1.0, 1e-4);
}
// d/dx cos(sin(x)) = -sin(sin(x)) * cos(x).
#[test]
fn test_chain_sin_cos() {
let a = leaf_scalar(0.5, true);
let b = sin(&a).unwrap();
let c = cos(&b).unwrap();
c.backward().unwrap();
let expected = -(0.5_f32.sin().sin()) * 0.5_f32.cos();
assert_scalar_approx(&a.grad().unwrap().unwrap(), expected, 1e-4);
}
// ---- No graph node is attached when the input does not require grad ----
#[test]
fn test_exp_no_grad_fn_when_not_tracking() {
let a = leaf_scalar(1.0, false);
let c = exp(&a).unwrap();
assert!(c.grad_fn().is_none());
}
#[test]
fn test_log_no_grad_fn_when_not_tracking() {
let a = leaf_scalar(1.0, false);
let c = log(&a).unwrap();
assert!(c.grad_fn().is_none());
}
#[test]
fn test_clamp_no_grad_fn_when_not_tracking() {
let a = leaf_scalar(1.0, false);
let c = clamp(&a, 0.0, 2.0).unwrap();
assert!(c.grad_fn().is_none());
}
// Compares `analytic_grad` with a central finite difference of `f` at `x`
// (step 1e-4, evaluated in f32) and asserts they differ by less than `tol`.
fn numerical_grad_check(f: impl Fn(f32) -> f32, x: f32, analytic_grad: f32, tol: f32) {
let h = 1e-4_f32;
let numerical = (f(x + h) - f(x - h)) / (2.0 * h);
assert!(
(analytic_grad - numerical).abs() < tol,
"analytic={analytic_grad}, numerical={numerical}",
);
}
// ---- Finite-difference gradient checks ----
#[test]
fn test_exp_numerical_grad() {
let x = 1.5_f32;
let a = leaf_scalar(x, true);
let c = exp(&a).unwrap();
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.exp(), x, g, 1e-3);
}
#[test]
fn test_log_numerical_grad() {
let x = 2.0_f32;
let a = leaf_scalar(x, true);
let c = log(&a).unwrap();
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.ln(), x, g, 1e-3);
}
#[test]
fn test_sin_numerical_grad() {
let x = 1.0_f32;
let a = leaf_scalar(x, true);
let c = sin(&a).unwrap();
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.sin(), x, g, 1e-3);
}
#[test]
fn test_cos_numerical_grad() {
let x = 1.0_f32;
let a = leaf_scalar(x, true);
let c = cos(&a).unwrap();
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.cos(), x, g, 1e-3);
}
#[test]
fn test_clamp_numerical_grad_interior() {
let x = 0.5_f32;
let a = leaf_scalar(x, true);
let c = clamp(&a, 0.0, 1.0).unwrap();
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.clamp(0.0, 1.0), x, g, 1e-3);
}
// ---- Forward value against std, plus finite-difference gradient ----
#[test]
fn test_tan_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = tan(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.tan()).abs() < 1e-6, "tan(0.5) = {v}");
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.tan(), 0.5, g, 1e-3);
}
#[test]
fn test_asin_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = asin(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.asin()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.asin(), 0.5, g, 1e-3);
}
#[test]
fn test_acos_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = acos(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.acos()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.acos(), 0.5, g, 1e-3);
}
#[test]
fn test_atan_forward_and_backward() {
let a = leaf_scalar(1.0_f32, true);
let c = atan(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 1.0_f32.atan()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.atan(), 1.0, g, 1e-3);
}
#[test]
fn test_sinh_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = sinh(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.sinh()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.sinh(), 0.5, g, 1e-3);
}
#[test]
fn test_cosh_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = cosh(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.cosh()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.cosh(), 0.5, g, 1e-3);
}
#[test]
fn test_asinh_forward_and_backward() {
let a = leaf_scalar(0.7_f32, true);
let c = asinh(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.7_f32.asinh()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.asinh(), 0.7, g, 1e-3);
}
#[test]
fn test_acosh_forward_and_backward() {
let a = leaf_scalar(1.5_f32, true);
let c = acosh(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 1.5_f32.acosh()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.acosh(), 1.5, g, 1e-3);
}
#[test]
fn test_atanh_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = atanh(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.atanh()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.atanh(), 0.5, g, 1e-3);
}
#[test]
fn test_exp2_forward_and_backward() {
let a = leaf_scalar(2.0_f32, true);
let c = exp2(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 4.0).abs() < 1e-5, "exp2(2) = {v}");
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.exp2(), 2.0, g, 1e-3);
}
#[test]
fn test_expm1_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = expm1(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.exp_m1()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.exp_m1(), 0.5, g, 1e-3);
}
#[test]
fn test_log2_forward_and_backward() {
let a = leaf_scalar(8.0_f32, true);
let c = log2(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 3.0).abs() < 1e-5);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.log2(), 8.0, g, 1e-3);
}
// NOTE(review): this check uses a much looser tolerance (1e-1) than its
// siblings; presumably the f32 finite difference is noisy at x = 100 — confirm
// the check is still meaningful for a gradient of roughly 0.0043.
#[test]
fn test_log10_forward_and_backward() {
let a = leaf_scalar(100.0_f32, true);
let c = log10(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 2.0).abs() < 1e-5);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.log10(), 100.0, g, 1e-1);
}
#[test]
fn test_log1p_forward_and_backward() {
let a = leaf_scalar(0.5_f32, true);
let c = log1p(&a).unwrap();
let v = c.item().unwrap();
assert!((v - 0.5_f32.ln_1p()).abs() < 1e-6);
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
numerical_grad_check(|v| v.ln_1p(), 0.5, g, 1e-3);
}
// ---- Rounding-style ops: exact forward values, expected constant gradient ----
#[test]
fn test_ceil_forward_and_zero_backward() {
let a = leaf_vec(&[-1.4, 0.5, 2.0], false);
let c = ceil(&a).unwrap();
assert_eq!(c.data().unwrap(), &[-1.0, 1.0, 2.0]);
let a2 = leaf_scalar(0.3, true);
let c2 = ceil(&a2).unwrap();
c2.backward().unwrap();
assert_scalar_approx(&a2.grad().unwrap().unwrap(), 0.0, 1e-6);
}
#[test]
fn test_floor_forward_and_zero_backward() {
let a = leaf_vec(&[-1.4, 0.5, 2.9], false);
let c = floor(&a).unwrap();
assert_eq!(c.data().unwrap(), &[-2.0, 0.0, 2.0]);
let a2 = leaf_scalar(0.3, true);
let c2 = floor(&a2).unwrap();
c2.backward().unwrap();
assert_scalar_approx(&a2.grad().unwrap().unwrap(), 0.0, 1e-6);
}
// Halfway cases are expected to round to the nearest even integer.
#[test]
fn test_round_banker_rounding() {
let a = leaf_vec(&[0.5, 1.5, 2.5, 3.5, -0.5, -1.5], false);
let c = round(&a).unwrap();
assert_eq!(c.data().unwrap(), &[0.0, 2.0, 2.0, 4.0, 0.0, -2.0]);
}
#[test]
fn test_trunc_forward_and_zero_backward() {
let a = leaf_vec(&[-1.7, 0.4, 2.9], false);
let c = trunc(&a).unwrap();
assert_eq!(c.data().unwrap(), &[-1.0, 0.0, 2.0]);
let a2 = leaf_scalar(0.3, true);
let c2 = trunc(&a2).unwrap();
c2.backward().unwrap();
assert_scalar_approx(&a2.grad().unwrap().unwrap(), 0.0, 1e-6);
}
// frac keeps the sign of the input (-1.5 -> -0.5) and passes the gradient through.
#[test]
fn test_frac_forward_and_pass_through_backward() {
let a = leaf_vec(&[-1.5, 0.4, 2.75], false);
let c = frac(&a).unwrap();
let d = c.data().unwrap();
assert!((d[0] - (-0.5)).abs() < 1e-6);
assert!((d[1] - 0.4).abs() < 1e-6);
assert!((d[2] - 0.75).abs() < 1e-6);
let a2 = leaf_scalar(0.3, true);
let c2 = frac(&a2).unwrap();
c2.backward().unwrap();
assert_scalar_approx(&a2.grad().unwrap().unwrap(), 1.0, 1e-6);
}
#[test]
fn test_sign_forward_and_zero_backward() {
let a = leaf_vec(&[-3.0, 0.0, 5.0], false);
let c = sign(&a).unwrap();
assert_eq!(c.data().unwrap(), &[-1.0, 0.0, 1.0]);
let a2 = leaf_scalar(5.0, true);
let c2 = sign(&a2).unwrap();
c2.backward().unwrap();
assert_scalar_approx(&a2.grad().unwrap().unwrap(), 0.0, 1e-6);
}
// ---- sinc: the expected values use the normalized form sin(pi x) / (pi x) ----
#[test]
fn test_sinc_zero_and_nonzero() {
let a = leaf_vec(&[0.0, 0.5, 1.0], false);
let c = sinc(&a).unwrap();
let d = c.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-6);
let expected = (std::f32::consts::PI * 0.5).sin() / (std::f32::consts::PI * 0.5);
assert!(
(d[1] - expected).abs() < 1e-6,
"sinc(0.5) = {} vs expected {}",
d[1],
expected
);
assert!(d[2].abs() < 1e-6);
}
#[test]
fn test_sinc_numerical_grad_interior() {
let x = 0.5_f32;
let a = leaf_scalar(x, true);
let c = sinc(&a).unwrap();
c.backward().unwrap();
let g = a.grad().unwrap().unwrap().item().unwrap();
// Reference implementation used for the finite-difference comparison.
let sinc_fn = |v: f32| {
if v == 0.0 {
1.0
} else {
let p = std::f32::consts::PI * v;
p.sin() / p
}
};
numerical_grad_check(sinc_fn, x, g, 1e-3);
}
#[test]
fn test_sinc_zero_backward_is_zero() {
let a = leaf_scalar(0.0, true);
let c = sinc(&a).unwrap();
c.backward().unwrap();
assert_scalar_approx(&a.grad().unwrap().unwrap(), 0.0, 1e-6);
}
// Builds a 1-D f64 CPU tensor whose length equals `data.len()`.
fn leaf_vec_f64(data: &[f64], requires_grad: bool) -> Tensor<f64> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
vec![data.len()],
requires_grad,
)
.unwrap()
}
// ---- nextafter: bit-exact forward results and the tie-masked gradient ----
// Covers stepping up, stepping down, leaving zero in both directions, and a
// large value stepping toward infinity.
#[test]
#[allow(
clippy::float_cmp,
reason = "nextafter returns the EXACT one-ULP IEEE neighbour; bit-exact equality against f64::next_up/next_down/from_bits is the contract being tested, not an approximation"
)]
fn test_nextafter_matches_std_nextafter_f64() {
let a = leaf_vec_f64(&[1.0, -1.0, 0.0, 0.0, 1e300], false);
let b = leaf_vec_f64(&[2.0, -2.0, 1.0, -1.0, f64::INFINITY], false);
let out = nextafter(&a, &b).unwrap();
let d = out.data().unwrap();
assert_eq!(d[0], 1.0_f64.next_up());
assert_eq!(d[1], (-1.0_f64).next_down());
assert_eq!(d[2], f64::from_bits(1));
assert_eq!(d[3], -f64::from_bits(1));
assert_eq!(d[4], 1e300_f64.next_up());
}
// Equal operands return `b` (including the sign of a zero), and a NaN in either
// operand yields NaN.
#[test]
#[allow(
clippy::float_cmp,
reason = "asserting the exact tie result (a==b -> b, carrying b's signed-zero) requires bit-exact equality, not an epsilon comparison"
)]
fn test_nextafter_equal_returns_b_and_nan_propagates() {
let a = leaf_vec_f64(&[5.0, 0.0], false);
let b = leaf_vec_f64(&[5.0, -0.0], false);
let out = nextafter(&a, &b).unwrap();
let d = out.data().unwrap();
assert_eq!(d[0], 5.0);
assert!(d[1].is_sign_negative() && d[1] == 0.0);
let an = leaf_vec_f64(&[f64::NAN, 1.0], false);
let bn = leaf_vec_f64(&[1.0, f64::NAN], false);
let outn = nextafter(&an, &bn).unwrap();
let dn = outn.data().unwrap();
assert!(dn[0].is_nan() && dn[1].is_nan());
}
// Gradient w.r.t. `a` passes through where a != b and is zero on ties; the
// gradient w.r.t. `b` is zero everywhere.
#[test]
#[allow(
clippy::float_cmp,
reason = "the VJP tie mask is exactly 1.0 (pass-through) or 0.0 (masked); bit-exact equality is the contract per derivatives.yaml nextafter"
)]
fn test_nextafter_backward_passthrough_and_tie_mask() {
let a = leaf_vec_f64(&[1.0, 3.0], true);
let b = leaf_vec_f64(&[2.0, 3.0], true);
let out = nextafter(&a, &b).unwrap();
out.sum_all().unwrap().backward().unwrap();
let ga = a.grad().unwrap().unwrap();
let gad = ga.data().unwrap();
assert_eq!(gad[0], 1.0);
assert_eq!(gad[1], 0.0);
let gb = b.grad().unwrap().unwrap();
for &v in gb.data().unwrap() {
assert_eq!(v, 0.0);
}
}
}