1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use std::any::Any;

use crate::{Primitive, Shape, Tensor};

/// Describes a sum reduction over a chosen set of tensor axes.
///
/// The value itself only records *which* axes are reduced; two instances
/// compare equal exactly when their `axes` vectors are equal (same elements,
/// same order).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ReduceSum {
    /// Axis indices to reduce over, kept exactly as supplied (no sorting or
    /// de-duplication is performed).
    pub axes: Vec<usize>,
}

impl ReduceSum {
    /// Creates a `ReduceSum` from anything convertible into a `Vec<usize>`,
    /// such as an owned vector, a fixed-size array, or a slice.
    pub fn new(axes: impl Into<Vec<usize>>) -> Self {
        let axes: Vec<usize> = axes.into();
        Self { axes }
    }
}

/// `Primitive` implementation for [`ReduceSum`]: cloning, labelling,
/// downcasting, and the two differentiation rules.
impl Primitive for ReduceSum {
    /// Returns a clone of this primitive boxed as a `Primitive` trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        let copy: Self = self.clone();
        Box::new(copy)
    }

    /// Builds the label `ReduceSum(<axes>)`, where `<axes>` is the `Debug`
    /// rendering of the stored axis list (e.g. `ReduceSum([0, 1])`).
    fn dot_label(&self) -> String {
        format!("ReduceSum({:?})", self.axes)
    }

    /// Exposes `self` as `&dyn Any` so callers can downcast to `ReduceSum`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Tangent rule: applies the same sum reduction to the first input tangent.
    ///
    /// `_output` and `_primals` are unused. Indexing `tangents[0]` panics if
    /// `tangents` is empty.
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        let input_tangent = &tangents[0];
        // NOTE(review): the `false` flag is presumably "keep_dim" — confirm
        // against `Tensor::reduce_sum`.
        input_tangent.reduce_sum((&self.axes, false))
    }

    /// Cotangent rule: expands `cotangent` back to the shape of the first
    /// primal and returns it as the only element of the result.
    ///
    /// `_output` is unused. Panics if `primals` is empty or if any stored
    /// axis is out of range for the primal's shape.
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let input = &primals[0];

        // Start from the input's shape and collapse every reduced axis to
        // length 1, so the reshaped cotangent lines up with the input axes.
        let mut collapsed_shape = input.shape().to_vec();
        for &axis in self.axes.iter() {
            collapsed_shape[axis] = 1;
        }

        let reshaped = cotangent.reshape(collapsed_shape);
        let input_cotangent = reshaped.broadcast_to(input);
        vec![input_cotangent]
    }
}