1#![allow(clippy::too_many_arguments)]
2use tensor_rs::tensor::{Tensor, PaddingMode};
3use std::cell::{RefCell};
4use std::rc::Rc;
5use super::{OpTrait, OpCall, Op, OpHandle};
6use crate::var::Var;
7use crate::err::AutoDiffError;
8
9#[cfg(feature = "use-serde")]
10use serde::{Serialize, Deserialize};
11#[cfg(feature = "use-serde")]
12use std::any::Any;
13
/// 1-D convolution operator (placeholder).
///
/// NOTE(review): the forward (`apply`) and backward (`grad`) passes of this op
/// are `unimplemented!()`, so evaluating it panics.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Conv1d {
    /// Scalar given to `Conv1d::new`; nothing visible in this op reads it yet.
    alpha: f32,
    /// Graph-node handle; excluded from serialization.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
20impl Conv1d {
21 pub fn new(alpha: f32) -> Conv1d {
22 Conv1d {
23 alpha,
24 handle: OpHandle::new(),
25 }
26 }
27 handle_method!();
28}
29impl OpTrait for Conv1d {
30 fn get_name(&self) -> &'static str {
31 "Conv1d"
32 }
33 fn get_input_size(&self) -> usize {
34 2
35 }
36 fn get_output_size(&self) -> usize {
37 1
38 }
39 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
41 unimplemented!();
42 }
43
44 fn grad(&self, input: &[Tensor],
48 output_grad: &[Tensor],
49 input_grad: &[Tensor]) {
50 unimplemented!();
51 }
52
53 fn get_values(&self) -> Vec<Tensor> {
55 Vec::new()
56 }
57 fn set_values(&self, v: &[Tensor]) {
58 }
59 fn get_grads(&self) -> Vec<Tensor> {
61 Vec::new()
62 }
63 #[cfg(feature = "use-serde")]
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67}
68
/// 2-D convolution operator with an optional per-output-channel bias.
///
/// Parameter tensors (`weight`, `bias`) and their gradients are stored on the
/// op itself; `Conv2d::new` allocates them with `Tensor::empty`.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Conv2d {
    /// Channel count the input must have (checked against `input.size()[1]` in `apply`).
    in_channels: usize,
    /// Channel count the convolution result must have (checked in `apply`).
    out_channels: usize,
    /// Kernel (height, width) — presumably; `apply` compares it to `weight.size()[2..4]`.
    kernel_size: (usize, usize),
    /// Stride passed through to `Tensor::conv2d` / `conv2d_grad`.
    stride: (usize, usize),
    /// Padding passed through to `Tensor::conv2d` / `conv2d_grad`.
    padding: (usize, usize),
    /// Dilation passed through to `Tensor::conv2d` / `conv2d_grad`.
    dilation: (usize, usize),
    /// Always 1 from `new`; `apply` does not support values > 1.
    groups: usize,
    /// Whether `bias` is added in the forward pass and `bias_grad` is filled in `grad`.
    bias_option: bool,
    /// Padding mode passed through to `Tensor::conv2d` / `conv2d_grad`.
    padding_mode: PaddingMode,

    /// Filter, shape `[out_channels, in_channels, kernel_size.0, kernel_size.1]` from `new`.
    weight: Tensor,
    /// Bias, shape `[out_channels]` from `new` (allocated even when `bias_option` is false).
    bias: Tensor,
    /// Gradient w.r.t. `weight`, written by `grad`.
    weight_grad: Tensor,
    /// Gradient w.r.t. `bias`, written by `grad` only when `bias_option` is true.
    bias_grad: Tensor,
    /// Graph-node handle; excluded from serialization.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
