auto_diff/op/reduction.rs

1#![allow(clippy::redundant_closure_call)]
2use std::cell::{RefCell};
3use std::rc::Rc;
4
5use tensor_rs::tensor::Tensor;
6use super::{OpTrait, OpCall, Op, OpHandle};
7use crate::err::AutoDiffError;
8
9#[cfg(feature = "use-serde")]
10use serde::{Serialize, Deserialize};
11#[cfg(feature = "use-serde")]
12use std::any::Any;
13
/// Generates a single-input, single-output reduction operator.
///
/// Arguments, in order:
/// * struct name of the generated operator;
/// * the string returned by `OpTrait::get_name`;
/// * the `Tensor` method used for the forward pass, invoked as
///   `input.method(dim, keepdim)` with `dim: Option<&[usize]>`;
/// * a parenthesised closure `(|input, output_grad, input_grad| { ... })` that
///   `OpTrait::grad` calls directly (hence the file-level
///   `clippy::redundant_closure_call` allowance).
macro_rules! reduce_macro {
    ($a:ident, $b:expr, $c:ident, $d: tt) => {
        /// Reduction operator generated by `reduce_macro!`.
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $a {
            // Not serialized: rebuilt via its default when deserializing.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            /// Dimensions handed to the tensor reduction method; `None` is forwarded as-is.
            dim: Option<Vec<usize>>,
            /// Forwarded unchanged as the tensor reduction method's `keepdim` flag.
            keepdim: bool,
        }

        impl $a {
            /// Creates the operator, copying `dim` into an owned vector.
            pub fn new(dim: Option<&[usize]>, keepdim: bool) -> $a {
                $a {
                    handle: OpHandle::new(),
                    dim: dim.map(|v| v.to_vec()),
                    keepdim,
                }
            }
            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }
            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpCall for $a {
            /// Wraps a fresh copy of this operator (new handle, same `dim` and
            /// `keepdim`) in an `Op` and applies it to `inputs[0]`, passing the
            /// remaining inputs along.
            ///
            /// # Panics
            /// Panics if `inputs` is empty.
            fn call(&mut self, inputs: &[&crate::var::Var]) -> Result<Vec<crate::var::Var>, AutoDiffError> {
                let new_one = $a {
                    handle: OpHandle::new(),
                    dim: self.dim.clone(),
                    keepdim: self.keepdim,
                };

                let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));

                inputs[0].called_with(op, &inputs[1..])
            }
        }

        impl OpTrait for $a {
            fn get_name(&self) -> &'static str {
                $b
            }
            fn get_input_size(&self) -> usize {
                1
            }
            fn get_output_size(&self) -> usize {
                1
            }
            /// Forward pass: reduces `input[0]` with the tensor method given to
            /// `reduce_macro!` and swaps the result into `output[0]`.
            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                // `as_deref` borrows the stored dims as a slice, so no copy is made
                // per forward pass.
                output[0].swap(&input[0].$c(self.dim.as_deref(), self.keepdim));
            }
            /// Backward pass: calls the closure given to `reduce_macro!` with the
            /// three slices unchanged.
            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $d(input, output_grad, input_grad)
            }
            /// The operator stores no tensors, so there is nothing to return.
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }
            /// The operator stores no tensors, so there is nothing to return.
            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }
            /// The operator stores no tensors; the argument is ignored.
            fn set_values(&self, _v: &[Tensor]) {}
            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
91
92
93
94
95reduce_macro!(Argmax, "Argmax", argmax,
96              (|input: &[Tensor],
97               output_grad: &[Tensor],
98               input_grad: &[Tensor]| {
99                   unimplemented!();
100               }));
101
102
103reduce_macro!(Argmin, "Argmin", argmin,
104              (|input: &[Tensor],
105               output_grad: &[Tensor],
106               input_grad: &[Tensor]| {
107                   unimplemented!();
108               }));
109
110
111reduce_macro!(Logsumexp, "Logsumexp", logsumexp,
112              (|input: &[Tensor],
113               output_grad: &[Tensor],
114               input_grad: &[Tensor]| {
115                   unimplemented!();
116               }));
117
118
119reduce_macro!(Mean, "Mean", mean,
120              (|input: &[Tensor],
121               output_grad: &[Tensor],
122               input_grad: &[Tensor]| {
123                   unimplemented!();
124               }));
125
126reduce_macro!(Prod, "Prod", prod,
127              (|input: &[Tensor],
128               output_grad: &[Tensor],
129               input_grad: &[Tensor]| {
130                   unimplemented!();
131               }));
132
133reduce_macro!(Std, "Std", std,
134              (|input: &[Tensor],
135               output_grad: &[Tensor],
136               input_grad: &[Tensor]| {
137                   unimplemented!();
138               }));
139
140reduce_macro!(Sum, "Sum", sum,
141              (|input: &[Tensor],
142               output_grad: &[Tensor],
143               input_grad: &[Tensor]| {
144                   unimplemented!();
145               }));
146
147reduce_macro!(Variance, "Var", var,
148              (|input: &[Tensor],
149               output_grad: &[Tensor],
150               input_grad: &[Tensor]| {
151                   unimplemented!();
152               }));
153
154reduce_macro!(Max, "Max", max,
155              (|input: &[Tensor],
156               output_grad: &[Tensor],
157               input_grad: &[Tensor]| {
158                   unimplemented!();
159               }));
160
161reduce_macro!(Min, "Min", min,
162              (|input: &[Tensor],
163               output_grad: &[Tensor],
164               input_grad: &[Tensor]| {
165                   unimplemented!();
166               }));