// pictorus_blocks/core_blocks/trigonometry_block.rs

1use crate::traits::MatrixOps;
2use num_traits::Float;
3use pictorus_block_data::{BlockData as OldBlockData, FromPass};
4use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
5
6#[derive(strum::EnumString, PartialEq)]
7pub enum TrigonometryFunction {
8    Sine,
9    Cosine,
10    Tangent,
11    ArcSine,
12    ArcCosine,
13    ArcTangent,
14    SineHyperbolic,
15    CosineHyperbolic,
16    TangentHyperbolic,
17    ArcSineHyperbolic,
18    ArcCosineHyperbolic,
19    ArcTangentHyperbolic,
20}
21
22pub struct Parameters {
23    pub function: TrigonometryFunction,
24}
25
26impl Parameters {
27    pub fn new(function: &str) -> Self {
28        Self {
29            function: function
30                .parse()
31                .expect("Failed to parse TrigonometryFunction"),
32        }
33    }
34}
35
36pub struct TrigonometryBlock<T> {
37    pub data: OldBlockData,
38    buffer: T,
39}
40
41impl<T> Default for TrigonometryBlock<T>
42where
43    T: Default + Pass,
44    OldBlockData: FromPass<T>,
45{
46    fn default() -> Self {
47        Self {
48            data: <OldBlockData as FromPass<T>>::from_pass(T::default().as_by()),
49            buffer: T::default(),
50        }
51    }
52}
53
54macro_rules! impl_trig_block {
55    ($type:ty) => {
56        impl ProcessBlock for TrigonometryBlock<$type> {
57            type Inputs = $type;
58            type Output = $type;
59            type Parameters = Parameters;
60
61            fn process(
62                &mut self,
63                parameters: &Self::Parameters,
64                _context: &dyn pictorus_traits::Context,
65                inputs: PassBy<'_, Self::Inputs>,
66            ) -> PassBy<Self::Output> {
67                let output = match parameters.function {
68                    TrigonometryFunction::Sine => Float::sin(inputs),
69                    TrigonometryFunction::Cosine => Float::cos(inputs),
70                    TrigonometryFunction::Tangent => Float::tan(inputs),
71                    TrigonometryFunction::ArcSine => Float::asin(inputs),
72                    TrigonometryFunction::ArcCosine => Float::acos(inputs),
73                    TrigonometryFunction::ArcTangent => Float::atan(inputs),
74                    TrigonometryFunction::SineHyperbolic => Float::sinh(inputs),
75                    TrigonometryFunction::CosineHyperbolic => Float::cosh(inputs),
76                    TrigonometryFunction::TangentHyperbolic => Float::tanh(inputs),
77                    TrigonometryFunction::ArcSineHyperbolic => Float::asinh(inputs),
78                    TrigonometryFunction::ArcCosineHyperbolic => Float::acosh(inputs),
79                    TrigonometryFunction::ArcTangentHyperbolic => Float::atanh(inputs),
80                };
81                self.buffer = output;
82                self.data = OldBlockData::from_scalar(output.into());
83                output
84            }
85        }
86
87        impl<const ROWS: usize, const COLS: usize> ProcessBlock
88            for TrigonometryBlock<Matrix<ROWS, COLS, $type>>
89        where
90            OldBlockData: FromPass<Matrix<ROWS, COLS, $type>>,
91        {
92            type Inputs = Matrix<ROWS, COLS, $type>;
93            type Output = Matrix<ROWS, COLS, $type>;
94            type Parameters = Parameters;
95
96            fn process(
97                &mut self,
98                parameters: &Self::Parameters,
99                _context: &dyn pictorus_traits::Context,
100                inputs: PassBy<'_, Self::Inputs>,
101            ) -> PassBy<Self::Output> {
102                inputs.for_each(|input, c, r| {
103                    let output = match parameters.function {
104                        TrigonometryFunction::Sine => Float::sin(input),
105                        TrigonometryFunction::Cosine => Float::cos(input),
106                        TrigonometryFunction::Tangent => Float::tan(input),
107                        TrigonometryFunction::ArcSine => Float::asin(input),
108                        TrigonometryFunction::ArcCosine => Float::acos(input),
109                        TrigonometryFunction::ArcTangent => Float::atan(input),
110                        TrigonometryFunction::SineHyperbolic => Float::sinh(input),
111                        TrigonometryFunction::CosineHyperbolic => Float::cosh(input),
112                        TrigonometryFunction::TangentHyperbolic => Float::tanh(input),
113                        TrigonometryFunction::ArcSineHyperbolic => Float::asinh(input),
114                        TrigonometryFunction::ArcCosineHyperbolic => Float::acosh(input),
115                        TrigonometryFunction::ArcTangentHyperbolic => Float::atanh(input),
116                    };
117                    self.buffer.data[c][r] = output;
118                });
119                self.data = OldBlockData::from_pass(&self.buffer);
120                &self.buffer
121            }
122        }
123    };
124}
125
126impl_trig_block!(f64);
127impl_trig_block!(f32);
128
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::testing::StubContext;
    use approx::assert_relative_eq;
    use core::f64::consts::PI;
    use rstest::rstest;

    /// Every function on a scalar `f64` block, checking both the returned value
    /// and the cached legacy `data`.
    #[rstest]
    #[case::sin_0("Sine", 0.0, 0.0)]
    #[case::sin_pi_2("Sine", PI / 2.0, 1.0)]
    #[case::cos_0("Cosine", 0.0, 1.0)]
    #[case::cos_pi_2("Cosine", PI / 2.0, 0.0)]
    #[case::tan_0("Tangent", 0.0, 0.0)]
    #[case::tan_pi_4("Tangent", PI / 4.0, 1.0)]
    #[case::asin_0("ArcSine", 0.0, 0.0)]
    #[case::asin_1("ArcSine", 1.0, PI / 2.0)]
    #[case::acos_1("ArcCosine", 1.0, 0.0)]
    #[case::acos_0("ArcCosine", 0.0, PI / 2.0)]
    #[case::atan_0("ArcTangent", 0.0, 0.0)]
    #[case::atan_1("ArcTangent", 1.0, PI / 4.0)]
    #[case::sinh_0("SineHyperbolic", 0.0, 0.0)]
    #[case::sinh_1("SineHyperbolic", 1.0, 1.17520)]
    #[case::cosh_0("CosineHyperbolic", 0.0, 1.0)]
    #[case::cosh_1("CosineHyperbolic", 1.0, 1.54308)]
    #[case::tanh_0("TangentHyperbolic", 0.0, 0.0)]
    #[case::tanh_1("TangentHyperbolic", 1.0, 0.76159)]
    #[case::asinh_0("ArcSineHyperbolic", 0.0, 0.0)]
    #[case::asinh_1_17520("ArcSineHyperbolic", 1.17520, 1.0)]
    #[case::acosh_1("ArcCosineHyperbolic", 1.0, 0.0)]
    #[case::acosh_1_54308("ArcCosineHyperbolic", 1.54308, 1.0)]
    #[case::atanh_0("ArcTangentHyperbolic", 0.0, 0.0)]
    #[case::atanh_0_76159("ArcTangentHyperbolic", 0.76159, 1.0)]
    fn test_trig_functions(
        #[case] function: &'static str,
        #[case] input: f64,
        #[case] expected: f64,
    ) {
        let c = StubContext::default();
        let mut block = TrigonometryBlock::<f64>::default();
        let p = Parameters::new(function);

        let output = block.process(&p, &c, input);
        assert_relative_eq!(output, expected, max_relative = 0.00001);
        assert_relative_eq!(block.data.scalar(), expected, max_relative = 0.00001);
    }

    /// The `f32` scalar impl is generated by the same macro; spot-check it.
    #[test]
    fn test_trig_functions_f32() {
        let c = StubContext::default();
        let mut block = TrigonometryBlock::<f32>::default();
        let p = Parameters::new("Cosine");

        let output = block.process(&p, &c, 0.0_f32);
        assert_relative_eq!(output, 1.0_f32, max_relative = 0.00001);
        assert_relative_eq!(block.data.scalar(), 1.0, max_relative = 0.00001);
    }

    #[test]
    fn test_trigonometry_block_vectorized() {
        let c = StubContext::default();
        let mut sine_block = TrigonometryBlock::<Matrix<1, 2, f64>>::default();
        let p = Parameters::new("Sine");
        let inputs = Matrix {
            data: [[0.0], [PI / 2.0]],
        };

        let output = sine_block.process(&p, &c, &inputs);
        assert_relative_eq!(
            output.data.as_flattened(),
            [[0.0], [1.0]].as_flattened(),
            max_relative = 0.00001
        );
    }

    /// Multi-row matrix: exercises the `[col][row]` write path with `ROWS > 1`.
    #[test]
    fn test_trigonometry_block_column_vector() {
        let c = StubContext::default();
        let mut cosine_block = TrigonometryBlock::<Matrix<2, 1, f64>>::default();
        let p = Parameters::new("Cosine");
        let inputs = Matrix {
            data: [[0.0, PI]],
        };

        let output = cosine_block.process(&p, &c, &inputs);
        assert_relative_eq!(
            output.data.as_flattened(),
            [[1.0, -1.0]].as_flattened(),
            max_relative = 0.00001
        );
    }

    #[test]
    #[should_panic(expected = "Failed to parse TrigonometryFunction")]
    fn test_parameters_rejects_unknown_function() {
        let _ = Parameters::new("Secant");
    }
}