auto_diff/op/convolution.rs

#![allow(clippy::too_many_arguments)]
use tensor_rs::tensor::{Tensor, PaddingMode};
use std::cell::RefCell;
use std::rc::Rc;
use super::{OpTrait, OpCall, Op, OpHandle};
use crate::var::Var;
use crate::err::AutoDiffError;

#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};
#[cfg(feature = "use-serde")]
use std::any::Any;

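/// 1-d convolution op. Currently a stub: `apply` and `grad` are
/// unimplemented, and `alpha` is a stored configuration value that is
/// not used yet.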
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Conv1d {
    alpha: f32,
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl Conv1d {
    pub fn new(alpha: f32) -> Conv1d {
        Conv1d {
            alpha,
            handle: OpHandle::new(),
        }
    }
    handle_method!();
}
impl OpTrait for Conv1d {
    fn get_name(&self) -> &'static str {
        "Conv1d"
    }
    fn get_input_size(&self) -> usize {
        2
    }
    fn get_output_size(&self) -> usize {
        1
    }
    /// Forward pass.
    fn apply(&self, _input: &[Tensor], _output: &[Tensor]) {
        unimplemented!();
    }

    /// Given the forward input values and the gradient flowing back from
    /// the output, update the weight gradient and write the gradient with
    /// respect to the input into `input_grad`.
    fn grad(&self, _input: &[Tensor],
            _output_grad: &[Tensor],
            _input_grad: &[Tensor]) {
        unimplemented!();
    }

    /// Access weight values.
    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {
    }
    /// Access gradient values.
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Conv2d

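/// 2-d convolution over an input of shape `[batch, in_channels, height,
/// width]`. The kernel weight has shape `[out_channels, in_channels,
/// kernel_size.0, kernel_size.1]` and the optional bias has shape
/// `[out_channels]`.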
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Conv2d {
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
    groups: usize,
    bias_option: bool,
    padding_mode: PaddingMode,

    weight: Tensor,
    bias: Tensor,
    weight_grad: Tensor,
    bias_grad: Tensor,
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl Conv2d {
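    /// Create a new Conv2d op. `groups` is fixed at 1 (grouped convolution
    /// is not supported yet), and the weight and bias tensors are allocated
    /// empty; populate them via `set_weight`/`set_bias` or `set_values`
    /// before calling the op.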
    pub fn new(in_channels: usize, out_channels: usize,
               kernel_size: (usize, usize),
               stride: (usize, usize),
               padding: (usize, usize),
               dilation: (usize, usize),
               bias: bool,
               padding_mode: PaddingMode
    ) -> Conv2d {
        Conv2d {
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups: 1,
            bias_option: bias,
            padding_mode,

            weight: Tensor::empty(&[out_channels, in_channels, kernel_size.0, kernel_size.1]),
            bias: Tensor::empty(&[out_channels]),
            weight_grad: Tensor::empty(&[out_channels, in_channels, kernel_size.0, kernel_size.1]),
            bias_grad: Tensor::empty(&[out_channels]),

            handle: OpHandle::new(),
        }
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn set_weight(&self, var: Var) {
        self.weight.swap(&var.val());
    }

    pub fn bias(&self) -> &Tensor {
        &self.bias
    }

    pub fn set_bias(&self, var: Var) {
        self.bias.swap(&var.val());
    }

    handle_method!();
}

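// `call` builds a fresh Conv2d whose parameter tensors are created with
// `ref_copy`, so the new op instance and this template appear to share the
// same underlying weight/bias storage rather than owning separate copies.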
impl OpCall for Conv2d {
    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
        let new_one = Conv2d {
            in_channels: self.in_channels,
            out_channels: self.out_channels,
            kernel_size: self.kernel_size,
            stride: self.stride,
            padding: self.padding,
            dilation: self.dilation,
            groups: self.groups,
            bias_option: self.bias_option,
            padding_mode: self.padding_mode,

            weight: self.weight.ref_copy(),
            bias: self.bias.ref_copy(),
            weight_grad: self.weight_grad.ref_copy(),
            bias_grad: self.bias_grad.ref_copy(),

            handle: OpHandle::new(),
        };

        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));

        inputs[0].called_with(op, &inputs[1..])
    }
}

impl OpTrait for Conv2d {
    fn get_name(&self) -> &'static str {
        "Conv2d"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }
    /// Forward pass.
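    ///
    /// Input is `[batch, in_channels, h_in, w_in]`; output is
    /// `[batch, out_channels, h_out, w_out]` where, by the usual
    /// convolution arithmetic,
    /// `h_out = (h_in + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1) / stride.0 + 1`
    /// (and analogously for `w_out`).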
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        if self.groups > 1 {
            unimplemented!();
        }
        if self.weight.size()[2] != self.kernel_size.0 || self.weight.size()[3] != self.kernel_size.1 {
            panic!("conv2d weight size {:?} does not match kernel_size {:?}", self.weight.size(), self.kernel_size);
        }
        let input_size = input[0].size();
        if input_size[1] != self.in_channels {
            panic!("conv2d expects matching input channels: input: {:?}, config: {:?}", input_size[1], self.in_channels);
        }
        let conv_output = input[0].conv2d(&self.weight, self.stride, self.padding, self.dilation, self.padding_mode);
        if conv_output.size()[1] != self.out_channels {
            panic!("conv2d expects matching output channels: output: {:?}, config: {:?}", conv_output.size()[1], self.out_channels);
        }

        if self.bias_option {
            // Expand the per-channel bias to the spatial size of the
            // convolution output before adding it.
            let expanded_bias = self.bias
                .unsqueeze(1)
                .unsqueeze(2)
                .repeat(&[1, conv_output.size()[2], conv_output.size()[3]]);
            let ret = conv_output.add(&expanded_bias);
            output[0].swap(&ret);
        } else {
            output[0].swap(&conv_output);
        }
    }

    /// Given the forward input values and the gradient flowing back from
    /// the output, update the weight (and bias) gradients and write the
    /// gradient with respect to the input into `input_grad`.
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        let (w_grad, d_grad) = input[0].conv2d_grad(&self.weight,
                                                    self.stride,
                                                    self.padding,
                                                    self.dilation,
                                                    self.padding_mode,
                                                    &output_grad[0]);
        self.weight_grad.swap(&w_grad);
        input_grad[0].swap(&d_grad);

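        // Average output_grad over the batch and spatial dimensions
        // (N, H, W) to get the per-channel bias gradient.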
        if self.bias_option {
            self.bias_grad.swap(&output_grad[0].mean(Some(&[0, 2, 3]), false));
        }
    }

    /// Access weight values.
    fn get_values(&self) -> Vec<Tensor> {
        vec![self.weight.ref_copy(), self.bias.ref_copy()]
    }
    fn set_values(&self, v: &[Tensor]) {
        self.weight.data_copy(&v[0]);
        self.bias.data_copy(&v[1]);
    }
    /// Access gradient values.
    fn get_grads(&self) -> Vec<Tensor> {
        vec![self.weight_grad.ref_copy(), self.bias_grad.ref_copy()]
    }
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
// Conv3d
// ConvTranspose1d
// ConvTranspose2d
// ConvTranspose3d
// Unfold
// Fold
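
// A minimal usage sketch (illustrative only; it assumes `PaddingMode::Zeros`
// exists in tensor_rs and that `input` is a `Var` holding a
// `[batch, in_channels, h, w]` tensor built elsewhere with this crate):
//
//     let mut conv = Conv2d::new(3, 16, (3, 3), (1, 1), (1, 1), (1, 1),
//                                true, PaddingMode::Zeros);
//     // With a 32x32 input, stride 1, padding 1, and a 3x3 kernel, the
//     // output shape is [batch, 16, 32, 32].
//     let output = conv.call(&[&input])?;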