auto_diff/op/
index_slicing.rs

1#![allow(clippy::redundant_closure_call)]
2use tensor_rs::tensor::Tensor;
3use super::{OpTrait, OpCall, Op, OpHandle};
4
5use std::cell::{RefCell};
6use std::rc::Rc;
7
8use crate::var::{Var};
9use crate::err::AutoDiffError;
10use super::macros::{many_to_1_op_with_paras,
11                    one_to_vec_op_with_paras,
12                    new_element_op,
13                    one_to_1_op_with_paras};
14
15#[cfg(feature = "use-serde")]
16use serde::{Serialize, Deserialize};
17#[cfg(feature = "use-serde")]
18use std::any::Any;
19
20#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
21pub struct Cat {
22    #[cfg_attr(feature = "use-serde", serde(skip))]
23    handle: OpHandle,
24    dim: usize
25}
26impl Cat {
27    pub fn new(dim: usize) -> Cat {
28        Cat {
29            handle: OpHandle::new(),
30            dim,
31        }
32    }
33    fn get_handle(&self) -> &OpHandle {
34        &self.handle
35    }
36    fn get_handle_mut(&mut self) -> &mut OpHandle {
37        &mut self.handle
38    }
39}
40impl OpCall for Cat {
41    fn call(&mut self, inputs: &[&Var])
42            -> Result<Vec<Var>, AutoDiffError> {
43        let new_one = Cat {
44            handle: OpHandle::new(),
45            dim: self.dim,
46        };
47
48        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
49
50        inputs[0].called_with(op, &inputs[1..inputs.len()])
51    }
52}
53impl OpTrait for Cat {
54
55    fn get_name(&self) -> &'static str {
56        "Cat"
57    }
58    fn get_input_size(&self) -> usize {
59        1
60    }
61    fn get_output_size(&self) -> usize {
62        1
63    }
64    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
65        let mut new_input = vec![];
66        for item in input.iter().skip(1) {
67            new_input.push(item.ref_copy());
68        }
69        output[0].swap(&input[0].cat(&new_input, self.dim));
70    }
71    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
72        let mut splits = Vec::new();
73        for i in input {
74            splits.push(i.size()[self.dim]);
75        }
76        let result = output_grad[0].split(&splits, self.dim);
77        for i in result {
78            input_grad[0].swap(&i);
79        }
80    }
81    fn get_values(&self) -> Vec<Tensor> {
82        Vec::new()
83    }
84    fn get_grads(&self) -> Vec<Tensor> {
85        Vec::new()
86    }
87    fn set_values(&self, _v: &[Tensor]) {
88    }
89    #[cfg(feature = "use-serde")]
90    fn as_any(&self) -> &dyn Any {
91	self
92    }
93}
94
95
96one_to_vec_op_with_paras!(Chunk,
97                          "Chunk",
98                          1,
99			  1, // TODO, this is dependent on the number of output.
100			  chunk,
101                          (|input: &[Tensor],
102                           output_grad: &[Tensor],
103                           input_grad: &[Tensor]| {
104                               unimplemented!();
105                           }),
106                          chunks: usize, dim: usize);
107                          
108// gather
109#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
110pub struct Gather {
111    #[cfg_attr(feature = "use-serde", serde(skip))]
112    handle: OpHandle,
113    dim: usize
114}
115impl Gather {
116    pub fn new(dim: usize) -> Gather {
117        Gather {
118            handle: OpHandle::new(),
119            dim,
120        }
121    }
122    fn get_handle(&self) -> &OpHandle {
123        &self.handle
124    }
125    fn get_handle_mut(&mut self) -> &mut OpHandle {
126        &mut self.handle
127    }
128}
129impl OpCall for Gather {
130    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
131        let new_one = Gather {
132            handle: OpHandle::new(),
133            dim: self.dim,
134        };
135
136        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
137
138        inputs[0].called_with(op, &inputs[1..inputs.len()])
139    }
140}
141impl OpTrait for Gather {
142
143    fn get_name(&self) -> &'static str {
144        "Gather"
145    }
146    fn get_input_size(&self) -> usize {
147        1
148    }
149    fn get_output_size(&self) -> usize {
150        1
151    }
152    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
153        output[0].swap(&input[0].gather(self.dim, &input[1]));
154    }
155    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
156        unimplemented!();
157    }
158    fn get_values(&self) -> Vec<Tensor> {
159        Vec::new()
160    }
161    fn get_grads(&self) -> Vec<Tensor> {
162        Vec::new()
163    }
164    fn set_values(&self, _v: &[Tensor]) {
165    }
166    #[cfg(feature = "use-serde")]
167    fn as_any(&self) -> &dyn Any {
168	self
169    }
170}
171
172// index_select
173#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
174pub struct IndexSelect {
175    #[cfg_attr(feature = "use-serde", serde(skip))]
176    handle: OpHandle,
177    dim: usize
178}
179impl IndexSelect {
180    pub fn new(dim: usize) -> IndexSelect {
181        IndexSelect {
182            handle: OpHandle::new(),
183            dim,
184        }
185    }
186    fn get_handle(&self) -> &OpHandle {
187        &self.handle
188    }
189    fn get_handle_mut(&mut self) -> &mut OpHandle {
190        &mut self.handle
191    }
192}
193impl OpCall for IndexSelect {
194    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
195        let new_one = IndexSelect {
196            handle: OpHandle::new(),
197            dim: self.dim,
198        };
199
200        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
201
202        inputs[0].called_with(op, &inputs[1..inputs.len()])
203    }
204}
205impl OpTrait for IndexSelect {
206
207    fn get_name(&self) -> &'static str {
208        "Index_select"
209    }
210    fn get_input_size(&self) -> usize {
211        1
212    }
213    fn get_output_size(&self) -> usize {
214        1
215    }
216    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
217        output[0].swap(&input[0].index_select(self.dim, &input[1]));
218    }
219    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
220        unimplemented!();
221    }
222    fn get_values(&self) -> Vec<Tensor> {
223        Vec::new()
224    }
225    fn get_grads(&self) -> Vec<Tensor> {
226        Vec::new()
227    }
228    fn set_values(&self, _v: &[Tensor]) {
229    }
230    #[cfg(feature = "use-serde")]
231    fn as_any(&self) -> &dyn Any {
232	self
233    }
234}
235
236// index_exclude
237#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
238pub struct IndexExclude {
239    #[cfg_attr(feature = "use-serde", serde(skip))]
240    handle: OpHandle,
241    dim: usize
242}
243impl IndexExclude {
244    pub fn new(dim: usize) -> IndexExclude {
245        IndexExclude {
246            handle: OpHandle::new(),
247            dim,
248        }
249    }
250    fn get_handle(&self) -> &OpHandle {
251        &self.handle
252    }
253    fn get_handle_mut(&mut self) -> &mut OpHandle {
254        &mut self.handle
255    }
256}
257impl OpCall for IndexExclude {
258    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
259        let new_one = IndexExclude {
260            handle: OpHandle::new(),
261            dim: self.dim,
262        };
263
264        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
265
266        inputs[0].called_with(op, &inputs[1..inputs.len()])
267    }
268}
269impl OpTrait for IndexExclude {
270
271    fn get_name(&self) -> &'static str {
272        "Index_exclude"
273    }
274    fn get_input_size(&self) -> usize {
275        1
276    }
277    fn get_output_size(&self) -> usize {
278        1
279    }
280    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
281        output[0].swap(&input[0].index_exclude(self.dim, &input[1]));
282    }
283    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
284        unimplemented!();
285    }
286    fn get_values(&self) -> Vec<Tensor> {
287        Vec::new()
288    }
289    fn get_grads(&self) -> Vec<Tensor> {
290        Vec::new()
291    }
292    fn set_values(&self, _v: &[Tensor]) {
293    }
294    #[cfg(feature = "use-serde")]
295    fn as_any(&self) -> &dyn Any {
296	self
297    }
298}
299
300// reshape
301#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
302pub struct Reshape {
303    #[cfg_attr(feature = "use-serde", serde(skip))]
304    handle: OpHandle,
305    new_shape: Vec<usize>,
306}
307impl Reshape {
308    pub fn new(new_shape: &[usize]) -> Reshape {
309        Reshape {
310            handle: OpHandle::new(),
311            new_shape: new_shape.to_vec(),
312        }
313    }
314    fn get_handle(&self) -> &OpHandle {
315        &self.handle
316    }
317    fn get_handle_mut(&mut self) -> &mut OpHandle {
318        &mut self.handle
319    }
320}
321impl OpCall for Reshape {
322    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
323        let new_one = Reshape {
324            handle: OpHandle::new(),
325            new_shape: self.new_shape.clone(),
326        };
327
328        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
329
330        inputs[0].called_with(op, &inputs[1..inputs.len()])
331    }
332}
333impl OpTrait for Reshape {
334
335    fn get_name(&self) -> &'static str {
336        "Reshape"
337    }
338    fn get_input_size(&self) -> usize {
339        1
340    }
341    fn get_output_size(&self) -> usize {
342        1
343    }
344    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
345        output[0].swap(&input[0].reshape(&self.new_shape));
346    }
347    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
348        unimplemented!();
349    }
350    fn get_values(&self) -> Vec<Tensor> {
351        Vec::new()
352    }
353    fn get_grads(&self) -> Vec<Tensor> {
354        Vec::new()
355    }
356    fn set_values(&self, _v: &[Tensor]) {
357    }
358    #[cfg(feature = "use-serde")]
359    fn as_any(&self) -> &dyn Any {
360	self
361    }
362}
363
364
365// split
366#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
367pub struct Split {
368    #[cfg_attr(feature = "use-serde", serde(skip))]
369    handle: OpHandle,
370    sections: Vec<usize>,
371    dim: usize,
372}
373impl Split {
374    pub fn new(sections: &[usize], dim: usize) -> Split {
375        Split {
376            handle: OpHandle::new(),
377            sections: sections.to_vec(),
378            dim,
379        }
380    }
381    fn get_handle(&self) -> &OpHandle {
382        &self.handle
383    }
384    fn get_handle_mut(&mut self) -> &mut OpHandle {
385        &mut self.handle
386    }
387}
388impl OpCall for Split {
389    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
390        let new_one = Split {
391            handle: OpHandle::new(),
392            sections: self.sections.clone(),
393            dim: self.dim,
394        };
395
396        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
397
398        inputs[0].called_with(op, &inputs[1..inputs.len()])
399    }
400}
401impl OpTrait for Split {
402
403    fn get_name(&self) -> &'static str {
404        "Split"
405    }
406    fn get_input_size(&self) -> usize {
407        1
408    }
409    fn get_output_size(&self) -> usize {
410        self.sections.len()
411    }
412    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
413        let mut result = input[0].split(&self.sections, self.dim);
414        for (index, i) in result.drain(..).enumerate() {
415            output[index].swap(&i);
416        }
417    }
418    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
419        unimplemented!();
420    }
421    fn get_values(&self) -> Vec<Tensor> {
422        Vec::new()
423    }
424    fn get_grads(&self) -> Vec<Tensor> {
425        Vec::new()
426    }
427    fn set_values(&self, _v: &[Tensor]) {
428    }
429    #[cfg(feature = "use-serde")]
430    fn as_any(&self) -> &dyn Any {
431	self
432    }
433}
434
435// squeeze
436one_to_1_op_with_paras!(Squeeze,
437                        "Squeeze",
438                        1, 1,
439                        squeeze,
440                        (|input: &[Tensor],
441                         output_grad: &[Tensor],
442                         input_grad: &[Tensor]| {
443                             unimplemented!();
444                         }),
445                        dim: Option<usize>);
446
447
448// stack
449many_to_1_op_with_paras!(Stack,
450                          "Stack",
451                          2, // TODO, this is dependent on the number of input.
452                          1,
453                          stack,
454                          (|input: &[Tensor],
455                           output_grad: &[Tensor],
456                           input_grad: &[Tensor]| {
457                               unimplemented!();
458                           }),
459                          dim: usize);
460// t
461new_element_op!(T,
462                "T",
463                t,
464                (|input: &[Tensor],
465                 output_grad: &[Tensor],
466                 input_grad: &[Tensor]| {
467                     unimplemented!();
468                 }));
469
470// take
471#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
472pub struct Take {
473    #[cfg_attr(feature = "use-serde", serde(skip))]
474    handle: OpHandle,
475    sizes: Vec<usize>,
476}
477impl Take {
478    pub fn new(sizes: &[usize]) -> Take {
479        Take {
480            handle: OpHandle::new(),
481            sizes: sizes.to_vec(),
482        }
483    }
484    fn get_handle(&self) -> &OpHandle {
485        &self.handle
486    }
487    fn get_handle_mut(&mut self) -> &mut OpHandle {
488        &mut self.handle
489    }
490}
491impl OpCall for Take {
492    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
493        let new_one = Take {
494            handle: OpHandle::new(),
495            sizes: self.sizes.clone(),
496        };
497
498        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
499
500        inputs[0].called_with(op, &inputs[1..inputs.len()])
501    }
502}
503impl OpTrait for Take {
504
505    fn get_name(&self) -> &'static str {
506        "Take"
507    }
508    fn get_input_size(&self) -> usize {
509        1
510    }
511    fn get_output_size(&self) -> usize {
512        1
513    }
514    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
515        output[0].swap(&input[0].take(&self.sizes))
516    }
517    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
518        unimplemented!();
519    }
520    fn get_values(&self) -> Vec<Tensor> {
521        Vec::new()
522    }
523    fn get_grads(&self) -> Vec<Tensor> {
524        Vec::new()
525    }
526    fn set_values(&self, _v: &[Tensor]) {
527    }
528    #[cfg(feature = "use-serde")]
529    fn as_any(&self) -> &dyn Any {
530	self
531    }
532}
533
534// permute
535#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
536pub struct Permute {
537    #[cfg_attr(feature = "use-serde", serde(skip))]
538    handle: OpHandle,
539    sizes: Vec<usize>,
540}
541impl Permute {
542    pub fn new(sizes: &[usize]) -> Permute {
543        Permute {
544            handle: OpHandle::new(),
545            sizes: sizes.to_vec(),
546        }
547    }
548    fn get_handle(&self) -> &OpHandle {
549        &self.handle
550    }
551    fn get_handle_mut(&mut self) -> &mut OpHandle {
552        &mut self.handle
553    }
554}
555impl OpCall for Permute {
556    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
557        let new_one = Permute {
558            handle: OpHandle::new(),
559            sizes: self.sizes.clone(),
560        };
561
562        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
563
564        inputs[0].called_with(op, &inputs[1..inputs.len()])
565    }
566}
567impl OpTrait for Permute {
568
569    fn get_name(&self) -> &'static str {
570        "Permute"
571    }
572    fn get_input_size(&self) -> usize {
573        1
574    }
575    fn get_output_size(&self) -> usize {
576        1
577    }
578    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
579        output[0].swap(&input[0].permute(&self.sizes))
580    }
581    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
582        unimplemented!();
583    }
584    fn get_values(&self) -> Vec<Tensor> {
585        Vec::new()
586    }
587    fn get_grads(&self) -> Vec<Tensor> {
588        Vec::new()
589    }
590    fn set_values(&self, _v: &[Tensor]) {
591    }
592    #[cfg(feature = "use-serde")]
593    fn as_any(&self) -> &dyn Any {
594	self
595    }
596}
597
598
599// unsqueeze
600one_to_1_op_with_paras!(Unsqueeze,
601                        "Unsqueeze",
602                        1, 1,
603                        unsqueeze,
604                        (|input: &[Tensor],
605                         output_grad: &[Tensor],
606                         input_grad: &[Tensor]| {
607                             unimplemented!();
608                         }),
609                        dim: usize);
610
611// conditional_select
612#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
613pub struct ConditionalSelect {
614    #[cfg_attr(feature = "use-serde", serde(skip))]
615    handle: OpHandle,
616}
617impl ConditionalSelect {
618    pub fn new() -> ConditionalSelect {
619        ConditionalSelect {
620            handle: OpHandle::new(),
621        }
622    }
623    fn get_handle(&self) -> &OpHandle {
624        &self.handle
625    }
626    fn get_handle_mut(&mut self) -> &mut OpHandle {
627        &mut self.handle
628    }
629}
630impl OpCall for ConditionalSelect {
631    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
632        let new_one = ConditionalSelect {
633            handle: OpHandle::new(),
634        };
635
636        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
637
638        inputs[0].called_with(op, &inputs[1..inputs.len()])
639    }
640}
641impl OpTrait for ConditionalSelect {
642
643    fn get_name(&self) -> &'static str {
644        "ConditionalSelect"
645    }
646    fn get_input_size(&self) -> usize {
647        3
648    }
649    fn get_output_size(&self) -> usize {
650        1
651    }
652    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
653        output[0].swap(&input[0].conditional_select(&input[0], &input[1]));
654    }
655    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
656        unimplemented!();
657    }
658    fn get_values(&self) -> Vec<Tensor> {
659        Vec::new()
660    }
661    fn get_grads(&self) -> Vec<Tensor> {
662        Vec::new()
663    }
664    fn set_values(&self, _v: &[Tensor]) {
665    }
666    #[cfg(feature = "use-serde")]
667    fn as_any(&self) -> &dyn Any {
668	self
669    }
670}
671impl Default for ConditionalSelect {
672    fn default() -> Self {
673        Self::new()
674    }
675}
676
677
678// repeat
679#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
680pub struct Repeat {
681    #[cfg_attr(feature = "use-serde", serde(skip))]
682    handle: OpHandle,
683    sizes: Vec<usize>,
684}
685impl Repeat {
686    pub fn new(sizes: &[usize]) -> Repeat {
687        Repeat {
688            handle: OpHandle::new(),
689            sizes: sizes.to_vec(),
690        }
691    }
692    fn get_handle(&self) -> &OpHandle {
693        &self.handle
694    }
695    fn get_handle_mut(&mut self) -> &mut OpHandle {
696        &mut self.handle
697    }
698}
699impl OpCall for Repeat {
700    fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
701        let new_one = Repeat {
702            handle: OpHandle::new(),
703            sizes: self.sizes.clone(),
704        };
705
706        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
707
708        inputs[0].called_with(op, &inputs[1..inputs.len()])
709    }
710}
711impl OpTrait for Repeat {
712
713    fn get_name(&self) -> &'static str {
714        "Repeat"
715    }
716    fn get_input_size(&self) -> usize {
717        1
718    }
719    fn get_output_size(&self) -> usize {
720        1
721    }
722    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
723        output[0].swap(&input[0].repeat(&self.sizes))
724    }
725    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
726        unimplemented!();
727    }
728    fn get_values(&self) -> Vec<Tensor> {
729        Vec::new()
730    }
731    fn get_grads(&self) -> Vec<Tensor> {
732        Vec::new()
733    }
734    fn set_values(&self, _v: &[Tensor]) {
735    }
736    #[cfg(feature = "use-serde")]
737    fn as_any(&self) -> &dyn Any {
738	self
739    }
740}