//! auto_diff/var_inner.rs — inner implementation of the differentiable variable type.

1use std::cell::RefCell;
2use std::rc::Rc;
3use std::fmt;
4use std::collections::BTreeMap;
5use ::rand::prelude::StdRng;
6
7use tensor_rs::tensor::{Tensor};
8use crate::compute_graph::{Net};
9use crate::collection::generational_index::{GenKey};
10use crate::op::{Op,
11                View,
12                Add, Sub, Mul, Div, Matmul, Outer,
13                ELU, ReLU, Sigmoid,
14                MSELoss, BCEWithLogitsLoss, CrossEntropyLoss,
15                Abs, Acos, Asin, Atan, Ceil, Cos, Cosh, Exp, Expm1, Floor, Frac, Log, Log10, Log1p, Log1pexp, Log2, Neg, Reciprocal, Round, Rsqrt, Sign, Sin, Sinh, Sqrt, Tan, Tanh, Trunc,
16                MaxPair, MinPair, ArgSort, EqElem, Equal, Ge, Gt, Le, Lt, Ne,
17                Cat, Chunk, Gather, IndexSelect, IndexExclude, Reshape, Split, Squeeze, Stack, T, Take, Permute, Unsqueeze, ConditionalSelect, Repeat,
18                Det, Inv, NormalizeUnit, Tr,
19                Argmax, Argmin, Logsumexp, Mean, Prod, Std, Sum, Variance, Max, Min,
20                GetPatch, SetPatch,
21};
22use crate::err::AutoDiffError;
23use crate::optim::Optimizer;
24
25
/// For elementwise ops
/// var_inner_1_to_1!(abs, Abs);
macro_rules! var_inner_1_to_1 {
    ($method:ident, $op_type:ident) => {
        // Generates a unary method: wrap the op, run it through
        // `called_with` with no extra inputs, and return the single output.
        pub fn $method(&self) -> Result<VarInner, AutoDiffError> {
            let op = Op::new(Rc::new(RefCell::new(Box::new($op_type::new()))));
            let mut outputs = self.called_with(op, &[])?;
            Ok(outputs.remove(0))
        }
    }
}
38
39
// Generates a binary method taking one extra operand var.
macro_rules! var_inner_2_to_1 {
    ($method:ident, $op_type:ident) => {
        pub fn $method(&self, other: &Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
            let op = Op::new(Rc::new(RefCell::new(Box::new($op_type::new()))));
            // `self` is the first input; `other` is passed as the second.
            let mut outputs = self.called_with(op, &[other.clone()])?;
            Ok(outputs.remove(0))
        }
    }
}
51
/// Multiple tensor in, 1 out and with parameters
macro_rules! var_inner_more_to_1_with_para {
    ($method:ident, $op_type:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        // Generates a variadic-input method whose op constructor takes
        // the listed extra parameters.
        pub fn $method(&self, inputs: &[Rc<RefCell<VarInner>>],
                       $( $arg_name : $ArgTy ),*) -> Result<VarInner, AutoDiffError> {
            let op_impl = $op_type::new($( $arg_name ),*);
            let op = Op::new(Rc::new(RefCell::new(Box::new(op_impl))));
            let mut outputs = self.called_with(op, inputs)?;
            Ok(outputs.remove(0))
        }
    }
}
64
// Generates a unary method whose op constructor takes extra parameters.
macro_rules! var_inner_1_to_1_with_para {
    ($method:ident, $op_type:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        pub fn $method(&self, $( $arg_name : $ArgTy ),*) -> Result<VarInner, AutoDiffError> {
            let op_impl = $op_type::new($( $arg_name ),*);
            let op = Op::new(Rc::new(RefCell::new(Box::new(op_impl))));
            let mut outputs = self.called_with(op, &[])?;
            Ok(outputs.remove(0))
        }
    }
}
75
// Generates a binary method whose op constructor takes extra parameters.
macro_rules! var_inner_2_to_1_with_para {
    ($method:ident, $op_type:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        pub fn $method(&self, other: &Rc<RefCell<VarInner>>,
                       $( $arg_name : $ArgTy ),*)
                       -> Result<VarInner, AutoDiffError> {
            let op_impl = $op_type::new($( $arg_name ),*);
            let op = Op::new(Rc::new(RefCell::new(Box::new(op_impl))));
            let mut outputs = self.called_with(op, &[other.clone()])?;
            Ok(outputs.remove(0))
        }
    }
}
88
89
90
// Macro for creation associated function.
// Not for method.
// Delegates to the `Tensor` constructor of the same name and wraps the
// result in a fresh single-tensor net.
macro_rules! delegate_new_inner_op {
    ($name:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        pub fn $name($( $arg_name : $ArgTy ),*) -> VarInner {
            let mut net = Net::new();
            let id = net.add_tensor(Tensor::$name($( $arg_name ),*));
            VarInner {
                id,
                need_grad: true,
                net: Rc::new(RefCell::new(net)),
            }
        }
    }
}
107
/// Inner state of a differentiable variable: a key into a shared
/// computation graph plus a gradient-tracking flag.
pub(crate) struct VarInner {
    id: GenKey,            // key of this var's tensor inside `net`
    need_grad: bool,       // when true, ops on this var are recorded in `net` (see `called_with`)
    net: Rc<RefCell<Net>>, // shared computation graph holding tensors and ops
}
113
impl VarInner {

