1#![allow(clippy::redundant_closure_call)]
2use tensor_rs::tensor::Tensor;
3use super::{OpTrait, OpCall, Op, OpHandle};
4
5use std::cell::{RefCell};
6use std::rc::Rc;
7
8use crate::var::{Var};
9use crate::err::AutoDiffError;
10use super::macros::{many_to_1_op_with_paras,
11 one_to_vec_op_with_paras,
12 new_element_op,
13 one_to_1_op_with_paras};
14
15#[cfg(feature = "use-serde")]
16use serde::{Serialize, Deserialize};
17#[cfg(feature = "use-serde")]
18use std::any::Any;
19
20#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
21pub struct Cat {
22 #[cfg_attr(feature = "use-serde", serde(skip))]
23 handle: OpHandle,
24 dim: usize
25}
26impl Cat {
27 pub fn new(dim: usize) -> Cat {
28 Cat {
29 handle: OpHandle::new(),
30 dim,
31 }
32 }
33 fn get_handle(&self) -> &OpHandle {
34 &self.handle
35 }
36 fn get_handle_mut(&mut self) -> &mut OpHandle {
37 &mut self.handle
38 }
39}
40impl OpCall for Cat {
41 fn call(&mut self, inputs: &[&Var])
42 -> Result<Vec<Var>, AutoDiffError> {
43 let new_one = Cat {
44 handle: OpHandle::new(),
45 dim: self.dim,
46 };
47
48 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
49
50 inputs[0].called_with(op, &inputs[1..inputs.len()])
51 }
52}
53impl OpTrait for Cat {
54
55 fn get_name(&self) -> &'static str {
56 "Cat"
57 }
58 fn get_input_size(&self) -> usize {
59 1
60 }
61 fn get_output_size(&self) -> usize {
62 1
63 }
64 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
65 let mut new_input = vec![];
66 for item in input.iter().skip(1) {
67 new_input.push(item.ref_copy());
68 }
69 output[0].swap(&input[0].cat(&new_input, self.dim));
70 }
71 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
72 let mut splits = Vec::new();
73 for i in input {
74 splits.push(i.size()[self.dim]);
75 }
76 let result = output_grad[0].split(&splits, self.dim);
77 for i in result {
78 input_grad[0].swap(&i);
79 }
80 }
81 fn get_values(&self) -> Vec<Tensor> {
82 Vec::new()
83 }
84 fn get_grads(&self) -> Vec<Tensor> {
85 Vec::new()
86 }
87 fn set_values(&self, _v: &[Tensor]) {
88 }
89 #[cfg(feature = "use-serde")]
90 fn as_any(&self) -> &dyn Any {
91 self
92 }
93}
94
95
// Chunk op, generated by `one_to_vec_op_with_paras!` with the name "Chunk",
// the identifier `chunk` and the parameters `chunks: usize, dim: usize`.
// NOTE(review): `chunk` is presumably the Tensor method used in the forward
// pass, and the two `1` arguments presumably the declared input/output counts;
// confirm against the macro definition.
// The gradient closure is a stub: backpropagating through Chunk panics.
one_to_vec_op_with_paras!(Chunk,
                          "Chunk",
                          1,
                          1, chunk,
                          (|input: &[Tensor],
                            output_grad: &[Tensor],
                            input_grad: &[Tensor]| {
                                unimplemented!();
                            }),
                          chunks: usize, dim: usize);
