use crate::autograd::DualTensor;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{ActivationOps, BinaryOps, CompareOps, ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
pub fn dual_relu<R, C>(a: &DualTensor<R>, client: &C) -> Result<DualTensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + ActivationOps<R> + CompareOps<R> + BinaryOps<R> + TensorOps<R>,
{
let primal = client.relu(a.primal())?;
let tangent = match a.tangent() {
Some(at) => {
let zero = Tensor::zeros(a.primal().shape(), a.primal().dtype(), a.primal().device());
let mask = client.gt(a.primal(), &zero)?;
let mask_typed = client.cast(&mask, at.dtype())?;
Some(client.mul(&mask_typed, at)?)
}
None => None,
};
Ok(DualTensor::new(primal, tangent))
}
pub fn dual_sigmoid<R, C>(a: &DualTensor<R>, client: &C) -> Result<DualTensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + ActivationOps<R> + BinaryOps<R> + ScalarOps<R>,
{
let primal = client.sigmoid(a.primal())?;
let tangent = match a.tangent() {
Some(at) => {
let one = Tensor::ones(primal.shape(), primal.dtype(), primal.device());
let one_minus_sig = client.sub(&one, &primal)?;
let grad = client.mul(&primal, &one_minus_sig)?;
Some(client.mul(&grad, at)?)
}
None => None,
};
Ok(DualTensor::new(primal, tangent))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
fn setup() -> (CpuDevice, CpuClient) {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
(device, client)
}
#[test]
fn test_dual_relu_positive() {
let (device, client) = setup();
let x_primal = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
let x_tangent = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
let x = DualTensor::with_tangent(x_primal, x_tangent);
let y = dual_relu(&x, &client).unwrap();
assert_eq!(y.primal().to_vec::<f32>(), [2.0]);
assert_eq!(y.tangent().unwrap().to_vec::<f32>(), [1.0]);
}
#[test]
fn test_dual_relu_negative() {
let (device, client) = setup();
let x_primal = Tensor::<CpuRuntime>::from_slice(&[-1.0f32], &[1], &device);
let x_tangent = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
let x = DualTensor::with_tangent(x_primal, x_tangent);
let y = dual_relu(&x, &client).unwrap();
assert_eq!(y.primal().to_vec::<f32>(), [0.0]);
assert_eq!(y.tangent().unwrap().to_vec::<f32>(), [0.0]);
}
#[test]
fn test_dual_relu_mixed() {
let (device, client) = setup();
let x_primal = Tensor::<CpuRuntime>::from_slice(&[-1.0f32, 0.0, 2.0], &[3], &device);
let x_tangent = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device);
let x = DualTensor::with_tangent(x_primal, x_tangent);
let y = dual_relu(&x, &client).unwrap();
assert_eq!(y.primal().to_vec::<f32>(), [0.0, 0.0, 2.0]);
assert_eq!(y.tangent().unwrap().to_vec::<f32>(), [0.0, 0.0, 1.0]);
}
#[test]
fn test_dual_sigmoid_at_zero() {
let (device, client) = setup();
let x_primal = Tensor::<CpuRuntime>::from_slice(&[0.0f32], &[1], &device);
let x_tangent = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
let x = DualTensor::with_tangent(x_primal, x_tangent);
let y = dual_sigmoid(&x, &client).unwrap();
assert!((y.primal().to_vec::<f32>()[0] - 0.5).abs() < 1e-6);
assert!((y.tangent().unwrap().to_vec::<f32>()[0] - 0.25).abs() < 1e-6);
}
#[test]
fn test_dual_sigmoid_large_positive() {
let (device, client) = setup();
let x_primal = Tensor::<CpuRuntime>::from_slice(&[10.0f32], &[1], &device);
let x_tangent = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
let x = DualTensor::with_tangent(x_primal, x_tangent);
let y = dual_sigmoid(&x, &client).unwrap();
assert!(y.primal().to_vec::<f32>()[0] > 0.999);
assert!(y.tangent().unwrap().to_vec::<f32>()[0] < 0.001);
}
#[test]
fn test_dual_sigmoid_large_negative() {
let (device, client) = setup();
let x_primal = Tensor::<CpuRuntime>::from_slice(&[-10.0f32], &[1], &device);
let x_tangent = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
let x = DualTensor::with_tangent(x_primal, x_tangent);
let y = dual_sigmoid(&x, &client).unwrap();
assert!(y.primal().to_vec::<f32>()[0] < 0.001);
assert!(y.tangent().unwrap().to_vec::<f32>()[0] < 0.001);
}
}