1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpHandle};
3
4#[cfg(feature = "use-serde")]
5use serde::{Serialize, Deserialize};
6#[cfg(feature = "use-serde")]
7use std::any::Any;
8
9#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
10pub struct NormalizeUnit {
11 #[cfg_attr(feature = "use-serde", serde(skip))]
12 handle: OpHandle,
13}
14impl NormalizeUnit {
15 pub fn new() -> NormalizeUnit {
16 NormalizeUnit {
17 handle: OpHandle::new(),
18 }
19 }
20 fn get_handle(&self) -> &OpHandle {
21 &self.handle
22 }
23 fn get_handle_mut(&mut self) -> &mut OpHandle {
24 &mut self.handle
25 }
26}
27impl OpTrait for NormalizeUnit {
28
29 fn get_name(&self) -> &'static str {
30 "NormalizeUnit"
31 }
32 fn get_input_size(&self) -> usize {
33 1
34 }
35 fn get_output_size(&self) -> usize {
36 1
37 }
38 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
39 output[0].swap(&input[0].normalize_unit());
40 }
41 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
42 unimplemented!();
43 }
44 fn get_values(&self) -> Vec<Tensor> {
45 Vec::new()
46 }
47 fn get_grads(&self) -> Vec<Tensor> {
48 Vec::new()
49 }
50 fn set_values(&self, _v: &[Tensor]) {
51 }
52 #[cfg(feature = "use-serde")]
53 fn as_any(&self) -> &dyn Any {
54 self
55 }
56}
57impl Default for NormalizeUnit {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63
64#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
65pub struct Det {
66 #[cfg_attr(feature = "use-serde", serde(skip))]
67 handle: OpHandle,
68}
69impl Det {
70 pub fn new() -> Det {
71 Det {
72 handle: OpHandle::new(),
73 }
74 }
75 fn get_handle(&self) -> &OpHandle {
76 &self.handle
77 }
78 fn get_handle_mut(&mut self) -> &mut OpHandle {
79 &mut self.handle
80 }
81}
82impl OpTrait for Det {
83
84 fn get_name(&self) -> &'static str {
85 "Det"
86 }
87 fn get_input_size(&self) -> usize {
88 1
89 }
90 fn get_output_size(&self) -> usize {
91 1
92 }
93 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
94 output[0].swap(&input[0].det().expect("det() does not get a result."));
95 }
96 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
97 unimplemented!();
98 }
99 fn get_values(&self) -> Vec<Tensor> {
100 Vec::new()
101 }
102 fn get_grads(&self) -> Vec<Tensor> {
103 Vec::new()
104 }
105 fn set_values(&self, _v: &[Tensor]) {
106 }
107 #[cfg(feature = "use-serde")]
108 fn as_any(&self) -> &dyn Any {
109 self
110 }
111}
112impl Default for Det {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
119pub struct Inv {
120 #[cfg_attr(feature = "use-serde", serde(skip))]
121 handle: OpHandle,
122}
123impl Inv {
124 pub fn new() -> Inv {
125 Inv {
126 handle: OpHandle::new(),
127 }
128 }
129 fn get_handle(&self) -> &OpHandle {
130 &self.handle
131 }
132 fn get_handle_mut(&mut self) -> &mut OpHandle {
133 &mut self.handle
134 }
135}
136impl OpTrait for Inv {
137
138 fn get_name(&self) -> &'static str {
139 "Inv"
140 }
141 fn get_input_size(&self) -> usize {
142 1
143 }
144 fn get_output_size(&self) -> usize {
145 1
146 }
147 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
148 output[0].swap(&input[0].inv().expect("inv() does not get a result."));
149 }
150 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
151 unimplemented!();
152 }
153 fn get_values(&self) -> Vec<Tensor> {
154 Vec::new()
155 }
156 fn get_grads(&self) -> Vec<Tensor> {
157 Vec::new()
158 }
159 fn set_values(&self, _v: &[Tensor]) {
160 }
161 #[cfg(feature = "use-serde")]
162 fn as_any(&self) -> &dyn Any {
163 self
164 }
165}
166impl Default for Inv {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
173pub struct Tr {
174 #[cfg_attr(feature = "use-serde", serde(skip))]
175 handle: OpHandle,
176}
177impl Tr {
178 pub fn new() -> Tr {
179 Tr {
180 handle: OpHandle::new(),
181 }
182 }
183 fn get_handle(&self) -> &OpHandle {
184 &self.handle
185 }
186 fn get_handle_mut(&mut self) -> &mut OpHandle {
187 &mut self.handle
188 }
189}
190impl OpTrait for Tr {
191
192 fn get_name(&self) -> &'static str {
193 "Tr"
194 }
195 fn get_input_size(&self) -> usize {
196 1
197 }
198 fn get_output_size(&self) -> usize {
199 1
200 }
201 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
202 output[0].swap(&input[0].tr());
203 }
204 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
205 unimplemented!();
206 }
207 fn get_values(&self) -> Vec<Tensor> {
208 Vec::new()
209 }
210 fn get_grads(&self) -> Vec<Tensor> {
211 Vec::new()
212 }
213 fn set_values(&self, _v: &[Tensor]) {
214 }
215 #[cfg(feature = "use-serde")]
216 fn as_any(&self) -> &dyn Any {
217 self
218 }
219}
220impl Default for Tr {
221 fn default() -> Self {
222 Self::new()
223 }
224}