//! auto_diff/op/linear.rs — Linear and BiLinear operator definitions.

1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpCall, Op, OpHandle};
3
4use std::cell::{RefCell};
5use std::rc::Rc;
6
7use crate::var::{Var};
8use crate::err::AutoDiffError;
9
10#[cfg(feature = "use-serde")]
11use serde::{Serialize, Deserialize};
12#[cfg(feature = "use-serde")]
13use std::any::Any;
14
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
/// Fully-connected (linear) op: forward computes `input.matmul(weight)`,
/// optionally followed by a bias add (see the `OpTrait::apply` impl below).
pub struct Linear {
    // Declared input feature count; None when the shape is not yet known.
    in_fea: Option<usize>,
    // Declared output feature count; None when the shape is not yet known.
    out_fea: Option<usize>,
    // When true, `bias` is added in the forward pass and `bias_grad` is
    // filled in the backward pass.
    bias_option: bool,
    // Weight matrix; allocated as [in_fea, out_fea] zeros by `new` when
    // both sizes are known, otherwise an empty placeholder tensor.
    weight: Tensor,
    // Bias vector of length out_fea (same allocation rule as `weight`).
    bias: Tensor,
    // Gradient accumulators, written by `OpTrait::grad`.
    weight_grad: Tensor,
    bias_grad: Tensor,
    // Graph bookkeeping handle; not serialized (serde skip), so it must be
    // reconstructible via Default on deserialization.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
27impl Linear {
28    pub fn new(in_features: Option<usize>,
29               out_features: Option<usize>,
30               bias: bool) -> Linear {
31        let weight: Tensor;
32        let bias_tensor: Tensor;
33        match (in_features, out_features) {
34            (Some(d1), Some(d2)) => {
35                weight = Tensor::zeros(&[d1, d2]);
36                bias_tensor = Tensor::zeros(&[d2,]);
37                Linear {
38                    in_fea: in_features,
39                    out_fea: out_features,
40                    bias_option: bias,
41                    weight,
42                    bias: bias_tensor,
43                    weight_grad: Tensor::new(),
44                    bias_grad: Tensor::new(),
45                    handle: OpHandle::new(),
46                }
47            },
48            _ => {
49                Linear {
50                    in_fea: in_features,
51                    out_fea: out_features,
52                    bias_option: bias,
53                    weight: Tensor::new(),
54                    bias: Tensor::new(),
55                    weight_grad: Tensor::new(),
56                    bias_grad: Tensor::new(),
57                    handle: OpHandle::new(),
58                }
59            },
60        }
61        
62    }
63
64    pub fn weight(&self) -> &Tensor {
65        &self.weight
66    }
67
68    pub fn set_weight(&self, var: Var) {
69        self.weight.swap(&var.val());
70    }
71    
72    pub fn bias(&self) -> &Tensor {
73        &self.bias
74    }
75    
76    pub fn set_bias(&self, var: Var) {
77        self.bias.swap(&var.val());
78    }
79
80    handle_method!();
81}
82
83impl OpCall for Linear {
84    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
85        let new_one = Linear {
86            in_fea: self.in_fea,
87            out_fea: self.out_fea,
88            bias_option: self.bias_option,
89            weight: self.weight.ref_copy(),
90            bias: self.bias.ref_copy(),
91            weight_grad: self.weight_grad.ref_copy(),
92            bias_grad: self.bias_grad.ref_copy(),
93            handle: OpHandle::new(), // TODO; change this to None, this shold never be used.
94        };
95        
96        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
97        
98        inputs[0].called_with(op, &inputs[1..inputs.len()])
99    }
100
101
102}
103
104impl OpTrait for Linear {
105
106    
107
108    fn get_name(&self) -> &'static str {
109        "Linear"
110    }
111
112    fn get_input_size(&self) -> usize {
113        1
114    }
115
116    fn get_output_size(&self) -> usize {
117        1
118    }
119
120    fn apply(&self, inputs: &[Tensor],
121             outputs: &[Tensor]) {
122        // TODO go through condition where dimension is missing somewhere.
123        //println!("left sie: {:?}, right size: {:?}", inputs[0], self.weight);
124        if inputs.len() != 1 {
125            panic!("linear expect one input.");
126        }
127        if inputs[0].size()[inputs[0].size().len()-1] != self.weight.size()[0] {
128            panic!("dismatched size");
129        }
130        let ret = inputs[0].matmul(&self.weight);
131        outputs[0].swap(&ret);
132        //println!("matmut done");
133        if self.bias_option {
134            let ret = outputs[0].add(&self.bias);
135            outputs[0].swap(&ret);
136        }
137    }
138
139    fn grad(&self, inputs: &[Tensor],
140            output_grad: &[Tensor],
141            input_grad: &[Tensor]) {
142        if inputs.is_empty() {
143            panic!("Expect one input tensor");
144        }
145        if inputs[0].size()[1] != self.weight.size()[0] {
146            panic!("Expect input dimension matches weight dimension {:?}, {:?}",
147                   inputs[0].size(), self.weight.size());
148        }
149        if inputs[0].size()[0] != output_grad[0].size()[0] {
150            panic!("Expect input population matches output gradient population {:?}, {:?}",
151                   inputs[0].size(), output_grad[0].size());
152        }
153        if output_grad[0].size()[1] != self.weight.size()[1] {
154            panic!("Expect output gradient dimension matches weight dimension {:?}, {:?}",
155                   output_grad[0].size(), self.weight.size());
156        }
157
158        input_grad[0].swap(&output_grad[0].matmul(&self.weight.permute(&[1,0])));
159        self.weight_grad.swap(&inputs[0].outer(&output_grad[0], Some(true)));
160        if self.bias_option {
161            self.bias_grad.swap(&output_grad[0].mean(Some(&[0]), false));
162        }
163    }
164
165    fn get_values(&self) -> Vec<Tensor> {
166        let mut ret = vec![self.weight.clone()];
167        if self.bias_option {
168            ret.push(self.bias.clone());
169        }
170        ret
171    }
172    fn set_values(&self, v: &[Tensor]) {
173        self.weight.swap(&v[0].clone());
174        if self.bias_option {
175            self.bias.swap(&v[1].clone());
176        }
177    }
178    /// access gradient values
179    fn get_grads(&self) -> Vec<Tensor> {
180        let mut ret = vec![self.weight_grad.clone()];
181        if self.bias_option {
182            ret.push(self.bias_grad.clone());
183        }
184        ret
185    }
186
187    #[cfg(feature = "use-serde")]
188    fn as_any(&self) -> &dyn Any {
189	self
190    }
191    
192}
193
// Bilinear
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
/// Bilinear op over two inputs. The parameter layout mirrors `Linear`,
/// but forward/backward are still `unimplemented!()` (see the OpTrait impl).
pub struct BiLinear {
    // Declared feature counts for each of the two inputs and the output;
    // None when not yet known.
    in1_fea: Option<usize>,
    in2_fea: Option<usize>,
    out_fea: Option<usize>,
    // When true, `bias` participates (once apply/grad are implemented).
    bias_option: bool,
    // Parameter tensors; `new` always leaves these as empty placeholders.
    weight: Tensor,
    bias: Tensor,
    // Gradient accumulators.
    weight_grad: Tensor,
    bias_grad: Tensor,
    // Graph bookkeeping handle; not serialized (serde skip).
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
208impl BiLinear {
209    pub fn new(in1_features: Option<usize>,
210               in2_features: Option<usize>,
211               out_features: Option<usize>,
212               bias: bool) -> BiLinear {
213        BiLinear {
214            in1_fea: in1_features,
215            in2_fea: in2_features,
216            out_fea: out_features,
217            bias_option: bias,
218            weight: Tensor::new(),
219            bias: Tensor::new(),
220            weight_grad: Tensor::new(),
221            bias_grad: Tensor::new(),
222            handle: OpHandle::new(),
223        }
224    }
225
226    pub fn weight(&self) -> &Tensor {
227        &self.weight
228    }
229
230    pub fn set_weight(&self, var: Var) {
231        self.weight.swap(&var.val());
232    }
233    
234    pub fn bias(&self) -> &Tensor {
235        &self.bias
236    }
237    
238    pub fn set_bias(&self, var: Var) {
239        self.bias.swap(&var.val());
240    }
241
242    handle_method!();
243}
244
245impl OpCall for BiLinear {
246    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
247        let new_one = BiLinear {
248            in1_fea: self.in1_fea,
249            in2_fea: self.in2_fea,
250            out_fea: self.out_fea,
251            bias_option: self.bias_option,
252            weight: self.weight.ref_copy(),
253            bias: self.bias.ref_copy(),
254            weight_grad: self.weight_grad.ref_copy(),
255            bias_grad: self.bias_grad.ref_copy(),
256            handle: OpHandle::new(), // TODO; change this to None, this shold never be used.
257        };
258        
259        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
260        
261        inputs[0].called_with(op, &inputs[1..inputs.len()])
262    }
263
264
265}
266
267impl OpTrait for BiLinear {
268    fn get_name(&self) -> &'static str {
269        "BiLinear"
270    }
271
272    fn get_input_size(&self) -> usize {
273        2
274    }
275
276    fn get_output_size(&self) -> usize {
277        1
278    }
279
280    fn apply(&self, inputs: &[Tensor],
281             outputs: &[Tensor]) {
282        
283        unimplemented!();
284        
285    }
286
287    fn grad(&self, inputs: &[Tensor],
288            output_grad: &[Tensor],
289            input_grad: &[Tensor]) {
290        if inputs.is_empty() {
291            panic!("Expect one input tensor");
292        }
293        if inputs[0].size()[1] != self.weight.size()[0] {
294            panic!("Expect input1 dimension matches weight dimension {:?}, {:?}",
295                   inputs[0].size(), self.weight.size());
296        }
297        if self.weight.size()[1] != inputs[1].size()[0] {
298            panic!("Expect weight dimension matches input2 dimension {:?}, {:?}",
299                   self.weight.size(), inputs[1].size());
300        }
301
302        unimplemented!();
303    }
304
305    fn get_values(&self) -> Vec<Tensor> {
306        let mut ret = vec![self.weight.clone()];
307        if self.bias_option {
308            ret.push(self.bias.clone());
309        }
310        ret
311    }
312    fn set_values(&self, v: &[Tensor]) {
313        self.weight.swap(&v[0].clone());
314        if self.bias_option {
315            self.bias.swap(&v[1].clone());
316        }
317    }
318    /// access gradient values
319    fn get_grads(&self) -> Vec<Tensor> {
320        let mut ret = vec![self.weight_grad.clone()];
321        if self.bias_option {
322            ret.push(self.bias_grad.clone());
323        }
324        ret
325    }
326
327    #[cfg(feature = "use-serde")]
328    fn as_any(&self) -> &dyn Any {
329	self
330    }
331    
332}