107
108#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
110pub struct Gather {
111 #[cfg_attr(feature = "use-serde", serde(skip))]
112 handle: OpHandle,
113 dim: usize
114}
115impl Gather {
116 pub fn new(dim: usize) -> Gather {
117 Gather {
118 handle: OpHandle::new(),
119 dim,
120 }
121 }
122 fn get_handle(&self) -> &OpHandle {
123 &self.handle
124 }
125 fn get_handle_mut(&mut self) -> &mut OpHandle {
126 &mut self.handle
127 }
128}
129impl OpCall for Gather {
130 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
131 let new_one = Gather {
132 handle: OpHandle::new(),
133 dim: self.dim,
134 };
135
136 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
137
138 inputs[0].called_with(op, &inputs[1..inputs.len()])
139 }
140}
141impl OpTrait for Gather {
142
143 fn get_name(&self) -> &'static str {
144 "Gather"
145 }
146 fn get_input_size(&self) -> usize {
147 1
148 }
149 fn get_output_size(&self) -> usize {
150 1
151 }
152 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
153 output[0].swap(&input[0].gather(self.dim, &input[1]));
154 }
155 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
156 unimplemented!();
157 }
158 fn get_values(&self) -> Vec<Tensor> {
159 Vec::new()
160 }
161 fn get_grads(&self) -> Vec<Tensor> {
162 Vec::new()
163 }
164 fn set_values(&self, _v: &[Tensor]) {
165 }
166 #[cfg(feature = "use-serde")]
167 fn as_any(&self) -> &dyn Any {
168 self
169 }
170}
171
172#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
174pub struct IndexSelect {
175 #[cfg_attr(feature = "use-serde", serde(skip))]
176 handle: OpHandle,
177 dim: usize
178}
179impl IndexSelect {
180 pub fn new(dim: usize) -> IndexSelect {
181 IndexSelect {
182 handle: OpHandle::new(),
183 dim,
184 }
185 }
186 fn get_handle(&self) -> &OpHandle {
187 &self.handle
188 }
189 fn get_handle_mut(&mut self) -> &mut OpHandle {
190 &mut self.handle
191 }
192}
193impl OpCall for IndexSelect {
194 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
195 let new_one = IndexSelect {
196 handle: OpHandle::new(),
197 dim: self.dim,
198 };
199
200 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
201
202 inputs[0].called_with(op, &inputs[1..inputs.len()])
203 }
204}
205impl OpTrait for IndexSelect {
206
207 fn get_name(&self) -> &'static str {
208 "Index_select"
209 }
210 fn get_input_size(&self) -> usize {
211 1
212 }
213 fn get_output_size(&self) -> usize {
214 1
215 }
216 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
217 output[0].swap(&input[0].index_select(self.dim, &input[1]));
218 }
219 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
220 unimplemented!();
221 }
222 fn get_values(&self) -> Vec<Tensor> {
223 Vec::new()
224 }
225 fn get_grads(&self) -> Vec<Tensor> {
226 Vec::new()
227 }
228 fn set_values(&self, _v: &[Tensor]) {
229 }
230 #[cfg(feature = "use-serde")]
231 fn as_any(&self) -> &dyn Any {
232 self
233 }
234}
235
236#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
238pub struct IndexExclude {
239 #[cfg_attr(feature = "use-serde", serde(skip))]
240 handle: OpHandle,
241 dim: usize
242}
243impl IndexExclude {
244 pub fn new(dim: usize) -> IndexExclude {
245 IndexExclude {
246 handle: OpHandle::new(),
247 dim,
248 }
249 }
250 fn get_handle(&self) -> &OpHandle {
251 &self.handle
252 }
253 fn get_handle_mut(&mut self) -> &mut OpHandle {
254 &mut self.handle
255 }
256}
257impl OpCall for IndexExclude {
258 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
259 let new_one = IndexExclude {
260 handle: OpHandle::new(),
261 dim: self.dim,
262 };
263
264 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
265
266 inputs[0].called_with(op, &inputs[1..inputs.len()])
267 }
268}
269impl OpTrait for IndexExclude {
270
271 fn get_name(&self) -> &'static str {
272 "Index_exclude"
273 }
274 fn get_input_size(&self) -> usize {
275 1
276 }
277 fn get_output_size(&self) -> usize {
278 1
279 }
280 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
281 output[0].swap(&input[0].index_exclude(self.dim, &input[1]));
282 }
283 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
284 unimplemented!();
285 }
286 fn get_values(&self) -> Vec<Tensor> {
287 Vec::new()
288 }
289 fn get_grads(&self) -> Vec<Tensor> {
290 Vec::new()
291 }
292 fn set_values(&self, _v: &[Tensor]) {
293 }
294 #[cfg(feature = "use-serde")]
295 fn as_any(&self) -> &dyn Any {
296 self
297 }
298}
299
300#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
302pub struct Reshape {
303 #[cfg_attr(feature = "use-serde", serde(skip))]
304 handle: OpHandle,
305 new_shape: Vec<usize>,
306}
307impl Reshape {
308 pub fn new(new_shape: &[usize]) -> Reshape {
309 Reshape {
310 handle: OpHandle::new(),
311 new_shape: new_shape.to_vec(),
312 }
313 }
314 fn get_handle(&self) -> &OpHandle {
315 &self.handle
316 }
317 fn get_handle_mut(&mut self) -> &mut OpHandle {
318 &mut self.handle
319 }
320}
321impl OpCall for Reshape {
322 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
323 let new_one = Reshape {
324 handle: OpHandle::new(),
325 new_shape: self.new_shape.clone(),
326 };
327
328 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
329
330 inputs[0].called_with(op, &inputs[1..inputs.len()])
331 }
332}
333impl OpTrait for Reshape {
334
335 fn get_name(&self) -> &'static str {
336 "Reshape"
337 }
338 fn get_input_size(&self) -> usize {
339 1
340 }
341 fn get_output_size(&self) -> usize {
342 1
343 }
344 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
345 output[0].swap(&input[0].reshape(&self.new_shape));
346 }
347 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
348 unimplemented!();
349 }
350 fn get_values(&self) -> Vec<Tensor> {
351 Vec::new()
352 }
353 fn get_grads(&self) -> Vec<Tensor> {
354 Vec::new()
355 }
356 fn set_values(&self, _v: &[Tensor]) {
357 }
358 #[cfg(feature = "use-serde")]
359 fn as_any(&self) -> &dyn Any {
360 self
361 }
362}
363
364
365#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
367pub struct Split {
368 #[cfg_attr(feature = "use-serde", serde(skip))]
369 handle: OpHandle,
370 sections: Vec<usize>,
371 dim: usize,
372}
373impl Split {
374 pub fn new(sections: &[usize], dim: usize) -> Split {
375 Split {
376 handle: OpHandle::new(),
377 sections: sections.to_vec(),
378 dim,
379 }
380 }
381 fn get_handle(&self) -> &OpHandle {
382 &self.handle
383 }
384 fn get_handle_mut(&mut self) -> &mut OpHandle {
385 &mut self.handle
386 }
387}
388impl OpCall for Split {
389 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
390 let new_one = Split {
391 handle: OpHandle::new(),
392 sections: self.sections.clone(),
393 dim: self.dim,
394 };
395
396 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
397
398 inputs[0].called_with(op, &inputs[1..inputs.len()])
399 }
400}
401impl OpTrait for Split {
402
403 fn get_name(&self) -> &'static str {
404 "Split"
405 }
406 fn get_input_size(&self) -> usize {
407 1
408 }
409 fn get_output_size(&self) -> usize {
410 self.sections.len()
411 }
412 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
413 let mut result = input[0].split(&self.sections, self.dim);
414 for (index, i) in result.drain(..).enumerate() {
415 output[index].swap(&i);
416 }
417 }
418 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
419 unimplemented!();
420 }
421 fn get_values(&self) -> Vec<Tensor> {
422 Vec::new()
423 }
424 fn get_grads(&self) -> Vec<Tensor> {
425 Vec::new()
426 }
427 fn set_values(&self, _v: &[Tensor]) {
428 }
429 #[cfg(feature = "use-serde")]
430 fn as_any(&self) -> &dyn Any {
431 self
432 }
433}
434
// Squeeze op, generated by `one_to_1_op_with_paras!` with the name "Squeeze",
// the identifier `squeeze` and the parameter `dim: Option<usize>`.
// NOTE(review): `1, 1` are presumably the declared input/output counts, and
// `squeeze` the Tensor method called in the forward pass; confirm against the macro.
// The gradient closure is a stub: backpropagating through Squeeze panics.
one_to_1_op_with_paras!(Squeeze,
                        "Squeeze",
                        1, 1,
                        squeeze,
                        (|input: &[Tensor],
                          output_grad: &[Tensor],
                          input_grad: &[Tensor]| {
                              unimplemented!();
                          }),
                        dim: Option<usize>);
