// auto_diff/var.rs

1use std::cell::RefCell;
2use std::rc::Rc;
3use std::fmt;
4use ::rand::prelude::StdRng;
5use std::ops::{Add, Div, Neg, Sub, Mul};
6use std::convert::TryFrom;
7
8use tensor_rs::tensor::{Tensor};
9use crate::op::{Op};
10use crate::err::AutoDiffError;
11use crate::optim::Optimizer;
12use crate::var_inner::VarInner;
13use crate::compute_graph::{Net};
14
15
16
/// Generate a unary method on [Var] that forwards to the
/// identically named `Result`-returning method on the inner [VarInner],
/// wrapping the result in a fresh [Var] handle.
macro_rules! var_1_to_1 {
    ($(#[$attr:meta])* $method:ident) => {
        $(#[$attr])*
        pub fn $method(&self) -> Result<Var, AutoDiffError> {
            let inner = self.var.borrow().$method()?;
            Ok(Var { var: Rc::new(RefCell::new(inner)) })
        }
    }
}
27
/// Generate a binary method on [Var] (`self` plus one other operand)
/// that forwards to the identically named method on [VarInner].
macro_rules! var_2_to_1 {
    ($(#[$attr:meta])*
     $a:ident) => {
        $(#[$attr])*
        pub fn $a(&self, other: &Var) -> Result<Var, AutoDiffError> {
            // Pass the Rc by reference directly; the previous
            // `&other.var.clone()` bumped the refcount only to take a
            // reference to the temporary (clippy: redundant_clone).
            Ok(Var {
                var: Rc::new(RefCell::new(self.var.borrow().$a(&other.var)?))})
        }
    }
}
38
/// Generate a variadic method on [Var]: `self`, a slice of extra [Var]
/// operands, plus trailing non-Var parameters, all forwarded to the
/// identically named method on [VarInner].
macro_rules! var_more_to_1_with_para {
    ($(#[$attr:meta])* $method:ident, $( $param:ident : $ParamTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $method(&self, other: &[Var], $( $param : $ParamTy ),*) -> Result<Var, AutoDiffError> {
            // Collect shared handles to every extra operand.
            let operands: Vec<_> = other.iter().map(|v| v.var.clone()).collect();
            let inner = self.var.borrow().$method(&operands, $( $param ),*)?;
            Ok(Var { var: Rc::new(RefCell::new(inner)) })
        }
    }
}
53
/// Generate a unary-with-parameters method on [Var] that forwards
/// `self` plus the listed non-Var parameters to [VarInner].
macro_rules! var_1_to_1_with_para {
    ($(#[$attr:meta])* $method:ident, $( $param:ident : $ParamTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $method(&self, $( $param : $ParamTy ),*) -> Result<Var, AutoDiffError> {
            let inner = self.var.borrow().$method($( $param ),*)?;
            Ok(Var { var: Rc::new(RefCell::new(inner)) })
        }
    }
}
64
/// Generate a binary-with-parameters method on [Var]: `self`, one other
/// [Var] operand, plus trailing non-Var parameters, forwarded to [VarInner].
macro_rules! var_2_to_1_with_para {
    ($(#[$attr:meta])*
     $a:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $a(&self, other: &Var, $( $arg_name : $ArgTy ),*) -> Result<Var, AutoDiffError> {
            // `&other.var` passes the Rc by reference; the former
            // `&other.var.clone()` was a needless refcount bump on a
            // temporary (clippy: redundant_clone).
            Ok(Var {
                var: Rc::new(RefCell::new(self.var.borrow().$a(&other.var, $( $arg_name ),*)?))})
        }
    }
}
75
/// Generate an associated constructor on [Var] that delegates to the
/// identically named infallible constructor on [VarInner].
macro_rules! delegate_new_op {
    ($(#[$attr:meta])* $ctor:ident, $( $param:ident : $ParamTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $ctor($( $param : $ParamTy ),*) -> Var {
            let inner = VarInner::$ctor($( $param ),*);
            Var { var: Rc::new(RefCell::new(inner)) }
        }
    }
}
87
/// [Var] can be thought as the value it holds plus a link
/// to the computation graph.
/// Majority of operators are methods on [Var].
///
/// A [Var] is a shared, interiorly mutable handle: `ref_copy()` returns
/// an aliasing handle to the same node, while `clone()` deep-copies the
/// underlying [VarInner].
pub struct Var {
    // Shared ownership lets several Vars alias one graph node.
    var: Rc<RefCell<VarInner>>
}
impl Var {
    /// Create a [Var] from raw data and a shape.
    /// The element type follows the crate feature (`use-f64` here).
    #[cfg(feature = "use-f64")]
    pub fn new(input: &[f64], dim: &[usize]) -> Var {
        Var::new_f64(input, dim)
    }
    /// Create a [Var] from raw data and a shape.
    /// The element type follows the crate feature (`use-f32` here).
    #[cfg(feature = "use-f32")]
    pub fn new(input: &[f32], dim: &[usize]) -> Var {
        Var::new_f32(input, dim)
    }
    /// Create an f64 [Var] with the given data and shape.
    pub fn new_f64(input: &[f64], dim: &[usize]) -> Var {
        Var {
            var: Rc::new(RefCell::new(VarInner::new_f64(input, dim)))
        }
    }
    /// Create an f32 [Var] with the given data and shape.
    pub fn new_f32(input: &[f32], dim: &[usize]) -> Var {
        Var {
            var: Rc::new(RefCell::new(VarInner::new_f32(input, dim)))
        }
    }

    /// Where an assign operator is needed, use this ref_copy:
    /// it returns a handle sharing the same underlying [VarInner].
    /// If a hard copy is necessary, then call clone().
    pub fn ref_copy(self: &Var) -> Var {
        Var {
            var: self.var.clone(),
        }
    }
    /// With a &Var, use set to copy a value.
    pub fn set(&self, o: &Var) {
        self.var.borrow_mut().set(&o.var.borrow());
    }

    /// Shape of the held tensor, one entry per dimension.
    pub fn size(&self) -> Vec<usize> {
        self.var.borrow().size()
    }
    /// Total number of elements.
    pub fn numel(&self) -> usize {
        self.var.borrow().numel()
    }
    /// Read the f32 element at the given multi-dimensional index.
    pub fn get_f32(&self, o: &[usize]) -> Result<f32, AutoDiffError> {
        self.var.borrow().get_f32(o)
    }
    /// Write the f32 element at the given multi-dimensional index.
    pub fn set_f32(&self, o: &[usize], v: f32) -> Result<(), AutoDiffError> {
        self.var.borrow_mut().set_f32(o, v)
    }
    /// Read the f64 element at the given multi-dimensional index.
    pub fn get_f64(&self, o: &[usize]) -> Result<f64, AutoDiffError> {
        self.var.borrow().get_f64(o)
    }
    /// Write the f64 element at the given multi-dimensional index.
    pub fn set_f64(&self, o: &[usize], v: f64) -> Result<(), AutoDiffError> {
        self.var.borrow_mut().set_f64(o, v)
    }

    /// New tensor of the given shape, every element set to `fill_value`.
    pub fn fill(size: &[usize], fill_value: &Var) -> Var {
        Var {
            var: Rc::new(RefCell::new(
                VarInner::fill(size, fill_value.var.clone())))
        }
    }
    delegate_new_op!(fill_f32, size: &[usize], fill_value: f32);
    delegate_new_op!(fill_f64, size: &[usize], fill_value: f64);
    delegate_new_op!(zeros, dim: &[usize]);
    delegate_new_op!(ones, dim: &[usize]);
    delegate_new_op!(twos, dim: &[usize]);
    //delegate_new_inner_op!(arange, end: usize);
    //delegate_new_inner_op!(range, start: f32, end: f32, step: Option<f32>);
    //delegate_new_inner_op!(linspace, start: f32, end: f32, steps: usize);
    //delegate_new_inner_op!(logspace, start: f32, end: f32, steps: usize, base: f32);
    delegate_new_op!(
        /// Identity matrix
        eye, n: usize, m: usize);
    delegate_new_op!(empty, dim: &[usize]);

    /// Fill row by row.
    // NOTE(review): mutates through a shared borrow(); presumably the
    // tensor inside VarInner has interior mutability — confirm.
    pub fn from_record_f32(&self, row: usize, record: &[f32]) {
        self.var.borrow().from_record_f32(row, record)
    }
    /// Fill row by row (f64 variant).
    pub fn from_record_f64(&self, row: usize, record: &[f64]) {
        self.var.borrow().from_record_f64(row, record)
    }

    // rand: constructors that draw values from the supplied StdRng.
    delegate_new_op!(rand_usize,
                     rng: &mut StdRng,
                     dim: &[usize],
                     left: usize, right: usize);

    delegate_new_op!(normal_f64,
                     rng: &mut StdRng,
                     dim: &[usize],
                     mean: f64, std: f64);
    delegate_new_op!(normal_f32,
                     rng: &mut StdRng,
                     dim: &[usize],
                     mean: f32, std: f32);
    /// Normal-distributed tensor; precision follows the crate feature.
    #[cfg(feature = "use-f32")]
    pub fn normal(rng: &mut StdRng,
                  dim: &[usize],
                  mean: f32, std: f32) -> Var {
        Self::normal_f32(rng, dim, mean, std)
    }
    /// Normal-distributed tensor; precision follows the crate feature.
    #[cfg(feature = "use-f64")]
    pub fn normal(rng: &mut StdRng,
                  dim: &[usize],
                  mean: f64, std: f64) -> Var {
        Self::normal_f64(rng, dim, mean, std)
    }

    delegate_new_op!(uniform_f64,
                     rng: &mut StdRng,
                     dim: &[usize],
                     from: f64, to: f64);
    delegate_new_op!(uniform_f32,
                     rng: &mut StdRng,
                     dim: &[usize],
                     from: f32, to: f32);
    /// Uniform-distributed tensor; precision follows the crate feature.
    #[cfg(feature = "use-f32")]
    pub fn uniform(rng: &mut StdRng,
                   dim: &[usize],
                   from: f32, to: f32) -> Var {
        Self::uniform_f32(rng, dim, from, to)
    }
    /// Uniform-distributed tensor; precision follows the crate feature.
    #[cfg(feature = "use-f64")]
    pub fn uniform(rng: &mut StdRng,
                   dim: &[usize],
                   from: f64, to: f64) -> Var {
        Self::uniform_f64(rng, dim, from, to)
    }

    // Panicking arithmetic helpers backing the operator-overload impls
    // (Add/Sub/Mul/Div below). The inner ops are expected infallible here.
    pub fn _add(&self, other: &Var) -> Var {
        Var {
            var: Rc::new(RefCell::new(self.var.borrow().add(&other.var).expect("never fail.")))
        }
    }
    pub fn _sub(&self, other: &Var) -> Var {
        Var {
            var: Rc::new(RefCell::new(self.var.borrow().sub(&other.var).expect("never fail.")))
        }
    }
    pub fn _mul(&self, other: &Var) -> Var {
        Var {
            var: Rc::new(RefCell::new(self.var.borrow().mul(&other.var).expect("never fail.")))
        }
    }
    pub fn _div(&self, other: &Var) -> Var {
        Var {
            var: Rc::new(RefCell::new(self.var.borrow().div(&other.var).expect("never fail.")))
        }
    }

    var_2_to_1!(
        /// Matrix/inner/dot product
        ///
        /// ```
        /// # use auto_diff::{Var, var_f64, AutoDiffError};
        /// # extern crate openblas_src;
        /// # fn test_matmul() -> Result<(), AutoDiffError> {
        /// let v1 = var_f64!([[1., 2., 3.],
        ///                    [4., 5., 6.]]);
        /// let v2 = var_f64!([[11., 12., 13.],
        ///                    [14., 15., 16.],
        ///                    [17., 18., 19.]]);
        /// let v3 = v1.matmul(&v2)?;
        /// let em = var_f64!([[90.0, 96.0, 102.0],
        ///                    [216.0, 231.0, 246.0]]);
        /// assert_eq!(v3, em);
        /// #   Ok(())
        /// # }
        /// # test_matmul();
        /// ```
        matmul);
    var_2_to_1!(
        /// Outer product
        /// ```
        /// # use auto_diff::{Var, var_f64, AutoDiffError};
        /// # fn test_outer() -> Result<(), AutoDiffError> {
        /// let v1 = Var::new_f64(&[1., 2., 3.], &[3]);
        /// let v2 = Var::new_f64(&[4., 5., 6.], &[3]);
        /// let v3 = v1.outer(&v2)?;
        /// let em = var_f64!([[4.,   5.,  6.],
        ///                    [8.,  10., 12.],
        ///                    [12., 15., 18.]]);
        /// assert_eq!(v3, em);
        /// #   Ok(())
        /// # }
        /// # test_outer();
        /// ```
        outer);

    // nonlinear
    /// ELU activation; `alpha` is consumed and snapshotted as a plain
    /// tensor (detached from any graph it belonged to).
    pub fn elu(&self, alpha: Var) -> Result<Var, AutoDiffError> {
        Ok(Var {
            var: Rc::new(RefCell::new(
                self.var.borrow().elu(
                    VarInner::new_tensor(alpha.val()))?))
        })
    }
    var_1_to_1!(relu);
    var_1_to_1!(sigmoid);

    // loss
    var_2_to_1!(mse_loss);
    var_2_to_1!(bce_with_logits_loss);
    var_2_to_1!(cross_entropy_loss);

    //elementwise op
    var_1_to_1!(abs);
    var_1_to_1!(acos);
    var_1_to_1!(asin);
    var_1_to_1!(atan);
    var_1_to_1!(ceil);
    var_1_to_1!(cos);
    var_1_to_1!(cosh);
    var_1_to_1!(exp);
    var_1_to_1!(expm1);
    var_1_to_1!(floor);
    var_1_to_1!(frac);
    var_1_to_1!(log);
    var_1_to_1!(log10);
    var_1_to_1!(log1p);
    var_1_to_1!(log1pexp);
    var_1_to_1!(log2);
    var_1_to_1!(neg);
    // Panicking negation backing the unary `-` operator impl below.
    pub fn _neg(&self) -> Var {
        Var {
            var: Rc::new(RefCell::new(self.var.borrow().neg().expect("never fail.")))
        }
    }
    var_1_to_1!(reciprocal);
    var_1_to_1!(round);
    var_1_to_1!(rsqrt);
    var_1_to_1!(sign);
    var_1_to_1!(sin);
    var_1_to_1!(sinh);
    var_1_to_1!(sqrt);
    var_1_to_1!(tan);
    var_1_to_1!(tanh);
    var_1_to_1!(trunc);

    // comparison
    var_2_to_1!(max_pair);
    var_2_to_1!(min_pair);
    var_1_to_1_with_para!(arg_sort, 
                          dim: usize, descending: bool);
    var_2_to_1!(eq_elem);
    var_2_to_1!(equal);
    var_2_to_1!(ge);
    var_2_to_1!(gt);
    var_2_to_1!(le);
    var_2_to_1!(lt);
    var_2_to_1!(ne);

    // index and slicing
    var_more_to_1_with_para!(
        /// Concatenates the given sequence of seq tensors
        /// in the given dimension.
        /// The input tensor should all have the same size except
        /// on the given dimension.
        /// The output tensor will have all the same size as the input
        /// except the given dimension, which will be the sum of
        /// the inputs on the given dimension.
        /// Apply cat on [tensor(5, 3, 2), tensor(5, 7, 2), ]
        /// will get a tensor(5, 10, 2).
        ///
        /// ```
        /// # use auto_diff::{Var, var_f64, AutoDiffError};
        /// # extern crate openblas_src;
        /// # fn test_cat() -> Result<(), AutoDiffError> {
        /// let m1 = Var::empty(&[3, 1]);
        /// let m2 = Var::empty(&[3, 1]);
        /// let m3 = Var::empty(&[3, 1]);
        /// let m4 = m1.cat(&[m2, m3], 1)?;
        /// assert_eq!(m4.size(), [3, 3]);
        /// #   Ok(())
        /// # }
        /// # test_cat();
        /// ```
        cat, dim: usize);
    /// Split into `chunks` pieces along dimension `dim`.
    pub fn chunk(&self, chunks: usize, dim: usize)
                 -> Result<Vec<Var>, AutoDiffError> {
        let mut result = self.var.borrow().chunk(chunks, dim)?;
        let mut ret = Vec::new();
        for i in result.drain(..) {
            ret.push(Var {
                var: Rc::new(RefCell::new(i)),
            });
        }
        Ok(ret)
    }
    /// Elementwise select: `self` acts as the condition choosing
    /// between `x` and `y`.
    pub fn conditional_select(&self, x: &Var, y: &Var)
                              -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().conditional_select(x.var.clone(),
                                                          y.var.clone())?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Gather values along `dim` at the positions given by `index`.
    pub fn gather(&self, dim: usize, index: Var)
                  -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().gather(dim, index.var)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Keep only the slices of `dim` listed in `index`.
    pub fn index_select(&self, dim: usize,
                        index: Var)
                        -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().index_select(
            dim, index.var)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Drop the slices of `dim` listed in `index`.
    pub fn index_exclude(&self, dim: usize,
                        index: Var)
                        -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().index_exclude(
            dim, index.var)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Reorder dimensions; `dim` is a permutation of 0..rank.
    pub fn permute(&self, dim: &[usize])
                   -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().permute(dim)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Tile the tensor the given number of times per dimension.
    pub fn repeat(&self, dim: &[usize])
                  -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().repeat(dim)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Reshape to `new_shape` (element count must match).
    pub fn reshape(&self, new_shape: &[usize])
                  -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().reshape(new_shape)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Split along `dim` into pieces with the given section sizes.
    pub fn split(&self, sections: &[usize], dim: usize)
                  -> Result<Vec<Var>, AutoDiffError> {
        let mut result = self.var.borrow().split(sections, dim)?;
        let mut ret = Vec::new();
        for i in result.drain(..) {
            ret.push(Var {
                var: Rc::new(RefCell::new(i)),
            });
        }
        Ok(ret)
    }
    /// Remove size-1 dimensions; `dim` restricts to one dimension.
    pub fn squeeze(&self, dim: Option<usize>)
                   -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().squeeze(dim)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    var_1_to_1!(t);
    /// Pick elements by flat index.
    pub fn take(&self, index: &[usize])
                -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().take(index)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    /// Insert a size-1 dimension at position `dim`.
    pub fn unsqueeze(&self, dim: usize)
                   -> Result<Var, AutoDiffError> {
        let result = self.var.borrow().unsqueeze(dim)?;
        Ok(Var {
            var: Rc::new(RefCell::new(result)),
        })
    }
    var_more_to_1_with_para!(
        /// Stack tensor with the same size along a new dimension
        /// specified by dim.
        /// The difference from cat is that cat don't create new dimension.
        ///
        /// ```
        /// # use auto_diff::{Var, var_f64, AutoDiffError};
        /// # fn test_stack() -> Result<(), AutoDiffError> {
        /// let m1 = var_f64!([[1., 2., ],
        ///                [3., 4., ]]);
        /// let m2 = var_f64!([[5., 6., ],
        ///                [7., 8., ]]);
        /// let m3 = m1.stack(&[m2], 1)?;
        /// #   let em = Var::new_f64(&[1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0], &[2, 2, 2]);
        /// #   assert_eq!(m3, em);
        /// #   Ok(())
        /// # }
        /// # test_stack();
        /// ```
        stack, dim: usize);

    // linalg
    var_1_to_1!(det);
    var_1_to_1!(inv);
    var_1_to_1!(normalize_unit);
    var_1_to_1!(tr);

    // reduction: `dim = None` reduces over all dimensions; `keepdim`
    // keeps reduced dimensions with size 1.
    var_1_to_1_with_para!(argmax, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(argmin, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(logsumexp, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(mean, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(prod, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(std, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(sum, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(var, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(max, dim: Option<&[usize]>, keepdim: bool);
    var_1_to_1_with_para!(min, dim: Option<&[usize]>, keepdim: bool);

    // images
    var_1_to_1_with_para!(
        /// Get a portion of the tensor and return it.
        ///
        /// ```
        /// # use auto_diff::{Var, var_f64, AutoDiffError};
        /// # fn test_get_patch() -> Result<(), AutoDiffError> {
        /// let m1 = var_f64!([[1., 2., 3.],
        ///                    [4., 5., 6.],
        ///                    [7., 8., 9.]]);
        /// let m2 = var_f64!([[4., 5.],
        ///                    [7., 8.]]);
        /// assert_eq!(m1.get_patch(&[(1, 3), (0, 2)], None)?, m2);
        /// #   Ok(())
        /// # }
        /// # test_get_patch();
        /// ```
        get_patch, range: &[(usize, usize)], step: Option<&[usize]>);
    var_2_to_1_with_para!(
        /// Set a portion of the tensor.
        ///
        /// ```
        /// # use auto_diff::{Var, var_f64, AutoDiffError};
        /// # fn test_set_patch() -> Result<(), AutoDiffError> {
        /// let m1 = var_f64!([[1., 2., 3.],
        ///                    [4., 5., 6.],
        ///                    [7., 8., 9.]]);
        /// let m2 = var_f64!([[10., 11.],
        ///                    [12., 13.]]);
        /// let m3 = var_f64!([[1.,   2., 3.],
        ///                    [10., 11., 6.],
        ///                    [12., 13., 9.]]);
        /// assert_eq!(m1.set_patch(&m2, &[(1, 3), (0, 2)], None)?, m3);
        /// #   Ok(())
        /// # }
        /// # test_set_patch();
        /// ```
        set_patch, range: &[(usize, usize)], step: Option<&[usize]>);
    var_1_to_1_with_para!(view, new_shape: &[usize]);


    // internal use
    /// Snapshot of the held tensor value (no graph link).
    pub fn val(&self) -> Tensor {
        self.var.borrow().val()
    }

    /// Use gradient or not, default is to use.
    pub fn set_grad(&self, use_gradient: bool) {
        self.var.borrow_mut().set_grad(use_gradient);
    }

    /// Reset net in the background.
    pub fn reset_net(&self) {
        self.var.borrow_mut().reset_net();
    }

    /// The current gradient for the Var.
    pub fn grad(&self) -> Result<Var, AutoDiffError> {
        Ok(Var {
            var: Rc::new(RefCell::new(self.var.borrow().grad()?))
        })
    }

    /// Apply back propagation to get numerical gradient.
    pub fn bp(&self) -> Result<(), AutoDiffError> {
        self.var.borrow().bp()?;

        Ok(())
    }

    /// Apply one optimizer step to the parameters in the graph.
    pub fn step(&self, opt: &mut dyn Optimizer) -> Result<(), AutoDiffError> {
        self.var.borrow().step(opt)
    }

    /// Run the computation graph again.
    pub fn rerun(&self) -> Result<(), AutoDiffError> {
        self.var.borrow().rerun()
    }

    /// Extract input and output from the hidden net.
    pub fn get_io_var(&self) -> Result<(Vec<Var>, Vec<Var>), AutoDiffError> {
        let (mut inputs, mut outputs) = self.var.borrow().get_io_var()?;
        Ok((inputs.drain(..).map(|x| Var {var: Rc::new(RefCell::new(x))}).collect(),
            outputs.drain(..).map(|x| Var {var: Rc::new(RefCell::new(x))}).collect()))
    }

    /// Get var by string label
    pub fn get_var_by_label(&self, label: &str) -> Result<Var, AutoDiffError> {
        let inner = self.var.borrow().get_var_by_label(label)?;
        Ok(Var {
            var: Rc::new(RefCell::new(inner)),
        })
    }

    /// Attach a string label to this var in the graph.
    pub fn set_label(&self, label: &str) -> Result<(), AutoDiffError> {
        self.var.borrow().set_label(label)
    }

    /// Mark this var as the graph's prediction output
    /// (reserved label "__predict__").
    pub fn set_predict(&self) -> Result<(), AutoDiffError> {
        self.set_label("__predict__")
    }
    /// Fetch the var previously marked with set_predict().
    pub fn predict(&self) -> Result<Var, AutoDiffError> {
        self.get_var_by_label("__predict__")
    }

    // Record an op application: `self` and `others` are the inputs,
    // the returned Vars are the op's outputs wired into the graph.
    pub(crate) fn called_with(&self, op: Op,
                              others: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
        let refs: Vec<Rc<RefCell<VarInner>>> = others.iter().map(|x| x.var.clone()).collect();
        let mut var_inners = self.var.borrow().called_with(op, &refs)?;
        let ret: Vec<Var> = var_inners.drain(..).map(|x| Var {
            var: Rc::new(RefCell::new(x))
        }).collect();
        Ok(ret)
    }

    /// For debug.
    pub fn dump_net(&self) -> Rc<RefCell<Net>> {
        self.var.borrow().dump_net()
    }
    // Share the inner handle with crate-internal callers.
    pub(crate) fn inner(&self) -> Rc<RefCell<VarInner>> {
        self.var.clone()
    }
    // Wrap an already-built VarInner into a public Var.
    pub(crate) fn set_inner(var: VarInner) -> Var {
        Var {
            var: Rc::new(RefCell::new(var))
        }
    }
}
635
636// Test for equal
637impl PartialEq for Var {
638    fn eq(&self, other: &Self) -> bool {
639        self.var.borrow().val().eq(&other.var.borrow().val())
640    }
641}
642
// NOTE(review): Eq requires the tensor value comparison above to be
// reflexive (e.g. no NaN-like elements breaking x == x) — confirm.
impl Eq for Var {}
644
645impl fmt::Display for Var {
646    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
647//        write!(f, "id: {}", self.id)?;
648        write!(f, "tensor: {}", self.var.borrow().val())
649    }
650}
651
652impl fmt::Debug for Var {
653    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
654//        write!(f, "id: {}", self.id)?;
655        write!(f, "tensor: {}", self.var.borrow().val())
656    }
657}
658
659impl Clone for Var {
660    fn clone(&self) -> Self {
661        Var {
662            var: Rc::new(RefCell::new(self.var.borrow().clone()))
663        }
664    }
665}
666
667// Operator overloading
668impl Add for Var {
669    type Output = Self;
670
671    fn add(self, other: Self) -> Self {
672        self._add(&other)
673    }
674}
675
676impl Sub for Var {
677    type Output = Self;
678
679    fn sub(self, other: Self) -> Self {
680        self._sub(&other)
681    }
682}
683
684impl Mul for Var {
685    type Output = Self;
686
687    fn mul(self, other: Self) -> Self {
688        self._mul(&other)
689    }
690}
691
692impl Div for Var {
693    type Output = Self;
694
695    fn div(self, other: Self) -> Self {
696        self._div(&other)
697    }
698}
699
700impl Neg for Var {
701    type Output = Self;
702
703    fn neg(self) -> Self {
704        self._neg()
705    }
706}
707
708impl TryFrom<Var> for f32 {
709    type Error = AutoDiffError;
710
711    fn try_from(value: Var) -> Result<Self, Self::Error> {
712        if value.numel() > 1 {
713            return Err(AutoDiffError::new("TryFrom<Var> for f32 only works for 1 element var."))
714        }
715	let index = vec![0; value.size().len()];
716	value.get_f32(&index)
717    }
718}
719
720impl TryFrom<Var> for f64 {
721    type Error = AutoDiffError;
722
723    fn try_from(value: Var) -> Result<Self, Self::Error> {
724        if value.numel() > 1 {
725            return Err(AutoDiffError::new("TryFrom<Var> for f64 only works for 1 element var."));
726        }
727	let index = vec![0; value.size().len()];
728	value.get_f64(&index)
729    }
730}
731
732impl TryFrom<Var> for Vec<usize> {
733    type Error = AutoDiffError;
734
735    fn try_from(value: Var) -> Result<Self, Self::Error> {
736        let t = value.val();
737        if t.size().len() > 2 {
738            return Err(AutoDiffError::new("expect size [n,1], or [n]."));
739        }
740        let value = t.reshape(&[value.numel()]);
741        let ret = value.get_raw_f64().iter().map(|x| *x as usize).collect();
742        Ok(ret)
743    }
744}
745
/// Build a [Var] from an array literal.
/// A 1-D input `[a, b, c]` becomes a column vector of shape `[3, 1]`;
/// a 2-D input `[[..; N]; M]` becomes a matrix of shape `[M, N]`.
#[macro_export]
macro_rules! var_f64 {
    ($a:expr) => {{
        trait Expand {
            fn expand(&self) -> Var;
        }
        // 1-D slice: column vector [len, 1].
        impl Expand for [f64] {
            fn expand(&self) -> Var {
                Var::new_f64(&self, &[self.len(), 1])
            }
        }
        // 2-D slice of fixed-width rows: one const-generic impl replaces
        // the four copy-pasted impls for N = 1..=4 and generalizes the
        // macro to any row width.
        impl<const N: usize> Expand for [[f64; N]] {
            fn expand(&self) -> Var {
                // Preallocate: rows * N elements, flattened row-major.
                let mut data = Vec::with_capacity(self.len() * N);
                for row in self {
                    data.extend_from_slice(row);
                }
                Var::new_f64(&data, &[self.len(), N])
            }
        }
        $a.expand()
    }}
}
808
809
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::OpCall;
    extern crate openblas_src;

    // Elementwise multiply, then gradient check through bp():
    // d(a*b)/da = b and d(a*b)/db = a.
    #[test]
    fn mul() {
        let a = Var::new(&[2., 3., 4., 5.], &[2, 2]);
        let b = Var::new(&[1., 2., 3., 4.], &[2, 2]);
        let c = a.ref_copy() * b.ref_copy();
        assert_eq!(c, Var::new(&[2., 6., 12., 20.], &[2, 2]));
        c.bp().unwrap();
        assert_eq!(a.grad().unwrap(), Var::new(&[1., 2., 3., 4.], &[2, 2]));
        assert_eq!(b.grad().unwrap(), Var::new(&[2., 3., 4., 5.], &[2, 2]));
    }

    // The same var may appear more than once in an expression graph.
    #[test]
    fn test_mul_repeat_vars() {
        let a = Var::new(&[2., 3., 4., 5.], &[2, 2]);
        let b = Var::new(&[1., 2., 3., 4.], &[2, 2]);
        let c = a * b.ref_copy();
        let d = c * b; // repeat vars
        assert_eq!(d, Var::new(&[2., 12., 36., 80.], &[2, 2]));
    }

    // Vars passed by reference into helper functions combine via ref_copy.
    #[test]
    fn test_add_in_fn() {
        let a = Var::new(&[2., 3., 4., 5.], &[2, 2]);
        let b = Var::new(&[1., 2., 3., 4.], &[2, 2]);
    
        fn my_mul(a: &Var, b: &Var) -> Var {
            a.ref_copy() * b.ref_copy()
        }
        let c = my_mul(&a, &b);
        assert_eq!(c, Var::new(&[2., 6., 12., 20.], &[2, 2]));
    }

    // MSE of two tensors offset by exactly 1 everywhere is 1.
    #[test]
    fn test_op_mse() {
        let a = Var::new(&[1., 2., 3., 4., 5., 6.,], &[3, 2]);
        let b = Var::new(&[2., 3., 4., 5., 6., 7.,], &[3, 2]);
        let c = a.mse_loss(&b).unwrap();
        assert_eq!(c , Var::new(&[1., ], &vec![1]));
    }

    // Linear layer with fixed weight/bias on an all-ones input.
    #[test]
    fn test_linear() {
        use crate::op::Linear;
        
        let mut op1 = Linear::new(Some(2), Some(5), true);
        op1.set_weight(Var::new(&[1.,2.,3.,4.,5.,6.,7.,8.,9.,10.], &[2, 5]));
        op1.set_bias(Var::new(&[1.,2.,3.,4.,5.], &[5]));
        let input = Var::ones(&[3,2]);
        let output = op1.call(&[&input]).unwrap().pop().unwrap();
        assert_eq!(output, Var::new(&[8.0, 11.0, 14.0, 17.0, 20.0, 8.0, 11.0, 14.0, 17.0, 20.0, 8.0, 11.0, 14.0, 17.0, 20.0],
                                    &vec![3, 5]));
    }

    // var_f64! shapes: 1-D input gives a column vector [n, 1];
    // 2-D input gives a matrix [rows, cols].
    #[test]
    fn test_macro_tensor() {
        let a = var_f64!([1., 2., 3.,]);
        assert_eq!(a.size(), [3, 1]);
        let a1 = a.squeeze(None).unwrap();
        println!("{:?}", a1.size());
        let a = var_f64!([[1., 2.,],
                          [4., 5.,],
                          [4., 5.,]]);
        assert_eq!(a.size(), [3, 2]);
        let a = var_f64!([[1., 2., 3.,],
                          [4., 5., 6.,]]);
        assert_eq!(a.size(), [2, 3]);
        let a = var_f64!([[1., 2., 3., 7.],
                          [4., 5., 6., 8.]]);
        assert_eq!(a.size(), [2, 4]);
    }
}