    // create functions.

    /// Build a standalone var from raw values and a shape; the var owns a
    /// fresh single-tensor net. Element type follows the crate precision
    /// feature (`use-f64` here).
    #[cfg(feature = "use-f64")]
    pub fn new(input: &[f64], dim: &[usize]) -> VarInner {
        let mut net = Net::new();

        let tensor = Tensor::from_vec_f64(input, dim);

        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }
    /// `use-f32` counterpart of `new`; identical except for the element type.
    #[cfg(feature = "use-f32")]
    pub fn new(input: &[f32], dim: &[usize]) -> VarInner {
        let mut net = Net::new();

        let tensor = Tensor::from_vec_f32(input, dim);

        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }
    /// Explicit f64 constructor, available regardless of the precision feature.
    pub fn new_f64(input: &[f64], dim: &[usize]) -> VarInner {
        let mut net = Net::new();

        let tensor = Tensor::from_vec_f64(input, dim);

        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }
    /// Explicit f32 constructor, available regardless of the precision feature.
    pub fn new_f32(input: &[f32], dim: &[usize]) -> VarInner {
        let mut net = Net::new();

        let tensor = Tensor::from_vec_f32(input, dim);

        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }

    /// Create a new var with an existing net and value.
    pub(crate) fn new_net_tensor(net: Rc<RefCell<Net>>,
                                 need_grad: bool,
                                 tensor: Tensor) -> VarInner {
        let id = net.borrow_mut().add_tensor(tensor);
        VarInner {
            id,
            need_grad,
            net
        }
    }

    /// Wrap an existing tensor in a brand-new single-tensor net.
    pub(crate) fn new_tensor(tensor: Tensor) -> VarInner {
        let mut net = Net::new();
        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }

    /// Key of this var's tensor inside its net.
    pub fn get_id(&self) -> GenKey {
        self.id
    }
    /// Whether ops applied to this var are recorded for backprop.
    pub fn get_need_grad(&self) -> bool {
        self.need_grad
    }
    /// Shared handle to the underlying computation graph.
    pub fn get_net(&self) -> Rc<RefCell<Net>> {
        self.net.clone()
    }

