Skip to main content

pictorus_blocks/core_blocks/
comparison_block.rs

1use crate::traits::{Apply, ApplyInto, MatrixOps, Scalar};
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
4
5/// The type of comparison operation to perform
6#[derive(Clone, Copy, Debug, PartialEq, strum::EnumString)]
7pub enum ComparisonType {
8    /// Check if the two inputs are equal
9    Equal,
10    /// Check if the two inputs are not equal
11    NotEqual,
12    /// Check if the first input is greater than the second
13    GreaterThan,
14    /// Check if the first input is greater than or equal to the second
15    GreaterOrEqual,
16    /// Check if the first input is less than the second
17    LessThan,
18    /// Check if the first input is less than or equal to the second
19    LessOrEqual,
20}
21
22/// Parameters for the comparison operator block
23pub struct Parameters {
24    pub comparison_type: ComparisonType,
25}
26
27impl Parameters {
28    pub fn new(comparison_type: &str) -> Self {
29        Self {
30            comparison_type: comparison_type
31                .parse()
32                .expect("Failed to parse comparison method."),
33        }
34    }
35}
36
37/// Performs an element-wise comparison operation on two inputs.
38///
39/// Currently supports the following comparison methods:
40/// - Equal
41/// - NotEqual
42/// - GreaterThan
43/// - GreaterOrEqual
44/// - LessThan
45/// - LessOrEqual
46pub struct ComparisonBlock<T>
47where
48    T: Apply<Parameters>,
49    OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
50{
51    pub data: OldBlockData,
52    buffer: Option<T::Output>,
53}
54
55impl<T> Default for ComparisonBlock<T>
56where
57    T: Apply<Parameters>,
58    OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
59{
60    fn default() -> Self {
61        Self {
62            data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
63            buffer: None,
64        }
65    }
66}
67
68impl<T> ProcessBlock for ComparisonBlock<T>
69where
70    T: Apply<Parameters>,
71    OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
72{
73    type Inputs = T;
74    type Output = T::Output;
75    type Parameters = Parameters;
76
77    fn process<'b>(
78        &'b mut self,
79        parameters: &Self::Parameters,
80        _context: &dyn pictorus_traits::Context,
81        inputs: PassBy<Self::Inputs>,
82    ) -> PassBy<'b, Self::Output> {
83        self.buffer = None;
84        T::apply(inputs, parameters, &mut self.buffer);
85        self.data = OldBlockData::from_pass(self.buffer.as_ref().unwrap().as_by());
86        self.buffer.as_ref().unwrap().as_by()
87    }
88}
89
90fn perform_op<S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>>(
91    lhs: S,
92    rhs: S,
93    comparison_type: ComparisonType,
94) -> S {
95    let res = match comparison_type {
96        ComparisonType::Equal => rhs == lhs,
97        ComparisonType::NotEqual => rhs != lhs,
98        ComparisonType::GreaterThan => rhs > lhs,
99        ComparisonType::GreaterOrEqual => rhs >= lhs,
100        ComparisonType::LessThan => rhs < lhs,
101        ComparisonType::LessOrEqual => rhs <= lhs,
102    };
103    res.into()
104}
105
106// Compare scalar with scalar
107impl<S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>> ApplyInto<S, Parameters>
108    for S
109{
110    fn apply_into<'a>(
111        input: PassBy<Self>,
112        params: &Parameters,
113        dest: &'a mut Option<S>,
114    ) -> PassBy<'a, S> {
115        match dest {
116            Some(dest) => {
117                *dest = perform_op(input, *dest, params.comparison_type);
118            }
119            None => {
120                *dest = Some(input);
121            }
122        }
123
124        dest.as_ref().unwrap().as_by()
125    }
126}
127
128// Compare matrix and matrix
129impl<
130        const R: usize,
131        const C: usize,
132        S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>,
133    > ApplyInto<Matrix<R, C, S>, Parameters> for Matrix<R, C, S>
134{
135    fn apply_into<'a>(
136        input: PassBy<Self>,
137        params: &Parameters,
138        dest: &'a mut Option<Matrix<R, C, S>>,
139    ) -> PassBy<'a, Matrix<R, C, S>> {
140        match dest {
141            Some(dest) => {
142                input
143                    .data
144                    .as_flattened()
145                    .iter()
146                    .zip(dest.data.as_flattened_mut().iter_mut())
147                    .for_each(|(input, dest)| {
148                        *dest = perform_op(*input, *dest, params.comparison_type);
149                    });
150            }
151            None => {
152                *dest = Some(*input);
153            }
154        }
155
156        dest.as_ref().unwrap().as_by()
157    }
158}
159
160// Compare scalar with matrix
161impl<
162        const R: usize,
163        const C: usize,
164        S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>,
165    > ApplyInto<Matrix<R, C, S>, Parameters> for S
166{
167    fn apply_into<'a>(
168        input: PassBy<Self>,
169        params: &Parameters,
170        dest: &'a mut Option<Matrix<R, C, S>>,
171    ) -> PassBy<'a, Matrix<R, C, S>> {
172        match dest {
173            Some(dest) => {
174                dest.data.as_flattened_mut().iter_mut().for_each(|dest| {
175                    *dest = perform_op(input, *dest, params.comparison_type);
176                });
177            }
178            None => {
179                *dest = Some(Matrix::<R, C, S>::from_element(input));
180            }
181        }
182
183        dest.as_ref().unwrap().as_by()
184    }
185}
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::testing::StubContext;

    /// Builds the `Matrix<1, 3, f64>` used below, one value per inner array of `data`.
    fn mat(values: [f64; 3]) -> Matrix<1, 3, f64> {
        let [a, b, c] = values;
        Matrix {
            data: [[a], [b], [c]],
        }
    }

    #[test]
    fn test_comparison_type() {
        let expectations = [
            ("Equal", ComparisonType::Equal),
            ("NotEqual", ComparisonType::NotEqual),
            ("GreaterThan", ComparisonType::GreaterThan),
            ("GreaterOrEqual", ComparisonType::GreaterOrEqual),
            ("LessThan", ComparisonType::LessThan),
            ("LessOrEqual", ComparisonType::LessOrEqual),
        ];
        for (name, expected) in expectations {
            assert_eq!(ComparisonType::from_str(name).unwrap(), expected);
        }
    }

    #[test]
    fn test_comparison_block_scalar() {
        let c = StubContext::default();
        let mut block = ComparisonBlock::<(f64, f64)>::default();

        // (method, (first, second), expected output); the same block is reused in order.
        let cases: &[(&str, (f64, f64), f64)] = &[
            ("Equal", (1., 1.), 1.0),
            ("Equal", (0., 1.), 0.0),
            ("NotEqual", (1., 0.), 1.0),
            ("NotEqual", (1., 1.), 0.0),
            ("GreaterThan", (1., 0.), 1.0),
            ("GreaterThan", (1., 1.), 0.0),
            ("GreaterThan", (0., 1.), 0.0),
            ("GreaterOrEqual", (1., 0.), 1.0),
            ("GreaterOrEqual", (1., 1.), 1.0),
            ("GreaterOrEqual", (0., 1.), 0.0),
            ("LessThan", (0., 1.), 1.0),
            ("LessThan", (1., 1.), 0.0),
            ("LessThan", (1., 0.), 0.0),
            ("LessOrEqual", (0., 1.), 1.0),
            ("LessOrEqual", (1., 1.), 1.0),
            ("LessOrEqual", (1., 0.), 0.0),
        ];
        for &(method, inputs, expected) in cases {
            let output = block.process(&Parameters::new(method), &c, inputs);
            assert_eq!(output, expected);
        }
    }

    #[test]
    fn test_comparison_block_matrix() {
        let c = StubContext::default();
        let mut block = ComparisonBlock::<(Matrix<1, 3, f64>, Matrix<1, 3, f64>)>::default();

        // (method, first, second, expected)
        let cases: &[(&str, [f64; 3], [f64; 3], [f64; 3])] = &[
            ("Equal", [1., 0., -1.], [1., 1., 1.], [1., 0., 0.]),
            ("NotEqual", [1., 0., -1.], [1., 1., 1.], [0., 1., 1.]),
            ("GreaterThan", [1., 1., -2.], [1., 0., -1.], [0., 1., 0.]),
            ("GreaterOrEqual", [1., 1., -2.], [1., 0., -1.], [1., 1., 0.]),
            ("LessThan", [1., 1., -2.], [1., 0., -1.], [0., 0., 1.]),
            ("LessOrEqual", [1., 1., -2.], [1., 0., -1.], [1., 0., 1.]),
        ];
        for &(method, first, second, expected) in cases {
            let output = block.process(
                &Parameters::new(method),
                &c,
                (&mat(first), &mat(second)),
            );
            assert_eq!(output, &mat(expected));
        }
    }

    #[test]
    fn test_comparison_block_scalar_matrix() {
        let c = StubContext::default();
        let mut block = ComparisonBlock::<(f64, Matrix<1, 3, f64>)>::default();

        // (method, scalar first input, matrix second input, expected)
        let cases: &[(&str, f64, [f64; 3], [f64; 3])] = &[
            ("Equal", 1., [1., 0., -1.], [1., 0., 0.]),
            ("NotEqual", 1., [1., 0., -1.], [0., 1., 1.]),
            ("GreaterThan", 1., [2., 1., -1.], [0., 0., 1.]),
            ("GreaterOrEqual", 1., [2., 1., -1.], [0., 1., 1.]),
            ("LessThan", 1., [2., 1., -1.], [1., 0., 0.]),
            ("LessOrEqual", 1., [2., 1., -1.], [1., 1., 0.]),
        ];
        for &(method, scalar, matrix, expected) in cases {
            let output = block.process(&Parameters::new(method), &c, (scalar, &mat(matrix)));
            assert_eq!(output, &mat(expected));
        }
    }
}