90impl Conv2d {
91 pub fn new(in_channels: usize, out_channels: usize,
92 kernel_size: (usize, usize),
93 stride: (usize, usize),
94 padding: (usize, usize),
95 dilation: (usize, usize),
96 bias: bool,
97 padding_mode: PaddingMode
98 ) -> Conv2d {
99 Conv2d {
100 in_channels,
101 out_channels,
102 kernel_size,
103 stride,
104 padding,
105 dilation,
106 groups: 1,
107 bias_option: bias,
108 padding_mode,
109
110 weight: Tensor::empty(&[out_channels, in_channels, kernel_size.0, kernel_size.1]),
111 bias: Tensor::empty(&[out_channels, ]),
112 weight_grad: Tensor::empty(&[out_channels, in_channels, kernel_size.0, kernel_size.1]),
113 bias_grad: Tensor::empty(&[out_channels, ]),
114
115 handle: OpHandle::new(),
116 }
117 }
118
119 pub fn weight(&self) -> &Tensor {
120 &self.weight
121 }
122
123 pub fn set_weight(&self, var: Var) {
124 self.weight.swap(&var.val());
125 }
126
127 pub fn bias(&self) -> &Tensor {
128 &self.bias
129 }
130
131 pub fn set_bias(&self, var: Var) {
132 self.bias.swap(&var.val());
133 }
134
135 handle_method!();
136}
137
138impl OpCall for Conv2d {
139 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
140 let new_one = Conv2d {
141 in_channels: self.in_channels,
142 out_channels: self.out_channels,
143 kernel_size: self.kernel_size,
144 stride: self.stride,
145 padding: self.padding,
146 dilation: self.dilation,
147 groups: self.groups,
148 bias_option: self.bias_option,
149 padding_mode: self.padding_mode,
150
151 weight: self.weight.ref_copy(),
152 bias: self.bias.ref_copy(),
153 weight_grad: self.weight_grad.ref_copy(),
154 bias_grad: self.bias_grad.ref_copy(),
155
156 handle: OpHandle::new(),
157 };
158
159 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
160
161 inputs[0].called_with(op, &inputs[1..inputs.len()])
162 }
163}
164
165impl OpTrait for Conv2d {
166 fn get_name(&self) -> &'static str {
167 "Conv2d"
168 }
169 fn get_input_size(&self) -> usize {
170 1
171 }
172 fn get_output_size(&self) -> usize {
173 1
174 }
175 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
177 if self.groups > 1 {
178 unimplemented!();
179 }
180 if self.weight.size()[2] != self.kernel_size.0 || self.weight.size()[3] != self.kernel_size.1 {
181 panic!("this is conv2d");
182 }
183 let input_size = input[0].size();
184 if input_size[1] != self.in_channels {
185 panic!("conv2d expect the same input channel: input: {:?}, config: {:?}", input_size[1], self.in_channels);
186 }
187 let conv_output = input[0].conv2d(&self.weight, self.stride, self.padding, self.dilation, self.padding_mode);
188 if conv_output.size()[1] != self.out_channels {
190 panic!("conv2d expect the same input channel {:?}, {:?}", input_size[1], self.in_channels);
191 }
192
193 if self.bias_option {
194 let expanded_bias = self.bias
196 .unsqueeze(1)
197 .unsqueeze(2)
198 .repeat(&[1, conv_output.size()[2], conv_output.size()[3]]);
199 let ret = conv_output.add(&expanded_bias);
201 output[0].swap(&ret);
202 } else {
203 output[0].swap(&conv_output);
204 }
205 }
206
207 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
211 let (w_grad, d_grad) = input[0].conv2d_grad(&self.weight,
212 self.stride,
213 self.padding,
214 self.dilation,
215 self.padding_mode,
216 &output_grad[0]);
217 self.weight_grad.swap(&w_grad);
218 input_grad[0].swap(&d_grad);
219
220 if self.bias_option {
221 self.bias_grad.swap(&output_grad[0].mean(Some(&[0, 2, 3]), false));
222 }
223 }
224
225 fn get_values(&self) -> Vec<Tensor> {
227 vec![self.weight.ref_copy(), self.bias.ref_copy()]
228 }
229 fn set_values(&self, v: &[Tensor]) {
230 self.weight.data_copy(&v[0]);
231 self.bias.data_copy(&v[1]);
232 }
233 fn get_grads(&self) -> Vec<Tensor> {
235 vec![self.weight_grad.ref_copy(), self.bias_grad.ref_copy()]
236 }
237 #[cfg(feature = "use-serde")]
238 fn as_any(&self) -> &dyn Any {
239 self
240 }
241}
242