    /// Shape of the underlying tensor.
    pub fn size(&self) -> Vec<usize> {
        self.net.borrow().get_tensor(self.id).expect("").size()
    }
    /// Total number of elements in the underlying tensor.
    pub fn numel(&self) -> usize {
        self.net.borrow().get_tensor(self.id).expect("").numel()
    }
    // Validate that an element index supplies one coordinate per dimension.
    fn check_index(v: &VarInner, o: &[usize]) -> Result<(), AutoDiffError> {
        if v.size().len() != o.len() {
            return Err(AutoDiffError::new(
                &format!("Index for get() should have the same len. t: {:?}, index: {:?}",
                         v.size(), o.len())));
        } else {
            Ok(())
        }
    }
    /// Read one element as f32 at the given multi-dimensional index.
    pub fn get_f32(&self, o: &[usize]) -> Result<f32, AutoDiffError> {
        Self::check_index(self, o)?;
        Ok(self.net.borrow().get_tensor(self.id)?.get_f32(o))
    }
    /// Write one f32 element at the given multi-dimensional index.
    // NOTE(review): mutation goes through an immutable net borrow, so this
    // relies on the returned Tensor sharing interior-mutable storage — confirm.
    pub fn set_f32(&mut self, o: &[usize], v: f32) -> Result<(), AutoDiffError> {
        Self::check_index(self, o)?;
        self.net.borrow().get_tensor(self.id)?.set_f32(o, v);
        Ok(())
    }
    /// Read one element as f64 at the given multi-dimensional index.
    pub fn get_f64(&self, o: &[usize]) -> Result<f64, AutoDiffError> {
        Self::check_index(self, o)?;
        Ok(self.net.borrow().get_tensor(self.id)?.get_f64(o))
    }
    /// Write one f64 element at the given multi-dimensional index.
    pub fn set_f64(&mut self, o: &[usize], v: f64) -> Result<(), AutoDiffError>{
        Self::check_index(self, o)?;
        self.net.borrow().get_tensor(self.id)?.set_f64(o, v);
        Ok(())
    }

    /// New var of the given size, filled with the (scalar) value of `fill_value`.
    pub fn fill(size: &[usize], fill_value: Rc<RefCell<VarInner>>) -> VarInner {
        let mut net = Net::new();
        let tensor = Tensor::fill(size, &fill_value.borrow().val());
        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }
    /// New var of the given size, filled with an f32 constant.
    pub fn fill_f32(size: &[usize], fill_value: f32) -> VarInner {
        let mut net = Net::new();
        let tensor = Tensor::fill_f32(size, fill_value);
        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }
    /// New var of the given size, filled with an f64 constant.
    pub fn fill_f64(size: &[usize], fill_value: f64) -> VarInner {
        let mut net = Net::new();
        let tensor = Tensor::fill_f64(size, fill_value);
        let id = net.add_tensor(tensor);
        VarInner {
            id,
            need_grad: true,
            net: Rc::new(RefCell::new(net)),
        }
    }
    // Constructors delegated 1:1 to the matching Tensor constructor.
    delegate_new_inner_op!(zeros, dim: &[usize]);
    delegate_new_inner_op!(ones, dim: &[usize]);
    delegate_new_inner_op!(twos, dim: &[usize]);
    //delegate_new_inner_op!(arange, end: usize);
    //delegate_new_inner_op!(range, start: f32, end: f32, step: Option<f32>);
    //delegate_new_inner_op!(linspace, start: f32, end: f32, steps: usize);
    //delegate_new_inner_op!(logspace, start: f32, end: f32, steps: usize, base: f32);
    delegate_new_inner_op!(eye, n: usize, m: usize);
    delegate_new_inner_op!(empty, dim: &[usize]);

    /// Overwrite row `row` of the underlying tensor with an f32 record.
    pub fn from_record_f32(&self, row: usize, record: &[f32]) {
        self.val().from_record_f32(row, record).expect("");
    }
    /// Overwrite row `row` of the underlying tensor with an f64 record.
    pub fn from_record_f64(&self, row: usize, record: &[f64]) {
        self.val().from_record_f64(row, record).expect("");
    }


    // rand
    // Random constructors delegated to Tensor, driven by a caller-supplied RNG.
    delegate_new_inner_op!(rand_usize,
                           rng: &mut StdRng,
                           dim: &[usize],
                           left: usize, right: usize);
    delegate_new_inner_op!(normal_f64,
                           rng: &mut StdRng,
                           dim: &[usize],
                           mean: f64, std: f64);
    delegate_new_inner_op!(normal_f32,
                           rng: &mut StdRng,
                           dim: &[usize],
                           mean: f32, std: f32);
    delegate_new_inner_op!(uniform_f64,
                           rng: &mut StdRng,
                           dim: &[usize],
                           from: f64, to: f64);
    delegate_new_inner_op!(uniform_f32,
                           rng: &mut StdRng,
                           dim: &[usize],
                           from: f32, to: f32);


