//! pictorus-blocks 0.0.0
//!
//! Implementations of Pictorus blocks.
use crate::traits::MatrixOps;
use num_traits::Float;
use pictorus_block_data::{BlockData as OldBlockData, FromPass};
use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};

/// Trigonometric function applied by [`TrigonometryBlock`] (element-wise for matrix inputs).
///
/// Values are parsed from the exact variant name (e.g. `"Sine"`, `"ArcCosineHyperbolic"`)
/// through the `FromStr` implementation generated by `strum::EnumString`.
#[derive(strum::EnumString, Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrigonometryFunction {
    /// `sin(x)`
    Sine,
    /// `cos(x)`
    Cosine,
    /// `tan(x)`
    Tangent,
    /// `asin(x)`
    ArcSine,
    /// `acos(x)`
    ArcCosine,
    /// `atan(x)`
    ArcTangent,
    /// `sinh(x)`
    SineHyperbolic,
    /// `cosh(x)`
    CosineHyperbolic,
    /// `tanh(x)`
    TangentHyperbolic,
    /// `asinh(x)`
    ArcSineHyperbolic,
    /// `acosh(x)`
    ArcCosineHyperbolic,
    /// `atanh(x)`
    ArcTangentHyperbolic,
}

pub struct Parameters {
    pub function: TrigonometryFunction,
}

impl Parameters {
    pub fn new(function: &str) -> Self {
        Self {
            function: function
                .parse()
                .expect("Failed to parse TrigonometryFunction"),
        }
    }
}

/// Block that applies a [`TrigonometryFunction`] to a scalar or matrix input.
///
/// `T` is both the input and the output type of the block.
pub struct TrigonometryBlock<T> {
    /// Legacy block-data view of the most recent output; refreshed on every `process` call.
    pub data: OldBlockData,
    /// Most recent output value, written on every `process` call.
    buffer: T,
}

impl<T> Default for TrigonometryBlock<T>
where
    T: Default + Pass,
    OldBlockData: FromPass<T>,
{
    /// Creates a block whose stored output (and its legacy `data` view) is `T::default()`.
    fn default() -> Self {
        let buffer = T::default();
        let data = <OldBlockData as FromPass<T>>::from_pass(buffer.as_by());
        Self { data, buffer }
    }
}

macro_rules! impl_trig_block {
    ($type:ty) => {
        impl ProcessBlock for TrigonometryBlock<$type> {
            type Inputs = $type;
            type Output = $type;
            type Parameters = Parameters;

            fn process(
                &mut self,
                parameters: &Self::Parameters,
                _context: &dyn pictorus_traits::Context,
                inputs: PassBy<'_, Self::Inputs>,
            ) -> PassBy<Self::Output> {
                let output = match parameters.function {
                    TrigonometryFunction::Sine => Float::sin(inputs),
                    TrigonometryFunction::Cosine => Float::cos(inputs),
                    TrigonometryFunction::Tangent => Float::tan(inputs),
                    TrigonometryFunction::ArcSine => Float::asin(inputs),
                    TrigonometryFunction::ArcCosine => Float::acos(inputs),
                    TrigonometryFunction::ArcTangent => Float::atan(inputs),
                    TrigonometryFunction::SineHyperbolic => Float::sinh(inputs),
                    TrigonometryFunction::CosineHyperbolic => Float::cosh(inputs),
                    TrigonometryFunction::TangentHyperbolic => Float::tanh(inputs),
                    TrigonometryFunction::ArcSineHyperbolic => Float::asinh(inputs),
                    TrigonometryFunction::ArcCosineHyperbolic => Float::acosh(inputs),
                    TrigonometryFunction::ArcTangentHyperbolic => Float::atanh(inputs),
                };
                self.buffer = output;
                self.data = OldBlockData::from_scalar(output.into());
                output
            }
        }

        impl<const ROWS: usize, const COLS: usize> ProcessBlock
            for TrigonometryBlock<Matrix<ROWS, COLS, $type>>
        where
            OldBlockData: FromPass<Matrix<ROWS, COLS, $type>>,
        {
            type Inputs = Matrix<ROWS, COLS, $type>;
            type Output = Matrix<ROWS, COLS, $type>;
            type Parameters = Parameters;

            fn process(
                &mut self,
                parameters: &Self::Parameters,
                _context: &dyn pictorus_traits::Context,
                inputs: PassBy<'_, Self::Inputs>,
            ) -> PassBy<Self::Output> {
                inputs.for_each(|input, c, r| {
                    let output = match parameters.function {
                        TrigonometryFunction::Sine => Float::sin(input),
                        TrigonometryFunction::Cosine => Float::cos(input),
                        TrigonometryFunction::Tangent => Float::tan(input),
                        TrigonometryFunction::ArcSine => Float::asin(input),
                        TrigonometryFunction::ArcCosine => Float::acos(input),
                        TrigonometryFunction::ArcTangent => Float::atan(input),
                        TrigonometryFunction::SineHyperbolic => Float::sinh(input),
                        TrigonometryFunction::CosineHyperbolic => Float::cosh(input),
                        TrigonometryFunction::TangentHyperbolic => Float::tanh(input),
                        TrigonometryFunction::ArcSineHyperbolic => Float::asinh(input),
                        TrigonometryFunction::ArcCosineHyperbolic => Float::acosh(input),
                        TrigonometryFunction::ArcTangentHyperbolic => Float::atanh(input),
                    };
                    self.buffer.data[c][r] = output;
                });
                self.data = OldBlockData::from_pass(&self.buffer);
                &self.buffer
            }
        }
    };
}

