1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpCall, Op, OpHandle};
3
4use std::cell::{RefCell};
5use std::rc::Rc;
6
7use crate::var::{Var};
8use crate::err::AutoDiffError;
9
10#[cfg(feature = "use-serde")]
11use serde::{Serialize, Deserialize};
12#[cfg(feature = "use-serde")]
13use std::any::Any;
14
/// A fully-connected (affine) op: `y = x.matmul(weight)` plus an optional
/// bias add. Parameter gradients are accumulated into `weight_grad` /
/// `bias_grad` by the backward pass.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Linear {
    in_fea: Option<usize>,   // input feature count; None when not yet known
    out_fea: Option<usize>,  // output feature count; None when not yet known
    bias_option: bool,       // whether the bias term is applied
    weight: Tensor,          // [in_fea, out_fea] parameter matrix
    bias: Tensor,            // [out_fea] parameter vector
    weight_grad: Tensor,     // gradient buffer for `weight`
    bias_grad: Tensor,       // gradient buffer for `bias`
    // Runtime graph handle; skipped by serde and rebuilt on deserialization.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
27impl Linear {
28 pub fn new(in_features: Option<usize>,
29 out_features: Option<usize>,
30 bias: bool) -> Linear {
31 let weight: Tensor;
32 let bias_tensor: Tensor;
33 match (in_features, out_features) {
34 (Some(d1), Some(d2)) => {
35 weight = Tensor::zeros(&[d1, d2]);
36 bias_tensor = Tensor::zeros(&[d2,]);
37 Linear {
38 in_fea: in_features,
39 out_fea: out_features,
40 bias_option: bias,
41 weight,
42 bias: bias_tensor,
43 weight_grad: Tensor::new(),
44 bias_grad: Tensor::new(),
45 handle: OpHandle::new(),
46 }
47 },
48 _ => {
49 Linear {
50 in_fea: in_features,
51 out_fea: out_features,
52 bias_option: bias,
53 weight: Tensor::new(),
54 bias: Tensor::new(),
55 weight_grad: Tensor::new(),
56 bias_grad: Tensor::new(),
57 handle: OpHandle::new(),
58 }
59 },
60 }
61
62 }
63
64 pub fn weight(&self) -> &Tensor {
65 &self.weight
66 }
67
68 pub fn set_weight(&self, var: Var) {
69 self.weight.swap(&var.val());
70 }
71
72 pub fn bias(&self) -> &Tensor {
73 &self.bias
74 }
75
76 pub fn set_bias(&self, var: Var) {
77 self.bias.swap(&var.val());
78 }
79
80 handle_method!();
81}
82
83impl OpCall for Linear {
84 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
85 let new_one = Linear {
86 in_fea: self.in_fea,
87 out_fea: self.out_fea,
88 bias_option: self.bias_option,
89 weight: self.weight.ref_copy(),
90 bias: self.bias.ref_copy(),
91 weight_grad: self.weight_grad.ref_copy(),
92 bias_grad: self.bias_grad.ref_copy(),
93 handle: OpHandle::new(), };
95
96 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
97
98 inputs[0].called_with(op, &inputs[1..inputs.len()])
99 }
100
101
102}
103
104impl OpTrait for Linear {
105
106
107
108 fn get_name(&self) -> &'static str {
109 "Linear"
110 }
111
112 fn get_input_size(&self) -> usize {
113 1
114 }
115
116 fn get_output_size(&self) -> usize {
117 1
118 }
119
120 fn apply(&self, inputs: &[Tensor],
121 outputs: &[Tensor]) {
122 if inputs.len() != 1 {
125 panic!("linear expect one input.");
126 }
127 if inputs[0].size()[inputs[0].size().len()-1] != self.weight.size()[0] {
128 panic!("dismatched size");
129 }
130 let ret = inputs[0].matmul(&self.weight);
131 outputs[0].swap(&ret);
132 if self.bias_option {
134 let ret = outputs[0].add(&self.bias);
135 outputs[0].swap(&ret);
136 }
137 }
138
139 fn grad(&self, inputs: &[Tensor],
140 output_grad: &[Tensor],
141 input_grad: &[Tensor]) {
142 if inputs.is_empty() {
143 panic!("Expect one input tensor");
144 }
145 if inputs[0].size()[1] != self.weight.size()[0] {
146 panic!("Expect input dimension matches weight dimension {:?}, {:?}",
147 inputs[0].size(), self.weight.size());
148 }
149 if inputs[0].size()[0] != output_grad[0].size()[0] {
150 panic!("Expect input population matches output gradient population {:?}, {:?}",
151 inputs[0].size(), output_grad[0].size());
152 }
153 if output_grad[0].size()[1] != self.weight.size()[1] {
154 panic!("Expect output gradient dimension matches weight dimension {:?}, {:?}",
155 output_grad[0].size(), self.weight.size());
156 }
157
158 input_grad[0].swap(&output_grad[0].matmul(&self.weight.permute(&[1,0])));
159 self.weight_grad.swap(&inputs[0].outer(&output_grad[0], Some(true)));
160 if self.bias_option {
161 self.bias_grad.swap(&output_grad[0].mean(Some(&[0]), false));
162 }
163 }
164
165 fn get_values(&self) -> Vec<Tensor> {
166 let mut ret = vec![self.weight.clone()];
167 if self.bias_option {
168 ret.push(self.bias.clone());
169 }
170 ret
171 }
172 fn set_values(&self, v: &[Tensor]) {
173 self.weight.swap(&v[0].clone());
174 if self.bias_option {
175 self.bias.swap(&v[1].clone());
176 }
177 }
178 fn get_grads(&self) -> Vec<Tensor> {
180 let mut ret = vec![self.weight_grad.clone()];
181 if self.bias_option {
182 ret.push(self.bias_grad.clone());
183 }
184 ret
185 }
186
187 #[cfg(feature = "use-serde")]
188 fn as_any(&self) -> &dyn Any {
189 self
190 }
191
192}
193
/// A bilinear op over two inputs. Forward/backward are not implemented yet
/// (`apply`/`grad` call `unimplemented!()`); this struct only carries the
/// configuration and parameter tensors.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct BiLinear {
    in1_fea: Option<usize>,  // first input feature count; None when unknown
    in2_fea: Option<usize>,  // second input feature count; None when unknown
    out_fea: Option<usize>,  // output feature count; None when unknown
    bias_option: bool,       // whether the bias term is applied
    weight: Tensor,          // weight parameter (allocated empty in new())
    bias: Tensor,            // bias parameter (allocated empty in new())
    weight_grad: Tensor,     // gradient buffer for `weight`
    bias_grad: Tensor,       // gradient buffer for `bias`
    // Runtime graph handle; skipped by serde and rebuilt on deserialization.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
