use crate::autograd::GradFn;
use crate::autograd::var::Var;
use crate::autograd::var_ops::{var_div_scalar, var_mul, var_mul_scalar};
use crate::error::Result;
use crate::ops::{BinaryOps, ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
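/// Backward pass for tensor-plus-scalar: d(a + s)/da = 1, so the upstream
/// gradient passes through to the input unchanged.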
pub struct AddScalarBackward<R: Runtime> {
input_id: TensorId,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> AddScalarBackward<R> {
pub fn new(input_id: TensorId, input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>) -> Self {
Self {
input_id,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for AddScalarBackward<R> {
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
Ok(vec![Some(grad_output.clone())])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
Ok(vec![Some(Var::new(
grad_output.tensor().clone(),
grad_output.requires_grad(),
))])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<std::sync::Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"AddScalarBackward"
}
}
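/// Backward pass for tensor-minus-scalar: d(a - s)/da = 1, so the upstream
/// gradient passes through to the input unchanged.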
pub struct SubScalarBackward<R: Runtime> {
input_id: TensorId,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> SubScalarBackward<R> {
pub fn new(input_id: TensorId, input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>) -> Self {
Self {
input_id,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for SubScalarBackward<R> {
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
Ok(vec![Some(grad_output.clone())])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
Ok(vec![Some(Var::new(
grad_output.tensor().clone(),
grad_output.requires_grad(),
))])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<std::sync::Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"SubScalarBackward"
}
}
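/// Backward pass for tensor-times-scalar: d(a * s)/da = s, so the upstream
/// gradient is scaled by the stored scalar.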
pub struct MulScalarBackward<R: Runtime> {
input_id: TensorId,
scalar: f64,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> MulScalarBackward<R> {
pub fn new(
input_id: TensorId,
scalar: f64,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
scalar,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for MulScalarBackward<R>
where
R::Client: ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let grad = client.mul_scalar(grad_output, self.scalar)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R>,
{
let client = R::default_client(grad_output.tensor().device());
let grad = var_mul_scalar(grad_output, self.scalar, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<std::sync::Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"MulScalarBackward"
}
}
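/// Backward pass for tensor-divided-by-scalar: d(a / s)/da = 1 / s, so the
/// upstream gradient is divided by the stored scalar.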
pub struct DivScalarBackward<R: Runtime> {
input_id: TensorId,
scalar: f64,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> DivScalarBackward<R> {
pub fn new(
input_id: TensorId,
scalar: f64,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
scalar,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for DivScalarBackward<R>
where
R::Client: ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let grad = client.div_scalar(grad_output, self.scalar)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R>,
{
let client = R::default_client(grad_output.tensor().device());
let grad = var_div_scalar(grad_output, self.scalar, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<std::sync::Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"DivScalarBackward"
}
}
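/// Backward pass for tensor raised to a scalar power: d(a^n)/da = n * a^(n-1),
/// which requires the saved input tensor in addition to the exponent.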
pub struct PowScalarBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
scalar: f64,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> PowScalarBackward<R> {
pub fn new(
input_id: TensorId,
input: Tensor<R>,
scalar: f64,
input_grad_fn: Option<std::sync::Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
saved_input: input,
scalar,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for PowScalarBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
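// d(a^n)/da = n * a^(n-1): evaluate a^(n-1) from the saved input, scale by n,
// then multiply by the upstream gradient.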
let a_pow_n_minus_1 = client.pow_scalar(&self.saved_input, self.scalar - 1.0)?;
let scaled = client.mul_scalar(&a_pow_n_minus_1, self.scalar)?;
let grad = client.mul(grad_output, &scaled)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R>,
{
use crate::autograd::var_ops::var_pow_scalar;
let client = R::default_client(grad_output.tensor().device());
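// Rebuild a Var over the saved input, reusing its original id, so the gradient
// expression is itself differentiable (used for higher-order gradients).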
let a_var = Var::with_id(self.saved_input.clone(), self.input_id, true);
let a_pow_n_minus_1 = var_pow_scalar(&a_var, self.scalar - 1.0, &client)?;
let scaled = var_mul_scalar(&a_pow_n_minus_1, self.scalar, &client)?;
let grad = var_mul(grad_output, &scaled, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<std::sync::Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"PowScalarBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::DType;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
#[test]
fn test_add_scalar_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3], DType::F32, &device);
let backward = AddScalarBackward::<CpuRuntime>::new(a.id(), None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert_eq!(grad_a, vec![1.0, 1.0, 1.0]);
}
#[test]
fn test_mul_scalar_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3], DType::F32, &device);
let backward = MulScalarBackward::<CpuRuntime>::new(a.id(), 3.0, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert_eq!(grad_a, vec![3.0, 3.0, 3.0]);
}
#[test]
fn test_div_scalar_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[4.0f32, 6.0, 8.0], &[3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3], DType::F32, &device);
let backward = DivScalarBackward::<CpuRuntime>::new(a.id(), 2.0, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert_eq!(grad_a, vec![0.5, 0.5, 0.5]);
}
#[test]
fn test_pow_scalar_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[2.0f32, 3.0, 4.0], &[3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3], DType::F32, &device);
let backward = PowScalarBackward::<CpuRuntime>::new(a.id(), a.clone(), 2.0, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert_eq!(grad_a, vec![4.0, 6.0, 8.0]);
}
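// Mirrors the add-scalar test above: d(a - s)/da = 1, so the upstream gradient
// should pass through SubScalarBackward unchanged.
#[test]
fn test_sub_scalar_backward() {
let device = CpuDevice::new();
let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3], DType::F32, &device);
let backward = SubScalarBackward::<CpuRuntime>::new(a.id(), None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f32> = grads[0].as_ref().unwrap().to_vec();
assert_eq!(grad_a, vec![1.0, 1.0, 1.0]);
}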
}