1use std::cell::RefCell;
2use std::rc::Rc;
3use std::fmt;
4use std::collections::BTreeMap;
5use ::rand::prelude::StdRng;
6
7use tensor_rs::tensor::{Tensor};
8use crate::compute_graph::{Net};
9use crate::collection::generational_index::{GenKey};
10use crate::op::{Op,
11 View,
12 Add, Sub, Mul, Div, Matmul, Outer,
13 ELU, ReLU, Sigmoid,
14 MSELoss, BCEWithLogitsLoss, CrossEntropyLoss,
15 Abs, Acos, Asin, Atan, Ceil, Cos, Cosh, Exp, Expm1, Floor, Frac, Log, Log10, Log1p, Log1pexp, Log2, Neg, Reciprocal, Round, Rsqrt, Sign, Sin, Sinh, Sqrt, Tan, Tanh, Trunc,
16 MaxPair, MinPair, ArgSort, EqElem, Equal, Ge, Gt, Le, Lt, Ne,
17 Cat, Chunk, Gather, IndexSelect, IndexExclude, Reshape, Split, Squeeze, Stack, T, Take, Permute, Unsqueeze, ConditionalSelect, Repeat,
18 Det, Inv, NormalizeUnit, Tr,
19 Argmax, Argmin, Logsumexp, Mean, Prod, Std, Sum, Variance, Max, Min,
20 GetPatch, SetPatch,
21};
22use crate::err::AutoDiffError;
23use crate::optim::Optimizer;
24
25
/// Generates a unary `VarInner` method `$a(&self)`.
///
/// The generated method builds a fresh `$b` op via `$b::new()`, applies it to
/// `self` with no extra inputs through `called_with`, and returns the first
/// output produced by the op.
macro_rules! var_inner_1_to_1 {
    ($a:ident, $b:ident) => {
        pub fn $a(&self) -> Result<VarInner, AutoDiffError> {
            let wrapped = Op::new(Rc::new(RefCell::new(Box::new($b::new()))));
            let mut produced = self.called_with(wrapped, &[])?;
            Ok(produced.remove(0))
        }
    };
}
38
39
/// Generates a binary `VarInner` method `$a(&self, other)`.
///
/// The generated method builds a fresh `$b` op via `$b::new()` and applies it
/// to `self` with `other` as the single extra input. It returns the first
/// output produced by the op.
macro_rules! var_inner_2_to_1 {
    ($a:ident, $b:ident) => {
        pub fn $a(&self, other: &Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
            let wrapped = Op::new(Rc::new(RefCell::new(Box::new($b::new()))));
            let mut produced = self.called_with(wrapped, &[Rc::clone(other)])?;
            Ok(produced.remove(0))
        }
    };
}
51
/// Generates a `VarInner` method `$a(&self, inputs, <params>)`.
///
/// The listed parameters are forwarded, in order, to `$b::new`. The resulting
/// op is applied to `self` followed by every var in `inputs`, and the first
/// output is returned.
macro_rules! var_inner_more_to_1_with_para {
    ($a:ident, $b:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        pub fn $a(
            &self,
            inputs: &[Rc<RefCell<VarInner>>],
            $( $arg_name : $ArgTy ),*
        ) -> Result<VarInner, AutoDiffError> {
            let wrapped = Op::new(Rc::new(RefCell::new(Box::new($b::new($( $arg_name ),*)))));
            let mut produced = self.called_with(wrapped, inputs)?;
            Ok(produced.remove(0))
        }
    };
}
63}
64
/// Generates a unary `VarInner` method `$a(&self, <params>)`.
///
/// The parameters are forwarded to `$b::new`. The op is applied to `self`
/// alone, and its first output is returned.
macro_rules! var_inner_1_to_1_with_para {
    ($a:ident, $b:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        pub fn $a(&self, $( $arg_name : $ArgTy ),*) -> Result<VarInner, AutoDiffError> {
            let wrapped = Op::new(Rc::new(RefCell::new(Box::new($b::new($( $arg_name ),*)))));
            let mut produced = self.called_with(wrapped, &[])?;
            Ok(produced.remove(0))
        }
    };
}
74}
75
/// Generates a binary `VarInner` method `$a(&self, other, <params>)`.
///
/// The parameters are forwarded to `$b::new`. The op is applied to `self` with
/// `other` as the single extra input, and its first output is returned.
macro_rules! var_inner_2_to_1_with_para {
    ($a:ident, $b:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        pub fn $a(
            &self,
            other: &Rc<RefCell<VarInner>>,
            $( $arg_name : $ArgTy ),*
        ) -> Result<VarInner, AutoDiffError> {
            let wrapped = Op::new(Rc::new(RefCell::new(Box::new($b::new($( $arg_name ),*)))));
            let extra = [Rc::clone(other)];
            let mut produced = self.called_with(wrapped, &extra)?;
            Ok(produced.remove(0))
        }
    };
}
87}
88
89
90
/// Generates a `VarInner` constructor `$a(<params>)` mirroring `Tensor::$a`.
///
/// The constructor calls `Tensor::$a` with the same arguments and stores the
/// result as the only tensor of a brand-new net. The returned var has
/// `need_grad` set to `true`.
macro_rules! delegate_new_inner_op {
    ($a:ident, $( $arg_name:ident : $ArgTy:ty ),* $(,)?) => {
        pub fn $a($( $arg_name : $ArgTy ),*) -> VarInner {
            let shared_net = Rc::new(RefCell::new(Net::new()));
            let id = shared_net.borrow_mut().add_tensor(Tensor::$a($( $arg_name ),*));
            VarInner {
                id,
                need_grad: true,
                net: shared_net,
            }
        }
    };
}
107
108pub(crate) struct VarInner {
109 id: GenKey,
110 need_grad: bool,
111 net: Rc<RefCell<Net>>,
112}
113
114impl VarInner {
115
116 #[cfg(feature = "use-f64")]
118 pub fn new(input: &[f64], dim: &[usize]) -> VarInner {
119 let mut net = Net::new();
120
121 let tensor = Tensor::from_vec_f64(input, dim);
122
123 let id = net.add_tensor(tensor);
124 VarInner {
125 id,
126 need_grad: true,
127 net: Rc::new(RefCell::new(net)),
128 }
129 }
130 #[cfg(feature = "use-f32")]
131 pub fn new(input: &[f32], dim: &[usize]) -> VarInner {
132 let mut net = Net::new();
133
134 let tensor = Tensor::from_vec_f32(input, dim);
135
136 let id = net.add_tensor(tensor);
137 VarInner {
138 id,
139 need_grad: true,
140 net: Rc::new(RefCell::new(net)),
141 }
142 }
143 pub fn new_f64(input: &[f64], dim: &[usize]) -> VarInner {
144 let mut net = Net::new();
145
146 let tensor = Tensor::from_vec_f64(input, dim);
147
148 let id = net.add_tensor(tensor);
149 VarInner {
150 id,
151 need_grad: true,
152 net: Rc::new(RefCell::new(net)),
153 }
154 }
155 pub fn new_f32(input: &[f32], dim: &[usize]) -> VarInner {
156 let mut net = Net::new();
157
158 let tensor = Tensor::from_vec_f32(input, dim);
159
160 let id = net.add_tensor(tensor);
161 VarInner {
162 id,
163 need_grad: true,
164 net: Rc::new(RefCell::new(net)),
165 }
166 }
167
168 pub(crate) fn new_net_tensor(net: Rc<RefCell<Net>>,
170 need_grad: bool,
171 tensor: Tensor) -> VarInner {
172 let id = net.borrow_mut().add_tensor(tensor);
173 VarInner {
174 id,
175 need_grad,
176 net
177 }
178 }
179
180 pub(crate) fn new_tensor(tensor: Tensor) -> VarInner {
181 let mut net = Net::new();
182 let id = net.add_tensor(tensor);
183 VarInner {
184 id,
185 need_grad: true,
186 net: Rc::new(RefCell::new(net)),
187 }
188 }
189
190 pub fn get_id(&self) -> GenKey {
191 self.id
192 }
193 pub fn get_need_grad(&self) -> bool {
194 self.need_grad
195 }
196 pub fn get_net(&self) -> Rc<RefCell<Net>> {
197 self.net.clone()
198 }
199
200 pub fn size(&self) -> Vec<usize> {
201 self.net.borrow().get_tensor(self.id).expect("").size()
202 }
203 pub fn numel(&self) -> usize {
204 self.net.borrow().get_tensor(self.id).expect("").numel()
205 }
206 fn check_index(v: &VarInner, o: &[usize]) -> Result<(), AutoDiffError> {
207 if v.size().len() != o.len() {
208 return Err(AutoDiffError::new(
209 &format!("Index for get() should have the same len. t: {:?}, index: {:?}",
210 v.size(), o.len())));
211 } else {
212 Ok(())
213 }
214 }
215 pub fn get_f32(&self, o: &[usize]) -> Result<f32, AutoDiffError> {
216 Self::check_index(self, o)?;
217 Ok(self.net.borrow().get_tensor(self.id)?.get_f32(o))
218 }
219 pub fn set_f32(&mut self, o: &[usize], v: f32) -> Result<(), AutoDiffError> {
220 Self::check_index(self, o)?;
221 self.net.borrow().get_tensor(self.id)?.set_f32(o, v);
222 Ok(())
223 }
224 pub fn get_f64(&self, o: &[usize]) -> Result<f64, AutoDiffError> {
225 Self::check_index(self, o)?;
226 Ok(self.net.borrow().get_tensor(self.id)?.get_f64(o))
227 }
228 pub fn set_f64(&mut self, o: &[usize], v: f64) -> Result<(), AutoDiffError>{
229 Self::check_index(self, o)?;
230 self.net.borrow().get_tensor(self.id)?.set_f64(o, v);
231 Ok(())
232 }
233
234 pub fn fill(size: &[usize], fill_value: Rc<RefCell<VarInner>>) -> VarInner {
235 let mut net = Net::new();
236 let tensor = Tensor::fill(size, &fill_value.borrow().val());
237 let id = net.add_tensor(tensor);
238 VarInner {
239 id,
240 need_grad: true,
241 net: Rc::new(RefCell::new(net)),
242 }
243 }
244 pub fn fill_f32(size: &[usize], fill_value: f32) -> VarInner {
245 let mut net = Net::new();
246 let tensor = Tensor::fill_f32(size, fill_value);
247 let id = net.add_tensor(tensor);
248 VarInner {
249 id,
250 need_grad: true,
251 net: Rc::new(RefCell::new(net)),
252 }
253 }
254 pub fn fill_f64(size: &[usize], fill_value: f64) -> VarInner {
255 let mut net = Net::new();
256 let tensor = Tensor::fill_f64(size, fill_value);
257 let id = net.add_tensor(tensor);
258 VarInner {
259 id,
260 need_grad: true,
261 net: Rc::new(RefCell::new(net)),
262 }
263 }
264 delegate_new_inner_op!(zeros, dim: &[usize]);
265 delegate_new_inner_op!(ones, dim: &[usize]);
266 delegate_new_inner_op!(twos, dim: &[usize]);
267 delegate_new_inner_op!(eye, n: usize, m: usize);
272 delegate_new_inner_op!(empty, dim: &[usize]);
273
274 pub fn from_record_f32(&self, row: usize, record: &[f32]) {
275 self.val().from_record_f32(row, record).expect("");
276 }
277 pub fn from_record_f64(&self, row: usize, record: &[f64]) {
278 self.val().from_record_f64(row, record).expect("");
279 }
280
281
282 delegate_new_inner_op!(rand_usize,
284 rng: &mut StdRng,
285 dim: &[usize],
286 left: usize, right: usize);
287 delegate_new_inner_op!(normal_f64,
288 rng: &mut StdRng,
289 dim: &[usize],
290 mean: f64, std: f64);
291 delegate_new_inner_op!(normal_f32,
292 rng: &mut StdRng,
293 dim: &[usize],
294 mean: f32, std: f32);
295 delegate_new_inner_op!(uniform_f64,
296 rng: &mut StdRng,
297 dim: &[usize],
298 from: f64, to: f64);
299 delegate_new_inner_op!(uniform_f32,
300 rng: &mut StdRng,
301 dim: &[usize],
302 from: f32, to: f32);
303
304
305 pub(crate) fn val(&self) -> Tensor {
308 self.net.borrow().get_tensor(self.id).unwrap()
309 }
310 pub(crate) fn set_val(&mut self, val: Tensor) {
311 self.net.borrow_mut().set_tensor(self.id, val).expect("");
312 }
313 pub fn set(&mut self, o: &VarInner) {
314 self.set_val(o.val())
315 }
316
317 pub fn grad(&self) -> Result<VarInner, AutoDiffError> {
318 Ok(VarInner::new_tensor(self.net.borrow().get_grad(self.id)?))
319 }
320
321 pub fn bp(&self) -> Result<(), AutoDiffError> {
323 let mut job = BTreeMap::new();
324 job.insert(self.id, Tensor::ones_like(&self.val()));
325 self.net.borrow_mut().bptt(&job);
326
327 Ok(())
328 }
329
330 pub fn step(&self, opt: &mut dyn Optimizer) -> Result<(), AutoDiffError> {
332 opt.step(self.net.clone());
333 Ok(())
334 }
335
336 pub fn rerun(&self) -> Result<(), AutoDiffError> {
337 let mut all_input = Vec::new();
338 for i in &self.net.borrow().get_input_edge_data() {
339 all_input.push(*i);
340 }
341 self.net.borrow_mut().eval(&all_input).expect("");
342 Ok(())
343 }
344
345 pub fn get_io_var(&self) -> Result<(Vec<VarInner>, Vec<VarInner>), AutoDiffError> {
346 let input_id = self.net.borrow().get_input_edge_data();
347 let output_id = self.net.borrow().get_output_edge_data();
348 Ok((input_id.iter().map(|x| VarInner {id: *x, need_grad: true, net: self.net.clone()}).collect(),
349 output_id.iter().map(|x| VarInner {id: *x, need_grad: true, net: self.net.clone()}).collect(),))
350 }
351
352 pub fn get_var_by_label(&self, label: &str) -> Result<VarInner, AutoDiffError> {
353 let id = self.net.borrow().get_id_by_label(label)?;
354 Ok(VarInner {
356 id,
357 need_grad: true,
358 net: self.net.clone(),
359 })
360 }
361
362 pub(crate) fn set_label(&self, label: &str) -> Result<(), AutoDiffError> {
363 self.net.borrow_mut().set_label(label, &self.id)
364 }
365
366 pub(crate) fn set_grad(&mut self, use_gradient: bool) {
367 self.need_grad = use_gradient;
368 }
369
370 pub(crate) fn reset_net(&mut self) {
371 let value = self.val();
372 let mut net = Net::new();
373 let id = net.add_tensor(value);
374 self.id = id;
375 self.net = Rc::new(RefCell::new(net));
376 }
377
378 pub(crate) fn called_with(&self, op: Op,
380 others: &[Rc<RefCell<VarInner>>])
381 -> Result<Vec<VarInner>, AutoDiffError> {
382 if self.need_grad {
383 let mut other_var_by_networks: Vec<Vec<Rc<RefCell<VarInner>>>> = vec![];
384 for item in others.iter().cloned() {
385 if !Rc::ptr_eq(&self.net, &item.borrow().net) {
386 let mut existing_net = false;
387 for set in &mut other_var_by_networks {
388 if Rc::ptr_eq(&item.borrow().net, &set[0].borrow().net) {
389 set.push(item.clone());
390 existing_net = true;
391 break;
392 }
393 }
394 if ! existing_net {
395 other_var_by_networks.push(vec![item.clone()]);
396 }
397 }
398 }
399 for set in other_var_by_networks {
400 let mut old_ids = vec![];
401 for item in &set {
402 old_ids.push(item.borrow().id);
403 }
404 let other_key = self.net.borrow_mut().append(
405 &set[0].borrow().net.borrow(), &old_ids)?;
406 for (index, item) in set.iter().enumerate() {
407 item.borrow_mut().net = self.net.clone();
408 item.borrow_mut().id = other_key[index];
409 }
410
411 }
412
413 let mut input_id = vec![self.id];
414 let mut inputs = vec![self.net.borrow().get_tensor(self.id)?];
415 for i in others {
416 input_id.push(i.borrow().id);
417 inputs.push(self.net.borrow().get_tensor(i.borrow().id)?);
418 }
419
420 let mut output_id = vec![];
421 let mut outputs = Vec::new();
422 let mut ret = Vec::new();
423 for _ in 0..op.get_output_size() {
424 let new_output = VarInner::new_net_tensor(self.net.clone(),
425 self.need_grad,
426 Tensor::new());
427 output_id.push(new_output.id);
428 outputs.push(self.net.borrow().get_tensor(new_output.id)?);
429 ret.push(new_output);
430 }
431
432 op.apply(&inputs, &outputs);
433 let opid = self.net.borrow_mut().add_op(op);
434
435 self.net.borrow_mut().connect(&input_id,
436 opid,
437 &output_id);
438
439 Ok(ret)
440 } else {
441 let mut inputs = vec![self.net.borrow().get_tensor(self.id)?];
442 for i in others {
443 inputs.push(i.borrow().net.borrow().get_tensor(i.borrow().id)?);
444 }
445
446 let mut ret = Vec::new();
447 let mut outputs = Vec::new();
448 for _ in 0..op.get_output_size() {
449 let new_output = VarInner::new_net_tensor(Rc::new(RefCell::new(Net::new())),
450 self.need_grad,
451 Tensor::new());
452 outputs.push(new_output.net.borrow().get_tensor(new_output.id)?);
453 ret.push(new_output);
454 }
455
456 op.apply(&inputs, &outputs);
457
458 Ok(ret)
459 }
460 }
461
462 var_inner_2_to_1!(add, Add);
464 var_inner_2_to_1!(sub, Sub);
465 var_inner_2_to_1!(mul, Mul);
466 var_inner_2_to_1!(div, Div);
467 var_inner_2_to_1!(matmul, Matmul);
468 var_inner_2_to_1!(outer, Outer);
469
470 pub fn elu(&self, alpha: VarInner) -> Result<VarInner, AutoDiffError> {
472 let new_one = ELU::new(alpha.val());
473 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
474 let mut result = self.called_with(op, &[])?;
475 Ok(result.remove(0))
476 }
477 var_inner_1_to_1!(relu, ReLU);
478 var_inner_1_to_1!(sigmoid, Sigmoid);
479
480 var_inner_2_to_1!(mse_loss, MSELoss);
482 var_inner_2_to_1!(bce_with_logits_loss, BCEWithLogitsLoss);
483 var_inner_2_to_1!(cross_entropy_loss, CrossEntropyLoss);
484
485 var_inner_1_to_1!(abs, Abs);
487 var_inner_1_to_1!(acos, Acos);
488 var_inner_1_to_1!(asin, Asin);
489 var_inner_1_to_1!(atan, Atan);
490 var_inner_1_to_1!(ceil, Ceil);
491 var_inner_1_to_1!(cos, Cos);
492 var_inner_1_to_1!(cosh, Cosh);
493 var_inner_1_to_1!(exp, Exp);
494 var_inner_1_to_1!(expm1, Expm1);
495 var_inner_1_to_1!(floor, Floor);
496 var_inner_1_to_1!(frac, Frac);
497 var_inner_1_to_1!(log, Log);
498 var_inner_1_to_1!(log10, Log10);
499 var_inner_1_to_1!(log1p, Log1p);
500 var_inner_1_to_1!(log1pexp, Log1pexp);
501 var_inner_1_to_1!(log2, Log2);
502 var_inner_1_to_1!(neg, Neg);
503 var_inner_1_to_1!(reciprocal, Reciprocal);
504 var_inner_1_to_1!(round, Round);
505 var_inner_1_to_1!(rsqrt, Rsqrt);
506 var_inner_1_to_1!(sign, Sign);
507 var_inner_1_to_1!(sin, Sin);
508 var_inner_1_to_1!(sinh, Sinh);
509 var_inner_1_to_1!(sqrt, Sqrt);
510 var_inner_1_to_1!(tan, Tan);
511 var_inner_1_to_1!(tanh, Tanh);
512 var_inner_1_to_1!(trunc, Trunc);
513
514 var_inner_2_to_1!(max_pair, MaxPair);
516 var_inner_2_to_1!(min_pair, MinPair);
517 var_inner_1_to_1_with_para!(arg_sort, ArgSort,
518 dim: usize, descending: bool);
519 var_inner_2_to_1!(eq_elem, EqElem);
520 var_inner_2_to_1!(equal, Equal);
521 var_inner_2_to_1!(ge, Ge);
522 var_inner_2_to_1!(gt, Gt);
523 var_inner_2_to_1!(le, Le);
524 var_inner_2_to_1!(lt, Lt);
525 var_inner_2_to_1!(ne, Ne);
526
527 var_inner_more_to_1_with_para!(cat, Cat, dim: usize);
529 pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<VarInner>, AutoDiffError> {
530 let new_one = Chunk::new(chunks, dim);
531 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
532 let result = self.called_with(op, &Vec::new())?;
533 Ok(result)
534 }
535 pub fn conditional_select(&self, x: Rc<RefCell<VarInner>>, y: Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
536 let new_one = ConditionalSelect::new();
537 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
538 let inputs = vec![x, y];
539 let mut result = self.called_with(op, &inputs)?;
540 Ok(result.remove(0))
541 }
542 pub fn gather(&self, dim: usize, index: Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
543 let new_one = Gather::new(dim);
544 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
545 let inputs = vec![index];
546 let mut result = self.called_with(op, &inputs)?;
547 Ok(result.remove(0))
548 }
549 pub fn index_select(&self, dim: usize, index: Rc<RefCell<VarInner>>) -> Result<VarInner, AutoDiffError> {
550 let new_one = IndexSelect::new(dim);
551 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
552 let inputs = vec![index];
553 let mut result = self.called_with(op, &inputs)?;
554 Ok(result.remove(0))
555 }
556 pub fn index_exclude(&self, dim: usize,
557 index: Rc<RefCell<VarInner>>)
558 -> Result<VarInner, AutoDiffError> {
559 let new_one = IndexExclude::new(dim);
560 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
561 let inputs = vec![index];
562 let mut result = self.called_with(op, &inputs)?;
563 Ok(result.remove(0))
564 }
565 pub fn permute(&self, dim: &[usize]) -> Result<VarInner, AutoDiffError> {
566 let new_one = Permute::new(dim);
567 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
568 let mut result = self.called_with(op, &[])?;
569 Ok(result.remove(0))
570 }
571 pub fn repeat(&self, dim: &[usize]) -> Result<VarInner, AutoDiffError> {
572 let new_one = Repeat::new(dim);
573 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
574 let mut result = self.called_with(op, &[])?;
575 Ok(result.remove(0))
576 }
577 pub fn reshape(&self, new_shape: &[usize]) -> Result<VarInner, AutoDiffError> {
578 let new_one = Reshape::new(new_shape);
579 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
580 let mut result = self.called_with(op, &[])?;
581 Ok(result.remove(0))
582 }
583 pub fn split(&self, sections: &[usize], dim: usize) -> Result<Vec<VarInner>, AutoDiffError> {
584 let new_one = Split::new(sections, dim);
585 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
586 let result = self.called_with(op, &Vec::new())?;
587 Ok(result)
588 }
589 pub fn squeeze(&self, dim: Option<usize>) -> Result<VarInner, AutoDiffError> {
590 let new_one = Squeeze::new(dim);
591 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
592 let mut result = self.called_with(op, &[])?;
593 Ok(result.remove(0))
594 }
595 var_inner_1_to_1!(t, T);
596 pub fn take(&self, index: &[usize]) -> Result<VarInner, AutoDiffError> {
597 let new_one = Take::new(index);
598 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
599 let mut result = self.called_with(op, &[])?;
600 Ok(result.remove(0))
601 }
602 pub fn unsqueeze(&self, dim: usize) -> Result<VarInner, AutoDiffError> {
603 let new_one = Unsqueeze::new(dim);
604 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
605 let mut result = self.called_with(op, &[])?;
606 Ok(result.remove(0))
607 }
608 var_inner_more_to_1_with_para!(stack, Stack, dim: usize);
609
610 var_inner_1_to_1!(det, Det);
612 var_inner_1_to_1!(inv, Inv);
613 var_inner_1_to_1!(normalize_unit, NormalizeUnit);
614 var_inner_1_to_1!(tr, Tr);
615
616 var_inner_1_to_1_with_para!(argmax, Argmax, dim: Option<&[usize]>, keepdim: bool);
618 var_inner_1_to_1_with_para!(argmin, Argmin, dim: Option<&[usize]>, keepdim: bool);
619 var_inner_1_to_1_with_para!(logsumexp, Logsumexp, dim: Option<&[usize]>, keepdim: bool);
620 var_inner_1_to_1_with_para!(mean, Mean, dim: Option<&[usize]>, keepdim: bool);
621 var_inner_1_to_1_with_para!(prod, Prod, dim: Option<&[usize]>, keepdim: bool);
622 var_inner_1_to_1_with_para!(std, Std, dim: Option<&[usize]>, keepdim: bool);
623 var_inner_1_to_1_with_para!(sum, Sum, dim: Option<&[usize]>, keepdim: bool);
624 var_inner_1_to_1_with_para!(var, Variance, dim: Option<&[usize]>, keepdim: bool);
625 var_inner_1_to_1_with_para!(max, Max, dim: Option<&[usize]>, keepdim: bool);
626 var_inner_1_to_1_with_para!(min, Min, dim: Option<&[usize]>, keepdim: bool);
627
628 var_inner_1_to_1_with_para!(get_patch, GetPatch, range: &[(usize, usize)], step: Option<&[usize]>);
630 var_inner_2_to_1_with_para!(set_patch, SetPatch, range: &[(usize, usize)], step: Option<&[usize]>);
631 var_inner_1_to_1_with_para!(view, View, new_shape: &[usize]);
632
633 pub fn dump_net(&self) -> Rc<RefCell<Net>> {
634 self.net.clone()
635 }
636
637 pub(crate) fn set_inner(id: GenKey, need_grad: bool, net: Net) -> VarInner {
638 VarInner {
639 id,
640 need_grad,
641 net: Rc::new(RefCell::new(net))
642 }
643 }
644}
645
646impl PartialEq for VarInner {
647 fn eq(&self, other: &Self) -> bool {
648 self.val().eq(&other.val())
649 }
650}
651
652impl Eq for VarInner {}
653
654impl fmt::Display for VarInner {
655 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
656 write!(f, "id: {}", self.id)?;
657 write!(f, "tensor: {}", self.val())
658 }
659}
660
661impl fmt::Debug for VarInner {
662 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
663 write!(f, "id: {}", self.id)?;
664 write!(f, "tensor: {}", self.val())
665 }
666}
667
668impl Clone for VarInner {
669 fn clone(&self) -> Self {
670 let val = self.val();
671 let mut ret = VarInner::new(&[], &[]);
672 ret.set_val(val);
673 ret.need_grad = self.need_grad;
674 ret
675 }
676}
677
678