auto_diff/op/comparison.rs
#![allow(clippy::redundant_closure_call)]
use tensor_rs::tensor::Tensor;
use super::{OpTrait, OpHandle, OpCall, Op};
use super::macros::new_binary_op;

use std::cell::RefCell;
use std::rc::Rc;

use crate::var::Var;
use crate::err::AutoDiffError;

#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};
#[cfg(feature = "use-serde")]
use std::any::Any;

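// Comparison ops for the autodiff graph. Each op's forward pass simply
// forwards to the corresponding comparison kernel on
// tensor_rs::tensor::Tensor; none of these ops define a backward pass yet,
// so every grad() panics via unimplemented!().
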
new_binary_op!(MaxPair, "Max_pair",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].max_pair(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
new_binary_op!(MinPair, "Min_pair",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].min_pair(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
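
// ArgSort carries per-op configuration (the dimension to sort along and the
// sort order), so it is written out by hand instead of going through the
// new_binary_op! macro; call() clones that configuration into a fresh node
// for the graph.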
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct ArgSort {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
    dim: usize,
    descending: bool,
}
impl ArgSort {
    pub fn new(dim: usize, descending: bool) -> ArgSort {
        ArgSort {
            handle: OpHandle::new(),
            dim,
            descending,
        }
    }
    fn get_handle(&self) -> &OpHandle {
        &self.handle
    }
    fn get_handle_mut(&mut self) -> &mut OpHandle {
        &mut self.handle
    }
}
impl OpCall for ArgSort {
    fn call(&mut self, inputs: &[&Var])
            -> Result<Vec<Var>, AutoDiffError> {
        let new_one = ArgSort {
            handle: OpHandle::new(),
            dim: self.dim,
            descending: self.descending,
        };

        let op = Op::new(Rc::new(RefCell::new(Box::new(new_one))));

        inputs[0].called_with(op, &inputs[1..inputs.len()])
    }
}
impl OpTrait for ArgSort {

    fn get_name(&self) -> &'static str {
        "Arg_sort"
    }
    fn get_input_size(&self) -> usize {
        1
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        output[0].swap(&input[0].arg_sort(self.dim, self.descending))
    }
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        unimplemented!();
    }
    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {
    }
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}
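
// A rough usage sketch (comments only, not compiled here): an ArgSort op is
// configured up front and then pushed into the graph through OpCall::call
// with borrowed Vars. The Var::new(values, shape) constructor shown below is
// an assumption about the crate's public API, not something defined in this
// file.
//
//     let x = Var::new(&[3.0, 1.0, 2.0], &[3]);
//     let mut op = ArgSort::new(0, false);
//     let sorted_indices = op.call(&[&x])?;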
new_binary_op!(EqElem, "Eq_t",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].eq_t(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
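
// Unlike EqElem, which produces an element-wise mask, Equal compares the two
// inputs as whole tensors and writes a single-element result: 0 when the
// tensors are equal and 1 when they differ.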
new_binary_op!(Equal, "Equal",
               (|a: &[Tensor], b: &[Tensor]|
                if a[0].equal(&a[1]) {
                    b[0].swap(&Tensor::zeros(&[1]))
                } else {
                    b[0].swap(&Tensor::ones(&[1]))
                }),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
new_binary_op!(Ge, "Ge",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].ge(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
new_binary_op!(Gt, "Gt",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].gt(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
new_binary_op!(Le, "Le",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].le(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
new_binary_op!(Lt, "Lt",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].lt(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
new_binary_op!(Ne, "Ne",
               (|a: &[Tensor], b: &[Tensor]|
                b[0].swap(&a[0].ne(&a[1]))
               ),
               (|input: &[Tensor], output_grad: &[Tensor],
                 input_grad: &[Tensor]| {
                   unimplemented!();
               }));
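
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal smoke-test sketch rather than a full suite. It assumes the
    // new_binary_op!-generated ops expose a no-argument new() constructor,
    // matching the other macro-generated ops in this crate.
    #[test]
    fn max_pair_takes_elementwise_maximum() {
        let op = MaxPair::new();
        let inputs = [Tensor::ones(&[2, 2]), Tensor::zeros(&[2, 2])];
        let outputs = [Tensor::zeros(&[2, 2])];
        op.apply(&inputs, &outputs);
        // The element-wise maximum of all-ones and all-zeros is all ones.
        assert!(outputs[0].equal(&Tensor::ones(&[2, 2])));
    }
}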