Skip to main content

pictorus_blocks/core_blocks/
exponent_block.rs

1use crate::traits::Scalar;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
4
5/// Raises the input to a specified power (coefficient),
6/// and optionally preserves the sign of the input.
7///
8/// It can accept a scalar or a matrix input. If the input is a matrix,
9/// the exponentiation is performed element-wise.
10///
11/// The power to raise the input to as well as a flag to optionally preserve the sign
12/// of the input when performing the exponentiation can be set in the parameters.
13///
14/// # Panics
15/// If the input is negative and the coefficient is < 1.0 and preserve_sign is false,
16/// a panic will occur.
17#[derive(Debug)]
18pub struct ExponentBlock<T: Pass + Default> {
19    pub data: OldBlockData,
20    output: Option<T>,
21}
22
23impl<T: Pass + Default> Default for ExponentBlock<T>
24where
25    OldBlockData: FromPass<T>,
26{
27    fn default() -> Self {
28        Self {
29            data: <OldBlockData as FromPass<T>>::from_pass(T::default().as_by()),
30            output: None,
31        }
32    }
33}
34
35impl<S: Scalar + num_traits::Float + num_traits::Zero> ProcessBlock for ExponentBlock<S>
36where
37    OldBlockData: FromPass<S>,
38{
39    type Inputs = S;
40    type Output = S;
41    type Parameters = Parameters<S>;
42
43    fn process<'b>(
44        &'b mut self,
45        parameters: &Self::Parameters,
46        _context: &dyn pictorus_traits::Context,
47        inputs: PassBy<'_, Self::Inputs>,
48    ) -> PassBy<'b, Self::Output> {
49        let mut inputs_local = inputs;
50        if (inputs < S::zero()) && (parameters.coefficient < S::one()) {
51            if !parameters.preserve_sign {
52                panic!("Negative input to Exponent with coefficient < 1.0!");
53            } else {
54                inputs_local = inputs_local.abs();
55            }
56        }
57        let output = self
58            .output
59            .insert(inputs_local.powf(parameters.coefficient));
60        if parameters.preserve_sign {
61            let should_flip_sign = (*output < S::zero()) != (inputs < S::zero());
62            if should_flip_sign {
63                *output = output.neg();
64            };
65        }
66        self.data = OldBlockData::from_pass(*output);
67        *output
68    }
69}
70
71impl<S: Scalar + num_traits::Float + num_traits::Zero, const NROWS: usize, const NCOLS: usize>
72    ProcessBlock for ExponentBlock<Matrix<NROWS, NCOLS, S>>
73where
74    OldBlockData: FromPass<Matrix<NROWS, NCOLS, S>>,
75{
76    type Inputs = Matrix<NROWS, NCOLS, S>;
77    type Output = Matrix<NROWS, NCOLS, S>;
78    type Parameters = Parameters<S>;
79
80    fn process<'b>(
81        &'b mut self,
82        parameters: &Self::Parameters,
83        _context: &dyn pictorus_traits::Context,
84        inputs: PassBy<'_, Self::Inputs>,
85    ) -> PassBy<'b, Self::Output> {
86        let output = self.output.insert(*inputs);
87        output.data.as_flattened_mut().iter_mut().for_each(|x| {
88            let mut x_local = *x;
89            if (x_local < S::zero()) && (parameters.coefficient < S::one()) {
90                if !parameters.preserve_sign {
91                    panic!("Negative input to Exponent with coefficient < 1.0!");
92                } else {
93                    x_local = x_local.abs();
94                }
95            }
96            x_local = x_local.powf(parameters.coefficient);
97            if parameters.preserve_sign {
98                let should_flip_sign = (x_local < S::zero()) != (*x < S::zero());
99                if should_flip_sign {
100                    x_local = x_local.neg();
101                };
102            }
103            *x = x_local;
104        });
105        self.data = OldBlockData::from_pass(output);
106        output
107    }
108}
109
110/// Parameters for the ExponentBlock
111#[derive(Debug, Clone, Copy)]
112pub struct Parameters<T: Scalar + num_traits::Float> {
113    /// The coefficient to raise the input to
114    /// has the effect of being a root if < 1.0
115    coefficient: T,
116    /// Whether to preserve the sign of the input
117    /// when performing the exponentiation.
118    /// If the [`coefficient`] is < 1.0 and the input is negative,
119    /// this will cause a panic if set to false.
120    preserve_sign: bool,
121}
122
123impl<T: Scalar + num_traits::Float> Parameters<T> {
124    pub fn new<S: Scalar>(coefficient: T, preserve_sign: S) -> Self {
125        Self {
126            coefficient,
127            preserve_sign: preserve_sign.is_truthy(),
128        }
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;

    #[test]
    fn test_exponent_block_scalar() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<f64>::default();

        // (coefficient, preserve_sign, input, expected output)
        let cases: [(f64, bool, f64, f64); 7] = [
            // Plain power: an even power discards the input's sign.
            (2.0, false, 2.0, 4.0),
            (2.0, false, -2.0, 4.0),
            // Preserve sign: an even power keeps the input's sign.
            (4.0, true, 11.0, 14641.0),
            (4.0, true, -11.0, -14641.0),
            // Root.
            (0.5, false, 4.0, 2.0),
            // Root with preserve sign accepts negative inputs.
            (0.5, true, 4.0, 2.0),
            (0.5, true, -4.0, -2.0),
        ];

        for (coefficient, preserve_sign, input, expected) in cases {
            let parameters = Parameters::new(coefficient, preserve_sign);
            let output = block.process(&parameters, &context, input.as_by());
            assert_eq!(
                output, expected,
                "coefficient={coefficient}, preserve_sign={preserve_sign}, input={input}"
            );
        }
    }

    #[test]
    #[should_panic(expected = "Negative input to Exponent with coefficient < 1.0!")]
    fn test_root_negative_input_no_preserve_sign_panic() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<f64>::default();
        let parameters = Parameters::new(0.5, false);
        let input: f64 = -4.0;
        block.process(&parameters, &context, input.as_by());
    }

    #[test]
    fn test_exponent_block_matrix() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<Matrix<2, 2, f32>>::default();

        let signed: Matrix<2, 2, f32> = Matrix {
            data: [[1.0, -2.0], [3.0, -4.0]],
        };
        let squares: Matrix<2, 2, f32> = Matrix {
            data: [[1.0, 4.0], [9.0, 16.0]],
        };
        let signed_squares: Matrix<2, 2, f32> = Matrix {
            data: [[1.0, -4.0], [9.0, -16.0]],
        };

        // (coefficient, preserve_sign, input, expected output data)
        let cases: [(f32, bool, Matrix<2, 2, f32>, _); 5] = [
            (2.0, false, signed, [[1.0, 4.0], [9.0, 16.0]]),
            (4.0, true, signed, [[1.0, -16.0], [81.0, -256.0]]),
            (0.5, false, squares, [[1.0, 2.0], [3.0, 4.0]]),
            (0.5, true, squares, [[1.0, 2.0], [3.0, 4.0]]),
            (0.5, true, signed_squares, [[1.0, -2.0], [3.0, -4.0]]),
        ];

        for (coefficient, preserve_sign, input, expected) in cases {
            let parameters = Parameters::new(coefficient, preserve_sign);
            let output = block.process(&parameters, &context, &input);
            assert_eq!(
                output.data, expected,
                "coefficient={coefficient}, preserve_sign={preserve_sign}"
            );
        }
    }

    #[test]
    #[should_panic(expected = "Negative input to Exponent with coefficient < 1.0!")]
    fn test_root_matrix_negative_input_no_preserve_sign_panic() {
        let context = StubContext::default();
        let mut block = ExponentBlock::<Matrix<2, 2, f32>>::default();
        let parameters = Parameters::new(0.5, false);
        let input: Matrix<2, 2, f32> = Matrix {
            data: [[1.0, -4.0], [9.0, -16.0]],
        };
        block.process(&parameters, &context, &input);
    }
}