Skip to main content

pictorus_blocks/core_blocks/
not_block.rs

1use pictorus_block_data::{BlockData as OldBlockData, FromPass};
2use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
3
4#[derive(strum::EnumString, Clone, Copy)]
5pub enum NotMethod {
6    Logical,
7    Bitwise,
8}
9
10/// A block that performs a logical or bitwise NOT operation on the input.
11pub struct NotBlock<T>
12where
13    T: Apply,
14    OldBlockData: FromPass<T::Output>,
15{
16    pub data: OldBlockData,
17    buffer: Option<T::Output>,
18}
19
20impl<T> Default for NotBlock<T>
21where
22    T: Apply,
23    OldBlockData: FromPass<T::Output>,
24{
25    fn default() -> Self {
26        Self {
27            data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
28            buffer: None,
29        }
30    }
31}
32
33impl<T> ProcessBlock for NotBlock<T>
34where
35    T: Apply,
36    OldBlockData: FromPass<T::Output>,
37{
38    type Inputs = T;
39    type Output = T::Output;
40    type Parameters = Parameters;
41
42    fn process(
43        &mut self,
44        parameters: &Self::Parameters,
45        _context: &dyn pictorus_traits::Context,
46        input: PassBy<Self::Inputs>,
47    ) -> PassBy<Self::Output> {
48        let output = T::apply(&mut self.buffer, input, parameters.method);
49        self.data = OldBlockData::from_pass(output);
50        output
51    }
52}
53
54pub trait Apply: Pass {
55    type Output: Pass + Default;
56
57    fn apply<'s>(
58        store: &'s mut Option<Self::Output>,
59        input: PassBy<Self>,
60        method: NotMethod,
61    ) -> PassBy<'s, Self::Output>;
62}
63
64impl Apply for bool {
65    type Output = bool;
66
67    fn apply<'s>(
68        store: &'s mut Option<Self::Output>,
69        input: PassBy<Self>,
70        method: NotMethod,
71    ) -> PassBy<'s, Self::Output> {
72        let output = match method {
73            NotMethod::Logical => !input,
74            NotMethod::Bitwise => !input,
75        };
76        *store = Some(output);
77        output
78    }
79}
80
81impl<const NROWS: usize, const NCOLS: usize> Apply for Matrix<NROWS, NCOLS, bool> {
82    type Output = Matrix<NROWS, NCOLS, bool>;
83
84    fn apply<'s>(
85        store: &'s mut Option<Self::Output>,
86        input: PassBy<Self>,
87        method: NotMethod,
88    ) -> PassBy<'s, Self::Output> {
89        let output = store.insert(Matrix::zeroed());
90        output
91            .data
92            .as_flattened_mut()
93            .iter_mut()
94            .enumerate()
95            .for_each(|(i, lhs)| {
96                let input_val = input.data.as_flattened()[i];
97                *lhs = match method {
98                    NotMethod::Logical => !input_val,
99                    NotMethod::Bitwise => !input_val,
100                };
101            });
102        output
103    }
104}
105
/// Implements [`Apply`] for a floating-point scalar `$type` and for
/// `Matrix<NROWS, NCOLS, $type>`.
///
/// `$cast_type` is the signed integer type used for the bitwise NOT. The value
/// is converted with an `as` cast, which truncates toward zero, saturates at the
/// integer's bounds and maps NaN to 0. Its bits are then inverted and the result
/// is converted back to `$type`.
macro_rules! impl_not_apply {
    ($type:ty, $cast_type:ty) => {
        impl Apply for $type {
            type Output = $type;

            fn apply<'s>(
                store: &'s mut Option<Self::Output>,
                input: PassBy<Self>,
                method: NotMethod,
            ) -> PassBy<'s, Self::Output> {
                let output = match method {
                    // Only an exact zero (including -0.0) counts as "false";
                    // every other value, NaN included, counts as "true".
                    NotMethod::Logical => {
                        if input == 0.0 {
                            1.0
                        } else {
                            0.0
                        }
                    }
                    NotMethod::Bitwise => !(input as $cast_type) as $type,
                };
                *store = Some(output);
                output
            }
        }

        impl<const NROWS: usize, const NCOLS: usize> Apply for Matrix<NROWS, NCOLS, $type> {
            type Output = Matrix<NROWS, NCOLS, $type>;

            fn apply<'s>(
                store: &'s mut Option<Self::Output>,
                input: PassBy<Self>,
                method: NotMethod,
            ) -> PassBy<'s, Self::Output> {
                let output = store.insert(Matrix::zeroed());
                // Input and output share the same shape, so their flattened
                // views line up element for element; zipping avoids
                // per-element indexing.
                output
                    .data
                    .as_flattened_mut()
                    .iter_mut()
                    .zip(input.data.as_flattened())
                    .for_each(|(lhs, &rhs)| {
                        *lhs = match method {
                            NotMethod::Logical => {
                                if rhs == 0.0 {
                                    1.0
                                } else {
                                    0.0
                                }
                            }
                            NotMethod::Bitwise => !(rhs as $cast_type) as $type,
                        };
                    });
                output
            }
        }
    };
}
163
164pub struct Parameters {
165    // The method to use for the NOT operation. Either 'Logical' or 'Bitwise'.
166    pub method: NotMethod,
167}
168
169impl Parameters {
170    pub fn new(method: &str) -> Self {
171        Self {
172            method: method
173                .parse()
174                .expect("Failed to parse NotMethod, expected 'Logical' or 'Bitwise'"),
175        }
176    }
177}
178
179impl_not_apply!(f32, i32);
180impl_not_apply!(f64, i64);
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;
    use paste::paste;

    macro_rules! test_not_block {
        ($type:ty) => {
            paste! {
                #[test]
                fn [<test_not_block_logical_scalar_ $type>]() {
                    let mut block = NotBlock::<$type>::default();
                    let context = StubContext::default();
                    let parameters = Parameters::new("Logical");

                    let res = block.process(&parameters, &context, 1.0);
                    assert_eq!(res, 0.0);
                    assert_eq!(block.data.scalar(), 0.0);

                    let res = block.process(&parameters, &context, 0.0);
                    assert_eq!(res, 1.0);
                    assert_eq!(block.data.scalar(), 1.0);

                    let res = block.process(&parameters, &context, -1.2);
                    assert_eq!(res, 0.0);
                    assert_eq!(block.data.scalar(), 0.0);

                    let res = block.process(&parameters, &context, 1.2);
                    assert_eq!(res, 0.0);
                    assert_eq!(block.data.scalar(), 0.0);
                }

                #[test]
                fn [<test_not_block_logical_special_values_ $type>]() {
                    let mut block = NotBlock::<$type>::default();
                    let context = StubContext::default();
                    let parameters = Parameters::new("Logical");

                    // Negative zero compares equal to zero, so it is "false".
                    let res = block.process(&parameters, &context, -0.0);
                    assert_eq!(res, 1.0);
                    assert_eq!(block.data.scalar(), 1.0);

                    // NaN is not equal to zero, so it is "true".
                    let res = block.process(&parameters, &context, <$type>::NAN);
                    assert_eq!(res, 0.0);
                    assert_eq!(block.data.scalar(), 0.0);
                }

                #[test]
                fn [<test_not_block_logical_matrix_ $type>]() {
                    let mut block = NotBlock::<Matrix<4, 1, $type>>::default();
                    let context = StubContext::default();
                    let parameters = Parameters::new("Logical");

                    let input = Matrix {
                        data: [[1.0, 0.0, -1.2, 1.2]],
                    };
                    let res = block.process(&parameters, &context, &input);
                    assert_eq!(res.data, [[0.0, 1.0, 0.0, 0.0]]);
                    assert_eq!(
                        block.data.get_data().as_slice(),
                        [[0.0, 1.0, 0.0, 0.0]].as_flattened()
                    );
                }

                #[test]
                fn [<test_not_block_bitwise_scalar_ $type>]() {
                    let mut block = NotBlock::<$type>::default();
                    let context = StubContext::default();
                    let parameters = Parameters::new("Bitwise");

                    let res = block.process(&parameters, &context, 1.0);
                    assert_eq!(res, -2.0);
                    assert_eq!(block.data.scalar(), -2.0);

                    let res = block.process(&parameters, &context, 42.0);
                    assert_eq!(res, -43.0);
                    assert_eq!(block.data.scalar(), -43.0);

                    let res = block.process(&parameters, &context, -1.2);
                    assert_eq!(res, 0.0);
                    assert_eq!(block.data.scalar(), 0.0);

                    let res = block.process(&parameters, &context, 1.2);
                    assert_eq!(res, -2.0);
                    assert_eq!(block.data.scalar(), -2.0);
                }

                #[test]
                fn [<test_not_block_bitwise_nan_ $type>]() {
                    let mut block = NotBlock::<$type>::default();
                    let context = StubContext::default();
                    let parameters = Parameters::new("Bitwise");

                    // A float-to-int `as` cast maps NaN to 0, and !0 == -1.
                    let res = block.process(&parameters, &context, <$type>::NAN);
                    assert_eq!(res, -1.0);
                    assert_eq!(block.data.scalar(), -1.0);
                }

                #[test]
                fn [<test_not_block_bitwise_matrix_ $type>]() {
                    let mut block = NotBlock::<Matrix<2, 2, $type>>::default();
                    let context = StubContext::default();
                    let parameters = Parameters::new("Bitwise");

                    let input = Matrix {
                        data: [[1.0, 42.0], [-1.2, 1.2]],
                    };
                    let res = block.process(&parameters, &context, &input);
                    assert_eq!(res.data, [[-2.0, -43.0], [0.0, -2.0]]);
                    assert_eq!(
                        block.data.get_data().as_slice(),
                        [[-2.0, -43.0], [0.0, -2.0]].as_flattened()
                    );
                }
            }
        };
    }

    test_not_block!(f32);
    test_not_block!(f64);

    #[test]
    fn test_scalar_bool() {
        let mut block = NotBlock::<bool>::default();
        let context = StubContext::default();
        let parameters = Parameters::new("Logical");

        let res = block.process(&parameters, &context, true);
        assert!(!res);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&parameters, &context, false);
        assert!(res);
        assert_eq!(block.data.scalar(), 1.0);

        let parameters = Parameters::new("Bitwise");
        let res = block.process(&parameters, &context, true);
        assert!(!res);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&parameters, &context, false);
        assert!(res);
        assert_eq!(block.data.scalar(), 1.0);
    }

    #[test]
    fn test_matrix_bool() {
        let mut block = NotBlock::<Matrix<2, 2, bool>>::default();
        let context = StubContext::default();
        let parameters = Parameters::new("Logical");

        let input = Matrix {
            data: [[true, false], [false, true]],
        };
        let res = block.process(&parameters, &context, &input);
        assert_eq!(res.data, [[false, true], [true, false]]);
        assert_eq!(
            block.data.get_data().as_slice(),
            [[0.0, 1.0], [1.0, 0.0]].as_flattened()
        );
    }

    #[test]
    fn test_matrix_bool_bitwise() {
        let mut block = NotBlock::<Matrix<2, 2, bool>>::default();
        let context = StubContext::default();
        let parameters = Parameters::new("Bitwise");

        let input = Matrix {
            data: [[true, false], [false, true]],
        };
        let res = block.process(&parameters, &context, &input);
        assert_eq!(res.data, [[false, true], [true, false]]);
        assert_eq!(
            block.data.get_data().as_slice(),
            [[0.0, 1.0], [1.0, 0.0]].as_flattened()
        );
    }

    #[test]
    #[should_panic(expected = "Failed to parse NotMethod")]
    fn test_parameters_invalid_method_panics() {
        let _ = Parameters::new("Xor");
    }
}
307            block.data.get_data().as_slice(),
308            [[0.0, 1.0], [1.0, 0.0]].as_flattened()
309        );
310    }
311}