use super::macros::*;
use super::ops::*;
use crate::autograd::Var;
use crate::error::Result;
use crate::ops::{ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
// Differentiable unary operations on `Var`, generated by the
// `impl_var_unary_op_*` macros. Each invocation produces a function
// `var_<op>(&Var, &client) -> Result<Var>` that applies the forward op and
// registers the named backward node on the autograd tape.
//
// NOTE(review): the macro suffix appears to encode what the backward pass
// captures — `_input` saves the forward input, `_output` saves the forward
// output, `_scalar` variants additionally fold in a scalar constant, and
// `_id` needs neither — TODO confirm against super::macros.

// d/dx(-x) = -1: gradient needs neither input nor output.
impl_var_unary_op_id!(
var_neg, neg, NegBackward
);
// d/dx exp(x) = exp(x): expressible from the saved output.
impl_var_unary_op_output!(
var_exp, exp, ExpBackward
);
// d/dx ln(x) = 1/x: needs the saved input.
impl_var_unary_op_input!(
var_log, log, LogBackward
);
// d/dx sqrt(x) = 1/(2*sqrt(x)): saved output plus the scalar 1/2.
impl_var_unary_op_output_scalar!(
var_sqrt, sqrt, SqrtBackward
);
// d/dx sin(x) = cos(x): needs the saved input.
impl_var_unary_op_input!(
var_sin, sin, SinBackward
);
// d/dx cos(x) = -sin(x): needs the saved input.
impl_var_unary_op_input!(
var_cos, cos, CosBackward
);
// d/dx tanh(x) = 1 - tanh(x)^2: saved output plus a scalar.
impl_var_unary_op_output_scalar!(
var_tanh, tanh, TanhBackward
);
// d/dx x^2 = 2x: saved input plus the scalar 2.
impl_var_unary_op_input_scalar!(
var_square, square, SquareBackward
);
// d/dx (1/x) = -1/x^2 = -out^2: saved output plus a scalar sign.
impl_var_unary_op_output_scalar!(
var_recip, recip, RecipBackward
);
// d/dx |x| = sign(x): needs the saved input.
impl_var_unary_op_input!(
var_abs, abs, AbsBackward
);
// d/dx tan(x) = 1 + tan(x)^2: saved input plus a scalar.
impl_var_unary_op_input_scalar!(
var_tan, tan, TanBackward
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::backward;
    use crate::runtime::cpu::{CpuDevice, CpuRuntime};
    use crate::tensor::Tensor;

    /// Builds a single-element, grad-tracking `Var` holding `value` on the CPU.
    fn scalar_var(value: f32, device: &CpuDevice) -> Var<CpuRuntime> {
        let tensor = Tensor::<CpuRuntime>::from_slice(&[value], &[1], device);
        Var::new(tensor, true)
    }

    /// d/dx exp(x) at x = 0 is exp(0) = 1.
    #[test]
    fn test_var_exp_backward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let input = scalar_var(0.0f32, &device);
        let output = var_exp(&input, &client).unwrap();
        let grads = backward(&output, &client).unwrap();
        let grad: Vec<f32> = grads.get(input.id()).unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-6);
    }

    /// d/dx tan(x) at x = 0 is sec^2(0) = 1.
    #[test]
    fn test_var_tan_backward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let input = scalar_var(0.0f32, &device);
        let output = var_tan(&input, &client).unwrap();
        let grads = backward(&output, &client).unwrap();
        let grad: Vec<f32> = grads.get(input.id()).unwrap().to_vec();
        assert!((grad[0] - 1.0).abs() < 1e-6);
    }
}