1use crate::traits::{Apply, ApplyInto, MatrixOps, Scalar};
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
4
5#[derive(Clone, Copy, Debug, PartialEq, strum::EnumString)]
7pub enum ComparisonType {
8 Equal,
10 NotEqual,
12 GreaterThan,
14 GreaterOrEqual,
16 LessThan,
18 LessOrEqual,
20}
21
22pub struct Parameters {
24 pub comparison_type: ComparisonType,
25}
26
27impl Parameters {
28 pub fn new(comparison_type: &str) -> Self {
29 Self {
30 comparison_type: comparison_type
31 .parse()
32 .expect("Failed to parse comparison method."),
33 }
34 }
35}
36
37pub struct ComparisonBlock<T>
47where
48 T: Apply<Parameters>,
49 OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
50{
51 pub data: OldBlockData,
52 buffer: Option<T::Output>,
53}
54
55impl<T> Default for ComparisonBlock<T>
56where
57 T: Apply<Parameters>,
58 OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
59{
60 fn default() -> Self {
61 Self {
62 data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
63 buffer: None,
64 }
65 }
66}
67
68impl<T> ProcessBlock for ComparisonBlock<T>
69where
70 T: Apply<Parameters>,
71 OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
72{
73 type Inputs = T;
74 type Output = T::Output;
75 type Parameters = Parameters;
76
77 fn process<'b>(
78 &'b mut self,
79 parameters: &Self::Parameters,
80 _context: &dyn pictorus_traits::Context,
81 inputs: PassBy<Self::Inputs>,
82 ) -> PassBy<'b, Self::Output> {
83 self.buffer = None;
84 T::apply(inputs, parameters, &mut self.buffer);
85 self.data = OldBlockData::from_pass(self.buffer.as_ref().unwrap().as_by());
86 self.buffer.as_ref().unwrap().as_by()
87 }
88}
89
90fn perform_op<S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>>(
91 lhs: S,
92 rhs: S,
93 comparison_type: ComparisonType,
94) -> S {
95 let res = match comparison_type {
96 ComparisonType::Equal => rhs == lhs,
97 ComparisonType::NotEqual => rhs != lhs,
98 ComparisonType::GreaterThan => rhs > lhs,
99 ComparisonType::GreaterOrEqual => rhs >= lhs,
100 ComparisonType::LessThan => rhs < lhs,
101 ComparisonType::LessOrEqual => rhs <= lhs,
102 };
103 res.into()
104}
105
106impl<S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>> ApplyInto<S, Parameters>
108 for S
109{
110 fn apply_into<'a>(
111 input: PassBy<Self>,
112 params: &Parameters,
113 dest: &'a mut Option<S>,
114 ) -> PassBy<'a, S> {
115 match dest {
116 Some(dest) => {
117 *dest = perform_op(input, *dest, params.comparison_type);
118 }
119 None => {
120 *dest = Some(input);
121 }
122 }
123
124 dest.as_ref().unwrap().as_by()
125 }
126}
127
128impl<
130 const R: usize,
131 const C: usize,
132 S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>,
133 > ApplyInto<Matrix<R, C, S>, Parameters> for Matrix<R, C, S>
134{
135 fn apply_into<'a>(
136 input: PassBy<Self>,
137 params: &Parameters,
138 dest: &'a mut Option<Matrix<R, C, S>>,
139 ) -> PassBy<'a, Matrix<R, C, S>> {
140 match dest {
141 Some(dest) => {
142 input
143 .data
144 .as_flattened()
145 .iter()
146 .zip(dest.data.as_flattened_mut().iter_mut())
147 .for_each(|(input, dest)| {
148 *dest = perform_op(*input, *dest, params.comparison_type);
149 });
150 }
151 None => {
152 *dest = Some(*input);
153 }
154 }
155
156 dest.as_ref().unwrap().as_by()
157 }
158}
159
160impl<
162 const R: usize,
163 const C: usize,
164 S: Scalar + core::cmp::PartialEq + core::cmp::PartialOrd + From<bool>,
165 > ApplyInto<Matrix<R, C, S>, Parameters> for S
166{
167 fn apply_into<'a>(
168 input: PassBy<Self>,
169 params: &Parameters,
170 dest: &'a mut Option<Matrix<R, C, S>>,
171 ) -> PassBy<'a, Matrix<R, C, S>> {
172 match dest {
173 Some(dest) => {
174 dest.data.as_flattened_mut().iter_mut().for_each(|dest| {
175 *dest = perform_op(input, *dest, params.comparison_type);
176 });
177 }
178 None => {
179 *dest = Some(Matrix::<R, C, S>::from_element(input));
180 }
181 }
182
183 dest.as_ref().unwrap().as_by()
184 }
185}
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::testing::StubContext;

    /// Matrix type used by the matrix fixtures below.
    type Mat3 = Matrix<1, 3, f64>;

    /// Builds a `Mat3` fixture from three values.
    fn mat3(a: f64, b: f64, c: f64) -> Mat3 {
        Matrix {
            data: [[a], [b], [c]],
        }
    }

    #[test]
    fn test_comparison_type() {
        // Every variant must parse from its own name.
        let expectations = [
            ("Equal", ComparisonType::Equal),
            ("NotEqual", ComparisonType::NotEqual),
            ("GreaterThan", ComparisonType::GreaterThan),
            ("GreaterOrEqual", ComparisonType::GreaterOrEqual),
            ("LessThan", ComparisonType::LessThan),
            ("LessOrEqual", ComparisonType::LessOrEqual),
        ];

        for &(name, variant) in expectations.iter() {
            assert_eq!(ComparisonType::from_str(name).unwrap(), variant);
        }
    }

    #[test]
    fn test_comparison_block_scalar() {
        let c = StubContext::default();
        let mut block = ComparisonBlock::<(f64, f64)>::default();

        // (comparison name, (first input, second input), expected output)
        let cases: &[(&str, (f64, f64), f64)] = &[
            ("Equal", (1., 1.), 1.0),
            ("Equal", (0., 1.), 0.0),
            ("NotEqual", (1., 0.), 1.0),
            ("NotEqual", (1., 1.), 0.0),
            ("GreaterThan", (1., 0.), 1.0),
            ("GreaterThan", (1., 1.), 0.0),
            ("GreaterThan", (0., 1.), 0.0),
            ("GreaterOrEqual", (1., 0.), 1.0),
            ("GreaterOrEqual", (1., 1.), 1.0),
            ("GreaterOrEqual", (0., 1.), 0.0),
            ("LessThan", (0., 1.), 1.0),
            ("LessThan", (1., 1.), 0.0),
            ("LessThan", (1., 0.), 0.0),
            ("LessOrEqual", (0., 1.), 1.0),
            ("LessOrEqual", (1., 1.), 1.0),
            ("LessOrEqual", (1., 0.), 0.0),
        ];

        for &(method, inputs, expected) in cases {
            let output = block.process(&Parameters::new(method), &c, inputs);
            assert_eq!(output, expected, "{} with inputs {:?}", method, inputs);
        }
    }

    #[test]
    fn test_comparison_block_matrix() {
        let c = StubContext::default();
        let mut block = ComparisonBlock::<(Matrix<1, 3, f64>, Matrix<1, 3, f64>)>::default();

        // (comparison name, first input, second input, expected output)
        let cases = [
            (
                "Equal",
                mat3(1., 0., -1.),
                mat3(1., 1., 1.),
                mat3(1., 0., 0.),
            ),
            (
                "NotEqual",
                mat3(1., 0., -1.),
                mat3(1., 1., 1.),
                mat3(0., 1., 1.),
            ),
            (
                "GreaterThan",
                mat3(1., 1., -2.),
                mat3(1., 0., -1.),
                mat3(0., 1., 0.),
            ),
            (
                "GreaterOrEqual",
                mat3(1., 1., -2.),
                mat3(1., 0., -1.),
                mat3(1., 1., 0.),
            ),
            (
                "LessThan",
                mat3(1., 1., -2.),
                mat3(1., 0., -1.),
                mat3(0., 0., 1.),
            ),
            (
                "LessOrEqual",
                mat3(1., 1., -2.),
                mat3(1., 0., -1.),
                mat3(1., 0., 1.),
            ),
        ];

        for (method, first, second, expected) in cases.iter() {
            let output = block.process(&Parameters::new(method), &c, (first, second));
            assert_eq!(output, expected, "{}", method);
        }
    }

    #[test]
    fn test_comparison_block_scalar_matrix() {
        let c = StubContext::default();
        let mut block = ComparisonBlock::<(f64, Matrix<1, 3, f64>)>::default();

        // (comparison name, scalar first input, matrix second input, expected output)
        let cases = [
            ("Equal", 1., mat3(1., 0., -1.), mat3(1., 0., 0.)),
            ("NotEqual", 1., mat3(1., 0., -1.), mat3(0., 1., 1.)),
            ("GreaterThan", 1., mat3(2., 1., -1.), mat3(0., 0., 1.)),
            ("GreaterOrEqual", 1., mat3(2., 1., -1.), mat3(0., 1., 1.)),
            ("LessThan", 1., mat3(2., 1., -1.), mat3(1., 0., 0.)),
            ("LessOrEqual", 1., mat3(2., 1., -1.), mat3(1., 1., 0.)),
        ];

        for (method, scalar, matrix, expected) in cases.iter() {
            let output = block.process(&Parameters::new(method), &c, (*scalar, matrix));
            assert_eq!(output, expected, "{}", method);
        }
    }
}