    // get and set.
    /// This is a ref. Clone it to cut the connection.
    pub(crate) fn val(&self) -> Tensor {
        self.net.borrow().get_tensor(self.id).unwrap()
    }
    /// Replace this var's tensor inside the net.
    pub(crate) fn set_val(&mut self, val: Tensor) {
        self.net.borrow_mut().set_tensor(self.id, val).expect("");
    }
    /// Copy another var's value into this var (nets stay separate).
    pub fn set(&mut self, o: &VarInner) {
        self.set_val(o.val())
    }

    /// Gradient of this var, wrapped in a fresh standalone net.
    pub fn grad(&self) -> Result<VarInner, AutoDiffError> {
        Ok(VarInner::new_tensor(self.net.borrow().get_grad(self.id)?))
    }

    /// backward pass.
    /// Seeds this var's gradient with ones and back-propagates through the net.
    pub fn bp(&self) -> Result<(), AutoDiffError> {
        let mut job = BTreeMap::new();
        job.insert(self.id, Tensor::ones_like(&self.val()));
        self.net.borrow_mut().bptt(&job);

        Ok(())
    }

    /// Update,
    /// Let the optimizer update all parameters held in this var's net.
    pub fn step(&self, opt: &mut dyn Optimizer) -> Result<(), AutoDiffError> {
        opt.step(self.net.clone());
        Ok(())
    }

    /// Re-evaluate the whole net forward from its input edge nodes.
    pub fn rerun(&self) -> Result<(), AutoDiffError> {
        let mut all_input = Vec::new();
        for i in &self.net.borrow().get_input_edge_data() {
            all_input.push(*i);
        }
        // NOTE(review): eval errors are swallowed by expect("") rather than
        // propagated — confirm whether that is intentional.
        self.net.borrow_mut().eval(&all_input).expect("");
        Ok(())
    }

    /// Vars for the net's input edge nodes and output edge nodes,
    /// all sharing this var's net.
    pub fn get_io_var(&self) -> Result<(Vec<VarInner>, Vec<VarInner>), AutoDiffError> {
        let input_id = self.net.borrow().get_input_edge_data();
        let output_id = self.net.borrow().get_output_edge_data();
        Ok((input_id.iter().map(|x| VarInner {id: *x, need_grad: true, net: self.net.clone()}).collect(),
            output_id.iter().map(|x| VarInner {id: *x, need_grad: true, net: self.net.clone()}).collect(),))
    }

    /// Look up a previously labeled tensor in this net and wrap it as a var.
    pub fn get_var_by_label(&self, label: &str) -> Result<VarInner, AutoDiffError> {
        let id = self.net.borrow().get_id_by_label(label)?;
        //self.net.borrow().
        Ok(VarInner {
            id,
            need_grad: true,
            net: self.net.clone(),
        })
    }

    /// Attach a label to this var's tensor inside the net.
    pub(crate) fn set_label(&self, label: &str) -> Result<(), AutoDiffError> {
        self.net.borrow_mut().set_label(label, &self.id)
    }

    /// Toggle gradient tracking for subsequent ops on this var.
    pub(crate) fn set_grad(&mut self, use_gradient: bool) {
        self.need_grad = use_gradient;
    }

    /// Detach this var from its history: keep only the current value in a
    /// brand-new net.
    pub(crate) fn reset_net(&mut self) {
        let value = self.val();
        let mut net = Net::new();
        let id = net.add_tensor(value);
        self.id = id;
        self.net = Rc::new(RefCell::new(net));
    }

