use std::cell::RefCell;
use std::rc::Rc;

use tensor_rs::tensor::Tensor;
use crate::var::Var;
use crate::err::AutoDiffError;
use crate::collection::generational_index::GenKey;
use crate::compute_graph::Net;


#[cfg(feature = "use-serde")]
use serde::{Serializer, de, de::MapAccess, de::SeqAccess};
#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};
#[cfg(feature = "use-serde")]
use std::any::Any;

/// The interface every op implements: its sizes, a forward pass, a
/// backward pass, and access to any learnable parameters.
pub trait OpTrait {
    /// A unique name identifying the op.
    fn get_name(&self) -> &'static str;

    /// The number of input tensors the op expects.
    fn get_input_size(&self) -> usize;

    /// The number of output tensors the op produces.
    fn get_output_size(&self) -> usize;

    /// Forward pass: read `input` and write the results into `output`.
    fn apply(&self, input: &[Tensor], output: &[Tensor]);

    /// Backward pass: given the forward `input` and the gradient flowing
    /// into each output, write the gradient w.r.t. each input into
    /// `input_grad`.
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]);

    /// The op's learnable parameters, if any.
    fn get_values(&self) -> Vec<Tensor>;
    /// Overwrite the op's learnable parameters.
    fn set_values(&self, v: &[Tensor]);
    /// The gradients accumulated for the learnable parameters.
    fn get_grads(&self) -> Vec<Tensor>;

    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any;
}
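
// A minimal sketch of an `OpTrait` implementor (not part of the original
// file): an identity op with no learnable parameters. The `Identity` name
// is illustrative only; `View` below is the real in-tree example.
#[cfg(test)]
mod op_trait_example {
    use super::*;

    struct Identity;
    impl OpTrait for Identity {
        fn get_name(&self) -> &'static str { "Identity" }
        fn get_input_size(&self) -> usize { 1 }
        fn get_output_size(&self) -> usize { 1 }
        fn apply(&self, input: &[Tensor], output: &[Tensor]) {
            // Forward: copy the single input through unchanged.
            output[0].swap(&input[0].clone());
        }
        fn grad(&self, _input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
            // Backward of the identity: pass the output gradient straight back.
            input_grad[0].swap(&output_grad[0].clone());
        }
        fn get_values(&self) -> Vec<Tensor> { Vec::new() }
        fn set_values(&self, _v: &[Tensor]) {}
        fn get_grads(&self) -> Vec<Tensor> { Vec::new() }
        #[cfg(feature = "use-serde")]
        fn as_any(&self) -> &dyn Any { self }
    }

    #[test]
    fn identity_roundtrip() {
        let op = Identity;
        let input = Tensor::from_vec_f64(&[1., 2.], &[2]);
        let output = Tensor::new();
        op.apply(&[input], &[output.ref_copy()]);
        assert_eq!(output.get_f64(&[0]), 1.);
    }
}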

/// Ops that can be invoked on `Var`s, wiring themselves into the
/// compute graph and returning the resulting output `Var`s.
pub trait OpCall {
    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError>;
}

/// A handle pairing an op's generational key with the network it
/// belongs to.
pub struct OpHandle {
    id: GenKey,
    net: Rc<RefCell<Net>>,
}
impl OpHandle {
    pub fn new() -> OpHandle {
        OpHandle {
            id: GenKey::new(0, 0),
            net: Rc::new(RefCell::new(Net::new())),
        }
    }
}
impl Default for OpHandle {
    fn default() -> Self {
        Self::new()
    }
}

/// Expands to the `OpHandle` accessor boilerplate shared by every op
/// struct that carries a `handle` field (see `View` below for a use).
macro_rules! handle_method {
    () => {
        fn get_handle(&self) -> &OpHandle {
            &self.handle
        }

        fn get_handle_mut(&mut self) -> &mut OpHandle {
            &mut self.handle
        }
    }
}


/// A shared, reference-counted wrapper around a boxed `OpTrait` object,
/// so one op instance can be referenced from several graph nodes.
#[derive(Clone)]
pub struct Op {
    inner_op: Rc<RefCell<Box<dyn OpTrait>>>,
}
impl Op {
    pub fn new(op: Rc<RefCell<Box<dyn OpTrait>>>) -> Self {
        Op {
            inner_op: op,
        }
    }
    pub fn inner(&self) -> &Rc<RefCell<Box<dyn OpTrait>>> {
        &self.inner_op
    }

    /// Make another handle to the same underlying op (a cheap `Rc` clone).
    pub fn ref_copy(&self) -> Self {
        Op {
            inner_op: self.inner_op.clone(),
        }
    }

    pub fn get_name(&self) -> String {
        self.inner_op.borrow().get_name().to_string()
    }
    pub fn get_input_size(&self) -> usize {
        self.inner_op.borrow().get_input_size()
    }
    pub fn get_output_size(&self) -> usize {
        self.inner_op.borrow().get_output_size()
    }
    /// Forward pass, delegated to the wrapped op.
    pub fn apply(&self, input: &[Tensor],
                 output: &[Tensor]) {
        self.inner_op.borrow().apply(input, output);
    }
    /// Backward pass, delegated to the wrapped op.
    pub fn grad(&self, input: &[Tensor],
                output_grad: &[Tensor],
                input_grad: &[Tensor]) {
        self.inner_op.borrow().grad(input, output_grad, input_grad);
    }

    pub fn get_values(&self) -> Vec<Tensor> {
        self.inner_op.borrow().get_values()
    }

    pub fn set_values(&self, v: &[Tensor]) {
        self.inner_op.borrow_mut().set_values(v);
    }

    pub fn get_grads(&self) -> Vec<Tensor> {
        self.inner_op.borrow().get_grads()
    }
}
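
// A minimal sketch (not in the original source) of wrapping an op in the
// shared `Op` handle and querying it; `View` is defined later in this file.
#[cfg(test)]
mod op_wrapper_example {
    use super::*;

    #[test]
    fn op_wraps_a_view() {
        let boxed: Box<dyn OpTrait> = Box::new(View::new(&[4]));
        let op = Op::new(Rc::new(RefCell::new(boxed)));
        assert_eq!(op.get_name(), "View");
        assert_eq!(op.get_input_size(), 1);
        assert_eq!(op.get_output_size(), 1);

        // `ref_copy` aliases the same underlying op.
        let alias = op.ref_copy();
        assert_eq!(alias.get_name(), "View");
    }
}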
/// Numerically check an op's `grad` implementation against a finite
/// difference estimate. The op must produce a single scalar output.
/// Returns `true` when every checked partial derivative is within
/// `tolerance` (squared error) of the finite-difference estimate.
pub fn _gradient_checker(op: &mut dyn OpTrait,
                         one_input: &[Tensor], input_mask: Option<&[bool]>,
                         step: Option<Tensor>, tolerance: Option<Tensor>)
                         -> bool {

    let x_mask = if let Some(val) = input_mask {val.to_vec()} else {vec![true; one_input.len()]};
    let delta = if let Some(val) = step {val.get_scale_f64()} else {0.01};
    let tol = if let Some(val) = tolerance {val.get_scale_f64()} else {0.01};

    // Baseline forward pass.
    let output = Tensor::new();
    op.apply(one_input, &[output.ref_copy()]);
    let output = output.get_scale_f64();

    // Analytic gradients from the op's own backward pass, seeded with an
    // output gradient of 1.
    let input_grad = vec![Tensor::new(); op.get_input_size()];
    let mut input_grad_ref = Vec::new();
    for i in &input_grad {
        input_grad_ref.push(i.ref_copy());
    }
    let output_grad = Tensor::from_vec_f64(&[1.], &[1]);
    op.grad(one_input, &[output_grad], &input_grad_ref);

    let mut numeric_gradient = Vec::new();
    for v in one_input {
        numeric_gradient.push(v.zeros_like())
    }

    let mut good_gradient = true;
    for (index, v) in one_input.iter().enumerate() {
        if !x_mask[index] {
            continue;
        }

        // Perturb each element in turn and compare the forward-difference
        // slope against the analytic gradient.
        for i in 0..v.numel() {
            let dimpos = v.index2dimpos(i);

            let base_value = v.get_f64(&dimpos);
            let right_value = base_value + delta;
            let mut right_tensor = (*v).clone();
            right_tensor.set_f64(&dimpos, right_value);

            let mut right_input = one_input.to_vec();
            right_input[index] = right_tensor.ref_copy();
            let right_output = Tensor::new();
            op.apply(&right_input, &[right_output.ref_copy()]);
            let right_output = right_output.get_scale_f64();

            let scale_gradient = (right_output - output)/delta;
            numeric_gradient[index].set_f64(&dimpos, scale_gradient);

            let system_gradient = input_grad[index].get_f64(&dimpos);

            if (scale_gradient - system_gradient)*(scale_gradient - system_gradient) > tol {
                println!("input: {:?}, numeric: {:?}, implemented: {:?}", one_input[0], scale_gradient, system_gradient);
                good_gradient = false;
            }
        }
    }
    good_gradient
}
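
// A minimal sketch (not in the original source) driving `_gradient_checker`
// with the `View` op defined below. `View` on a one-element tensor is the
// identity, so the analytic and numeric gradients should both be 1 and the
// check should pass with the default step and tolerance.
#[cfg(test)]
mod gradient_checker_example {
    use super::*;

    #[test]
    fn view_passes_gradient_check() {
        let mut op = View::new(&[1]);
        let input = Tensor::from_vec_f64(&[3.0], &[1]);
        assert!(_gradient_checker(&mut op, &[input], None, None, None));
    }
}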

/// Reshape the input tensor to the configured shape, leaving the
/// element order unchanged.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct View {
    shape: Vec<usize>,
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl View {
    pub fn new(new_shape: &[usize]) -> View {
        View {
            shape: new_shape.to_vec(),
            handle: OpHandle::new(),
        }
    }
    handle_method!();
}
impl OpCall for View {
    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
        let new_one = View {
            shape: self.shape.clone(),
            handle: OpHandle::new(),
        };

        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));

        inputs[0].called_with(op, &inputs[1..inputs.len()])
    }
}
impl OpTrait for View {
    fn get_name(&self) -> &'static str {
        "View"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }

    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        if input.len() > 1 {
            panic!("view only accepts one input");
        }

        let total_numel: usize = self.shape.iter().product();
        if input[0].numel() != total_numel {
            panic!("view expects a tensor with {} elements in total, got {}", total_numel, input[0].numel());
        }

        output[0].swap(&input[0].reshape(&self.shape));
    }

    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        // A reshape preserves element order, so the backward pass just
        // reshapes the output gradient back to the input's shape.
        input_grad[0].swap(&output_grad[0].reshape(&input[0].size()));
    }

    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {
    }
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }

    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
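
// A minimal sketch (not in the original source): applying `View` directly
// at the tensor level, bypassing the `Var`/compute-graph layer. It assumes
// `Tensor::size()` returns the shape as a `Vec<usize>`, consistent with its
// use in `View::grad` above.
#[cfg(test)]
mod view_apply_example {
    use super::*;

    #[test]
    fn view_reshapes_six_elements() {
        let op = View::new(&[2, 3]);
        let input = Tensor::from_vec_f64(&[1., 2., 3., 4., 5., 6.], &[6]);
        let output = Tensor::new();
        op.apply(&[input], &[output.ref_copy()]);
        assert_eq!(output.size(), vec![2, 3]);
    }
}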

pub mod macros;

pub mod local;
pub use local::{Add, Sub, Mul, Div, Matmul, Outer};

pub mod linear;
pub use linear::Linear;

pub mod nonlinear;
pub use nonlinear::{ELU, ReLU};

pub mod convolution;
pub use convolution::Conv2d;

pub mod pooling;

pub mod loss;
pub use loss::{MSELoss, BCEWithLogitsLoss, CrossEntropyLoss};

pub mod element;
pub use element::{Abs, Acos, Asin, Atan, Ceil, Cos, Cosh, Exp, Expm1, Floor, Frac, Log, Log10, Log1p, Log1pexp, Log2, Neg, Reciprocal, Round, Rsqrt, Sigmoid, Sign, Sin, Sinh, Sqrt, Tan, Tanh, Trunc};

pub mod comparison;
pub use comparison::{MaxPair, MinPair, ArgSort, EqElem, Equal, Ge, Gt, Le, Lt, Ne};

pub mod index_slicing;
pub use index_slicing::{Cat, Chunk, ConditionalSelect, Gather, IndexSelect, IndexExclude, Reshape, Split, Squeeze, Stack, T, Take, Permute, Unsqueeze, Repeat};

pub mod linalg;
pub use linalg::{Det, Inv, NormalizeUnit, Tr};

pub mod reduction;
pub use reduction::{Argmax, Argmin, Logsumexp, Mean, Prod, Std, Sum, Variance, Max, Min};

pub mod vision;
pub use vision::{GetPatch, SetPatch};

#[cfg(feature = "use-serde")]
use auto_diff_macros::gen_serde_funcs;
#[cfg(feature = "use-serde")]
use serde::ser;
#[cfg(feature = "use-serde")]
gen_serde_funcs!(View,
                 Add, Sub, Mul, Div, Matmul, Outer,
                 Linear,
                 ELU, ReLU,
                 Conv2d,
                 MSELoss, BCEWithLogitsLoss, CrossEntropyLoss,
                 Abs, Acos, Asin, Atan, Ceil, Cos, Cosh, Exp, Expm1, Floor, Frac, Log, Log10, Log1p, Log1pexp, Log2, Neg, Reciprocal, Round, Rsqrt, Sigmoid, Sign, Sin, Sinh, Sqrt, Tan, Tanh, Trunc,
                 MaxPair, MinPair, ArgSort, EqElem, Equal, Ge, Gt, Le, Lt, Ne,
                 Cat, Chunk, ConditionalSelect, Gather, IndexSelect, IndexExclude, Reshape, Split, Squeeze, Stack, T, Take, Permute, Unsqueeze, Repeat,
                 Det, Inv, NormalizeUnit, Tr,
                 Argmax, Argmin, Logsumexp, Mean, Prod, Std, Sum, Variance, Max, Min,
                 GetPatch, SetPatch);