impl_trig_block!(f64);
impl_trig_block!(f32);

#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::testing::StubContext;
    use approx::assert_relative_eq;
    use core::f64::consts::PI;
    use rstest::rstest;

    #[rstest]
    #[case::sin_0("Sine", 0.0, 0.0)]
    #[case::sin_pi_2("Sine", PI / 2.0, 1.0)]
    #[case::cos_0("Cosine", 0.0, 1.0)]
    #[case::cos_pi_2("Cosine", PI / 2.0, 0.0)]
    #[case::tan_0("Tangent", 0.0, 0.0)]
    #[case::tan_pi_4("Tangent", PI / 4.0, 1.0)]
    #[case::asin_0("ArcSine", 0.0, 0.0)]
    #[case::asin_1("ArcSine", 1.0, PI / 2.0)]
    #[case::acos_1("ArcCosine", 1.0, 0.0)]
    #[case::acos_0("ArcCosine", 0.0, PI / 2.0)]
    #[case::atan_0("ArcTangent", 0.0, 0.0)]
    #[case::atan_1("ArcTangent", 1.0, PI / 4.0)]
    #[case::sinh_0("SineHyperbolic", 0.0, 0.0)]
    #[case::sinh_1("SineHyperbolic", 1.0, 1.17520)]
    #[case::cosh_0("CosineHyperbolic", 0.0, 1.0)]
    #[case::cosh_1("CosineHyperbolic", 1.0, 1.54308)]
    #[case::tanh_0("TangentHyperbolic", 0.0, 0.0)]
    #[case::tanh_1("TangentHyperbolic", 1.0, 0.76159)]
    #[case::asinh_0("ArcSineHyperbolic", 0.0, 0.0)]
    #[case::asinh_1_17520("ArcSineHyperbolic", 1.17520, 1.0)]
    #[case::acosh_1("ArcCosineHyperbolic", 1.0, 0.0)]
    #[case::acosh_1_54308("ArcCosineHyperbolic", 1.54308, 1.0)]
    #[case::atanh_0("ArcTangentHyperbolic", 0.0, 0.0)]
    #[case::atanh_0_76159("ArcTangentHyperbolic", 0.76159, 1.0)]
    fn test_trig_functions(
        #[case] function: &'static str,
        #[case] input: f64,
        #[case] expected: f64,
    ) {
        let c = StubContext::default();
        let mut block = TrigonometryBlock::<f64>::default();
        let p = Parameters::new(function);

        let output = block.process(&p, &c, input);
        assert_relative_eq!(output, expected, max_relative = 0.00001);
        assert_relative_eq!(block.data.scalar(), expected, max_relative = 0.00001);
    }

    #[test]
    fn test_trigonometry_block_vectorized() {
        let c = StubContext::default();
        let mut sine_block = TrigonometryBlock::<Matrix<1, 2, f64>>::default();
        let p = Parameters::new("Sine");
        let inputs = Matrix {
            data: [[0.0], [PI / 2.0]],
        };

        let output = sine_block.process(&p, &c, &inputs);
        assert_relative_eq!(
            output.data.as_flattened(),
            [[0.0], [1.0]].as_flattened(),
            max_relative = 0.00001
        );
    }

    /// Reusing one matrix block with different parameters must overwrite, not accumulate.
    #[test]
    fn test_trigonometry_block_vectorized_overwrites_previous_output() {
        let c = StubContext::default();
        let mut block = TrigonometryBlock::<Matrix<1, 2, f64>>::default();
        let sine = Parameters::new("Sine");
        let cosine = Parameters::new("Cosine");
        let inputs = Matrix {
            data: [[0.0], [PI / 2.0]],
        };

        let _ = block.process(&sine, &c, &inputs);
        let output = block.process(&cosine, &c, &inputs);
        assert_relative_eq!(
            output.data.as_flattened(),
            [[1.0], [0.0]].as_flattened(),
            max_relative = 0.00001
        );
    }

    #[test]
    #[should_panic(expected = "Failed to parse TrigonometryFunction")]
    fn test_unknown_function_name_panics() {
        let _ = Parameters::new("Secant");
    }
}