Skip to main content

pictorus_blocks/core_blocks/
abs_block.rs

1use crate::nalgebra_interop::MatrixExt;
2use num_traits::Float;
3use pictorus_block_data::{BlockData as OldBlockData, FromPass};
4use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock, Scalar};
5
/// Parameters for [`AbsBlock`].
///
/// The absolute-value operation has nothing to configure, so this type is empty.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Parameter {}

impl Parameter {
    /// Creates the (empty) parameter set; equivalent to [`Parameter::default`].
    pub const fn new() -> Self {
        Self {}
    }
}
19
20/// Computes the absolute value of a scalar, vector, or matrix.
21pub struct AbsBlock<T: Pass + Default> {
22    pub data: OldBlockData,
23    buffer: Option<T>,
24}
25
26impl<T> Default for AbsBlock<T>
27where
28    T: Pass + Default,
29    OldBlockData: FromPass<T>,
30{
31    fn default() -> Self {
32        Self {
33            data: <OldBlockData as FromPass<T>>::from_pass(T::default().as_by()),
34            buffer: None,
35        }
36    }
37}
38
39macro_rules! impl_abs_block {
40    ($type:ty) => {
41        impl ProcessBlock for AbsBlock<$type>
42        where
43            $type: Scalar,
44            OldBlockData: FromPass<$type>,
45        {
46            type Inputs = $type;
47            type Output = $type;
48            type Parameters = Parameter;
49
50            fn process<'b>(
51                &'b mut self,
52                _parameters: &Self::Parameters,
53                _context: &dyn pictorus_traits::Context,
54                inputs: pictorus_traits::PassBy<'_, Self::Inputs>,
55            ) -> pictorus_traits::PassBy<'b, Self::Output> {
56                let output = Float::abs(inputs);
57                self.data = OldBlockData::from_scalar(output.into());
58                output
59            }
60        }
61
62        impl<const ROWS: usize, const COLS: usize> ProcessBlock
63            for AbsBlock<Matrix<ROWS, COLS, $type>>
64        where
65            $type: Scalar,
66            OldBlockData: FromPass<Matrix<ROWS, COLS, $type>>,
67        {
68            type Inputs = Matrix<ROWS, COLS, $type>;
69            type Output = Matrix<ROWS, COLS, $type>;
70            type Parameters = Parameter;
71
72            fn process(
73                &mut self,
74                _parameters: &Self::Parameters,
75                _context: &dyn pictorus_traits::Context,
76                input: PassBy<Self::Inputs>,
77            ) -> PassBy<Self::Output> {
78                let abs = input.as_view().abs();
79                let o = Matrix::<ROWS, COLS, $type>::from_view(&abs.as_view());
80                let output = self.buffer.insert(o);
81                self.data = OldBlockData::from_pass(output);
82                output
83            }
84        }
85    };
86}
87
88impl_abs_block!(f32);
89impl_abs_block!(f64);
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;
    use num_traits::One;
    use paste::paste;

    // The inputs deliberately include non-unit magnitudes and (negative) zero.
    // With only +/-1 inputs, an implementation such as `x * x` would also pass.
    macro_rules! test_abs_block {
        ($name:ident, $type:ty) => {
            paste! {
                #[test]
                fn [<test_abs_block_scalar_ $name>]() {
                    let mut block = AbsBlock::<$type>::default();
                    let context = StubContext::default();

                    let output = block.process(&Parameter::new(), &context, <$type>::one());
                    assert_eq!(output, <$type>::one());
                    assert_eq!(block.data, OldBlockData::from_scalar(1.0));

                    let output = block.process(&Parameter::new(), &context, -<$type>::one());
                    assert_eq!(output, <$type>::one());
                    assert_eq!(block.data, OldBlockData::from_scalar(1.0));

                    let input: $type = -2.5;
                    let output = block.process(&Parameter::new(), &context, input);
                    assert_eq!(output, 2.5);
                    assert_eq!(block.data, OldBlockData::from_scalar(2.5));

                    // |-0.0| must be +0.0.
                    let input: $type = -0.0;
                    let output = block.process(&Parameter::new(), &context, input);
                    assert_eq!(output, 0.0);
                    assert!(output.is_sign_positive());
                    assert_eq!(block.data, OldBlockData::from_scalar(0.0));
                }

                #[test]
                fn [<test_abs_block_vector_1x2_ $name>]() {
                    let mut block = AbsBlock::<Matrix<1, 2, $type>>::default();
                    let context = StubContext::default();
                    // `data` is indexed `[col][row]`.
                    let mut input = Matrix::<1, 2, $type>::zeroed();
                    input.data[0][0] = 2.5;
                    input.data[1][0] = -3.0;

                    let output = block.process(&Parameter::new(), &context, &input);
                    assert_eq!(output.data[0][0], 2.5);
                    assert_eq!(output.data[1][0], 3.0);
                    assert_eq!(block.data, OldBlockData::from_matrix(&[&[2.5, 3.0]]));
                }

                #[test]
                fn [<test_abs_block_vector_2x1_ $name>]() {
                    let mut block = AbsBlock::<Matrix<2, 1, $type>>::default();
                    let context = StubContext::default();
                    let mut input = Matrix::<2, 1, $type>::zeroed();
                    input.data[0][0] = 2.5;
                    input.data[0][1] = -3.0;

                    let output = block.process(&Parameter::new(), &context, &input);
                    assert_eq!(output.data[0][0], 2.5);
                    assert_eq!(output.data[0][1], 3.0);
                    assert_eq!(block.data, OldBlockData::from_matrix(&[&[2.5], &[3.0]]));
                }

                #[test]
                fn [<test_abs_block_matrix_ $name>]() {
                    let mut block = AbsBlock::<Matrix<2, 2, $type>>::default();
                    let context = StubContext::default();
                    let mut input = Matrix::<2, 2, $type>::zeroed();
                    input.data[0][0] = -1.5;
                    input.data[0][1] = 4.0;
                    input.data[1][0] = -4.0;
                    input.data[1][1] = -0.0;

                    let output = block.process(&Parameter::new(), &context, &input);
                    assert_eq!(output.data[0][0], 1.5);
                    assert_eq!(output.data[0][1], 4.0);
                    assert_eq!(output.data[1][0], 4.0);
                    assert_eq!(output.data[1][1], 0.0);
                    assert_eq!(
                        block.data,
                        OldBlockData::from_matrix(&[&[1.5, 4.0], &[4.0, 0.0]])
                    );

                    // A second call must overwrite the internal output buffer
                    // rather than leave stale values from the first call.
                    let mut input = Matrix::<2, 2, $type>::zeroed();
                    input.data[0][0] = <$type>::one();
                    input.data[0][1] = -<$type>::one();
                    input.data[1][0] = <$type>::one();
                    input.data[1][1] = -<$type>::one();

                    let output = block.process(&Parameter::new(), &context, &input);
                    assert_eq!(output.data[0][0], <$type>::one());
                    assert_eq!(output.data[0][1], <$type>::one());
                    assert_eq!(output.data[1][0], <$type>::one());
                    assert_eq!(output.data[1][1], <$type>::one());
                    assert_eq!(
                        block.data,
                        OldBlockData::from_matrix(&[&[1.0, 1.0], &[1.0, 1.0]])
                    );
                }
            }
        };
    }

    test_abs_block!(f32, f32);
    test_abs_block!(f64, f64);
}