auto_diff/op/
linalg.rs

1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpHandle};
3
4#[cfg(feature = "use-serde")]
5use serde::{Serialize, Deserialize};
6#[cfg(feature = "use-serde")]
7use std::any::Any;
8
9#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
10pub struct NormalizeUnit {
11    #[cfg_attr(feature = "use-serde", serde(skip))]
12    handle: OpHandle,
13}
14impl NormalizeUnit {
15    pub fn new() -> NormalizeUnit {
16        NormalizeUnit {
17            handle: OpHandle::new(),
18        }
19    }
20    fn get_handle(&self) -> &OpHandle {
21        &self.handle
22    }
23    fn get_handle_mut(&mut self) -> &mut OpHandle {
24        &mut self.handle
25    }
26}
27impl OpTrait for NormalizeUnit {
28     
29    fn get_name(&self) -> &'static str {
30        "NormalizeUnit"
31    }
32    fn get_input_size(&self) -> usize {
33        1
34    }
35    fn get_output_size(&self) -> usize {
36        1
37    }
38    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
39        output[0].swap(&input[0].normalize_unit());
40    }
41    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
42        unimplemented!();
43    }
44    fn get_values(&self) -> Vec<Tensor> {
45        Vec::new()
46    }
47    fn get_grads(&self) -> Vec<Tensor> {
48        Vec::new()
49    }
50    fn set_values(&self, _v: &[Tensor]) {
51    }
52    #[cfg(feature = "use-serde")]
53    fn as_any(&self) -> &dyn Any {
54	self
55    }
56}
57impl Default for NormalizeUnit {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63
64#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
65pub struct Det {
66    #[cfg_attr(feature = "use-serde", serde(skip))]
67    handle: OpHandle,
68}
69impl Det {
70    pub fn new() -> Det {
71        Det {
72            handle: OpHandle::new(),
73        }
74    }
75    fn get_handle(&self) -> &OpHandle {
76        &self.handle
77    }
78    fn get_handle_mut(&mut self) -> &mut OpHandle {
79        &mut self.handle
80    }
81}
82impl OpTrait for Det {
83     
84    fn get_name(&self) -> &'static str {
85        "Det"
86    }
87    fn get_input_size(&self) -> usize {
88        1
89    }
90    fn get_output_size(&self) -> usize {
91        1
92    }
93    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
94        output[0].swap(&input[0].det().expect("det() does not get a result."));
95    }
96    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
97        unimplemented!();
98    }
99    fn get_values(&self) -> Vec<Tensor> {
100        Vec::new()
101    }
102    fn get_grads(&self) -> Vec<Tensor> {
103        Vec::new()
104    }
105    fn set_values(&self, _v: &[Tensor]) {
106    }
107    #[cfg(feature = "use-serde")]
108    fn as_any(&self) -> &dyn Any {
109	self
110    }
111}
112impl Default for Det {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
119pub struct Inv {
120    #[cfg_attr(feature = "use-serde", serde(skip))]
121    handle: OpHandle,
122}
123impl Inv {
124    pub fn new() -> Inv {
125        Inv {
126            handle: OpHandle::new(),
127        }
128    }
129    fn get_handle(&self) -> &OpHandle {
130        &self.handle
131    }
132    fn get_handle_mut(&mut self) -> &mut OpHandle {
133        &mut self.handle
134    }
135}
136impl OpTrait for Inv {
137     
138    fn get_name(&self) -> &'static str {
139        "Inv"
140    }
141    fn get_input_size(&self) -> usize {
142        1
143    }
144    fn get_output_size(&self) -> usize {
145        1
146    }
147    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
148        output[0].swap(&input[0].inv().expect("inv() does not get a result."));
149    }
150    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
151        unimplemented!();
152    }
153    fn get_values(&self) -> Vec<Tensor> {
154        Vec::new()
155    }
156    fn get_grads(&self) -> Vec<Tensor> {
157        Vec::new()
158    }
159    fn set_values(&self, _v: &[Tensor]) {
160    }
161    #[cfg(feature = "use-serde")]
162    fn as_any(&self) -> &dyn Any {
163	self
164    }
165}
166impl Default for Inv {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
173pub struct Tr {
174    #[cfg_attr(feature = "use-serde", serde(skip))]
175    handle: OpHandle,
176}
177impl Tr {
178    pub fn new() -> Tr {
179        Tr {
180            handle: OpHandle::new(),
181        }
182    }
183    fn get_handle(&self) -> &OpHandle {
184        &self.handle
185    }
186    fn get_handle_mut(&mut self) -> &mut OpHandle {
187        &mut self.handle
188    }
189}
190impl OpTrait for Tr {
191     
192    fn get_name(&self) -> &'static str {
193        "Tr"
194    }
195    fn get_input_size(&self) -> usize {
196        1
197    }
198    fn get_output_size(&self) -> usize {
199        1
200    }
201    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
202        output[0].swap(&input[0].tr());
203    }
204    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
205        unimplemented!();
206    }
207    fn get_values(&self) -> Vec<Tensor> {
208        Vec::new()
209    }
210    fn get_grads(&self) -> Vec<Tensor> {
211        Vec::new()
212    }
213    fn set_values(&self, _v: &[Tensor]) {
214    }
215    #[cfg(feature = "use-serde")]
216    fn as_any(&self) -> &dyn Any {
217	self
218    }
219}
220impl Default for Tr {
221    fn default() -> Self {
222        Self::new()
223    }
224}