    /// used in OpCall trait implementation.
    /// Apply `op` to `self` plus `others`, returning one var per op output.
    /// With `need_grad` set, all operands are first migrated into `self`'s
    /// net and the op is recorded in the graph; otherwise the op is only
    /// evaluated eagerly and nothing is recorded.
    pub(crate) fn called_with(&self, op: Op,
                              others: &[Rc<RefCell<VarInner>>])
                              -> Result<Vec<VarInner>, AutoDiffError> {
        if self.need_grad {
            // Group the other operands by the (distinct) net they live in,
            // using pointer identity; operands already in self's net need no work.
            let mut other_var_by_networks: Vec<Vec<Rc<RefCell<VarInner>>>> = vec![];
            for item in others.iter().cloned() {
                if !Rc::ptr_eq(&self.net, &item.borrow().net) {
                    let mut existing_net = false;
                    for set in &mut other_var_by_networks {
                        if Rc::ptr_eq(&item.borrow().net, &set[0].borrow().net) {
                            set.push(item.clone());
                            existing_net = true;
                            break;
                        }
                    }
                    if ! existing_net {
                        other_var_by_networks.push(vec![item.clone()]);
                    }
                }
            }
            // Merge each foreign net into self's net, then repoint every
            // migrated var at self's net with its freshly assigned key.
            for set in other_var_by_networks {
                let mut old_ids = vec![];
                for item in &set {
                    old_ids.push(item.borrow().id);
                }
                let other_key = self.net.borrow_mut().append(
                    &set[0].borrow().net.borrow(), &old_ids)?;
                for (index, item) in set.iter().enumerate() {
                    item.borrow_mut().net = self.net.clone();
                    item.borrow_mut().id = other_key[index];
                }

            }

            // Collect input keys and tensors: self first, then the others.
            let mut input_id = vec![self.id];
            let mut inputs = vec![self.net.borrow().get_tensor(self.id)?];
            for i in others {
                input_id.push(i.borrow().id);
                inputs.push(self.net.borrow().get_tensor(i.borrow().id)?);
            }

            // Allocate one placeholder output var per op output, all in self's net.
            let mut output_id = vec![];
            let mut outputs = Vec::new();
            let mut ret = Vec::new();
            for _ in 0..op.get_output_size() {
                let new_output = VarInner::new_net_tensor(self.net.clone(),
                                                          self.need_grad,
                                                          Tensor::new());
                output_id.push(new_output.id);
                outputs.push(self.net.borrow().get_tensor(new_output.id)?);
                ret.push(new_output);
            }

            // Eagerly compute the op, then record it and wire the graph edges.
            op.apply(&inputs, &outputs);
            let opid = self.net.borrow_mut().add_op(op);

            self.net.borrow_mut().connect(&input_id,
                                          opid,
                                          &output_id);

            Ok(ret)
        } else {
            // No-grad path: evaluate eagerly; outputs each get a throwaway net
            // and nothing is recorded in any graph.
            let mut inputs = vec![self.net.borrow().get_tensor(self.id)?];
            for i in others {
                inputs.push(i.borrow().net.borrow().get_tensor(i.borrow().id)?);
            }

            let mut ret = Vec::new();
            let mut outputs = Vec::new();
            for _ in 0..op.get_output_size() {
                let new_output = VarInner::new_net_tensor(Rc::new(RefCell::new(Net::new())),
                                                          self.need_grad,
                                                          Tensor::new());
                outputs.push(new_output.net.borrow().get_tensor(new_output.id)?);
                ret.push(new_output);
            }

            op.apply(&inputs, &outputs);

            Ok(ret)
        }
    }

    // 2-in-1 ops
    var_inner_2_to_1!(add, Add);
    var_inner_2_to_1!(sub, Sub);
    var_inner_2_to_1!(mul, Mul);
    var_inner_2_to_1!(div, Div);
    var_inner_2_to_1!(matmul, Matmul);
    var_inner_2_to_1!(outer, Outer);

    // nonlinear
    /// ELU activation; `alpha` supplies the op's scale value.
    pub fn elu(&self, alpha: VarInner) -> Result<VarInner, AutoDiffError> {
        let new_one = ELU::new(alpha.val());
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let mut result = self.called_with(op, &[])?;
        Ok(result.remove(0))
    }
    var_inner_1_to_1!(relu, ReLU);
    var_inner_1_to_1!(sigmoid, Sigmoid);

    // loss
    var_inner_2_to_1!(mse_loss, MSELoss);
    var_inner_2_to_1!(bce_with_logits_loss, BCEWithLogitsLoss);
    var_inner_2_to_1!(cross_entropy_loss, CrossEntropyLoss);

