1use core::ops::Sub;
2use num_traits::One;
3use pictorus_block_data::{BlockData as OldBlockData, FromPass};
4use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
5
6use crate::traits::{Apply, ApplyInto, MatrixOps, Scalar};
7
8pub struct LogicalBlock<T>
16where
17 T: Apply<Parameters>,
18 T::Output: Finalize,
19 OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
20{
21 store: Option<T::Output>,
22 pub data: OldBlockData,
23}
24
25impl<T> Default for LogicalBlock<T>
26where
27 T: Apply<Parameters>,
28 T::Output: Finalize,
29 OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
30{
31 fn default() -> Self {
32 Self {
33 store: None,
34 data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
35 }
36 }
37}
38
39impl<T> ProcessBlock for LogicalBlock<T>
40where
41 T: Apply<Parameters>,
42 T::Output: Finalize,
43 OldBlockData: FromPass<<T as Apply<Parameters>>::Output>,
44{
45 type Inputs = T;
46 type Output = T::Output;
47 type Parameters = Parameters;
48 fn process<'b>(
49 &'b mut self,
50 parameters: &Self::Parameters,
51 _context: &dyn pictorus_traits::Context,
52 inputs: PassBy<'_, Self::Inputs>,
53 ) -> PassBy<'b, Self::Output> {
54 self.store = None;
55 T::apply(inputs, parameters, &mut self.store);
56 let result = T::Output::finalize(parameters.method, &mut self.store);
57 self.data = OldBlockData::from_pass(result);
58 result
59 }
60}
61
62fn perform_op<S: Scalar + From<bool>>(input: S, dest: S, method: LogicalMethod) -> S {
63 let x0 = input.is_truthy();
64 let x1 = dest.is_truthy();
65 let res = match method {
66 LogicalMethod::And => x0 & x1,
67 LogicalMethod::Or => x0 | x1,
68 LogicalMethod::Nand => x0 & x1,
71 LogicalMethod::Nor => x0 | x1,
72 };
73
74 res.into()
75}
76
77impl<S: Scalar + From<bool>> ApplyInto<S, Parameters> for S {
79 fn apply_into<'a>(
80 input: PassBy<Self>,
81 params: &Parameters,
82 dest: &'a mut Option<S>,
83 ) -> PassBy<'a, S> {
84 match dest {
85 Some(dest) => {
86 *dest = perform_op(input, *dest, params.method);
87 }
88 None => {
89 *dest = Some(input);
90 }
91 }
92
93 dest.as_ref().unwrap().as_by()
94 }
95}
96
97impl<const R: usize, const C: usize, S: Scalar + From<bool>> ApplyInto<Matrix<R, C, S>, Parameters>
99 for Matrix<R, C, S>
100{
101 fn apply_into<'a>(
102 input: PassBy<Self>,
103 params: &Parameters,
104 dest: &'a mut Option<Matrix<R, C, S>>,
105 ) -> PassBy<'a, Matrix<R, C, S>> {
106 match dest {
107 Some(dest) => {
108 input
109 .data
110 .as_flattened()
111 .iter()
112 .zip(dest.data.as_flattened_mut().iter_mut())
113 .for_each(|(input, dest)| {
114 *dest = perform_op(*input, *dest, params.method);
115 });
116 }
117 None => {
118 *dest = Some(*input);
119 }
120 }
121
122 dest.as_ref().unwrap().as_by()
123 }
124}
125
126impl<const R: usize, const C: usize, S: Scalar + From<bool>> ApplyInto<Matrix<R, C, S>, Parameters>
128 for S
129{
130 fn apply_into<'a>(
131 input: PassBy<Self>,
132 params: &Parameters,
133 dest: &'a mut Option<Matrix<R, C, S>>,
134 ) -> PassBy<'a, Matrix<R, C, S>> {
135 match dest {
136 Some(dest) => {
137 dest.data.as_flattened_mut().iter_mut().for_each(|dest| {
138 *dest = perform_op(input, *dest, params.method);
139 });
140 }
141 None => {
142 *dest = Some(Matrix::<R, C, S>::from_element(input));
143 }
144 }
145
146 dest.as_ref().unwrap().as_by()
147 }
148}
149
150pub trait Finalize: Pass + Default {
151 fn finalize(method: LogicalMethod, dest: &mut Option<Self>) -> PassBy<'_, Self>;
152}
153
154impl<S: Scalar + One + Sub<Output = S>> Finalize for S {
155 fn finalize(method: LogicalMethod, dest: &mut Option<Self>) -> PassBy<'_, Self> {
156 let input = dest.get_or_insert(S::default());
157 let res = match method {
158 LogicalMethod::Nor => S::one() - *input,
159 LogicalMethod::Nand => S::one() - *input,
160 _ => *input,
161 };
162
163 *dest = Some(res);
164 dest.as_ref().unwrap().as_by()
165 }
166}
167
168impl<const R: usize, const C: usize, S: Scalar + One + Sub<Output = S>> Finalize
169 for Matrix<R, C, S>
170{
171 fn finalize(method: LogicalMethod, dest: &mut Option<Self>) -> PassBy<'_, Self> {
172 let dest = dest.get_or_insert(Matrix::<R, C, S>::default());
173 dest.data.as_flattened_mut().iter_mut().for_each(|dest| {
174 *dest = match method {
175 LogicalMethod::Nor => S::one() - *dest,
176 LogicalMethod::Nand => S::one() - *dest,
177 _ => *dest,
178 };
179 });
180
181 dest.as_by()
182 }
183}
184
185#[derive(Debug, Clone, Copy, strum::EnumString)]
186pub enum LogicalMethod {
188 And,
190 Or,
192 Nor,
194 Nand,
196}
197
198pub struct Parameters {
200 method: LogicalMethod,
202}
203
204impl Parameters {
205 pub fn new(method: &str) -> Self {
206 Self {
207 method: method.parse().expect("Failed to parse logical method."),
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use crate::testing::StubContext;

    use super::*;

    #[test]
    fn test_logical_and_scalar() {
        let ctxt = StubContext::default();
        let params = Parameters::new("And");
        let mut block = LogicalBlock::<(f64, f64, f64)>::default();

        let res = block.process(&params, &ctxt, (0.0, 0.0, 0.0));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&params, &ctxt, (1.0, 0.0, 1.0));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&params, &ctxt, (1.0, 1.0, 1.0));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);

        // Any non-zero value counts as truthy and the output is normalized to 1.
        let res = block.process(&params, &ctxt, (1.0, -2.0, 3.5));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);
    }

    #[test]
    fn test_logical_or_scalar() {
        let ctxt = StubContext::default();
        let params = Parameters::new("Or");
        let mut block = LogicalBlock::<(f64, f64, f64)>::default();

        let res = block.process(&params, &ctxt, (0.0, 0.0, 0.0));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&params, &ctxt, (1.0, 0.0, 1.0));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);

        let res = block.process(&params, &ctxt, (1.0, 1.0, 1.0));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);

        let res = block.process(&params, &ctxt, (1.0, -2.0, 3.5));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);
    }

    #[test]
    fn test_logical_nor_scalar() {
        let ctxt = StubContext::default();
        let params = Parameters::new("Nor");
        let mut block = LogicalBlock::<(f64, f64, f64)>::default();

        // No truthy input: Or folds to 0, Nor inverts it to 1.
        let res = block.process(&params, &ctxt, (0.0, 0.0, 0.0));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);

        let res = block.process(&params, &ctxt, (1.0, 0.0, 1.0));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&params, &ctxt, (1.0, 1.0, 1.0));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&params, &ctxt, (1.0, -2.0, 3.5));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);
    }

    #[test]
    fn test_logical_nand_scalar() {
        let ctxt = StubContext::default();
        let params = Parameters::new("Nand");
        let mut block = LogicalBlock::<(f64, f64, f64)>::default();

        // And folds to 0, Nand inverts it to 1.
        let res = block.process(&params, &ctxt, (0.0, 0.0, 0.0));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);

        let res = block.process(&params, &ctxt, (1.0, 0.0, 1.0));
        assert_eq!(res, 1.0);
        assert_eq!(block.data.scalar(), 1.0);

        let res = block.process(&params, &ctxt, (1.0, 1.0, 1.0));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);

        let res = block.process(&params, &ctxt, (1.0, -2.0, 3.5));
        assert_eq!(res, 0.0);
        assert_eq!(block.data.scalar(), 0.0);
    }

    #[test]
    fn test_matrix_ops() {
        let ctxt = StubContext::default();
        let mut params = Parameters::new("And");
        let mut block =
            LogicalBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, f64>, Matrix<2, 2, f64>)>::default();

        let input = (
            &Matrix {
                data: [[1.0, 0.0], [0.0, 1.0]],
            },
            &Matrix {
                data: [[0.0, 1.0], [1.0, 0.0]],
            },
            &Matrix {
                data: [[1.0, 1.0], [1.0, 1.0]],
            },
        );

        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 0.0], [0.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Or;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 1.0], [1.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nor;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 0.0], [0.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nand;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 1.0], [1.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_matrix_scalar_ops() {
        let ctxt = StubContext::default();
        let mut params = Parameters::new("And");
        let mut block = LogicalBlock::<(Matrix<2, 2, f64>, f64)>::default();

        let input = (
            &Matrix {
                data: [[1.0, 0.0], [0.0, 1.0]],
            },
            1.0,
        );

        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 0.0], [0.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Or;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[1.0, 1.0], [1.0, 1.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nor;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 0.0], [0.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        params.method = LogicalMethod::Nand;
        let res = block.process(&params, &ctxt, input);
        let expected = Matrix {
            data: [[0.0, 1.0], [1.0, 0.0]],
        };
        assert_eq!(res, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }
}