446
447
// Stack op, generated by `many_to_1_op_with_paras!` with the name "Stack",
// the identifier `stack` and the parameter `dim: usize`.
// NOTE(review): `2, 1` are presumably the declared input/output counts;
// confirm against the macro definition.
// The gradient closure is a stub: backpropagating through Stack panics.
many_to_1_op_with_paras!(Stack,
                         "Stack",
                         2, 1,
                         stack,
                         (|input: &[Tensor],
                           output_grad: &[Tensor],
                           input_grad: &[Tensor]| {
                               unimplemented!();
                           }),
                         dim: usize);
// T op, generated by `new_element_op!` with the name "T" and the identifier `t`.
// NOTE(review): presumably a transpose forwarding to the Tensor method `t`;
// confirm against the macro definition.
// The gradient closure is a stub: backpropagating through T panics.
new_element_op!(T,
                "T",
                t,
                (|input: &[Tensor],
                  output_grad: &[Tensor],
                  input_grad: &[Tensor]| {
                      unimplemented!();
                  }));
469
470#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
472pub struct Take {
473 #[cfg_attr(feature = "use-serde", serde(skip))]
474 handle: OpHandle,
475 sizes: Vec<usize>,
476}
477impl Take {
478 pub fn new(sizes: &[usize]) -> Take {
479 Take {
480 handle: OpHandle::new(),
481 sizes: sizes.to_vec(),
482 }
483 }
484 fn get_handle(&self) -> &OpHandle {
485 &self.handle
486 }
487 fn get_handle_mut(&mut self) -> &mut OpHandle {
488 &mut self.handle
489 }
490}
491impl OpCall for Take {
492 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
493 let new_one = Take {
494 handle: OpHandle::new(),
495 sizes: self.sizes.clone(),
496 };
497
498 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
499
500 inputs[0].called_with(op, &inputs[1..inputs.len()])
501 }
502}
503impl OpTrait for Take {
504
505 fn get_name(&self) -> &'static str {
506 "Take"
507 }
508 fn get_input_size(&self) -> usize {
509 1
510 }
511 fn get_output_size(&self) -> usize {
512 1
513 }
514 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
515 output[0].swap(&input[0].take(&self.sizes))
516 }
517 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
518 unimplemented!();
519 }
520 fn get_values(&self) -> Vec<Tensor> {
521 Vec::new()
522 }
523 fn get_grads(&self) -> Vec<Tensor> {
524 Vec::new()
525 }
526 fn set_values(&self, _v: &[Tensor]) {
527 }
528 #[cfg(feature = "use-serde")]
529 fn as_any(&self) -> &dyn Any {
530 self
531 }
532}
533
534#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
536pub struct Permute {
537 #[cfg_attr(feature = "use-serde", serde(skip))]
538 handle: OpHandle,
539 sizes: Vec<usize>,
540}
541impl Permute {
542 pub fn new(sizes: &[usize]) -> Permute {
543 Permute {
544 handle: OpHandle::new(),
545 sizes: sizes.to_vec(),
546 }
547 }
548 fn get_handle(&self) -> &OpHandle {
549 &self.handle
550 }
551 fn get_handle_mut(&mut self) -> &mut OpHandle {
552 &mut self.handle
553 }
554}
555impl OpCall for Permute {
556 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
557 let new_one = Permute {
558 handle: OpHandle::new(),
559 sizes: self.sizes.clone(),
560 };
561
562 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
563
564 inputs[0].called_with(op, &inputs[1..inputs.len()])
565 }
566}
567impl OpTrait for Permute {
568
569 fn get_name(&self) -> &'static str {
570 "Permute"
571 }
572 fn get_input_size(&self) -> usize {
573 1
574 }
575 fn get_output_size(&self) -> usize {
576 1
577 }
578 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
579 output[0].swap(&input[0].permute(&self.sizes))
580 }
581 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
582 unimplemented!();
583 }
584 fn get_values(&self) -> Vec<Tensor> {
585 Vec::new()
586 }
587 fn get_grads(&self) -> Vec<Tensor> {
588 Vec::new()
589 }
590 fn set_values(&self, _v: &[Tensor]) {
591 }
592 #[cfg(feature = "use-serde")]
593 fn as_any(&self) -> &dyn Any {
594 self
595 }
596}
597
598
// Unsqueeze op, generated by `one_to_1_op_with_paras!` with the name
// "Unsqueeze", the identifier `unsqueeze` and the parameter `dim: usize`.
// NOTE(review): `1, 1` are presumably the declared input/output counts;
// confirm against the macro definition.
// The gradient closure is a stub: backpropagating through Unsqueeze panics.
one_to_1_op_with_paras!(Unsqueeze,
                        "Unsqueeze",
                        1, 1,
                        unsqueeze,
                        (|input: &[Tensor],
                          output_grad: &[Tensor],
                          input_grad: &[Tensor]| {
                              unimplemented!();
                          }),
                        dim: usize);
