Skip to main content

pictorus_blocks/core_blocks/
logical_block.rs

1use core::ops::Sub;
2use num_traits::One;
3use pictorus_block_data::{BlockData as OldBlockData, FromPass};
4use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
5
6use crate::traits::{Apply, ApplyInto, MatrixOps, Scalar};
7
8/// Performs logical operations on inputs.
9///
10/// Currently supports the following methods:
11/// - And
12/// - Or
13/// - Nor
14/// - Nand
15pub struct LogicalBlock<T>
16where
17    T: Apply<Parameters>,
18    T::Output: Finalize,
19    OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
20{
21    store: Option<T::Output>,
22    pub data: OldBlockData,
23}
24
25impl<T> Default for LogicalBlock<T>
26where
27    T: Apply<Parameters>,
28    T::Output: Finalize,
29    OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
30{
31    fn default() -> Self {
32        Self {
33            store: None,
34            data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
35        }
36    }
37}
38
39impl<T> ProcessBlock for LogicalBlock<T>
40where
41    T: Apply<Parameters>,
42    T::Output: Finalize,
43    OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
44{
45    type Inputs = T;
46    type Output = T::Output;
47    type Parameters = Parameters;
48    fn process<'b>(
49        &'b mut self,
50        parameters: &Self::Parameters,
51        _context: &dyn pictorus_traits::Context,
52        inputs: PassBy<'_, Self::Inputs>,
53    ) -> PassBy<'b, Self::Output> {
54        self.store = None;
55        T::apply(inputs, parameters, &mut self.store);
56        let result = T::Output::finalize(parameters.method, &mut self.store);
57        self.data = OldBlockData::from_pass(result);
58        result
59    }
60}
61
62fn perform_op<S: Scalar + From<bool>>(input: S, dest: S, method: LogicalMethod) -> S {
63    let x0 = input.is_truthy();
64    let x1 = dest.is_truthy();
65    let res = match method {
66        LogicalMethod::And => x0 & x1,
67        LogicalMethod::Or => x0 | x1,
68        // NAND and NOR behave the same as OR and AND during
69        // the calculation, but the final result is inverted in the finalize step
70        LogicalMethod::Nand => x0 & x1,
71        LogicalMethod::Nor => x0 | x1,
72    };
73
74    res.into()
75}
76
77// Compare scalar with scalar
78impl<S: Scalar + From<bool>> ApplyInto<S, Parameters> for S {
79    fn apply_into<'a>(
80        input: PassBy<Self>,
81        params: &Parameters,
82        dest: &'a mut Option<S>,
83    ) -> PassBy<'a, S> {
84        match dest {
85            Some(dest) => {
86                *dest = perform_op(input, *dest, params.method);
87            }
88            None => {
89                *dest = Some(input);
90            }
91        }
92
93        dest.as_ref().unwrap().as_by()
94    }
95}
96
97// Compare matrix and matrix
98impl<const R: usize, const C: usize, S: Scalar + From<bool>> ApplyInto<Matrix<R, C, S>, Parameters>
99    for Matrix<R, C, S>
100{
101    fn apply_into<'a>(
102        input: PassBy<Self>,
103        params: &Parameters,
104        dest: &'a mut Option<Matrix<R, C, S>>,
105    ) -> PassBy<'a, Matrix<R, C, S>> {
106        match dest {
107            Some(dest) => {
108                input
109                    .data
110                    .as_flattened()
111                    .iter()
112                    .zip(dest.data.as_flattened_mut().iter_mut())
113                    .for_each(|(input, dest)| {
114                        *dest = perform_op(*input, *dest, params.method);
115                    });
116            }
117            None => {
118                *dest = Some(*input);
119            }
120        }
121
122        dest.as_ref().unwrap().as_by()
123    }
124}
125
126// Compare scalar with matrix
127impl<const R: usize, const C: usize, S: Scalar + From<bool>> ApplyInto<Matrix<R, C, S>, Parameters>
128    for S
129{
130    fn apply_into<'a>(
131        input: PassBy<Self>,
132        params: &Parameters,
133        dest: &'a mut Option<Matrix<R, C, S>>,
134    ) -> PassBy<'a, Matrix<R, C, S>> {
135        match dest {
136            Some(dest) => {
137                dest.data.as_flattened_mut().iter_mut().for_each(|dest| {
138                    *dest = perform_op(input, *dest, params.method);
139                });
140            }
141            None => {
142                *dest = Some(Matrix::<R, C, S>::from_element(input));
143            }
144        }
145
146        dest.as_ref().unwrap().as_by()
147    }
148}
149
150pub trait Finalize: Pass + Default {
151    fn finalize(method: LogicalMethod, dest: &mut Option<Self>) -> PassBy<'_, Self>;
152}
153
154impl<S: Scalar + One + Sub<Output = S>> Finalize for S {
155    fn finalize(method: LogicalMethod, dest: &mut Option<Self>) -> PassBy<'_, Self> {
156        let input = dest.get_or_insert(S::default());
157        let res = match method {
158            LogicalMethod::Nor => S::one() - *input,
159            LogicalMethod::Nand => S::one() - *input,
160            _ => *input,
161        };
162
163        *dest = Some(res);
164        dest.as_ref().unwrap().as_by()
165    }
166}
167
168impl<const R: usize, const C: usize, S: Scalar + One + Sub<Output = S>> Finalize
169    for Matrix<R, C, S>
170{
171    fn finalize(method: LogicalMethod, dest: &mut Option<Self>) -> PassBy<'_, Self> {
172        let dest = dest.get_or_insert(Matrix::<R, C, S>::default());
173        dest.data.as_flattened_mut().iter_mut().for_each(|dest| {
174            *dest = match method {
175                LogicalMethod::Nor => S::one() - *dest,
176                LogicalMethod::Nand => S::one() - *dest,
177                _ => *dest,
178            };
179        });
180
181        dest.as_by()
182    }
183}
184
185#[derive(Debug, Clone, Copy, strum::EnumString)]
186/// Logical methods that can be applied
187pub enum LogicalMethod {
188    /// Logical AND operation
189    And,
190    /// Logical OR operation
191    Or,
192    /// Logical NOR operation (NOT OR)
193    Nor,
194    /// Logical NAND operation (NOT AND)
195    Nand,
196}
197
198/// Parameters for the logical block
199pub struct Parameters {
200    /// The logical method to use
201    method: LogicalMethod,
202}
203
204impl Parameters {
205    pub fn new(method: &str) -> Self {
206        Self {
207            method: method.parse().expect("Failed to parse logical method."),
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use crate::testing::StubContext;

    use super::*;

    /// Runs each `(inputs, expected)` case, in order, through one three-input
    /// scalar block configured with `method`.
    ///
    /// Each case checks both the returned value and the mirrored `data` field.
    fn check_scalar_cases(method: &str, cases: &[((f64, f64, f64), f64)]) {
        let ctxt = StubContext::default();
        let params = Parameters::new(method);
        let mut block = LogicalBlock::<(f64, f64, f64)>::default();

        for &(inputs, expected) in cases {
            let res = block.process(&params, &ctxt, inputs);
            assert_eq!(res, expected, "{method} output for {inputs:?}");
            assert_eq!(block.data.scalar(), expected, "{method} data for {inputs:?}");
        }
    }

    #[test]
    fn test_logical_and_scalar() {
        check_scalar_cases(
            "And",
            &[
                // All zero aka false inputs = false output
                ((0.0, 0.0, 0.0), 0.0),
                // Some zero inputs = false output
                ((1.0, 0.0, 1.0), 0.0),
                // All non-zero inputs = true output
                ((1.0, 1.0, 1.0), 1.0),
                // Even floats and negative data!
                ((1.0, -2.0, 3.5), 1.0),
            ],
        );
    }

    #[test]
    fn test_logical_or_scalar() {
        check_scalar_cases(
            "Or",
            &[
                // All zero aka false inputs = false output
                ((0.0, 0.0, 0.0), 0.0),
                // Some zero inputs = true output
                ((1.0, 0.0, 1.0), 1.0),
                // All non-zero inputs = true output
                ((1.0, 1.0, 1.0), 1.0),
                // Even floats and negative data!
                ((1.0, -2.0, 3.5), 1.0),
            ],
        );
    }

    #[test]
    fn test_logical_nor_scalar() {
        // These cases are the inverse of the OR cases
        check_scalar_cases(
            "Nor",
            &[
                // All zero aka false inputs = true output
                ((0.0, 0.0, 0.0), 1.0),
                // Some zero inputs = false output
                ((1.0, 0.0, 1.0), 0.0),
                // All non-zero inputs = false output
                ((1.0, 1.0, 1.0), 0.0),
                // Even floats and negative data!
                ((1.0, -2.0, 3.5), 0.0),
            ],
        );
    }

    #[test]
    fn test_logical_nand_scalar() {
        // These cases are the inverse of the AND cases
        check_scalar_cases(
            "Nand",
            &[
                // All zero aka false inputs = true output
                ((0.0, 0.0, 0.0), 1.0),
                // Some zero inputs = true output
                ((1.0, 0.0, 1.0), 1.0),
                // All non-zero inputs = false output
                ((1.0, 1.0, 1.0), 0.0),
                // Even floats and negative data!
                ((1.0, -2.0, 3.5), 0.0),
            ],
        );
    }

    #[test]
    fn test_matrix_ops() {
        let ctxt = StubContext::default();
        let mut params = Parameters::new("And");
        let mut block =
            LogicalBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, f64>, Matrix<2, 2, f64>)>::default();

        let input = (
            &Matrix {
                data: [[1.0, 0.0], [0.0, 1.0]],
            },
            &Matrix {
                data: [[0.0, 1.0], [1.0, 0.0]],
            },
            &Matrix {
                data: [[1.0, 1.0], [1.0, 1.0]],
            },
        );

        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 0.0], [0.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Or;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 1.0], [1.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nor;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 0.0], [0.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nand;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 1.0], [1.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_matrix_scalar_ops() {
        let ctxt = StubContext::default();
        let mut params = Parameters::new("And");
        let mut block = LogicalBlock::<(Matrix<2, 2, f64>, f64)>::default();

        let input = (
            &Matrix {
                data: [[1.0, 0.0], [0.0, 1.0]],
            },
            1.0,
        );

        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 0.0], [0.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Or;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 1.0], [1.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nor;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 0.0], [0.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nand;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 1.0], [1.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }
}