    // element ops
    var_inner_1_to_1!(abs, Abs);
    var_inner_1_to_1!(acos, Acos);
    var_inner_1_to_1!(asin, Asin);
    var_inner_1_to_1!(atan, Atan);
    var_inner_1_to_1!(ceil, Ceil);
    var_inner_1_to_1!(cos, Cos);
    var_inner_1_to_1!(cosh, Cosh);
    var_inner_1_to_1!(exp, Exp);
    var_inner_1_to_1!(expm1, Expm1);
    var_inner_1_to_1!(floor, Floor);
    var_inner_1_to_1!(frac, Frac);
    var_inner_1_to_1!(log, Log);
    var_inner_1_to_1!(log10, Log10);
    var_inner_1_to_1!(log1p, Log1p);
    var_inner_1_to_1!(log1pexp, Log1pexp);
    var_inner_1_to_1!(log2, Log2);
    var_inner_1_to_1!(neg, Neg);
    var_inner_1_to_1!(reciprocal, Reciprocal);
    var_inner_1_to_1!(round, Round);
    var_inner_1_to_1!(rsqrt, Rsqrt);
    var_inner_1_to_1!(sign, Sign);
    var_inner_1_to_1!(sin, Sin);
    var_inner_1_to_1!(sinh, Sinh);
    var_inner_1_to_1!(sqrt, Sqrt);
    var_inner_1_to_1!(tan, Tan);
    var_inner_1_to_1!(tanh, Tanh);
    var_inner_1_to_1!(trunc, Trunc);

    // comparison
    var_inner_2_to_1!(max_pair, MaxPair);
    var_inner_2_to_1!(min_pair, MinPair);
    var_inner_1_to_1_with_para!(arg_sort, ArgSort,
                                dim: usize, descending: bool);
    var_inner_2_to_1!(eq_elem, EqElem);
    var_inner_2_to_1!(equal, Equal);
    var_inner_2_to_1!(ge, Ge);
    var_inner_2_to_1!(gt, Gt);
    var_inner_2_to_1!(le, Le);
    var_inner_2_to_1!(lt, Lt);
    var_inner_2_to_1!(ne, Ne);

