// auto_diff/op/reduction.rs
#![allow(clippy::redundant_closure_call)]
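//! Reduction ops (Sum, Mean, Max, ...) over an optional set of dimensions.
//! Each generated op wraps the matching `tensor_rs::tensor::Tensor`
//! reduction method and plugs into the compute graph through `OpCall`.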
use std::cell::RefCell;
use std::rc::Rc;

use tensor_rs::tensor::Tensor;
use super::{OpTrait, OpCall, Op, OpHandle};
use crate::err::AutoDiffError;

#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};
#[cfg(feature = "use-serde")]
use std::any::Any;
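/// Generates a reduction op. `$a` is the struct name, `$b` the display
/// name returned by `get_name`, `$c` the `Tensor` reduction method to
/// delegate to, and `$d` a closure implementing the backward pass.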
macro_rules! reduce_macro {
    ($a:ident, $b:expr, $c:ident, $d:tt) => {
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $a {
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            dim: Option<Vec<usize>>,
            keepdim: bool,
        }
        impl $a {
            pub fn new(dim: Option<&[usize]>, keepdim: bool) -> $a {
                $a {
                    handle: OpHandle::new(),
                    dim: dim.map(|v| v.to_vec()),
                    keepdim,
                }
            }
            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }
            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }
        impl OpCall for $a {
            fn call(&mut self, inputs: &[&crate::var::Var]) -> Result<Vec<crate::var::Var>, AutoDiffError> {
                // Clone the op configuration into a fresh node so the
                // graph owns an independent copy with its own handle.
                let new_one = $a {
                    handle: OpHandle::new(),
                    dim: self.dim.clone(),
                    keepdim: self.keepdim,
                };

                let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));

                inputs[0].called_with(op, &inputs[1..])
            }
        }
        impl OpTrait for $a {
            fn get_name(&self) -> &'static str {
                $b
            }
            fn get_input_size(&self) -> usize {
                1
            }
            fn get_output_size(&self) -> usize {
                1
            }
            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                // Delegate to the underlying tensor reduction, writing the
                // result into the preallocated output tensor.
                output[0].swap(&input[0].$c(self.dim.as_deref(), self.keepdim));
            }
            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $d(input, output_grad, input_grad)
            }
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }
            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }
            fn set_values(&self, _v: &[Tensor]) {
            }
            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    }
}
// Backward passes for the reduction ops are not implemented yet.

reduce_macro!(Argmax, "Argmax", argmax,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Argmin, "Argmin", argmin,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Logsumexp, "Logsumexp", logsumexp,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Mean, "Mean", mean,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Prod, "Prod", prod,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Std, "Std", std,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Sum, "Sum", sum,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Variance, "Var", var,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Max, "Max", max,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));

reduce_macro!(Min, "Min", min,
              (|_input: &[Tensor], _output_grad: &[Tensor], _input_grad: &[Tensor]| {
                  unimplemented!();
              }));
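
// A minimal usage sketch (not part of the original file). It assumes
// `Tensor::ones` and `Tensor::zeros` constructors exist in tensor_rs;
// adjust the constructor calls if the actual API differs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sum_apply_reduces_all_elements() {
        // Sum over all elements (dim = None), without keeping dimensions.
        let op = Sum::new(None, false);
        let input = Tensor::ones(&[2, 3]);   // assumed constructor
        let output = Tensor::zeros(&[1]);    // assumed constructor
        // `apply` swaps the reduced result into the output tensor.
        op.apply(&[input], &[output]);
    }
}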