pictorus_blocks/core_blocks/
trigonometry_block.rs1use crate::traits::MatrixOps;
2use num_traits::Float;
3use pictorus_block_data::{BlockData as OldBlockData, FromPass};
4use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
5
6#[derive(strum::EnumString, PartialEq)]
7pub enum TrigonometryFunction {
8 Sine,
9 Cosine,
10 Tangent,
11 ArcSine,
12 ArcCosine,
13 ArcTangent,
14 SineHyperbolic,
15 CosineHyperbolic,
16 TangentHyperbolic,
17 ArcSineHyperbolic,
18 ArcCosineHyperbolic,
19 ArcTangentHyperbolic,
20}
21
22pub struct Parameters {
23 pub function: TrigonometryFunction,
24}
25
26impl Parameters {
27 pub fn new(function: &str) -> Self {
28 Self {
29 function: function
30 .parse()
31 .expect("Failed to parse TrigonometryFunction"),
32 }
33 }
34}
35
36pub struct TrigonometryBlock<T> {
37 pub data: OldBlockData,
38 buffer: T,
39}
40
41impl<T> Default for TrigonometryBlock<T>
42where
43 T: Default + Pass,
44 OldBlockData: FromPass<T>,
45{
46 fn default() -> Self {
47 Self {
48 data: <OldBlockData as FromPass<T>>::from_pass(T::default().as_by()),
49 buffer: T::default(),
50 }
51 }
52}
53
54macro_rules! impl_trig_block {
55 ($type:ty) => {
56 impl ProcessBlock for TrigonometryBlock<$type> {
57 type Inputs = $type;
58 type Output = $type;
59 type Parameters = Parameters;
60
61 fn process(
62 &mut self,
63 parameters: &Self::Parameters,
64 _context: &dyn pictorus_traits::Context,
65 inputs: PassBy<'_, Self::Inputs>,
66 ) -> PassBy<Self::Output> {
67 let output = match parameters.function {
68 TrigonometryFunction::Sine => Float::sin(inputs),
69 TrigonometryFunction::Cosine => Float::cos(inputs),
70 TrigonometryFunction::Tangent => Float::tan(inputs),
71 TrigonometryFunction::ArcSine => Float::asin(inputs),
72 TrigonometryFunction::ArcCosine => Float::acos(inputs),
73 TrigonometryFunction::ArcTangent => Float::atan(inputs),
74 TrigonometryFunction::SineHyperbolic => Float::sinh(inputs),
75 TrigonometryFunction::CosineHyperbolic => Float::cosh(inputs),
76 TrigonometryFunction::TangentHyperbolic => Float::tanh(inputs),
77 TrigonometryFunction::ArcSineHyperbolic => Float::asinh(inputs),
78 TrigonometryFunction::ArcCosineHyperbolic => Float::acosh(inputs),
79 TrigonometryFunction::ArcTangentHyperbolic => Float::atanh(inputs),
80 };
81 self.buffer = output;
82 self.data = OldBlockData::from_scalar(output.into());
83 output
84 }
85 }
86
87 impl<const ROWS: usize, const COLS: usize> ProcessBlock
88 for TrigonometryBlock<Matrix<ROWS, COLS, $type>>
89 where
90 OldBlockData: FromPass<Matrix<ROWS, COLS, $type>>,
91 {
92 type Inputs = Matrix<ROWS, COLS, $type>;
93 type Output = Matrix<ROWS, COLS, $type>;
94 type Parameters = Parameters;
95
96 fn process(
97 &mut self,
98 parameters: &Self::Parameters,
99 _context: &dyn pictorus_traits::Context,
100 inputs: PassBy<'_, Self::Inputs>,
101 ) -> PassBy<Self::Output> {
102 inputs.for_each(|input, c, r| {
103 let output = match parameters.function {
104 TrigonometryFunction::Sine => Float::sin(input),
105 TrigonometryFunction::Cosine => Float::cos(input),
106 TrigonometryFunction::Tangent => Float::tan(input),
107 TrigonometryFunction::ArcSine => Float::asin(input),
108 TrigonometryFunction::ArcCosine => Float::acos(input),
109 TrigonometryFunction::ArcTangent => Float::atan(input),
110 TrigonometryFunction::SineHyperbolic => Float::sinh(input),
111 TrigonometryFunction::CosineHyperbolic => Float::cosh(input),
112 TrigonometryFunction::TangentHyperbolic => Float::tanh(input),
113 TrigonometryFunction::ArcSineHyperbolic => Float::asinh(input),
114 TrigonometryFunction::ArcCosineHyperbolic => Float::acosh(input),
115 TrigonometryFunction::ArcTangentHyperbolic => Float::atanh(input),
116 };
117 self.buffer.data[c][r] = output;
118 });
119 self.data = OldBlockData::from_pass(&self.buffer);
120 &self.buffer
121 }
122 }
123 };
124}
125
126impl_trig_block!(f64);
127impl_trig_block!(f32);
128
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::testing::StubContext;
    use approx::assert_relative_eq;
    use core::f64::consts::PI;
    use rstest::rstest;

    #[rstest]
    #[case::sin_0("Sine", 0.0, 0.0)]
    #[case::sin_pi_2("Sine", PI / 2.0, 1.0)]
    #[case::cos_0("Cosine", 0.0, 1.0)]
    #[case::cos_pi_2("Cosine", PI / 2.0, 0.0)]
    #[case::tan_0("Tangent", 0.0, 0.0)]
    #[case::tan_pi_4("Tangent", PI / 4.0, 1.0)]
    #[case::asin_0("ArcSine", 0.0, 0.0)]
    #[case::asin_1("ArcSine", 1.0, PI / 2.0)]
    #[case::acos_1("ArcCosine", 1.0, 0.0)]
    #[case::acos_0("ArcCosine", 0.0, PI / 2.0)]
    #[case::atan_0("ArcTangent", 0.0, 0.0)]
    #[case::atan_1("ArcTangent", 1.0, PI / 4.0)]
    #[case::sinh_0("SineHyperbolic", 0.0, 0.0)]
    #[case::sinh_1("SineHyperbolic", 1.0, 1.17520)]
    #[case::cosh_0("CosineHyperbolic", 0.0, 1.0)]
    #[case::cosh_1("CosineHyperbolic", 1.0, 1.54308)]
    #[case::tanh_0("TangentHyperbolic", 0.0, 0.0)]
    #[case::tanh_1("TangentHyperbolic", 1.0, 0.76159)]
    #[case::asinh_0("ArcSineHyperbolic", 0.0, 0.0)]
    #[case::asinh_1_17520("ArcSineHyperbolic", 1.17520, 1.0)]
    #[case::acosh_1("ArcCosineHyperbolic", 1.0, 0.0)]
    #[case::acosh_1_54308("ArcCosineHyperbolic", 1.54308, 1.0)]
    #[case::atanh_0("ArcTangentHyperbolic", 0.0, 0.0)]
    #[case::atanh_0_76159("ArcTangentHyperbolic", 0.76159, 1.0)]
    fn test_trig_functions(
        #[case] function: &'static str,
        #[case] input: f64,
        #[case] expected: f64,
    ) {
        let c = StubContext::default();
        let mut block = TrigonometryBlock::<f64>::default();
        let p = Parameters::new(function);

        let output = block.process(&p, &c, input);
        assert_relative_eq!(output, expected, max_relative = 0.00001);
        assert_relative_eq!(block.data.scalar(), expected, max_relative = 0.00001);
    }

    #[test]
    fn test_trigonometry_block_vectorized() {
        let c = StubContext::default();
        let mut sine_block = TrigonometryBlock::<Matrix<1, 2, f64>>::default();
        let p = Parameters::new("Sine");
        let inputs = Matrix {
            data: [[0.0], [PI / 2.0]],
        };

        let output = sine_block.process(&p, &c, &inputs);
        assert_relative_eq!(
            output.data.as_flattened(),
            [[0.0], [1.0]].as_flattened(),
            max_relative = 0.00001
        );
    }

    #[test]
    fn test_trigonometry_block_matrix_preserves_layout() {
        let c = StubContext::default();
        let mut cos_block = TrigonometryBlock::<Matrix<2, 2, f64>>::default();
        let p = Parameters::new("Cosine");
        // Every element is distinct, so a row/column mix-up produces a visibly different result.
        let inputs = Matrix {
            data: [[0.0, PI], [PI / 3.0, 2.0 * PI / 3.0]],
        };

        let output = cos_block.process(&p, &c, &inputs);
        assert_relative_eq!(
            output.data.as_flattened(),
            [[1.0, -1.0], [0.5, -0.5]].as_flattened(),
            max_relative = 0.00001
        );
    }

    #[test]
    #[should_panic(expected = "Failed to parse TrigonometryFunction")]
    fn test_unknown_function_name_panics() {
        let _ = Parameters::new("Secant");
    }
}