208impl BiLinear {
209 pub fn new(in1_features: Option<usize>,
210 in2_features: Option<usize>,
211 out_features: Option<usize>,
212 bias: bool) -> BiLinear {
213 BiLinear {
214 in1_fea: in1_features,
215 in2_fea: in2_features,
216 out_fea: out_features,
217 bias_option: bias,
218 weight: Tensor::new(),
219 bias: Tensor::new(),
220 weight_grad: Tensor::new(),
221 bias_grad: Tensor::new(),
222 handle: OpHandle::new(),
223 }
224 }
225
226 pub fn weight(&self) -> &Tensor {
227 &self.weight
228 }
229
230 pub fn set_weight(&self, var: Var) {
231 self.weight.swap(&var.val());
232 }
233
234 pub fn bias(&self) -> &Tensor {
235 &self.bias
236 }
237
238 pub fn set_bias(&self, var: Var) {
239 self.bias.swap(&var.val());
240 }
241
242 handle_method!();
243}
244
245impl OpCall for BiLinear {
246 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
247 let new_one = BiLinear {
248 in1_fea: self.in1_fea,
249 in2_fea: self.in2_fea,
250 out_fea: self.out_fea,
251 bias_option: self.bias_option,
252 weight: self.weight.ref_copy(),
253 bias: self.bias.ref_copy(),
254 weight_grad: self.weight_grad.ref_copy(),
255 bias_grad: self.bias_grad.ref_copy(),
256 handle: OpHandle::new(), };
258
259 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
260
261 inputs[0].called_with(op, &inputs[1..inputs.len()])
262 }
263
264
265}
266
267impl OpTrait for BiLinear {
268 fn get_name(&self) -> &'static str {
269 "BiLinear"
270 }
271
272 fn get_input_size(&self) -> usize {
273 2
274 }
275
276 fn get_output_size(&self) -> usize {
277 1
278 }
279
280 fn apply(&self, inputs: &[Tensor],
281 outputs: &[Tensor]) {
282
283 unimplemented!();
284
285 }
286
287 fn grad(&self, inputs: &[Tensor],
288 output_grad: &[Tensor],
289 input_grad: &[Tensor]) {
290 if inputs.is_empty() {
291 panic!("Expect one input tensor");
292 }
293 if inputs[0].size()[1] != self.weight.size()[0] {
294 panic!("Expect input1 dimension matches weight dimension {:?}, {:?}",
295 inputs[0].size(), self.weight.size());
296 }
297 if self.weight.size()[1] != inputs[1].size()[0] {
298 panic!("Expect weight dimension matches input2 dimension {:?}, {:?}",
299 self.weight.size(), inputs[1].size());
300 }
301
302 unimplemented!();
303 }
304
305 fn get_values(&self) -> Vec<Tensor> {
306 let mut ret = vec![self.weight.clone()];
307 if self.bias_option {
308 ret.push(self.bias.clone());
309 }
310 ret
311 }
312 fn set_values(&self, v: &[Tensor]) {
313 self.weight.swap(&v[0].clone());
314 if self.bias_option {
315 self.bias.swap(&v[1].clone());
316 }
317 }
318 fn get_grads(&self) -> Vec<Tensor> {
320 let mut ret = vec![self.weight_grad.clone()];
321 if self.bias_option {
322 ret.push(self.bias_grad.clone());
323 }
324 ret
325 }
326
327 #[cfg(feature = "use-serde")]
328 fn as_any(&self) -> &dyn Any {
329 self
330 }
331
332}