1use std::cell::RefCell;
2use std::rc::Rc;
3use std::fmt;
4use ::rand::prelude::StdRng;
5use std::ops::{Add, Div, Neg, Sub, Mul};
6use std::convert::TryFrom;
7
8use tensor_rs::tensor::{Tensor};
9use crate::op::{Op};
10use crate::err::AutoDiffError;
11use crate::optim::Optimizer;
12use crate::var_inner::VarInner;
13use crate::compute_graph::{Net};
14
15
16
/// Generates `pub fn $a(&self) -> Result<Var, AutoDiffError>`, which calls the
/// same-named method on the inner `VarInner` and wraps the produced
/// `VarInner` in a brand-new `Var`. Errors from the inner method are
/// propagated unchanged.
macro_rules! var_1_to_1 {
    ($(#[$attr:meta])*
     $a:ident) => {
        $(#[$attr])*
        pub fn $a(&self) -> Result<Var, AutoDiffError> {
            let result = self.var.borrow().$a()?;
            Ok(Var { var: Rc::new(RefCell::new(result)) })
        }
    };
}
27
/// Generates `pub fn $a(&self, other: &Var) -> Result<Var, AutoDiffError>`.
///
/// The generated method calls the same-named `VarInner` method on `self`,
/// passing a reference to `other`'s shared inner handle. It wraps the
/// resulting `VarInner` in a new `Var`. Errors from the inner method are
/// propagated unchanged.
macro_rules! var_2_to_1 {
    ($(#[$attr:meta])*
     $a:ident) => {
        $(#[$attr])*
        pub fn $a(&self, other: &Var) -> Result<Var, AutoDiffError> {
            // Borrow the `Rc` directly: cloning it only to take a reference
            // bumped and immediately dropped the reference count.
            Ok(Var {
                var: Rc::new(RefCell::new(self.var.borrow().$a(&other.var)?)),
            })
        }
    };
}
38
/// Generates `pub fn $a(&self, other: &[Var], <extra params>) -> Result<Var, AutoDiffError>`.
///
/// The inner `Rc` handles of every `Var` in `other` are collected into a
/// `Vec`. That `Vec` is passed, together with the extra parameters, to the
/// same-named `VarInner` method on `self`. The produced `VarInner` is wrapped
/// in a new `Var`.
macro_rules! var_more_to_1_with_para {
    ($(#[$attr:meta])*
     $a:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $a(&self, other: &[Var], $( $arg_name : $ArgTy ),*) -> Result<Var, AutoDiffError> {
            // The inner op works on shared `Rc` handles rather than on `Var`s.
            let handles: Vec<_> = other.iter().map(|item| item.var.clone()).collect();
            let merged = self.var.borrow().$a(&handles, $( $arg_name ),*)?;
            Ok(Var { var: Rc::new(RefCell::new(merged)) })
        }
    };
}
53
/// Generates `pub fn $a(&self, <extra params>) -> Result<Var, AutoDiffError>`.
///
/// The extra parameters are forwarded verbatim to the same-named `VarInner`
/// method on `self`, and the produced `VarInner` is wrapped in a new `Var`.
macro_rules! var_1_to_1_with_para {
    ($(#[$attr:meta])*
     $a:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $a(&self, $( $arg_name : $ArgTy ),*) -> Result<Var, AutoDiffError> {
            let produced = self.var.borrow().$a($( $arg_name ),*)?;
            Ok(Var { var: Rc::new(RefCell::new(produced)) })
        }
    };
}
64
/// Generates `pub fn $a(&self, other: &Var, <extra params>) -> Result<Var, AutoDiffError>`.
///
/// The generated method calls the same-named `VarInner` method on `self`
/// with a reference to `other`'s shared inner handle followed by the extra
/// parameters. It wraps the produced `VarInner` in a new `Var`.
macro_rules! var_2_to_1_with_para {
    ($(#[$attr:meta])*
     $a:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $a(&self, other: &Var, $( $arg_name : $ArgTy ),*) -> Result<Var, AutoDiffError> {
            // Borrow the `Rc` directly instead of cloning it just to take a reference.
            Ok(Var {
                var: Rc::new(RefCell::new(
                    self.var.borrow().$a(&other.var, $( $arg_name ),*)?)),
            })
        }
    };
}
75
/// Generates an associated constructor `pub fn $a(<params>) -> Var`.
///
/// The constructor forwards every parameter to the same-named `VarInner`
/// constructor and wraps the result in a new `Var`.
macro_rules! delegate_new_op {
    ($(#[$attr:meta])*
     $a:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        $(#[$attr])*
        pub fn $a($( $arg_name : $ArgTy ),*) -> Var {
            let fresh = VarInner::$a($( $arg_name ),*);
            Var { var: Rc::new(RefCell::new(fresh)) }
        }
    };
}
87
88pub struct Var {
92 var: Rc<RefCell<VarInner>>
93}
94impl Var {
95 #[cfg(feature = "use-f64")]
96 pub fn new(input: &[f64], dim: &[usize]) -> Var {
97 Var::new_f64(input, dim)
98 }
99 #[cfg(feature = "use-f32")]
100 pub fn new(input: &[f32], dim: &[usize]) -> Var {
101 Var::new_f32(input, dim)
102 }
103 pub fn new_f64(input: &[f64], dim: &[usize]) -> Var {
104 Var {
105 var: Rc::new(RefCell::new(VarInner::new_f64(input, dim)))
106 }
107 }
108 pub fn new_f32(input: &[f32], dim: &[usize]) -> Var {
109 Var {
110 var: Rc::new(RefCell::new(VarInner::new_f32(input, dim)))
111 }
112 }
113
114 pub fn ref_copy(self: &Var) -> Var {
118 Var {
119 var: self.var.clone(),
120 }
121 }
122 pub fn set(&self, o: &Var) {
124 self.var.borrow_mut().set(&o.var.borrow());
125 }
126
127 pub fn size(&self) -> Vec<usize> {
128 self.var.borrow().size()
129 }
130 pub fn numel(&self) -> usize {
131 self.var.borrow().numel()
132 }
133 pub fn get_f32(&self, o: &[usize]) -> Result<f32, AutoDiffError> {
134 self.var.borrow().get_f32(o)
135 }
136 pub fn set_f32(&self, o: &[usize], v: f32) -> Result<(), AutoDiffError> {
137 self.var.borrow_mut().set_f32(o, v)
138 }
139 pub fn get_f64(&self, o: &[usize]) -> Result<f64, AutoDiffError> {
140 self.var.borrow().get_f64(o)
141 }
142 pub fn set_f64(&self, o: &[usize], v: f64) -> Result<(), AutoDiffError> {
143 self.var.borrow_mut().set_f64(o, v)
144 }
145
146 pub fn fill(size: &[usize], fill_value: &Var) -> Var {
147 Var {
148 var: Rc::new(RefCell::new(
149 VarInner::fill(size, fill_value.var.clone())))
150 }
151 }
152 delegate_new_op!(fill_f32, size: &[usize], fill_value: f32);
153 delegate_new_op!(fill_f64, size: &[usize], fill_value: f64);
154 delegate_new_op!(zeros, dim: &[usize]);
155 delegate_new_op!(ones, dim: &[usize]);
156 delegate_new_op!(twos, dim: &[usize]);
157 delegate_new_op!(
162 eye, n: usize, m: usize);
164 delegate_new_op!(empty, dim: &[usize]);
165
166 pub fn from_record_f32(&self, row: usize, record: &[f32]) {
168 self.var.borrow().from_record_f32(row, record)
169 }
170 pub fn from_record_f64(&self, row: usize, record: &[f64]) {
171 self.var.borrow().from_record_f64(row, record)
172 }
173
174 delegate_new_op!(rand_usize,
176 rng: &mut StdRng,
177 dim: &[usize],
178 left: usize, right: usize);
179
180 delegate_new_op!(normal_f64,
181 rng: &mut StdRng,
182 dim: &[usize],
183 mean: f64, std: f64);
184 delegate_new_op!(normal_f32,
185 rng: &mut StdRng,
186 dim: &[usize],
187 mean: f32, std: f32);
188 #[cfg(feature = "use-f32")]
189 pub fn normal(rng: &mut StdRng,
190 dim: &[usize],
191 mean: f32, std: f32) -> Var {
192 Self::normal_f32(rng, dim, mean, std)
193 }
194 #[cfg(feature = "use-f64")]
195 pub fn normal(rng: &mut StdRng,
196 dim: &[usize],
197 mean: f64, std: f64) -> Var {
198 Self::normal_f64(rng, dim, mean, std)
199 }
200
201 delegate_new_op!(uniform_f64,
202 rng: &mut StdRng,
203 dim: &[usize],
204 from: f64, to: f64);
205 delegate_new_op!(uniform_f32,
206 rng: &mut StdRng,
207 dim: &[usize],
208 from: f32, to: f32);
209 #[cfg(feature = "use-f32")]
210 pub fn uniform(rng: &mut StdRng,
211 dim: &[usize],
212 from: f32, to: f32) -> Var {
213 Self::uniform_f32(rng, dim, from, to)
214 }
215 #[cfg(feature = "use-f64")]
216 pub fn uniform(rng: &mut StdRng,
217 dim: &[usize],
218 from: f64, to: f64) -> Var {
219 Self::uniform_f64(rng, dim, from, to)
220 }
221
222 pub fn _add(&self, other: &Var) -> Var {
223 Var {
224 var: Rc::new(RefCell::new(self.var.borrow().add(&other.var).expect("never fail.")))
225 }
226 }
227 pub fn _sub(&self, other: &Var) -> Var {
228 Var {
229 var: Rc::new(RefCell::new(self.var.borrow().sub(&other.var).expect("never fail.")))
230 }
231 }
232 pub fn _mul(&self, other: &Var) -> Var {
233 Var {
234 var: Rc::new(RefCell::new(self.var.borrow().mul(&other.var).expect("never fail.")))
235 }
236 }
237 pub fn _div(&self, other: &Var) -> Var {
238 Var {
239 var: Rc::new(RefCell::new(self.var.borrow().div(&other.var).expect("never fail.")))
240 }
241 }
242
243 var_2_to_1!(
244 matmul);
263 var_2_to_1!(
264 outer);
280
281 pub fn elu(&self, alpha: Var) -> Result<Var, AutoDiffError> {
283 Ok(Var {
284 var: Rc::new(RefCell::new(
285 self.var.borrow().elu(
286 VarInner::new_tensor(alpha.val()))?))
287 })
288 }
289 var_1_to_1!(relu);
290 var_1_to_1!(sigmoid);
291
292 var_2_to_1!(mse_loss);
294 var_2_to_1!(bce_with_logits_loss);
295 var_2_to_1!(cross_entropy_loss);
296
297 var_1_to_1!(abs);
299 var_1_to_1!(acos);
300 var_1_to_1!(asin);
301 var_1_to_1!(atan);
302 var_1_to_1!(ceil);
303 var_1_to_1!(cos);
304 var_1_to_1!(cosh);
305 var_1_to_1!(exp);
306 var_1_to_1!(expm1);
307 var_1_to_1!(floor);
308 var_1_to_1!(frac);
309 var_1_to_1!(log);
310 var_1_to_1!(log10);
311 var_1_to_1!(log1p);
312 var_1_to_1!(log1pexp);
313 var_1_to_1!(log2);
314 var_1_to_1!(neg);
315 pub fn _neg(&self) -> Var {
316 Var {
317 var: Rc::new(RefCell::new(self.var.borrow().neg().expect("never fail.")))
318 }
319 }
320 var_1_to_1!(reciprocal);
321 var_1_to_1!(round);
322 var_1_to_1!(rsqrt);
323 var_1_to_1!(sign);
324 var_1_to_1!(sin);
325 var_1_to_1!(sinh);
326 var_1_to_1!(sqrt);
327 var_1_to_1!(tan);
328 var_1_to_1!(tanh);
329 var_1_to_1!(trunc);
330
331 var_2_to_1!(max_pair);
333 var_2_to_1!(min_pair);
334 var_1_to_1_with_para!(arg_sort,
335 dim: usize, descending: bool);
336 var_2_to_1!(eq_elem);
337 var_2_to_1!(equal);
338 var_2_to_1!(ge);
339 var_2_to_1!(gt);
340 var_2_to_1!(le);
341 var_2_to_1!(lt);
342 var_2_to_1!(ne);
343
344 var_more_to_1_with_para!(
346 cat, dim: usize);
370 pub fn chunk(&self, chunks: usize, dim: usize)
371 -> Result<Vec<Var>, AutoDiffError> {
372 let mut result = self.var.borrow().chunk(chunks, dim)?;
373 let mut ret = Vec::new();
374 for i in result.drain(..) {
375 ret.push(Var {
376 var: Rc::new(RefCell::new(i)),
377 });
378 }
379 Ok(ret)
380 }
381 pub fn conditional_select(&self, x: &Var, y: &Var)
382 -> Result<Var, AutoDiffError> {
383 let result = self.var.borrow().conditional_select(x.var.clone(),
384 y.var.clone())?;
385 Ok(Var {
386 var: Rc::new(RefCell::new(result)),
387 })
388 }
389 pub fn gather(&self, dim: usize, index: Var)
390 -> Result<Var, AutoDiffError> {
391 let result = self.var.borrow().gather(dim, index.var)?;
392 Ok(Var {
393 var: Rc::new(RefCell::new(result)),
394 })
395 }
396 pub fn index_select(&self, dim: usize,
397 index: Var)
398 -> Result<Var, AutoDiffError> {
399 let result = self.var.borrow().index_select(
400 dim, index.var)?;
401 Ok(Var {
402 var: Rc::new(RefCell::new(result)),
403 })
404 }
405 pub fn index_exclude(&self, dim: usize,
406 index: Var)
407 -> Result<Var, AutoDiffError> {
408 let result = self.var.borrow().index_exclude(
409 dim, index.var)?;
410 Ok(Var {
411 var: Rc::new(RefCell::new(result)),
412 })
413 }
414 pub fn permute(&self, dim: &[usize])
415 -> Result<Var, AutoDiffError> {
416 let result = self.var.borrow().permute(dim)?;
417 Ok(Var {
418 var: Rc::new(RefCell::new(result)),
419 })
420 }
421 pub fn repeat(&self, dim: &[usize])
422 -> Result<Var, AutoDiffError> {
423 let result = self.var.borrow().repeat(dim)?;
424 Ok(Var {
425 var: Rc::new(RefCell::new(result)),
426 })
427 }
428 pub fn reshape(&self, new_shape: &[usize])
429 -> Result<Var, AutoDiffError> {
430 let result = self.var.borrow().reshape(new_shape)?;
431 Ok(Var {
432 var: Rc::new(RefCell::new(result)),
433 })
434 }
435 pub fn split(&self, sections: &[usize], dim: usize)
436 -> Result<Vec<Var>, AutoDiffError> {
437 let mut result = self.var.borrow().split(sections, dim)?;
438 let mut ret = Vec::new();
439 for i in result.drain(..) {
440 ret.push(Var {
441 var: Rc::new(RefCell::new(i)),
442 });
443 }
444 Ok(ret)
445 }
446 pub fn squeeze(&self, dim: Option<usize>)
447 -> Result<Var, AutoDiffError> {
448 let result = self.var.borrow().squeeze(dim)?;
449 Ok(Var {
450 var: Rc::new(RefCell::new(result)),
451 })
452 }
453 var_1_to_1!(t);
454 pub fn take(&self, index: &[usize])
455 -> Result<Var, AutoDiffError> {
456 let result = self.var.borrow().take(index)?;
457 Ok(Var {
458 var: Rc::new(RefCell::new(result)),
459 })
460 }
461 pub fn unsqueeze(&self, dim: usize)
462 -> Result<Var, AutoDiffError> {
463 let result = self.var.borrow().unsqueeze(dim)?;
464 Ok(Var {
465 var: Rc::new(RefCell::new(result)),
466 })
467 }
468 var_more_to_1_with_para!(
469 stack, dim: usize);
488
489 var_1_to_1!(det);
491 var_1_to_1!(inv);
492 var_1_to_1!(normalize_unit);
493 var_1_to_1!(tr);
494
495 var_1_to_1_with_para!(argmax, dim: Option<&[usize]>, keepdim: bool);
497 var_1_to_1_with_para!(argmin, dim: Option<&[usize]>, keepdim: bool);
498 var_1_to_1_with_para!(logsumexp, dim: Option<&[usize]>, keepdim: bool);
499 var_1_to_1_with_para!(mean, dim: Option<&[usize]>, keepdim: bool);
500 var_1_to_1_with_para!(prod, dim: Option<&[usize]>, keepdim: bool);
501 var_1_to_1_with_para!(std, dim: Option<&[usize]>, keepdim: bool);
502 var_1_to_1_with_para!(sum, dim: Option<&[usize]>, keepdim: bool);
503 var_1_to_1_with_para!(var, dim: Option<&[usize]>, keepdim: bool);
504 var_1_to_1_with_para!(max, dim: Option<&[usize]>, keepdim: bool);
505 var_1_to_1_with_para!(min, dim: Option<&[usize]>, keepdim: bool);
506
507 var_1_to_1_with_para!(
509 get_patch, range: &[(usize, usize)], step: Option<&[usize]>);
525 var_2_to_1_with_para!(
526 set_patch, range: &[(usize, usize)], step: Option<&[usize]>);
545 var_1_to_1_with_para!(view, new_shape: &[usize]);
546
547
548 pub fn val(&self) -> Tensor {
550 self.var.borrow().val()
551 }
552
553 pub fn set_grad(&self, use_gradient: bool) {
555 self.var.borrow_mut().set_grad(use_gradient);
556 }
557
558 pub fn reset_net(&self) {
560 self.var.borrow_mut().reset_net();
561 }
562
563 pub fn grad(&self) -> Result<Var, AutoDiffError> {
565 Ok(Var {
566 var: Rc::new(RefCell::new(self.var.borrow().grad()?))
567 })
568 }
569
570 pub fn bp(&self) -> Result<(), AutoDiffError> {
572 self.var.borrow().bp()?;
573
574 Ok(())
575 }
576
577 pub fn step(&self, opt: &mut dyn Optimizer) -> Result<(), AutoDiffError> {
578 self.var.borrow().step(opt)
579 }
580
581 pub fn rerun(&self) -> Result<(), AutoDiffError> {
583 self.var.borrow().rerun()
584 }
585
586 pub fn get_io_var(&self) -> Result<(Vec<Var>, Vec<Var>), AutoDiffError> {
588 let (mut inputs, mut outputs) = self.var.borrow().get_io_var()?;
589 Ok((inputs.drain(..).map(|x| Var {var: Rc::new(RefCell::new(x))}).collect(),
590 outputs.drain(..).map(|x| Var {var: Rc::new(RefCell::new(x))}).collect()))
591 }
592
593 pub fn get_var_by_label(&self, label: &str) -> Result<Var, AutoDiffError> {
595 let inner = self.var.borrow().get_var_by_label(label)?;
596 Ok(Var {
597 var: Rc::new(RefCell::new(inner)),
598 })
599 }
600
601 pub fn set_label(&self, label: &str) -> Result<(), AutoDiffError> {
602 self.var.borrow().set_label(label)
603 }
604
605 pub fn set_predict(&self) -> Result<(), AutoDiffError> {
606 self.set_label("__predict__")
607 }
608 pub fn predict(&self) -> Result<Var, AutoDiffError> {
609 self.get_var_by_label("__predict__")
610 }
611
612 pub(crate) fn called_with(&self, op: Op,
613 others: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
614 let refs: Vec<Rc<RefCell<VarInner>>> = others.iter().map(|x| x.var.clone()).collect();
615 let mut var_inners = self.var.borrow().called_with(op, &refs)?;
616 let ret: Vec<Var> = var_inners.drain(..).map(|x| Var {
617 var: Rc::new(RefCell::new(x))
618 }).collect();
619 Ok(ret)
620 }
621
622 pub fn dump_net(&self) -> Rc<RefCell<Net>> {
624 self.var.borrow().dump_net()
625 }
626 pub(crate) fn inner(&self) -> Rc<RefCell<VarInner>> {
627 self.var.clone()
628 }
629 pub(crate) fn set_inner(var: VarInner) -> Var {
630 Var {
631 var: Rc::new(RefCell::new(var))
632 }
633 }
634}
635
636impl PartialEq for Var {
638 fn eq(&self, other: &Self) -> bool {
639 self.var.borrow().val().eq(&other.var.borrow().val())
640 }
641}
642
643impl Eq for Var {}
644
645impl fmt::Display for Var {
646 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
647write!(f, "tensor: {}", self.var.borrow().val())
649 }
650}
651
652impl fmt::Debug for Var {
653 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
654write!(f, "tensor: {}", self.var.borrow().val())
656 }
657}
658
659impl Clone for Var {
660 fn clone(&self) -> Self {
661 Var {
662 var: Rc::new(RefCell::new(self.var.borrow().clone()))
663 }
664 }
665}
666
667impl Add for Var {
669 type Output = Self;
670
671 fn add(self, other: Self) -> Self {
672 self._add(&other)
673 }
674}
675
676impl Sub for Var {
677 type Output = Self;
678
679 fn sub(self, other: Self) -> Self {
680 self._sub(&other)
681 }
682}
683
684impl Mul for Var {
685 type Output = Self;
686
687 fn mul(self, other: Self) -> Self {
688 self._mul(&other)
689 }
690}
691
692impl Div for Var {
693 type Output = Self;
694
695 fn div(self, other: Self) -> Self {
696 self._div(&other)
697 }
698}
699
700impl Neg for Var {
701 type Output = Self;
702
703 fn neg(self) -> Self {
704 self._neg()
705 }
706}
707
708impl TryFrom<Var> for f32 {
709 type Error = AutoDiffError;
710
711 fn try_from(value: Var) -> Result<Self, Self::Error> {
712 if value.numel() > 1 {
713 return Err(AutoDiffError::new("TryFrom<Var> for f32 only works for 1 element var."))
714 }
715 let index = vec![0; value.size().len()];
716 value.get_f32(&index)
717 }
718}
719
720impl TryFrom<Var> for f64 {
721 type Error = AutoDiffError;
722
723 fn try_from(value: Var) -> Result<Self, Self::Error> {
724 if value.numel() > 1 {
725 return Err(AutoDiffError::new("TryFrom<Var> for f64 only works for 1 element var."));
726 }
727 let index = vec![0; value.size().len()];
728 value.get_f64(&index)
729 }
730}
731
732impl TryFrom<Var> for Vec<usize> {
733 type Error = AutoDiffError;
734
735 fn try_from(value: Var) -> Result<Self, Self::Error> {
736 let t = value.val();
737 if t.size().len() > 2 {
738 return Err(AutoDiffError::new("expect size [n,1], or [n]."));
739 }
740 let value = t.reshape(&[value.numel()]);
741 let ret = value.get_raw_f64().iter().map(|x| *x as usize).collect();
742 Ok(ret)
743 }
744}
745
/// Builds a `Var` from an `f64` array literal.
///
/// * A flat array `[a, b, c]` becomes shape `[len, 1]`.
/// * A nested array whose rows have 1 to 4 entries, e.g. `[[a, b], [c, d]]`,
///   becomes shape `[rows, cols]` with the rows concatenated in order.
///
/// `Var` must be in scope at the call site.
#[macro_export]
macro_rules! var_f64 {
    ($a:expr) => {{
        trait Expand {
            fn expand(&self) -> Var;
        }

        // Shared body of the nested-array impls: concatenates the rows and
        // uses `[row count, length of the last row]` as the shape (0 columns
        // when there are no rows).
        fn var_f64_from_rows<R: AsRef<[f64]>>(rows: &[R]) -> Var {
            let mut flat: Vec<f64> = Vec::new();
            let mut cols = 0;
            for row in rows {
                let row = row.as_ref();
                cols = row.len();
                flat.extend_from_slice(row);
            }
            Var::new_f64(&flat, &[rows.len(), cols])
        }

        impl Expand for [f64] {
            fn expand(&self) -> Var {
                Var::new_f64(self, &[self.len(), 1])
            }
        }
        impl Expand for [[f64; 1]] {
            fn expand(&self) -> Var {
                var_f64_from_rows(self)
            }
        }
        impl Expand for [[f64; 2]] {
            fn expand(&self) -> Var {
                var_f64_from_rows(self)
            }
        }
        impl Expand for [[f64; 3]] {
            fn expand(&self) -> Var {
                var_f64_from_rows(self)
            }
        }
        impl Expand for [[f64; 4]] {
            fn expand(&self) -> Var {
                var_f64_from_rows(self)
            }
        }
        $a.expand()
    }};
}
808
809
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::OpCall;
    extern crate openblas_src;

    /// Element-wise product and the gradients it back-propagates.
    #[test]
    fn test_mul() {
        let a = Var::new(&[2., 3., 4., 5.], &[2, 2]);
        let b = Var::new(&[1., 2., 3., 4.], &[2, 2]);
        let c = a.ref_copy() * b.ref_copy();
        assert_eq!(c, Var::new(&[2., 6., 12., 20.], &[2, 2]));
        c.bp().unwrap();
        assert_eq!(a.grad().unwrap(), Var::new(&[1., 2., 3., 4.], &[2, 2]));
        assert_eq!(b.grad().unwrap(), Var::new(&[2., 3., 4., 5.], &[2, 2]));
    }

    /// The same variable used twice in one expression: (a * b) * b.
    #[test]
    fn test_mul_repeat_vars() {
        let a = Var::new(&[2., 3., 4., 5.], &[2, 2]);
        let b = Var::new(&[1., 2., 3., 4.], &[2, 2]);
        let c = a * b.ref_copy();
        let d = c * b;
        assert_eq!(d, Var::new(&[2., 12., 36., 80.], &[2, 2]));
    }

    /// Multiplication performed inside a helper that only borrows its inputs.
    #[test]
    fn test_mul_in_fn() {
        let a = Var::new(&[2., 3., 4., 5.], &[2, 2]);
        let b = Var::new(&[1., 2., 3., 4.], &[2, 2]);

        fn my_mul(a: &Var, b: &Var) -> Var {
            a.ref_copy() * b.ref_copy()
        }
        let c = my_mul(&a, &b);
        assert_eq!(c, Var::new(&[2., 6., 12., 20.], &[2, 2]));
    }

    #[test]
    fn test_op_mse() {
        let a = Var::new(&[1., 2., 3., 4., 5., 6.,], &[3, 2]);
        let b = Var::new(&[2., 3., 4., 5., 6., 7.,], &[3, 2]);
        let c = a.mse_loss(&b).unwrap();
        assert_eq!(c, Var::new(&[1., ], &[1]));
    }

    #[test]
    fn test_linear() {
        use crate::op::Linear;

        let mut op1 = Linear::new(Some(2), Some(5), true);
        op1.set_weight(Var::new(&[1.,2.,3.,4.,5.,6.,7.,8.,9.,10.], &[2, 5]));
        op1.set_bias(Var::new(&[1.,2.,3.,4.,5.], &[5]));
        let input = Var::ones(&[3,2]);
        let output = op1.call(&[&input]).unwrap().pop().unwrap();
        assert_eq!(output, Var::new(&[8.0, 11.0, 14.0, 17.0, 20.0,
                                      8.0, 11.0, 14.0, 17.0, 20.0,
                                      8.0, 11.0, 14.0, 17.0, 20.0],
                                    &[3, 5]));
    }

    #[test]
    fn test_macro_tensor() {
        let a = var_f64!([1., 2., 3.,]);
        assert_eq!(a.size(), [3, 1]);
        // Squeezing with no explicit dim is expected to drop the unit column.
        let a1 = a.squeeze(None).unwrap();
        assert_eq!(a1.size(), [3]);
        let a = var_f64!([[1., 2.,],
                          [4., 5.,],
                          [4., 5.,]]);
        assert_eq!(a.size(), [3, 2]);
        let a = var_f64!([[1., 2., 3.,],
                          [4., 5., 6.,]]);
        assert_eq!(a.size(), [2, 3]);
        let a = var_f64!([[1., 2., 3., 7.],
                          [4., 5., 6., 8.]]);
        assert_eq!(a.size(), [2, 4]);
    }
}