610
611#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
613pub struct ConditionalSelect {
614 #[cfg_attr(feature = "use-serde", serde(skip))]
615 handle: OpHandle,
616}
617impl ConditionalSelect {
618 pub fn new() -> ConditionalSelect {
619 ConditionalSelect {
620 handle: OpHandle::new(),
621 }
622 }
623 fn get_handle(&self) -> &OpHandle {
624 &self.handle
625 }
626 fn get_handle_mut(&mut self) -> &mut OpHandle {
627 &mut self.handle
628 }
629}
630impl OpCall for ConditionalSelect {
631 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
632 let new_one = ConditionalSelect {
633 handle: OpHandle::new(),
634 };
635
636 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
637
638 inputs[0].called_with(op, &inputs[1..inputs.len()])
639 }
640}
641impl OpTrait for ConditionalSelect {
642
643 fn get_name(&self) -> &'static str {
644 "ConditionalSelect"
645 }
646 fn get_input_size(&self) -> usize {
647 3
648 }
649 fn get_output_size(&self) -> usize {
650 1
651 }
652 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
653 output[0].swap(&input[0].conditional_select(&input[0], &input[1]));
654 }
655 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
656 unimplemented!();
657 }
658 fn get_values(&self) -> Vec<Tensor> {
659 Vec::new()
660 }
661 fn get_grads(&self) -> Vec<Tensor> {
662 Vec::new()
663 }
664 fn set_values(&self, _v: &[Tensor]) {
665 }
666 #[cfg(feature = "use-serde")]
667 fn as_any(&self) -> &dyn Any {
668 self
669 }
670}
671impl Default for ConditionalSelect {
672 fn default() -> Self {
673 Self::new()
674 }
675}
676
677
678#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
680pub struct Repeat {
681 #[cfg_attr(feature = "use-serde", serde(skip))]
682 handle: OpHandle,
683 sizes: Vec<usize>,
684}
685impl Repeat {
686 pub fn new(sizes: &[usize]) -> Repeat {
687 Repeat {
688 handle: OpHandle::new(),
689 sizes: sizes.to_vec(),
690 }
691 }
692 fn get_handle(&self) -> &OpHandle {
693 &self.handle
694 }
695 fn get_handle_mut(&mut self) -> &mut OpHandle {
696 &mut self.handle
697 }
698}
699impl OpCall for Repeat {
700 fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
701 let new_one = Repeat {
702 handle: OpHandle::new(),
703 sizes: self.sizes.clone(),
704 };
705
706 let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
707
708 inputs[0].called_with(op, &inputs[1..inputs.len()])
709 }
710}
711impl OpTrait for Repeat {
712
713 fn get_name(&self) -> &'static str {
714 "Repeat"
715 }
716 fn get_input_size(&self) -> usize {
717 1
718 }
719 fn get_output_size(&self) -> usize {
720 1
721 }
722 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
723 output[0].swap(&input[0].repeat(&self.sizes))
724 }
725 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
726 unimplemented!();
727 }
728 fn get_values(&self) -> Vec<Tensor> {
729 Vec::new()
730 }
731 fn get_grads(&self) -> Vec<Tensor> {
732 Vec::new()
733 }
734 fn set_values(&self, _v: &[Tensor]) {
735 }
736 #[cfg(feature = "use-serde")]
737 fn as_any(&self) -> &dyn Any {
738 self
739 }
740}