    // index and slicing
    var_inner_more_to_1_with_para!(cat, Cat, dim: usize);
    /// Split into `chunks` pieces along `dim`; one output var per chunk.
    pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<VarInner>, AutoDiffError> {
        let new_one = Chunk::new(chunks, dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let result = self.called_with(op, &Vec::new())?;
        Ok(result)
    }
    /// Elementwise select between `x` and `y` with `self` as the condition.
    pub fn conditional_select(&self, x: Rc<RefCell<VarInner>>, y: Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
        let new_one = ConditionalSelect::new();
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let inputs = vec![x, y];
        let mut result = self.called_with(op, &inputs)?;
        Ok(result.remove(0))
    }
    /// Gather values along `dim` using an index var.
    pub fn gather(&self, dim: usize, index: Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
        let new_one = Gather::new(dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let inputs = vec![index];
        let mut result = self.called_with(op, &inputs)?;
        Ok(result.remove(0))
    }
    /// Select sub-tensors along `dim` at the positions given by `index`.
    pub fn index_select(&self, dim: usize, index: Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
        let new_one = IndexSelect::new(dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let inputs = vec![index];
        let mut result = self.called_with(op, &inputs)?;
        Ok(result.remove(0))
    }
    /// Complement of `index_select`: drop the indexed positions along `dim`.
    pub fn index_exclude(&self, dim: usize,
                         index: Rc<RefCell<VarInner>>)
                         -> Result<VarInner, AutoDiffError> {
        let new_one = IndexExclude::new(dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let inputs = vec![index];
        let mut result = self.called_with(op, &inputs)?;
        Ok(result.remove(0))
    }
    /// Permute the dimensions into the given order.
    pub fn permute(&self, dim: &[usize]) -> Result<VarInner, AutoDiffError> {
        let new_one = Permute::new(dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let mut result = self.called_with(op, &[])?;
        Ok(result.remove(0))
    }
    /// Repeat the tensor along each dimension the given number of times.
    pub fn repeat(&self, dim: &[usize]) -> Result<VarInner, AutoDiffError> {
        let new_one = Repeat::new(dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let mut result = self.called_with(op, &[])?;
        Ok(result.remove(0))
    }
    /// Reshape to `new_shape` (element count is handled by the op).
    pub fn reshape(&self, new_shape: &[usize]) -> Result<VarInner, AutoDiffError> {
        let new_one = Reshape::new(new_shape);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let mut result = self.called_with(op, &[])?;
        Ok(result.remove(0))
    }
    /// Split along `dim` into pieces of the given section sizes.
    pub fn split(&self, sections: &[usize], dim: usize) -> Result<Vec<VarInner>, AutoDiffError> {
        let new_one = Split::new(sections, dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let result = self.called_with(op, &Vec::new())?;
        Ok(result)
    }
    /// Drop size-1 dimensions (all of them, or only `dim` when given).
    pub fn squeeze(&self, dim: Option<usize>) -> Result<VarInner, AutoDiffError> {
        let new_one = Squeeze::new(dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let mut result = self.called_with(op, &[])?;
        Ok(result.remove(0))
    }
    var_inner_1_to_1!(t, T);
    /// Take elements at the given flat indices.
    pub fn take(&self, index: &[usize]) -> Result<VarInner, AutoDiffError> {
        let new_one = Take::new(index);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let mut result = self.called_with(op, &[])?;
        Ok(result.remove(0))
    }
    /// Insert a size-1 dimension at position `dim`.
    pub fn unsqueeze(&self, dim: usize) -> Result<VarInner, AutoDiffError> {
        let new_one = Unsqueeze::new(dim);
        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
        let mut result = self.called_with(op, &[])?;
        Ok(result.remove(0))
    }
    var_inner_more_to_1_with_para!(stack, Stack, dim: usize);

    // linalg
    var_inner_1_to_1!(det, Det);
    var_inner_1_to_1!(inv, Inv);
    var_inner_1_to_1!(normalize_unit, NormalizeUnit);
    var_inner_1_to_1!(tr, Tr);

    // reduction
    var_inner_1_to_1_with_para!(argmax, Argmax, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(argmin, Argmin, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(logsumexp, Logsumexp, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(mean, Mean, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(prod, Prod, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(std, Std, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(sum, Sum, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(var, Variance, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(max, Max, dim: Option<&[usize]>, keepdim: bool);
    var_inner_1_to_1_with_para!(min, Min, dim: Option<&[usize]>, keepdim: bool);

    // images
    var_inner_1_to_1_with_para!(get_patch, GetPatch, range: &[(usize, usize)], step: Option<&[usize]>);
    var_inner_2_to_1_with_para!(set_patch, SetPatch, range: &[(usize, usize)], step: Option<&[usize]>);
    var_inner_1_to_1_with_para!(view, View, new_shape: &[usize]);

    /// Expose the shared net handle (same `Rc`, not a deep copy).
    pub fn dump_net(&self) -> Rc<RefCell<Net>> {
        self.net.clone()
    }

    /// Rebuild a var from its raw parts; takes ownership of `net`.
    pub(crate) fn set_inner(id: GenKey, need_grad: bool, net: Net) -> VarInner {
        VarInner {
            id,
            need_grad,
            net: Rc::new(RefCell::new(net))
        }
    }
}
645
646impl PartialEq for VarInner {
647    fn eq(&self, other: &Self) -> bool {
648        self.val().eq(&other.val())
649    }
650}
651
652impl Eq for VarInner {}
653
654impl fmt::Display for VarInner {
655    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
656        write!(f, "id: {}", self.id)?;
657        write!(f, "tensor: {}", self.val())
658    }
659}
660
661impl fmt::Debug for VarInner {
662    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
663        write!(f, "id: {}", self.id)?;
664        write!(f, "tensor: {}", self.val())
665    }
666}
667
668impl Clone for VarInner {
669    fn clone(&self) -> Self {
670        let val = self.val();
671        let mut ret = VarInner::new(&[], &[]);
672        ret.set_val(val);
673        ret.need_grad = self.need_grad;
674        ret
675    }
676}
677
678