1use crate::nalgebra_interop::MatrixExt;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock, Scalar};
4
5pub struct SumBlock<T: Summable>
7where
8 pictorus_block_data::BlockData: FromPass<<T as Summable>::Output>,
9{
10 store: Option<T::Output>,
11 pub data: OldBlockData,
12}
13
14impl<T: Summable> Default for SumBlock<T>
15where
16 pictorus_block_data::BlockData: FromPass<<T as Summable>::Output>,
17{
18 fn default() -> Self {
19 Self {
20 store: None,
21 data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
22 }
23 }
24}
25
26impl<T> ProcessBlock for SumBlock<T>
27where
28 T: Summable,
29 OldBlockData: FromPass<T::Output>,
30{
31 type Inputs = T;
32 type Output = T::Output;
33 type Parameters = T::Parameters;
34
35 fn process(
36 &mut self,
37 parameters: &Self::Parameters,
38 _context: &dyn pictorus_traits::Context,
39 input: PassBy<Self::Inputs>,
40 ) -> PassBy<Self::Output> {
41 self.store = None;
42 let result = T::get_sum(input, *parameters, &mut self.store);
43 self.data = OldBlockData::from_pass(result);
44 result
45 }
46}
47
48trait SumScalar:
49 Scalar
50 + nalgebra::Scalar
51 + core::ops::Neg<Output = Self>
52 + core::ops::Add<Output = Self>
53 + core::ops::Sub<Output = Self>
54 + core::ops::AddAssign
55 + core::ops::SubAssign
56{
57}
58impl SumScalar for f32 {}
59impl SumScalar for f64 {}
60
61pub trait TypePromotion<RHS> {
67 type Output: Pass + Default;
68}
69
70impl<S: SumScalar> TypePromotion<S> for S {
72 type Output = S;
73}
74
75impl<const R: usize, const C: usize, S: SumScalar> TypePromotion<S> for Matrix<R, C, S> {
77 type Output = Matrix<R, C, S>;
78}
79
80impl<const R: usize, const C: usize, S: SumScalar> TypePromotion<Matrix<R, C, S>> for S {
82 type Output = Matrix<R, C, S>;
83}
84
85impl<const R: usize, const C: usize, S: SumScalar> TypePromotion<Matrix<R, C, S>>
87 for Matrix<R, C, S>
88{
89 type Output = Matrix<R, C, S>;
90}
91
92impl<A, B, C> TypePromotion<(B, C)> for A
94where
95 B: TypePromotion<C>,
96 A: TypePromotion<<B as TypePromotion<C>>::Output>,
97{
98 type Output = <A as TypePromotion<B::Output>>::Output;
99}
100
101impl<A, B, C, D> TypePromotion<(B, C, D)> for A
103where
104 B: TypePromotion<(C, D)>,
105 A: TypePromotion<B::Output>,
106{
107 type Output = <A as TypePromotion<B::Output>>::Output;
108}
109
110impl<A, B, C, D, E> TypePromotion<(B, C, D, E)> for A
112where
113 B: TypePromotion<(C, D, E)>,
114 A: TypePromotion<B::Output>,
115{
116 type Output = <A as TypePromotion<B::Output>>::Output;
117}
118
119impl<A, B, C, D, E, F> TypePromotion<(B, C, D, E, F)> for A
121where
122 B: TypePromotion<(C, D, E, F)>,
123 A: TypePromotion<B::Output>,
124{
125 type Output = <A as TypePromotion<B::Output>>::Output;
126}
127
128impl<A, B, C, D, E, F, G> TypePromotion<(B, C, D, E, F, G)> for A
130where
131 B: TypePromotion<(C, D, E, F, G)>,
132 A: TypePromotion<B::Output>,
133{
134 type Output = <A as TypePromotion<B::Output>>::Output;
135}
136
137impl<A, B, C, D, E, F, G, H> TypePromotion<(B, C, D, E, F, G, H)> for A
139where
140 B: TypePromotion<(C, D, E, F, G, H)>,
141 A: TypePromotion<B::Output>,
142{
143 type Output = <A as TypePromotion<B::Output>>::Output;
144}
145
146pub trait SumInto<DEST: Pass>: Pass {
150 fn sum_into<'a>(
151 input: PassBy<Self>,
152 sum_type: SumType,
153 dest: &'a mut Option<DEST>,
154 ) -> PassBy<'a, DEST>;
155}
156
157impl<S: SumScalar> SumInto<S> for S {
159 fn sum_into<'a>(
160 input: PassBy<Self>,
161 sum_type: SumType,
162 dest: &'a mut Option<S>,
163 ) -> PassBy<'a, S> {
164 let dest = dest.get_or_insert(S::default());
165 match sum_type {
166 SumType::Addition => {
167 *dest += input;
168 }
169 SumType::Subtraction => {
170 *dest -= input;
171 }
172 }
173 *dest
174 }
175}
176
177impl<const R: usize, const C: usize, S: SumScalar> SumInto<Matrix<R, C, S>> for Matrix<R, C, S> {
179 fn sum_into<'a>(
180 input: PassBy<Self>,
181 sum_type: SumType,
182 dest: &'a mut Option<Matrix<R, C, S>>,
183 ) -> PassBy<'a, Matrix<R, C, S>> {
184 let dest = dest.get_or_insert(Matrix::<R, C, S>::zeroed());
185 let orig_dest = dest.as_view().clone_owned();
186 match sum_type {
187 SumType::Addition => {
188 orig_dest.add_to(&input.as_view(), &mut dest.as_view_mut());
189 }
190 SumType::Subtraction => {
191 orig_dest.sub_to(&input.as_view(), &mut dest.as_view_mut());
192 }
193 }
194 dest
195 }
196}
197
198impl<const R: usize, const C: usize, S: SumScalar> SumInto<Matrix<R, C, S>> for S {
200 fn sum_into<'a>(
201 input: PassBy<Self>,
202 sum_type: SumType,
203 dest: &'a mut Option<Matrix<R, C, S>>,
204 ) -> PassBy<'a, Matrix<R, C, S>> {
205 let dest = dest.get_or_insert(Matrix::<R, C, S>::zeroed());
206 let mut orig_dest = dest.as_view().clone_owned();
207 match sum_type {
208 SumType::Addition => {
209 orig_dest = orig_dest.add_scalar(input);
210 }
211 SumType::Subtraction => {
212 orig_dest = orig_dest.add_scalar(-input);
213 }
214 }
215 dest.as_view_mut().copy_from(&orig_dest);
216 dest
217 }
218}
219
220pub trait Summable: Pass {
222 type Output: Pass + Default;
223 type Parameters: Copy;
224
225 fn get_sum<'a>(
226 input: PassBy<Self>,
227 parameters: Self::Parameters,
228 dest: &'a mut Option<Self::Output>,
229 ) -> PassBy<'a, Self::Output>;
230}
231
232impl<S: SumScalar> Summable for S {
234 type Output = S;
235 type Parameters = Parameters<1>;
236
237 fn get_sum<'a>(
238 input: PassBy<Self>,
239 parameters: Self::Parameters,
240 dest: &'a mut Option<Self::Output>,
241 ) -> PassBy<'a, Self::Output> {
242 Self::sum_into(input, parameters.operations[0], dest);
243 dest.unwrap()
244 }
245}
246
247impl<const R: usize, const C: usize, S: SumScalar> Summable for Matrix<R, C, S> {
249 type Output = Matrix<R, C, S>;
250 type Parameters = Parameters<1>;
251
252 fn get_sum<'a>(
253 input: PassBy<Self>,
254 parameters: Self::Parameters,
255 dest: &'a mut Option<Self::Output>,
256 ) -> PassBy<'a, Self::Output> {
257 Self::sum_into(input, parameters.operations[0], dest);
258 dest.as_ref().unwrap()
259 }
260}
261
262impl<A, B> Summable for (A, B)
263where
264 A: TypePromotion<B>,
265 A: SumInto<A::Output>,
266 B: SumInto<A::Output>,
267{
268 type Output = A::Output;
269 type Parameters = Parameters<2>;
270
271 fn get_sum<'a>(
272 input: PassBy<Self>,
273 parameters: Self::Parameters,
274 dest: &'a mut Option<Self::Output>,
275 ) -> PassBy<'a, Self::Output> {
276 let (a, b) = input;
277 A::sum_into(a, parameters.operations[0], dest);
278 B::sum_into(b, parameters.operations[1], dest)
279 }
280}
281
282impl<A, B, C> Summable for (A, B, C)
283where
284 A: TypePromotion<(B, C)>,
285 A: SumInto<A::Output>,
286 B: SumInto<A::Output>,
287 C: SumInto<A::Output>,
288{
289 type Output = A::Output;
290 type Parameters = Parameters<3>;
291
292 fn get_sum<'a>(
293 input: PassBy<Self>,
294 parameters: Self::Parameters,
295 dest: &'a mut Option<Self::Output>,
296 ) -> PassBy<'a, Self::Output> {
297 let (a, b, c) = input;
298 A::sum_into(a, parameters.operations[0], dest);
299 B::sum_into(b, parameters.operations[1], dest);
300 C::sum_into(c, parameters.operations[2], dest)
301 }
302}
303
304impl<A, B, C, D> Summable for (A, B, C, D)
305where
306 A: TypePromotion<(B, C, D)>,
307 A: SumInto<A::Output>,
308 B: SumInto<A::Output>,
309 C: SumInto<A::Output>,
310 D: SumInto<A::Output>,
311{
312 type Output = A::Output;
313 type Parameters = Parameters<4>;
314
315 fn get_sum<'a>(
316 input: PassBy<Self>,
317 parameters: Self::Parameters,
318 dest: &'a mut Option<Self::Output>,
319 ) -> PassBy<'a, Self::Output> {
320 let (a, b, c, d) = input;
321 A::sum_into(a, parameters.operations[0], dest);
322 B::sum_into(b, parameters.operations[1], dest);
323 C::sum_into(c, parameters.operations[2], dest);
324 D::sum_into(d, parameters.operations[3], dest)
325 }
326}
327
328impl<A, B, C, D, E> Summable for (A, B, C, D, E)
329where
330 A: TypePromotion<(B, C, D, E)>,
331 A: SumInto<A::Output>,
332 B: SumInto<A::Output>,
333 C: SumInto<A::Output>,
334 D: SumInto<A::Output>,
335 E: SumInto<A::Output>,
336{
337 type Output = A::Output;
338 type Parameters = Parameters<5>;
339
340 fn get_sum<'a>(
341 input: PassBy<Self>,
342 parameters: Self::Parameters,
343 dest: &'a mut Option<Self::Output>,
344 ) -> PassBy<'a, Self::Output> {
345 let (a, b, c, d, e) = input;
346 A::sum_into(a, parameters.operations[0], dest);
347 B::sum_into(b, parameters.operations[1], dest);
348 C::sum_into(c, parameters.operations[2], dest);
349 D::sum_into(d, parameters.operations[3], dest);
350 E::sum_into(e, parameters.operations[4], dest)
351 }
352}
353
354impl<A, B, C, D, E, F> Summable for (A, B, C, D, E, F)
355where
356 A: TypePromotion<(B, C, D, E, F)>,
357 A: SumInto<A::Output>,
358 B: SumInto<A::Output>,
359 C: SumInto<A::Output>,
360 D: SumInto<A::Output>,
361 E: SumInto<A::Output>,
362 F: SumInto<A::Output>,
363{
364 type Output = A::Output;
365 type Parameters = Parameters<6>;
366
367 fn get_sum<'a>(
368 input: PassBy<Self>,
369 parameters: Self::Parameters,
370 dest: &'a mut Option<Self::Output>,
371 ) -> PassBy<'a, Self::Output> {
372 let (a, b, c, d, e, f) = input;
373 A::sum_into(a, parameters.operations[0], dest);
374 B::sum_into(b, parameters.operations[1], dest);
375 C::sum_into(c, parameters.operations[2], dest);
376 D::sum_into(d, parameters.operations[3], dest);
377 E::sum_into(e, parameters.operations[4], dest);
378 F::sum_into(f, parameters.operations[5], dest)
379 }
380}
381
382impl<A, B, C, D, E, F, G> Summable for (A, B, C, D, E, F, G)
383where
384 A: TypePromotion<(B, C, D, E, F, G)>,
385 A: SumInto<A::Output>,
386 B: SumInto<A::Output>,
387 C: SumInto<A::Output>,
388 D: SumInto<A::Output>,
389 E: SumInto<A::Output>,
390 F: SumInto<A::Output>,
391 G: SumInto<A::Output>,
392{
393 type Output = A::Output;
394 type Parameters = Parameters<7>;
395
396 fn get_sum<'a>(
397 input: PassBy<Self>,
398 parameters: Self::Parameters,
399 dest: &'a mut Option<Self::Output>,
400 ) -> PassBy<'a, Self::Output> {
401 let (a, b, c, d, e, f, g) = input;
402 A::sum_into(a, parameters.operations[0], dest);
403 B::sum_into(b, parameters.operations[1], dest);
404 C::sum_into(c, parameters.operations[2], dest);
405 D::sum_into(d, parameters.operations[3], dest);
406 E::sum_into(e, parameters.operations[4], dest);
407 F::sum_into(f, parameters.operations[5], dest);
408 G::sum_into(g, parameters.operations[6], dest)
409 }
410}
411
412impl<A, B, C, D, E, F, G, H> Summable for (A, B, C, D, E, F, G, H)
413where
414 A: TypePromotion<(B, C, D, E, F, G, H)>,
415 A: SumInto<A::Output>,
416 B: SumInto<A::Output>,
417 C: SumInto<A::Output>,
418 D: SumInto<A::Output>,
419 E: SumInto<A::Output>,
420 F: SumInto<A::Output>,
421 G: SumInto<A::Output>,
422 H: SumInto<A::Output>,
423{
424 type Output = A::Output;
425 type Parameters = Parameters<8>;
426
427 fn get_sum<'a>(
428 input: PassBy<Self>,
429 parameters: Self::Parameters,
430 dest: &'a mut Option<Self::Output>,
431 ) -> PassBy<'a, Self::Output> {
432 let (a, b, c, d, e, f, g, h) = input;
433 A::sum_into(a, parameters.operations[0], dest);
434 B::sum_into(b, parameters.operations[1], dest);
435 C::sum_into(c, parameters.operations[2], dest);
436 D::sum_into(d, parameters.operations[3], dest);
437 E::sum_into(e, parameters.operations[4], dest);
438 F::sum_into(f, parameters.operations[5], dest);
439 G::sum_into(g, parameters.operations[6], dest);
440 H::sum_into(h, parameters.operations[7], dest)
441 }
442}
443
/// Sign applied to one input of the sum block.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SumType {
    /// The input is added to the running sum.
    Addition,
    /// The input is subtracted from the running sum.
    Subtraction,
}
450
451#[derive(Clone, Copy, Debug, PartialEq)]
453pub struct Parameters<const NUM_INPUTS: usize> {
454 pub operations: [SumType; NUM_INPUTS],
455}
456
457impl<const NUM_INPUTS: usize> Parameters<NUM_INPUTS> {
458 pub fn new(input: [f64; NUM_INPUTS]) -> Self {
461 let mut operations = [SumType::Addition; NUM_INPUTS];
462 for (i, &val) in input.iter().enumerate() {
463 if val < 0.0 {
464 operations[i] = SumType::Subtraction;
465 }
466 }
467 Self { operations }
468 }
469}
470
#[cfg(test)]
mod tests {
    // Unit tests for `SumBlock`. Every `process` call below passes `&parameters`; the
    // previous text had it mis-encoded as `¶meters` (HTML `&para` entity), which does
    // not compile.
    use super::*;
    use crate::testing::StubContext;
    use approx::assert_relative_eq;

