use crate::autograd::GradFn;
use crate::autograd::var::Var;
use crate::autograd::var_ops::var_mul;
use crate::error::Result;
use crate::ops::{BinaryOps, CompareOps, ReduceOps, ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
use super::common::ensure_contiguous;
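/// Backward function for a `max` reduction over `dims`.
///
/// The forward input is saved so that the backward pass can recompute which
/// positions attained the maximum; the incoming gradient is routed only to
/// those positions, with ties sharing it equally (see [`extremum_backward`]).
///
/// Sketch of standalone use, mirroring the test at the bottom of this file
/// (normally this node is constructed by the forward op, not by hand):
///
/// ```ignore
/// let backward = MaxBackward::<CpuRuntime>::new(a.id(), a.clone(), &[1], true, None);
/// let grads = backward.backward(&grad_out)?;
/// ```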
pub struct MaxBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
dims: Vec<usize>,
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> MaxBackward<R> {
pub fn new(
input_id: TensorId,
input: Tensor<R>,
dims: &[usize],
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
saved_input: input,
dims: dims.to_vec(),
keepdim,
input_grad_fn,
}
}
}
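/// Shared gradient computation for `max`/`min` reductions.
///
/// The gradient of an extremum reduction is an indicator mask: only positions
/// equal to the reduced value receive gradient. When several positions tie,
/// the mask is normalized by the tie count so each tied position receives an
/// equal share. If the forward pass used `keepdim = false`, the reduced
/// dimensions are re-inserted before broadcasting the incoming gradient back
/// to the input shape.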
fn extremum_backward<R: Runtime>(
saved_input: &Tensor<R>,
grad_output: &Tensor<R>,
dims: &[usize],
keepdim: bool,
is_max: bool,
) -> Result<Tensor<R>>
where
R::Client: TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,
{
let client = R::default_client(grad_output.device());
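    // Recompute the extremum with keepdim = true so it can be broadcast back
    // against the saved input.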
let extremum_vals = if is_max {
client.max(saved_input, dims, true)?
} else {
client.min(saved_input, dims, true)?
};
let extremum_broadcast = ensure_contiguous(extremum_vals.broadcast_to(saved_input.shape())?);
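    // Indicator mask: 1 where the input attains the extremum, 0 elsewhere.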
let mask = client.eq(saved_input, &extremum_broadcast)?;
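    // Normalize by the number of ties so tied positions share the gradient equally.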
let mask_sum = client.sum(&mask, dims, true)?;
let mask_sum_broadcast = ensure_contiguous(mask_sum.broadcast_to(saved_input.shape())?);
let normalized_mask = client.div(&mask, &mask_sum_broadcast)?;
let mut grad = grad_output.clone();
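    // If the reduction dropped the reduced dims, re-insert them (in ascending
    // order) so the gradient lines up with the input for broadcasting.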
if !keepdim {
let mut sorted_dims = dims.to_vec();
sorted_dims.sort();
for &dim in &sorted_dims {
grad = grad.unsqueeze(dim as isize)?;
}
}
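    // Broadcast the gradient to the input shape and route it through the mask.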
let grad_broadcast = ensure_contiguous(grad.broadcast_to(saved_input.shape())?);
client.mul(&grad_broadcast, &normalized_mask)
}
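/// `Var`-based twin of [`extremum_backward`], used when the backward pass
/// itself must be differentiable. The mask is computed on plain tensors and
/// the final multiply goes through [`var_mul`] so it can be recorded in the
/// autograd graph.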
fn extremum_backward_var<R: Runtime>(
saved_input: &Tensor<R>,
grad_output: &Var<R>,
dims: &[usize],
keepdim: bool,
is_max: bool,
) -> Result<Var<R>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,
{
let client = R::default_client(grad_output.tensor().device());
let extremum_vals = if is_max {
client.max(saved_input, dims, true)?
} else {
client.min(saved_input, dims, true)?
};
let extremum_broadcast = ensure_contiguous(extremum_vals.broadcast_to(saved_input.shape())?);
let mask = client.eq(saved_input, &extremum_broadcast)?;
let mask_sum = client.sum(&mask, dims, true)?;
let mask_sum_broadcast = ensure_contiguous(mask_sum.broadcast_to(saved_input.shape())?);
let normalized_mask = client.div(&mask, &mask_sum_broadcast)?;
let mut grad_tensor = grad_output.tensor().clone();
if !keepdim {
let mut sorted_dims = dims.to_vec();
sorted_dims.sort();
for &dim in &sorted_dims {
grad_tensor = grad_tensor.unsqueeze(dim as isize)?;
}
}
let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(saved_input.shape())?);
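    // Wrap the operands in `Var`s so the final multiply is tracked, allowing
    // higher-order differentiation through the mask routing.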
let grad_var = Var::new(grad_broadcast, grad_output.requires_grad());
let mask_var = Var::new(normalized_mask, false);
var_mul(&grad_var, &mask_var, &client)
}
impl<R: Runtime> GradFn<R> for MaxBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let grad_input = extremum_backward(
&self.saved_input,
grad_output,
&self.dims,
self.keepdim,
true,
)?;
Ok(vec![Some(grad_input)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,
{
let grad_input = extremum_backward_var(
&self.saved_input,
grad_output,
&self.dims,
self.keepdim,
true,
)?;
Ok(vec![Some(grad_input)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"MaxBackward"
}
}
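/// Backward function for a `min` reduction over `dims`.
///
/// Identical to [`MaxBackward`] except that the gradient is routed to the
/// positions that attained the minimum.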
pub struct MinBackward<R: Runtime> {
input_id: TensorId,
saved_input: Tensor<R>,
dims: Vec<usize>,
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> MinBackward<R> {
pub fn new(
input_id: TensorId,
input: Tensor<R>,
dims: &[usize],
keepdim: bool,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
saved_input: input,
dims: dims.to_vec(),
keepdim,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for MinBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let grad_input = extremum_backward(
&self.saved_input,
grad_output,
&self.dims,
self.keepdim,
false,
)?;
Ok(vec![Some(grad_input)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,
{
let grad_input = extremum_backward_var(
&self.saved_input,
grad_output,
&self.dims,
self.keepdim,
false,
)?;
Ok(vec![Some(grad_input)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn saved_tensors(&self) -> &[Tensor<R>] {
std::slice::from_ref(&self.saved_input)
}
fn name(&self) -> &'static str {
"MinBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::DType;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
#[test]
fn test_max_backward() {
let device = CpuDevice::new();
let _client = CpuRuntime::default_client(&device);
let a =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 3.0, 2.0, 4.0, 2.0, 5.0], &[2, 3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 1], DType::F32, &device);
let backward = MaxBackward::<CpuRuntime>::new(a.id(), a.clone(), &[1], true, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 3]);
let grad_data: Vec<f32> = grad_a.to_vec();
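        // Max along dim 1 picks 3.0 (row 0, col 1) and 5.0 (row 1, col 2); the
        // gradient flows only to those positions.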
assert_eq!(grad_data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
}
#[test]
fn test_min_backward() {
let device = CpuDevice::new();
let _client = CpuRuntime::default_client(&device);
let a =
Tensor::<CpuRuntime>::from_slice(&[3.0f32, 1.0, 2.0, 4.0, 2.0, 5.0], &[2, 3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 1], DType::F32, &device);
let backward = MinBackward::<CpuRuntime>::new(a.id(), a.clone(), &[1], true, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 3]);
let grad_data: Vec<f32> = grad_a.to_vec();
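        // Min along dim 1 picks 1.0 (row 0, col 1) and 2.0 (row 1, col 1).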
assert_eq!(grad_data, vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0]);
}
#[test]
fn test_max_backward_with_ties() {
let device = CpuDevice::new();
let _client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 3.0, 1.0], &[1, 3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[1, 1], DType::F32, &device);
let backward = MaxBackward::<CpuRuntime>::new(a.id(), a.clone(), &[1], true, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[1, 3]);
let grad_data: Vec<f32> = grad_a.to_vec();
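        // The two 3.0 entries tie for the maximum, so each receives half of the
        // incoming gradient.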
assert!((grad_data[0] - 0.5).abs() < 1e-6);
assert!((grad_data[1] - 0.5).abs() < 1e-6);
assert!((grad_data[2] - 0.0).abs() < 1e-6);
}
}