auto_diff/op/
mod.rs

//! Only NCWH format is supported.
use std::cell::RefCell;
use std::rc::Rc;

use tensor_rs::tensor::Tensor;
use crate::var::Var;
use crate::err::AutoDiffError;
use crate::collection::generational_index::GenKey;
use crate::compute_graph::Net;


#[cfg(feature = "use-serde")]
use serde::{Serializer, de, de::MapAccess, de::SeqAccess};
#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};
#[cfg(feature = "use-serde")]
use std::any::Any;

/// Implement this trait for an operator
/// so that the operator can be stored
/// in the computation graph.
pub trait OpTrait {
    /// A conventional name for the op.
    fn get_name(&self) -> &'static str;

    /// The number of inputs needed by this op.
    fn get_input_size(&self) -> usize;

    /// The number of outputs produced by this op.
    fn get_output_size(&self) -> usize;

    /// Forward pass.
    fn apply(&self, input: &[Tensor], output: &[Tensor]);

    /// Given the forward input values and the backward output_grad,
    /// update the weight gradient and
    /// return the backward input gradient.
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]);

    /// Access weight values.
    fn get_values(&self) -> Vec<Tensor>;
    fn set_values(&self, v: &[Tensor]);
    /// Access gradient values.
    fn get_grads(&self) -> Vec<Tensor>;

    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any;
}
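
// A minimal sketch of an `OpTrait` implementation, kept behind `cfg(test)` so it
// stays out of the library surface: `SumAll` reduces its single input to a scalar
// sum and sends a gradient of 1.0 back to every input element. The names here
// (`SumAll`, the test module) are illustrative only, and the sketch relies solely
// on `Tensor` calls that already appear elsewhere in this file (`numel`,
// `index2dimpos`, `get_f64`, `set_f64`, `zeros_like`, `from_vec_f64`,
// `get_scale_f64`, `swap`, `ref_copy`).
#[cfg(test)]
mod op_trait_example {
    use super::*;

    pub(crate) struct SumAll {}

    impl OpTrait for SumAll {
        fn get_name(&self) -> &'static str {
            "SumAll"
        }
        fn get_input_size(&self) -> usize {
            1
        }
        fn get_output_size(&self) -> usize {
            1
        }
        // Forward: write the scalar sum of the input into the output tensor.
        fn apply(&self, input: &[Tensor], output: &[Tensor]) {
            let mut s = 0.;
            for i in 0..input[0].numel() {
                s += input[0].get_f64(&input[0].index2dimpos(i));
            }
            output[0].swap(&Tensor::from_vec_f64(&[s], &[1]));
        }
        // Backward: d(sum)/dx is 1 for every element, scaled by the incoming gradient.
        fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
            let g = output_grad[0].get_scale_f64();
            let mut full = input[0].zeros_like();
            for i in 0..input[0].numel() {
                full.set_f64(&input[0].index2dimpos(i), g);
            }
            input_grad[0].swap(&full);
        }
        // No trainable parameters, so the parameter accessors stay empty.
        fn get_values(&self) -> Vec<Tensor> {
            Vec::new()
        }
        fn set_values(&self, _v: &[Tensor]) {
        }
        fn get_grads(&self) -> Vec<Tensor> {
            Vec::new()
        }
        #[cfg(feature = "use-serde")]
        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    #[test]
    fn sum_all_forward() {
        let op = SumAll {};
        let x = Tensor::from_vec_f64(&[1., 2., 3., 4.], &[2, 2]);
        let out = Tensor::new();
        op.apply(&[x.ref_copy()], &[out.ref_copy()]);
        assert_eq!(out.get_scale_f64(), 10.);
    }
}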

/// Ops that are first created and
/// then called need to follow this behavior.
pub trait OpCall {
    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError>;
}

pub struct OpHandle {
    id: GenKey,
    net: Rc<RefCell<Net>>,
}
impl OpHandle {
    pub fn new() -> OpHandle {
        OpHandle {
            id: GenKey::new(0, 0),
            net: Rc::new(RefCell::new(Net::new()))
        }
    }
}
impl Default for OpHandle {
    fn default() -> Self {
        Self::new()
    }
}

macro_rules! handle_method {
    () => {
        fn get_handle(&self) -> &OpHandle {
            &self.handle
        }

        fn get_handle_mut(&mut self) -> &mut OpHandle {
            &mut self.handle
        }
    }
}


///
/// Op is the Rc wrapper of the typed op trait.
///
#[derive(Clone)]
pub struct Op {
    inner_op: Rc<RefCell<Box<dyn OpTrait>>>,
}
impl Op {
    pub fn new(op: Rc<RefCell<Box<dyn OpTrait>>>) -> Self {
        Op {
            inner_op: op.clone(),
        }
    }
    pub fn inner(&self) -> &Rc<RefCell<Box<dyn OpTrait>>> {
        &self.inner_op
    }

    pub fn ref_copy(&self) -> Self {
        Op {
            inner_op: self.inner_op.clone(),
        }
    }

    pub fn get_name(&self) -> String {
        self.inner_op.borrow().get_name().to_string()
    }
    pub fn get_input_size(&self) -> usize {
        self.inner_op.borrow().get_input_size()
    }
    pub fn get_output_size(&self) -> usize {
        self.inner_op.borrow().get_output_size()
    }
    /// Read the input, do the calculation, and write the result to the output.
    /// Called by compute_graph.
    pub fn apply(&self, input: &[Tensor],
                 output: &[Tensor]) {
        self.inner_op.borrow().apply(input, output);
    }
    /// Given the input and output_grad, return input_grad (forward view).
    /// Called by compute_graph.
    pub fn grad(&self, input: &[Tensor],
                output_grad: &[Tensor],
                input_grad: &[Tensor]) {
        self.inner_op.borrow().grad(input, output_grad, input_grad);
    }

    /// Access weights/parameters.
    pub fn get_values(&self) -> Vec<Tensor> {
        self.inner_op.borrow().get_values()
    }

    /// Set parameters.
    pub fn set_values(&self, v: &[Tensor]) {
        self.inner_op.borrow_mut().set_values(v);
    }

    /// Return gradients for weights/parameters.
    pub fn get_grads(&self) -> Vec<Tensor> {
        self.inner_op.borrow().get_grads()
    }
}
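
// A short, hedged sketch of how a concrete op is wrapped in `Op` and driven
// through the type-erased interface. It reuses `View` (defined later in this
// file) as the concrete `OpTrait` implementor and mirrors the wrapping pattern
// used in `OpCall::call` below; the module and test names are illustrative only.
#[cfg(test)]
mod op_wrapper_example {
    use super::*;

    #[test]
    fn wrap_view_in_op() {
        // Box the concrete op, share it behind Rc<RefCell<...>>, and hand it to Op.
        let op = Op::new(Rc::new(RefCell::new(Box::new(View::new(&[4])))));
        assert_eq!(op.get_name(), "View");
        assert_eq!(op.get_input_size(), 1);

        // `apply` reads the input tensors and writes the result into the output tensor.
        let input = Tensor::from_vec_f64(&[1., 2., 3., 4.], &[2, 2]);
        let output = Tensor::new();
        op.apply(&[input.ref_copy()], &[output.ref_copy()]);
        assert_eq!(output.numel(), 4);
    }
}
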
//impl Clone for Op {
//    fn clone(&self) -> Self {
//        Op {
//            update_counter: self.update_counter.clone(),
//            para_grad: self.para_grad.iter().map(|(a, b)| (a.clone(), b.clone())).collect(),
//            func_apply: self.func_apply.clone(),
//            func_gradient: self.func_gradient.clone(),
//            name: self.name.clone(),
//            input_size: self.input_size,
//            output_size: self.output_size,
//        }
//    }
//}



//pub struct Nop {
//}
//impl OpTrait for Nop {
//    fn get_name(&self) -> String {
//        "Nop".to_string()
//    }
//    fn get_input_size(&self) -> usize {
//        0
//    }
//    fn get_output_size(&self) -> usize {
//        0
//    }
//
//    /// Forward pass
//    fn apply(&mut self, _input: &[&Tensor], _output: &[&Tensor]) {
//
//    }
//    fn grad(&self, _input: &[&Tensor], _output_grad: &[&Tensor], _input_grad: &[&Tensor]) {
//
//    }
//
//    /// access weight values
//    fn get_values(&self) -> Vec<&Tensor> {
//        Vec::new()
//    }
//    fn set_values(&self, _v: &[Tensor]) {
//
//    }
//    /// access gradient values
//    fn get_grads(&self) -> Vec<&Tensor> {
//        Vec::new()
//    }
//}



///
/// Verify that the gradient implementation is correct.
///
/// op: the tested operator.
/// one_input: test data points.
/// input_mask: skip a data point when the corresponding element is false.
/// step: the delta used for the numeric difference.
/// tolerance: the numeric tolerance for equality.
///
/// one_input and input_mask should have the same size.
/// step and tolerance are both scalars.
pub fn _gradient_checker(op: &mut dyn OpTrait,
                         one_input: &[Tensor], input_mask: Option<&[bool]>,
                         step: Option<Tensor>, tolerance: Option<Tensor>)
                         -> bool {

    let x_mask = if let Some(val) = input_mask {val.to_vec()} else {vec![true; one_input.len()]};
    let delta = if let Some(val) = step {val.get_scale_f64()} else {0.01};
    let tol = if let Some(val) = tolerance {val.get_scale_f64()} else {0.01};


    // system output
    let output = Tensor::new();
    op.apply(one_input, &[output.ref_copy()]);

    let output = output.get_scale_f64();

    // get the system gradient
    let input_grad = vec![Tensor::new(); op.get_input_size()];
    let mut input_grad_ref = Vec::new();
    for i in &input_grad {
        input_grad_ref.push(i.ref_copy());
    }
    let output_grad = Tensor::from_vec_f64(&[1.], &[1]);
    op.grad(one_input, &[output_grad], &input_grad_ref);

    // get the numeric gradient
    let mut numeric_gradient = Vec::new();
    for v in one_input {
        numeric_gradient.push(v.zeros_like())
    }

    let mut good_gradient = true;
    for (index, v) in one_input.iter().enumerate() {
        if !x_mask[index] {
            continue;
        }

        for i in 0..v.numel() {
            let dimpos = v.index2dimpos(i);

            let base_value = v.get_f64(&dimpos);
            let right_value = base_value + delta;
            let mut right_tensor = (*v).clone();
            right_tensor.set_f64(&dimpos, right_value);

            let mut right_input = one_input.to_vec();
            right_input[index] = right_tensor.ref_copy();
            let right_output = Tensor::new();
            op.apply(&right_input, &[right_output.ref_copy()]);
            let right_output = right_output.get_scale_f64();

            let scale_gradient = (right_output - output) / delta;
            numeric_gradient[index].set_f64(&dimpos, scale_gradient);

            let system_gradient = input_grad[index].get_f64(&dimpos);

            if (scale_gradient - system_gradient) * (scale_gradient - system_gradient) > tol {
                println!("input: {:?}, numeric: {:?}, implemented: {:?}", one_input[index], scale_gradient, system_gradient);
                good_gradient = false;
            }
        }
    }
    good_gradient
}
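
// A hedged usage sketch for `_gradient_checker`, reusing the `SumAll` example op
// defined after `OpTrait` above. The analytic gradient of a plain sum is 1.0 for
// every element, so the numeric estimate and the implemented gradient should
// agree within the default step (0.01) and tolerance (0.01).
#[cfg(test)]
mod gradient_checker_example {
    use super::*;
    use super::op_trait_example::SumAll;

    #[test]
    fn sum_all_gradient_matches_numeric_estimate() {
        let mut op = SumAll {};
        let x = Tensor::from_vec_f64(&[0.5, -1.0, 2.0, 3.0], &[2, 2]);
        // No input mask, default step and tolerance.
        assert!(_gradient_checker(&mut op, &[x], None, None, None));
    }
}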

///
/// View op
///
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct View {
    shape: Vec<usize>,
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
impl View {
    pub fn new(new_shape: &[usize]) -> View {
        View {
            shape: new_shape.to_vec(),
            handle: OpHandle::new(),
        }
    }
    handle_method!();
}
impl OpCall for View {
    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
        let new_one = View {
            shape: self.shape.clone(),
            handle: OpHandle::new(),
        };

        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));

        inputs[0].called_with(op, &inputs[1..inputs.len()])
    }
}
impl OpTrait for View {
    fn get_name(&self) -> &'static str {
        "View"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }

    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        if input.len() > 1 {
            panic!("view only accepts one input");
        }

        let total_numel: usize = self.shape.iter().product();
        if input[0].numel() != total_numel {
            panic!("view expects a tensor with {} elements in total, got {}", total_numel, input[0].numel());
        }

        output[0].swap(&input[0].reshape(&self.shape));
    }

    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        input_grad[0].swap(&output_grad[0].reshape(&input[0].size()));
    }

    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {
    }
    /// Access gradient values.
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }

    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
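
// A brief, hedged sketch of the View op used directly on tensors: the forward
// pass reshapes the input to the configured shape, and the backward pass
// reshapes the incoming gradient back to the input's shape. The `size()`
// comparisons assume it returns the shape as a `Vec<usize>`, as its use in
// `grad` above suggests; the module and test names are illustrative only.
#[cfg(test)]
mod view_example {
    use super::*;

    #[test]
    fn view_reshapes_forward_and_backward() {
        let op = View::new(&[3, 2]);

        // Forward: [2, 3] -> [3, 2].
        let input = Tensor::from_vec_f64(&[1., 2., 3., 4., 5., 6.], &[2, 3]);
        let output = Tensor::new();
        op.apply(&[input.ref_copy()], &[output.ref_copy()]);
        assert_eq!(output.size(), vec![3, 2]);

        // Backward: the gradient is reshaped back to the input's shape.
        let output_grad = Tensor::from_vec_f64(&[1.; 6], &[3, 2]);
        let input_grad = Tensor::new();
        op.grad(&[input.ref_copy()], &[output_grad.ref_copy()], &[input_grad.ref_copy()]);
        assert_eq!(input_grad.size(), input.size());
    }
}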

pub mod macros;

pub mod local;
pub use local::{Add, Sub, Mul, Div, Matmul, Outer};

pub mod linear;
pub use linear::Linear;

pub mod nonlinear;
pub use nonlinear::{ELU, ReLU};

pub mod convolution;
pub use convolution::Conv2d;

pub mod pooling;

pub mod loss;
pub use loss::{MSELoss, BCEWithLogitsLoss, CrossEntropyLoss};

pub mod element;
pub use element::{Abs, Acos, Asin, Atan, Ceil, Cos, Cosh, Exp, Expm1, Floor, Frac, Log, Log10, Log1p, Log1pexp, Log2, Neg, Reciprocal, Round, Rsqrt, Sigmoid, Sign, Sin, Sinh, Sqrt, Tan, Tanh, Trunc};

pub mod comparison;
pub use comparison::{MaxPair, MinPair, ArgSort, EqElem, Equal, Ge, Gt, Le, Lt, Ne};

pub mod index_slicing;
pub use index_slicing::{Cat, Chunk, ConditionalSelect, Gather, IndexSelect, IndexExclude, Reshape, Split, Squeeze, Stack, T, Take, Permute, Unsqueeze, Repeat};

pub mod linalg;
pub use linalg::{Det, Inv, NormalizeUnit, Tr};

pub mod reduction;
pub use reduction::{Argmax, Argmin, Logsumexp, Mean, Prod, Std, Sum, Variance, Max, Min};

pub mod vision;
pub use vision::{GetPatch, SetPatch};

#[cfg(feature = "use-serde")]
use auto_diff_macros::gen_serde_funcs;
#[cfg(feature = "use-serde")]
use serde::ser;
#[cfg(feature = "use-serde")]
gen_serde_funcs!(View,
                 Add, Sub, Mul, Div, Matmul, Outer,
                 Linear,
                 ELU, ReLU,
                 Conv2d,
                 MSELoss, BCEWithLogitsLoss, CrossEntropyLoss,
                 Abs, Acos, Asin, Atan, Ceil, Cos, Cosh, Exp, Expm1, Floor, Frac, Log, Log10, Log1p, Log1pexp, Log2, Neg, Reciprocal, Round, Rsqrt, Sigmoid, Sign, Sin, Sinh, Sqrt, Tan, Tanh, Trunc,
                 MaxPair, MinPair, ArgSort, EqElem, Equal, Ge, Gt, Le, Lt, Ne,
                 Cat, Chunk, ConditionalSelect, Gather, IndexSelect, IndexExclude, Reshape, Split, Squeeze, Stack, T, Take, Permute, Unsqueeze, Repeat,
                 Det, Inv, NormalizeUnit, Tr,
                 Argmax, Argmin, Logsumexp, Mean, Prod, Std, Sum, Variance, Max, Min,
                 GetPatch, SetPatch);