pictorus_blocks/core_blocks/
exponent_block.rs1use crate::traits::Scalar;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
4
5#[derive(Debug)]
18pub struct ExponentBlock<T: Pass + Default> {
19 pub data: OldBlockData,
20 output: Option<T>,
21}
22
23impl<T: Pass + Default> Default for ExponentBlock<T>
24where
25 OldBlockData: FromPass<T>,
26{
27 fn default() -> Self {
28 Self {
29 data: <OldBlockData as FromPass<T>>::from_pass(T::default().as_by()),
30 output: None,
31 }
32 }
33}
34
35impl<S: Scalar + num_traits::Float + num_traits::Zero> ProcessBlock for ExponentBlock<S>
36where
37 OldBlockData: FromPass<S>,
38{
39 type Inputs = S;
40 type Output = S;
41 type Parameters = Parameters<S>;
42
43 fn process<'b>(
44 &'b mut self,
45 parameters: &Self::Parameters,
46 _context: &dyn pictorus_traits::Context,
47 inputs: PassBy<'_, Self::Inputs>,
48 ) -> PassBy<'b, Self::Output> {
49 let mut inputs_local = inputs;
50 if (inputs < S::zero()) && (parameters.coefficient < S::one()) {
51 if !parameters.preserve_sign {
52 panic!("Negative input to Exponent with coefficient < 1.0!");
53 } else {
54 inputs_local = inputs_local.abs();
55 }
56 }
57 let output = self
58 .output
59 .insert(inputs_local.powf(parameters.coefficient));
60 if parameters.preserve_sign {
61 let should_flip_sign = (*output < S::zero()) != (inputs < S::zero());
62 if should_flip_sign {
63 *output = output.neg();
64 };
65 }
66 self.data = OldBlockData::from_pass(*output);
67 *output
68 }
69}
70
71impl<S: Scalar + num_traits::Float + num_traits::Zero, const NROWS: usize, const NCOLS: usize>
72 ProcessBlock for ExponentBlock<Matrix<NROWS, NCOLS, S>>
73where
74 OldBlockData: FromPass<Matrix<NROWS, NCOLS, S>>,
75{
76 type Inputs = Matrix<NROWS, NCOLS, S>;
77 type Output = Matrix<NROWS, NCOLS, S>;
78 type Parameters = Parameters<S>;
79
80 fn process<'b>(
81 &'b mut self,
82 parameters: &Self::Parameters,
83 _context: &dyn pictorus_traits::Context,
84 inputs: PassBy<'_, Self::Inputs>,
85 ) -> PassBy<'b, Self::Output> {
86 let output = self.output.insert(*inputs);
87 output.data.as_flattened_mut().iter_mut().for_each(|x| {
88 let mut x_local = *x;
89 if (x_local < S::zero()) && (parameters.coefficient < S::one()) {
90 if !parameters.preserve_sign {
91 panic!("Negative input to Exponent with coefficient < 1.0!");
92 } else {
93 x_local = x_local.abs();
94 }
95 }
96 x_local = x_local.powf(parameters.coefficient);
97 if parameters.preserve_sign {
98 let should_flip_sign = (x_local < S::zero()) != (*x < S::zero());
99 if should_flip_sign {
100 x_local = x_local.neg();
101 };
102 }
103 *x = x_local;
104 });
105 self.data = OldBlockData::from_pass(output);
106 output
107 }
108}
109
110#[derive(Debug, Clone, Copy)]
112pub struct Parameters<T: Scalar + num_traits::Float> {
113 coefficient: T,
116 preserve_sign: bool,
121}
122
123impl<T: Scalar + num_traits::Float> Parameters<T> {
124 pub fn new<S: Scalar>(coefficient: T, preserve_sign: S) -> Self {
125 Self {
126 coefficient,
127 preserve_sign: preserve_sign.is_truthy(),
128 }
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;

    #[test]
    fn test_exponent_block_scalar() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<f64>::default();

        // Squaring without sign preservation: a negative input's sign is lost.
        let parameters = Parameters::new(2.0, false);
        let input = 2.0;
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, 4.0);
        let input = -2.0;
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, 4.0);

        // Even power with sign preservation: a negative input stays negative.
        let parameters = Parameters::new(4.0, true);
        let input = 11.0;
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, 14641.0);
        let input = -11.0;
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, -14641.0);

        // Odd power: a negative input stays negative either way.
        let parameters = Parameters::new(3.0, false);
        let input = -2.0;
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, -8.0);
        let parameters = Parameters::new(3.0, true);
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, -8.0);

        // Square root of a positive input.
        let parameters = Parameters::new(0.5, false);
        let input = 4.0;
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, 2.0);

        // Square root with sign preservation: negative inputs are rooted by
        // magnitude and keep their sign.
        let parameters = Parameters::new(0.5, true);
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, 2.0);
        let input = -4.0;
        let output = block.process(&parameters, &context, input.as_by());
        assert_eq!(output, -2.0);
    }

    #[test]
    #[should_panic]
    fn test_root_negative_input_no_preserve_sign_panic() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<f64>::default();
        let parameters = Parameters::new(0.5, false);
        let input = -4.0;
        block.process(&parameters, &context, input.as_by());
    }

    #[test]
    fn test_exponent_block_matrix() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<Matrix<2, 2, f32>>::default();

        // Squaring without sign preservation.
        let parameters = Parameters::new(2.0, false);
        let input = Matrix {
            data: [[1.0, -2.0], [3.0, -4.0]],
        };
        let output = block.process(&parameters, &context, &input);
        assert_eq!(output.data, [[1.0, 4.0], [9.0, 16.0]]);

        // Even power with sign preservation.
        let parameters = Parameters::new(4.0, true);
        let output = block.process(&parameters, &context, &input);
        assert_eq!(output.data, [[1.0, -16.0], [81.0, -256.0]]);

        // Square root of non-negative elements.
        let parameters = Parameters::new(0.5, false);
        let input = Matrix {
            data: [[1.0, 4.0], [9.0, 16.0]],
        };
        let output = block.process(&parameters, &context, &input);
        assert_eq!(output.data, [[1.0, 2.0], [3.0, 4.0]]);

        // Square root with sign preservation, positive then mixed-sign input.
        let parameters = Parameters::new(0.5, true);
        let output = block.process(&parameters, &context, &input);
        assert_eq!(output.data, [[1.0, 2.0], [3.0, 4.0]]);

        let input = Matrix {
            data: [[1.0, -4.0], [9.0, -16.0]],
        };
        let output = block.process(&parameters, &context, &input);
        assert_eq!(output.data, [[1.0, -2.0], [3.0, -4.0]]);
    }

    #[test]
    #[should_panic]
    fn test_root_matrix_negative_input_no_preserve_sign_panic() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<Matrix<2, 2, f32>>::default();
        let parameters = Parameters::new(0.5, false);
        let input = Matrix {
            data: [[1.0, -4.0], [9.0, -16.0]],
        };
        block.process(&parameters, &context, &input);
    }
}