auto_diff/op/
comparison.rs

1#![allow(clippy::redundant_closure_call)]
2use tensor_rs::tensor::Tensor;
3use super::{OpTrait, OpHandle, OpCall, Op};
4use super::macros::new_binary_op;
5
6use std::cell::{RefCell};
7use std::rc::Rc;
8
9use crate::var::{Var};
10use crate::err::AutoDiffError;
11
12#[cfg(feature = "use-serde")]
13use serde::{Serialize, Deserialize};
14#[cfg(feature = "use-serde")]
15use std::any::Any;
16
17// max_pair
18new_binary_op!(MaxPair, "Max_pair",
19               (|a:&[Tensor], b:&[Tensor]|
20                b[0].swap(&a[0].max_pair(&a[1]))
21               ),
22               (|input: &[Tensor], output_grad: &[Tensor],
23                input_grad: &[Tensor]| {
24                    unimplemented!();
25               }));
26// max, in reduction
27// min_pair
28new_binary_op!(MinPair, "Min_pair",
29               (|a:&[Tensor], b:&[Tensor]|
30                b[0].swap(&a[0].min_pair(&a[1]))
31               ),
32               (|input: &[Tensor], output_grad: &[Tensor],
33                input_grad: &[Tensor]| {
34                    unimplemented!();
35               }));
36// min, in reduction
37// arg_sort
38#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
39pub struct ArgSort {
40    #[cfg_attr(feature = "use-serde", serde(skip))]
41    handle: OpHandle,
42    dim: usize,
43    descending: bool,
44}
45impl ArgSort {
46    pub fn new(dim: usize, descending: bool) -> ArgSort {
47        ArgSort {
48            handle: OpHandle::new(),
49            dim,
50            descending,
51        }
52    }
53    fn get_handle(&self) -> &OpHandle {
54        &self.handle
55    }
56    fn get_handle_mut(&mut self) -> &mut OpHandle {
57        &mut self.handle
58    }
59}
60impl OpCall for ArgSort {
61    fn call(&mut self, inputs: &[&Var])
62            -> Result<Vec<Var>, AutoDiffError> {
63        let new_one = ArgSort {
64            handle: OpHandle::new(),
65            dim: self.dim,
66            descending: self.descending,
67        };
68
69        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));
70
71        inputs[0].called_with(op, &inputs[1..inputs.len()])
72    }
73}
74impl OpTrait for ArgSort {
75
76    fn get_name(&self) -> &'static str {
77        "Arg_sort"
78    }
79    fn get_input_size(&self) -> usize {
80        1
81    }
82    fn get_output_size(&self) -> usize {
83        1
84    }
85    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
86        output[0].swap(&input[0].arg_sort(self.dim, self.descending))
87    }
88    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
89        unimplemented!();
90    }
91    fn get_values(&self) -> Vec<Tensor> {
92        Vec::new()
93    }
94    fn get_grads(&self) -> Vec<Tensor> {
95        Vec::new()
96    }
97    fn set_values(&self, _v: &[Tensor]) {
98    }
99    #[cfg(feature = "use-serde")]
100    fn as_any(&self) -> &dyn Any {
101	self
102    }
103}
104// eq_t (use eq_elem)
105new_binary_op!(EqElem, "Eq_t",
106               (|a:&[Tensor], b:&[Tensor]|
107                b[0].swap(&a[0].eq_t(&a[1]))
108               ),
109               (|input: &[Tensor], output_grad: &[Tensor],
110                input_grad: &[Tensor]| {
111                    unimplemented!();
112               }));
113// equal, 0 is == 1 is !=
114new_binary_op!(Equal, "Equal",
115               (|a:&[Tensor], b:&[Tensor]|
116                if a[0].equal(&a[1]) {
117                    b[0].swap(&Tensor::zeros(&[1]))
118                } else {
119                    b[0].swap(&Tensor::ones(&[1]))
120                }),
121               (|input: &[Tensor], output_grad: &[Tensor],
122                input_grad: &[Tensor]| {
123                    unimplemented!();
124               }));
125// ge
126new_binary_op!(Ge, "Ge",
127               (|a:&[Tensor], b:&[Tensor]|
128                b[0].swap(&a[0].ge(&a[1]))
129               ),
130               (|input: &[Tensor], output_grad: &[Tensor],
131                input_grad: &[Tensor]| {
132                    unimplemented!();
133               }));
134// gt
135new_binary_op!(Gt, "Gt",
136               (|a:&[Tensor], b:&[Tensor]|
137                b[0].swap(&a[0].gt(&a[1]))
138               ),
139               (|input: &[Tensor], output_grad: &[Tensor],
140                input_grad: &[Tensor]| {
141                    unimplemented!();
142               }));
143// le
144new_binary_op!(Le, "Le",
145               (|a:&[Tensor], b:&[Tensor]|
146                b[0].swap(&a[0].le(&a[1]))
147               ),
148               (|input: &[Tensor], output_grad: &[Tensor],
149                input_grad: &[Tensor]| {
150                    unimplemented!();
151               }));
152// lt
153new_binary_op!(Lt, "Lt",
154               (|a:&[Tensor], b:&[Tensor]|
155                b[0].swap(&a[0].lt(&a[1]))
156               ),
157               (|input: &[Tensor], output_grad: &[Tensor],
158                input_grad: &[Tensor]| {
159                    unimplemented!();
160               }));
161// ne
162new_binary_op!(Ne, "Ne",
163               (|a:&[Tensor], b:&[Tensor]|
164                b[0].swap(&a[0].ne(&a[1]))
165               ),
166               (|input: &[Tensor], output_grad: &[Tensor],
167                input_grad: &[Tensor]| {
168                    unimplemented!();
169               }));