    #[test]
    fn test_one_scalar() {
        let mut block = SumBlock::<f64>::default();
        let input = 3.0;
        let stub_context = StubContext::default();
        let parameters = Parameters {
            operations: [SumType::Addition],
        };
        let result = block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 3.0);
    }

    #[test]
    fn test_one_matrix() {
        let mut block = SumBlock::<Matrix<2, 2, f64>>::default();
        let input = Matrix {
            data: [[1.0, 2.0], [3.0, 4.0]],
        };
        let stub_context = StubContext::default();
        let parameters = Parameters {
            operations: [SumType::Addition],
        };
        let result = block.process(&parameters, &stub_context, &input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[1.0, 2.0], [3.0, 4.0]].as_flattened()
        );
    }

    #[test]
    fn test_multiple_scalars() {
        let stub_context = StubContext::default();

        let mut two_block = SumBlock::<(f64, f64)>::default();
        let input = (3.0, 4.0);
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 7.0);

        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Subtraction],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, -1.0);

        let parameters = Parameters {
            operations: [SumType::Subtraction, SumType::Addition],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 1.0);

        let parameters = Parameters {
            operations: [SumType::Subtraction, SumType::Subtraction],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, -7.0);

        let mut three_block = SumBlock::<(f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0);
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition, SumType::Addition],
        };
        let result = three_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 12.0);

        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition, SumType::Subtraction],
        };
        let result = three_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 2.0);

        let mut four_block = SumBlock::<(f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0);
        let parameters = Parameters {
            operations: [
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
            ],
        };
        let result = four_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 18.0);

        let mut five_block = SumBlock::<(f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0);
        let parameters = Parameters {
            operations: [
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
            ],
        };
        let result = five_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 25.0);

        let mut six_block = SumBlock::<(f64, f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let parameters = Parameters {
            operations: [
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
            ],
        };
        let result = six_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 33.0);

        let mut seven_block = SumBlock::<(f64, f64, f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
        let parameters = Parameters {
            operations: [
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
            ],
        };
        let result = seven_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 42.0);

        let mut eight_block = SumBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0);
        let parameters = Parameters {
            operations: [
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
            ],
        };
        let result = eight_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(result, 52.0);
    }

    #[test]
    fn test_multiple_matrices() {
        let stub_context = StubContext::default();

        let mut two_block = SumBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, f64>)>::default();
        let input = (
            &Matrix {
                data: [[1.0, 2.0], [3.0, 4.0]],
            },
            &Matrix {
                data: [[5.0, 6.0], [7.0, 8.0]],
            },
        );
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[6.0, 8.0], [10.0, 12.0]].as_flattened()
        );

        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Subtraction],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[-4.0, -4.0], [-4.0, -4.0]].as_flattened()
        );

        let parameters = Parameters {
            operations: [SumType::Subtraction, SumType::Addition],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[4.0, 4.0], [4.0, 4.0]].as_flattened()
        );

        let parameters = Parameters {
            operations: [SumType::Subtraction, SumType::Subtraction],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[-6.0, -8.0], [-10.0, -12.0]].as_flattened()
        );

        let mut three_block =
            SumBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, f64>, Matrix<2, 2, f64>)>::default();
        let input = (
            &Matrix {
                data: [[1.0, 2.0], [3.0, 4.0]],
            },
            &Matrix {
                data: [[5.0, 6.0], [7.0, 8.0]],
            },
            &Matrix {
                data: [[9.0, 10.0], [11.0, 12.0]],
            },
        );
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition, SumType::Addition],
        };
        let result = three_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[15.0, 18.0], [21.0, 24.0]].as_flattened()
        );

        let mut four_block = SumBlock::<(
            Matrix<2, 2, f64>,
            Matrix<2, 2, f64>,
            Matrix<2, 2, f64>,
            Matrix<2, 2, f64>,
        )>::default();
        let input = (
            &Matrix {
                data: [[1.0, 2.0], [3.0, 4.0]],
            },
            &Matrix {
                data: [[5.0, 6.0], [7.0, 8.0]],
            },
            &Matrix {
                data: [[9.0, 10.0], [11.0, 12.0]],
            },
            &Matrix {
                data: [[13.0, 14.0], [15.0, 16.0]],
            },
        );
        let parameters = Parameters {
            operations: [
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
                SumType::Addition,
            ],
        };
        let result = four_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[28.0, 32.0], [36.0, 40.0]].as_flattened()
        );
    }

    #[test]
    fn test_mixed_scalars_and_matrices() {
        let stub_context = StubContext::default();

        let mut two_block = SumBlock::<(f64, Matrix<2, 2, f64>)>::default();
        let input = (
            3.0,
            &Matrix {
                data: [[1.0, 2.0], [3.0, 4.0]],
            },
        );
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition],
        };
        let result = two_block.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[4.0, 5.0], [6.0, 7.0]].as_flattened()
        );

        let mut three_block_1 = SumBlock::<(f64, Matrix<2, 2, f64>, f64)>::default();
        let input = (
            3.0,
            &Matrix {
                data: [[1.0, 2.0], [3.0, 4.0]],
            },
            5.0,
        );
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition, SumType::Addition],
        };
        let result = three_block_1.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[9.0, 10.0], [11.0, 12.0]].as_flattened()
        );

        let mut three_block_2 = SumBlock::<(Matrix<2, 2, f64>, f64, Matrix<2, 2, f64>)>::default();
        let input = (
            &Matrix {
                data: [[1.0, 2.0], [3.0, 4.0]],
            },
            5.0,
            &Matrix {
                data: [[5.0, 6.0], [7.0, 8.0]],
            },
        );
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Addition, SumType::Addition],
        };
        let result = three_block_2.process(&parameters, &stub_context, input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[11.0, 13.0], [15.0, 17.0]].as_flattened()
        );
    }

    #[test]
    fn test_mixed_subtraction() {
        let stub_context = StubContext::default();
        let matrix = Matrix {
            data: [[1.0, 2.0], [3.0, 4.0]],
        };

        // A scalar subtracted from a matrix is subtracted from every element.
        let mut matrix_first = SumBlock::<(Matrix<2, 2, f64>, f64)>::default();
        let parameters = Parameters {
            operations: [SumType::Addition, SumType::Subtraction],
        };
        let result = matrix_first.process(&parameters, &stub_context, (&matrix, 1.0));
        assert_relative_eq!(
            result.data.as_flattened(),
            [[0.0, 1.0], [2.0, 3.0]].as_flattened()
        );

        // A negated leading scalar is broadcast before the matrix is added.
        let mut scalar_first = SumBlock::<(f64, Matrix<2, 2, f64>)>::default();
        let parameters = Parameters {
            operations: [SumType::Subtraction, SumType::Addition],
        };
        let result = scalar_first.process(&parameters, &stub_context, (3.0, &matrix));
        assert_relative_eq!(
            result.data.as_flattened(),
            [[-2.0, -1.0], [0.0, 1.0]].as_flattened()
        );
    }

    #[test]
    fn test_parameters_from_gains() {
        let parameters = Parameters::new([1.0, -1.0, 0.0, -2.5]);
        assert_eq!(
            parameters.operations,
            [
                SumType::Addition,
                SumType::Subtraction,
                SumType::Addition,
                SumType::Subtraction,
            ]
        );
    }
}