use crate::autograd::GradFn;
use crate::autograd::var::Var;
use crate::autograd::var_ops::var_div_scalar;
use crate::error::Result;
use crate::ops::{ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
use super::common::ensure_contiguous;
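/// Backward pass for a `sum` reduction over one or more dimensions.
///
/// The derivative of a sum is 1 with respect to every summed element, so
/// the backward pass simply expands `grad_output` back to the input shape:
/// reduced axes are re-inserted (when `keepdim` was false) and the result
/// is broadcast.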
pub struct SumBackward<R: Runtime> {
input_id: TensorId,
input_shape: Vec<usize>,
dims: Vec<usize>,
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> SumBackward<R> {
pub fn new(
input_id: TensorId,
input_shape: &[usize],
dims: &[usize],
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
input_shape: input_shape.to_vec(),
dims: dims.to_vec(),
keepdim,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for SumBackward<R> {
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let mut grad = grad_output.clone();
if !self.keepdim {
// Re-insert the reduced axes in ascending order so that earlier
// insertions do not shift the indices of later ones.
let mut sorted_dims = self.dims.clone();
sorted_dims.sort_unstable();
for &dim in &sorted_dims {
grad = grad.unsqueeze(dim as isize)?;
}
}
// The gradient of `sum` is 1 everywhere, so expanding `grad_output`
// back to the input shape is the whole backward pass.
grad = ensure_contiguous(grad.broadcast_to(&self.input_shape)?);
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
// Same expansion as `backward`, but kept in Var form so the result can
// participate in further autograd.
let mut grad_tensor = grad_output.tensor().clone();
if !self.keepdim {
let mut sorted_dims = self.dims.clone();
sorted_dims.sort_unstable();
for &dim in &sorted_dims {
grad_tensor = grad_tensor.unsqueeze(dim as isize)?;
}
}
grad_tensor = ensure_contiguous(grad_tensor.broadcast_to(&self.input_shape)?);
// Propagate `requires_grad` from the incoming gradient rather than
// hardcoding it, matching `MeanBackward::backward_var` below.
Ok(vec![Some(Var::new(grad_tensor, grad_output.requires_grad()))])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
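/// Backward pass for a `mean` reduction over one or more dimensions.
///
/// Identical to [`SumBackward`], except the expanded gradient is scaled by
/// `1 / count`, where `count` is the product of the reduced dimension sizes.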
pub struct MeanBackward<R: Runtime> {
input_id: TensorId,
input_shape: Vec<usize>,
dims: Vec<usize>,
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> MeanBackward<R> {
pub fn new(
input_id: TensorId,
input_shape: &[usize],
dims: &[usize],
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
input_shape: input_shape.to_vec(),
dims: dims.to_vec(),
keepdim,
input_grad_fn,
}
}
}
// `var_div_scalar` needs tensor and scalar ops on the client; the bounds
// live on the impl because a trait method implementation may not add
// requirements that the trait definition itself does not declare.
impl<R: Runtime> GradFn<R> for MeanBackward<R>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
// Number of elements averaged over: the product of the reduced dim sizes.
let count: usize = self.dims.iter().map(|&d| self.input_shape[d]).product();
let count_f64 = count as f64;
let mut grad = grad_output.clone();
if !self.keepdim {
// Re-insert the reduced axes in ascending order.
let mut sorted_dims = self.dims.clone();
sorted_dims.sort_unstable();
for &dim in &sorted_dims {
grad = grad.unsqueeze(dim as isize)?;
}
}
grad = ensure_contiguous(grad.broadcast_to(&self.input_shape)?);
// d(mean)/dx = 1 / count for every input element.
let grad = client.div_scalar(&grad, count_f64)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let client = R::default_client(grad_output.tensor().device());
let count: usize = self.dims.iter().map(|&d| self.input_shape[d]).product();
let count_f64 = count as f64;
let mut grad_tensor = grad_output.tensor().clone();
if !self.keepdim {
let mut sorted_dims = self.dims.clone();
sorted_dims.sort_unstable();
for &dim in &sorted_dims {
grad_tensor = grad_tensor.unsqueeze(dim as isize)?;
}
}
grad_tensor = ensure_contiguous(grad_tensor.broadcast_to(&self.input_shape)?);
// Divide through a Var op so the scaling is recorded for potential
// higher-order gradients.
let grad_var = Var::new(grad_tensor, grad_output.requires_grad());
let grad = var_div_scalar(&grad_var, count_f64, &client)?;
Ok(vec![Some(grad)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"MeanBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::DType;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
#[test]
fn test_sum_backward_keepdim() {
let device = CpuDevice::new();
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 1], DType::F32, &device);
let backward = SumBackward::<CpuRuntime>::new(TensorId::new(), &[2, 3], &[1], true, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 3]);
let grad_data: Vec<f32> = grad_a.to_vec();
assert_eq!(grad_data, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
}
#[test]
fn test_sum_backward_no_keepdim() {
let device = CpuDevice::new();
let grad_out = Tensor::<CpuRuntime>::ones(&[2], DType::F32, &device);
let backward = SumBackward::<CpuRuntime>::new(TensorId::new(), &[2, 3], &[1], false, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 3]);
let grad_data: Vec<f32> = grad_a.to_vec();
assert_eq!(grad_data, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
}
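// Additional coverage sketches, using only APIs already exercised above.
// Reducing a [2, 3, 4] input over dims {0, 2} without keepdim leaves a
// [3]-shaped gradient, which must be re-inserted in ascending order
// ([1, 3], then [1, 3, 1]) before broadcasting back to [2, 3, 4]; the
// dims are passed unsorted on purpose to exercise the sort.
#[test]
fn test_sum_backward_multi_dims_no_keepdim() {
let device = CpuDevice::new();
let grad_out = Tensor::<CpuRuntime>::ones(&[3], DType::F32, &device);
let backward =
SumBackward::<CpuRuntime>::new(TensorId::new(), &[2, 3, 4], &[2, 0], false, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 3, 4]);
let grad_data: Vec<f32> = grad_a.to_vec();
assert_eq!(grad_data, vec![1.0; 24]);
}
// The Var path should mirror the tensor path and propagate `requires_grad`
// from the incoming gradient instead of hardcoding it.
#[test]
fn test_sum_backward_var_propagates_requires_grad() {
let device = CpuDevice::new();
let grad_out = Var::new(Tensor::<CpuRuntime>::ones(&[2], DType::F32, &device), false);
let backward = SumBackward::<CpuRuntime>::new(TensorId::new(), &[2, 3], &[1], false, None);
let grads = backward.backward_var(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.tensor().shape(), &[2, 3]);
assert!(!grad_a.requires_grad());
}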
#[test]
fn test_mean_backward() {
let device = CpuDevice::new();
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 1], DType::F32, &device);
let backward = MeanBackward::<CpuRuntime>::new(TensorId::new(), &[2, 3], &[1], true, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 3]);
let grad_data: Vec<f32> = grad_a.to_vec();
let expected = 1.0 / 3.0;
for val in grad_data {
assert!((val - expected).abs() < 1e-6);
}
}
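// Mean over two axes of a [2, 3, 4] input divides by count = 3 * 4 = 12,
// so every element of the expanded gradient should equal 1/12. A sketch
// along the same lines as test_mean_backward above.
#[test]
fn test_mean_backward_multi_dims() {
let device = CpuDevice::new();
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 1, 1], DType::F32, &device);
let backward =
MeanBackward::<CpuRuntime>::new(TensorId::new(), &[2, 3, 4], &[1, 2], true, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 3, 4]);
let grad_data: Vec<f32> = grad_a.to_vec();
let expected = 1.0 / 12.0;
for val in grad_data {
assert!((val - expected).abs() < 1e-6);
}
}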
}