use crate::traits::Scalar;
use pictorus_block_data::{BlockData as OldBlockData, FromPass};
use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
#[derive(Debug)]
pub struct ExponentBlock<T: Pass + Default> {
pub data: OldBlockData,
output: Option<T>,
}
impl<T: Pass + Default> Default for ExponentBlock<T>
where
OldBlockData: FromPass<T>,
{
fn default() -> Self {
Self {
data: <OldBlockData as FromPass<T>>::from_pass(T::default().as_by()),
output: None,
}
}
}
impl<S: Scalar + num_traits::Float + num_traits::Zero> ProcessBlock for ExponentBlock<S>
where
OldBlockData: FromPass<S>,
{
type Inputs = S;
type Output = S;
type Parameters = Parameters<S>;
fn process<'b>(
&'b mut self,
parameters: &Self::Parameters,
_context: &dyn pictorus_traits::Context,
inputs: PassBy<'_, Self::Inputs>,
) -> PassBy<'b, Self::Output> {
let mut inputs_local = inputs;
if (inputs < S::zero()) && (parameters.coefficient < S::one()) {
if !parameters.preserve_sign {
panic!("Negative input to Exponent with coefficient < 1.0!");
} else {
inputs_local = inputs_local.abs();
}
}
let output = self
.output
.insert(inputs_local.powf(parameters.coefficient));
if parameters.preserve_sign {
let should_flip_sign = (*output < S::zero()) != (inputs < S::zero());
if should_flip_sign {
*output = output.neg();
};
}
self.data = OldBlockData::from_pass(*output);
*output
}
}
impl<S: Scalar + num_traits::Float + num_traits::Zero, const NROWS: usize, const NCOLS: usize>
ProcessBlock for ExponentBlock<Matrix<NROWS, NCOLS, S>>
where
OldBlockData: FromPass<Matrix<NROWS, NCOLS, S>>,
{
type Inputs = Matrix<NROWS, NCOLS, S>;
type Output = Matrix<NROWS, NCOLS, S>;
type Parameters = Parameters<S>;
fn process<'b>(
&'b mut self,
parameters: &Self::Parameters,
_context: &dyn pictorus_traits::Context,
inputs: PassBy<'_, Self::Inputs>,
) -> PassBy<'b, Self::Output> {
let output = self.output.insert(*inputs);
output.data.as_flattened_mut().iter_mut().for_each(|x| {
let mut x_local = *x;
if (x_local < S::zero()) && (parameters.coefficient < S::one()) {
if !parameters.preserve_sign {
panic!("Negative input to Exponent with coefficient < 1.0!");
} else {
x_local = x_local.abs();
}
}
x_local = x_local.powf(parameters.coefficient);
if parameters.preserve_sign {
let should_flip_sign = (x_local < S::zero()) != (*x < S::zero());
if should_flip_sign {
x_local = x_local.neg();
};
}
*x = x_local;
});
self.data = OldBlockData::from_pass(output);
output
}
}
#[derive(Debug, Clone, Copy)]
pub struct Parameters<T: Scalar + num_traits::Float> {
coefficient: T,
preserve_sign: bool,
}
impl<T: Scalar + num_traits::Float> Parameters<T> {
pub fn new<S: Scalar>(coefficient: T, preserve_sign: S) -> Self {
Self {
coefficient,
preserve_sign: preserve_sign.is_truthy(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::testing::StubContext;
#[test]
fn test_exponent_block_scalar() {
let context = StubContext::default();
let mut block = ExponentBlock::<f64>::default();
let parameters = Parameters::new(2.0, false);
let input = 2.0;
let output = block.process(¶meters, &context, input.as_by());
assert_eq!(output, 4.0);
let input = -2.0;
let output = block.process(¶meters, &context, input.as_by());
assert_eq!(output, 4.0);
let parameters = Parameters::new(4.0, true);
let input = 11.0;
let output = block.process(¶meters, &context, input.as_by());
assert_eq!(output, 14641.0);
let input = -11.0;
let output = block.process(¶meters, &context, input.as_by());
assert_eq!(output, -14641.0);
let parameters = Parameters::new(0.5, false);
let input = 4.0;
let output = block.process(¶meters, &context, input.as_by());
assert_eq!(output, 2.0);
let parameters = Parameters::new(0.5, true);
let output = block.process(¶meters, &context, input.as_by());
assert_eq!(output, 2.0);
let input = -4.0;
let output = block.process(¶meters, &context, input.as_by());
assert_eq!(output, -2.0);
}
#[test]
#[should_panic]
fn test_root_negative_input_no_preserve_sign_panic() {
let context = StubContext::default();
let mut block = ExponentBlock::<f64>::default();
let parameters = Parameters::new(0.5, false);
let input = -4.0;
block.process(¶meters, &context, input.as_by());
}
#[test]
fn test_exponent_block_matrix() {
let context = StubContext::default();
let mut block = ExponentBlock::<Matrix<2, 2, f32>>::default();
let parameters = Parameters::new(2.0, false);
let input = Matrix {
data: [[1.0, -2.0], [3.0, -4.0]],
};
let output = block.process(¶meters, &context, &input);
assert_eq!(output.data, [[1.0, 4.0], [9.0, 16.0]]);
let parameters = Parameters::new(4.0, true);
let output = block.process(¶meters, &context, &input);
assert_eq!(output.data, [[1.0, -16.0], [81.0, -256.0]]);
let parameters = Parameters::new(0.5, false);
let input = Matrix {
data: [[1.0, 4.0], [9.0, 16.0]],
};
let output = block.process(¶meters, &context, &input);
assert_eq!(output.data, [[1.0, 2.0], [3.0, 4.0]]);
let parameters = Parameters::new(0.5, true);
let output = block.process(¶meters, &context, &input);
assert_eq!(output.data, [[1.0, 2.0], [3.0, 4.0]]);
let input = Matrix {
data: [[1.0, -4.0], [9.0, -16.0]],
};
let output = block.process(¶meters, &context, &input);
assert_eq!(output.data, [[1.0, -2.0], [3.0, -4.0]]);
}
#[test]
#[should_panic]
fn test_root_matrix_negative_input_no_preserve_sign_panic() {
let context = StubContext::default();
let mut block = ExponentBlock::<Matrix<2, 2, f32>>::default();
let parameters = Parameters::new(0.5, false);
let input = Matrix {
data: [[1.0, -4.0], [9.0, -16.0]],
};
block.process(¶meters, &context, &input);
}
}