Skip to main content

pictorus_blocks/core_blocks/
sum_block.rs

1use crate::nalgebra_interop::MatrixExt;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock, Scalar};
4
5/// Sums (adds or subtracts) all inputs together.
6pub struct SumBlock<T: Summable>
7where
8    pictorus_block_data::BlockData: FromPass<<T as Summable>::Output>,
9{
10    store: Option<T::Output>,
11    pub data: OldBlockData,
12}
13
14impl<T: Summable> Default for SumBlock<T>
15where
16    pictorus_block_data::BlockData: FromPass<<T as Summable>::Output>,
17{
18    fn default() -> Self {
19        Self {
20            store: None,
21            data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
22        }
23    }
24}
25
26impl<T> ProcessBlock for SumBlock<T>
27where
28    T: Summable,
29    OldBlockData: FromPass<T::Output>,
30{
31    type Inputs = T;
32    type Output = T::Output;
33    type Parameters = T::Parameters;
34
35    fn process(
36        &mut self,
37        parameters: &Self::Parameters,
38        _context: &dyn pictorus_traits::Context,
39        input: PassBy<Self::Inputs>,
40    ) -> PassBy<Self::Output> {
41        self.store = None;
42        let result = T::get_sum(input, *parameters, &mut self.store);
43        self.data = OldBlockData::from_pass(result);
44        result
45    }
46}
47
48trait SumScalar:
49    Scalar
50    + nalgebra::Scalar
51    + core::ops::Neg<Output = Self>
52    + core::ops::Add<Output = Self>
53    + core::ops::Sub<Output = Self>
54    + core::ops::AddAssign
55    + core::ops::SubAssign
56{
57}
58impl SumScalar for f32 {}
59impl SumScalar for f64 {}
60
61/// This trait is used to determine the output type of a sum operation
62/// between two types, most importantly it can be used recursively. To get the output type for
63/// a tuple of inputs. For an input of all scalars the output is scalar. For all inputs being a
64/// single size of matrix, or a mix of scalars and a single size of matrix the output is a matrix
65/// of that size.
66pub trait TypePromotion<RHS> {
67    type Output: Pass + Default;
68}
69
70/// A Scalar and a scalar outputs a scalar
71impl<S: SumScalar> TypePromotion<S> for S {
72    type Output = S;
73}
74
75/// A Scalar and a Matrix outputs a Matrix
76impl<const R: usize, const C: usize, S: SumScalar> TypePromotion<S> for Matrix<R, C, S> {
77    type Output = Matrix<R, C, S>;
78}
79
80/// A Matrix and a Scalar outputs a Matrix
81impl<const R: usize, const C: usize, S: SumScalar> TypePromotion<Matrix<R, C, S>> for S {
82    type Output = Matrix<R, C, S>;
83}
84
85/// A Matrix and a Matrix outputs a Matrix
86impl<const R: usize, const C: usize, S: SumScalar> TypePromotion<Matrix<R, C, S>>
87    for Matrix<R, C, S>
88{
89    type Output = Matrix<R, C, S>;
90}
91
92/// Recursive definition for 3 inputs
93impl<A, B, C> TypePromotion<(B, C)> for A
94where
95    B: TypePromotion<C>,
96    A: TypePromotion<<B as TypePromotion<C>>::Output>,
97{
98    type Output = <A as TypePromotion<B::Output>>::Output;
99}
100
101/// Recursive definition for 4 inputs
102impl<A, B, C, D> TypePromotion<(B, C, D)> for A
103where
104    B: TypePromotion<(C, D)>,
105    A: TypePromotion<B::Output>,
106{
107    type Output = <A as TypePromotion<B::Output>>::Output;
108}
109
110/// Recursive definition for 5 inputs
111impl<A, B, C, D, E> TypePromotion<(B, C, D, E)> for A
112where
113    B: TypePromotion<(C, D, E)>,
114    A: TypePromotion<B::Output>,
115{
116    type Output = <A as TypePromotion<B::Output>>::Output;
117}
118
119/// Recursive definition for 6 inputs
120impl<A, B, C, D, E, F> TypePromotion<(B, C, D, E, F)> for A
121where
122    B: TypePromotion<(C, D, E, F)>,
123    A: TypePromotion<B::Output>,
124{
125    type Output = <A as TypePromotion<B::Output>>::Output;
126}
127
128/// Recursive definition for 7 inputs
129impl<A, B, C, D, E, F, G> TypePromotion<(B, C, D, E, F, G)> for A
130where
131    B: TypePromotion<(C, D, E, F, G)>,
132    A: TypePromotion<B::Output>,
133{
134    type Output = <A as TypePromotion<B::Output>>::Output;
135}
136
137/// Recursive definition for 8 inputs
138impl<A, B, C, D, E, F, G, H> TypePromotion<(B, C, D, E, F, G, H)> for A
139where
140    B: TypePromotion<(C, D, E, F, G, H)>,
141    A: TypePromotion<B::Output>,
142{
143    type Output = <A as TypePromotion<B::Output>>::Output;
144}
145
146/// This trait allow the implementor to be "summed into" a destination type
147/// A matrix can only be summed into a matrix of the same size, a scalar can be summed into
148/// a matrix or another scalar
149pub trait SumInto<DEST: Pass>: Pass {
150    fn sum_into<'a>(
151        input: PassBy<Self>,
152        sum_type: SumType,
153        dest: &'a mut Option<DEST>,
154    ) -> PassBy<'a, DEST>;
155}
156
157/// Scalar summing into a scalar
158impl<S: SumScalar> SumInto<S> for S {
159    fn sum_into<'a>(
160        input: PassBy<Self>,
161        sum_type: SumType,
162        dest: &'a mut Option<S>,
163    ) -> PassBy<'a, S> {
164        let dest = dest.get_or_insert(S::default());
165        match sum_type {
166            SumType::Addition => {
167                *dest += input;
168            }
169            SumType::Subtraction => {
170                *dest -= input;
171            }
172        }
173        *dest
174    }
175}
176
177/// Matrix summing into a matrix
178impl<const R: usize, const C: usize, S: SumScalar> SumInto<Matrix<R, C, S>> for Matrix<R, C, S> {
179    fn sum_into<'a>(
180        input: PassBy<Self>,
181        sum_type: SumType,
182        dest: &'a mut Option<Matrix<R, C, S>>,
183    ) -> PassBy<'a, Matrix<R, C, S>> {
184        let dest = dest.get_or_insert(Matrix::<R, C, S>::zeroed());
185        let orig_dest = dest.as_view().clone_owned();
186        match sum_type {
187            SumType::Addition => {
188                orig_dest.add_to(&input.as_view(), &mut dest.as_view_mut());
189            }
190            SumType::Subtraction => {
191                orig_dest.sub_to(&input.as_view(), &mut dest.as_view_mut());
192            }
193        }
194        dest
195    }
196}
197
198/// Scalar summing into a matrix
199impl<const R: usize, const C: usize, S: SumScalar> SumInto<Matrix<R, C, S>> for S {
200    fn sum_into<'a>(
201        input: PassBy<Self>,
202        sum_type: SumType,
203        dest: &'a mut Option<Matrix<R, C, S>>,
204    ) -> PassBy<'a, Matrix<R, C, S>> {
205        let dest = dest.get_or_insert(Matrix::<R, C, S>::zeroed());
206        let mut orig_dest = dest.as_view().clone_owned();
207        match sum_type {
208            SumType::Addition => {
209                orig_dest = orig_dest.add_scalar(input);
210            }
211            SumType::Subtraction => {
212                orig_dest = orig_dest.add_scalar(-input);
213            }
214        }
215        dest.as_view_mut().copy_from(&orig_dest);
216        dest
217    }
218}
219
220/// This trait makes use of the two above , `SumInto` and `TypePromotion` to sum a tuple of inputs (or a single input)
221pub trait Summable: Pass {
222    type Output: Pass + Default;
223    type Parameters: Copy;
224
225    fn get_sum<'a>(
226        input: PassBy<Self>,
227        parameters: Self::Parameters,
228        dest: &'a mut Option<Self::Output>,
229    ) -> PassBy<'a, Self::Output>;
230}
231
232/// Single scalar input
233impl<S: SumScalar> Summable for S {
234    type Output = S;
235    type Parameters = Parameters<1>;
236
237    fn get_sum<'a>(
238        input: PassBy<Self>,
239        parameters: Self::Parameters,
240        dest: &'a mut Option<Self::Output>,
241    ) -> PassBy<'a, Self::Output> {
242        Self::sum_into(input, parameters.operations[0], dest);
243        dest.unwrap()
244    }
245}
246
247/// Single matrix input
248impl<const R: usize, const C: usize, S: SumScalar> Summable for Matrix<R, C, S> {
249    type Output = Matrix<R, C, S>;
250    type Parameters = Parameters<1>;
251
252    fn get_sum<'a>(
253        input: PassBy<Self>,
254        parameters: Self::Parameters,
255        dest: &'a mut Option<Self::Output>,
256    ) -> PassBy<'a, Self::Output> {
257        Self::sum_into(input, parameters.operations[0], dest);
258        dest.as_ref().unwrap()
259    }
260}
261
262impl<A, B> Summable for (A, B)
263where
264    A: TypePromotion<B>,
265    A: SumInto<A::Output>,
266    B: SumInto<A::Output>,
267{
268    type Output = A::Output;
269    type Parameters = Parameters<2>;
270
271    fn get_sum<'a>(
272        input: PassBy<Self>,
273        parameters: Self::Parameters,
274        dest: &'a mut Option<Self::Output>,
275    ) -> PassBy<'a, Self::Output> {
276        let (a, b) = input;
277        A::sum_into(a, parameters.operations[0], dest);
278        B::sum_into(b, parameters.operations[1], dest)
279    }
280}
281
282impl<A, B, C> Summable for (A, B, C)
283where
284    A: TypePromotion<(B, C)>,
285    A: SumInto<A::Output>,
286    B: SumInto<A::Output>,
287    C: SumInto<A::Output>,
288{
289    type Output = A::Output;
290    type Parameters = Parameters<3>;
291
292    fn get_sum<'a>(
293        input: PassBy<Self>,
294        parameters: Self::Parameters,
295        dest: &'a mut Option<Self::Output>,
296    ) -> PassBy<'a, Self::Output> {
297        let (a, b, c) = input;
298        A::sum_into(a, parameters.operations[0], dest);
299        B::sum_into(b, parameters.operations[1], dest);
300        C::sum_into(c, parameters.operations[2], dest)
301    }
302}
303
304impl<A, B, C, D> Summable for (A, B, C, D)
305where
306    A: TypePromotion<(B, C, D)>,
307    A: SumInto<A::Output>,
308    B: SumInto<A::Output>,
309    C: SumInto<A::Output>,
310    D: SumInto<A::Output>,
311{
312    type Output = A::Output;
313    type Parameters = Parameters<4>;
314
315    fn get_sum<'a>(
316        input: PassBy<Self>,
317        parameters: Self::Parameters,
318        dest: &'a mut Option<Self::Output>,
319    ) -> PassBy<'a, Self::Output> {
320        let (a, b, c, d) = input;
321        A::sum_into(a, parameters.operations[0], dest);
322        B::sum_into(b, parameters.operations[1], dest);
323        C::sum_into(c, parameters.operations[2], dest);
324        D::sum_into(d, parameters.operations[3], dest)
325    }
326}
327
328impl<A, B, C, D, E> Summable for (A, B, C, D, E)
329where
330    A: TypePromotion<(B, C, D, E)>,
331    A: SumInto<A::Output>,
332    B: SumInto<A::Output>,
333    C: SumInto<A::Output>,
334    D: SumInto<A::Output>,
335    E: SumInto<A::Output>,
336{
337    type Output = A::Output;
338    type Parameters = Parameters<5>;
339
340    fn get_sum<'a>(
341        input: PassBy<Self>,
342        parameters: Self::Parameters,
343        dest: &'a mut Option<Self::Output>,
344    ) -> PassBy<'a, Self::Output> {
345        let (a, b, c, d, e) = input;
346        A::sum_into(a, parameters.operations[0], dest);
347        B::sum_into(b, parameters.operations[1], dest);
348        C::sum_into(c, parameters.operations[2], dest);
349        D::sum_into(d, parameters.operations[3], dest);
350        E::sum_into(e, parameters.operations[4], dest)
351    }
352}
353
354impl<A, B, C, D, E, F> Summable for (A, B, C, D, E, F)
355where
356    A: TypePromotion<(B, C, D, E, F)>,
357    A: SumInto<A::Output>,
358    B: SumInto<A::Output>,
359    C: SumInto<A::Output>,
360    D: SumInto<A::Output>,
361    E: SumInto<A::Output>,
362    F: SumInto<A::Output>,
363{
364    type Output = A::Output;
365    type Parameters = Parameters<6>;
366
367    fn get_sum<'a>(
368        input: PassBy<Self>,
369        parameters: Self::Parameters,
370        dest: &'a mut Option<Self::Output>,
371    ) -> PassBy<'a, Self::Output> {
372        let (a, b, c, d, e, f) = input;
373        A::sum_into(a, parameters.operations[0], dest);
374        B::sum_into(b, parameters.operations[1], dest);
375        C::sum_into(c, parameters.operations[2], dest);
376        D::sum_into(d, parameters.operations[3], dest);
377        E::sum_into(e, parameters.operations[4], dest);
378        F::sum_into(f, parameters.operations[5], dest)
379    }
380}
381
382impl<A, B, C, D, E, F, G> Summable for (A, B, C, D, E, F, G)
383where
384    A: TypePromotion<(B, C, D, E, F, G)>,
385    A: SumInto<A::Output>,
386    B: SumInto<A::Output>,
387    C: SumInto<A::Output>,
388    D: SumInto<A::Output>,
389    E: SumInto<A::Output>,
390    F: SumInto<A::Output>,
391    G: SumInto<A::Output>,
392{
393    type Output = A::Output;
394    type Parameters = Parameters<7>;
395
396    fn get_sum<'a>(
397        input: PassBy<Self>,
398        parameters: Self::Parameters,
399        dest: &'a mut Option<Self::Output>,
400    ) -> PassBy<'a, Self::Output> {
401        let (a, b, c, d, e, f, g) = input;
402        A::sum_into(a, parameters.operations[0], dest);
403        B::sum_into(b, parameters.operations[1], dest);
404        C::sum_into(c, parameters.operations[2], dest);
405        D::sum_into(d, parameters.operations[3], dest);
406        E::sum_into(e, parameters.operations[4], dest);
407        F::sum_into(f, parameters.operations[5], dest);
408        G::sum_into(g, parameters.operations[6], dest)
409    }
410}
411
412impl<A, B, C, D, E, F, G, H> Summable for (A, B, C, D, E, F, G, H)
413where
414    A: TypePromotion<(B, C, D, E, F, G, H)>,
415    A: SumInto<A::Output>,
416    B: SumInto<A::Output>,
417    C: SumInto<A::Output>,
418    D: SumInto<A::Output>,
419    E: SumInto<A::Output>,
420    F: SumInto<A::Output>,
421    G: SumInto<A::Output>,
422    H: SumInto<A::Output>,
423{
424    type Output = A::Output;
425    type Parameters = Parameters<8>;
426
427    fn get_sum<'a>(
428        input: PassBy<Self>,
429        parameters: Self::Parameters,
430        dest: &'a mut Option<Self::Output>,
431    ) -> PassBy<'a, Self::Output> {
432        let (a, b, c, d, e, f, g, h) = input;
433        A::sum_into(a, parameters.operations[0], dest);
434        B::sum_into(b, parameters.operations[1], dest);
435        C::sum_into(c, parameters.operations[2], dest);
436        D::sum_into(d, parameters.operations[3], dest);
437        E::sum_into(e, parameters.operations[4], dest);
438        F::sum_into(f, parameters.operations[5], dest);
439        G::sum_into(g, parameters.operations[6], dest);
440        H::sum_into(h, parameters.operations[7], dest)
441    }
442}
443
/// How a single input contributes to the running total.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SumType {
    /// The input is added to the total.
    Addition,
    /// The input is subtracted from the total.
    Subtraction,
}

/// Parameters for the sum block: one [`SumType`] per input, in input order.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Parameters<const NUM_INPUTS: usize> {
    pub operations: [SumType; NUM_INPUTS],
}

impl<const NUM_INPUTS: usize> Parameters<NUM_INPUTS> {
    /// Builds parameters from one sign value per input.
    ///
    /// A value below zero selects [`SumType::Subtraction`]. Anything else, including `0.0`,
    /// `-0.0` and NaN, selects [`SumType::Addition`].
    ///
    /// It takes a fixed-size `f64` array because that is what codegen currently produces.
    /// This should be revisited when codegen changes.
    pub fn new(input: [f64; NUM_INPUTS]) -> Self {
        let operations = input.map(|sign| {
            if sign < 0.0 {
                SumType::Subtraction
            } else {
                SumType::Addition
            }
        });
        Self { operations }
    }
}
470
#[cfg(test)]
mod tests {
    use super::SumType::{Addition, Subtraction};
    use super::*;
    use crate::testing::StubContext;
    use approx::assert_relative_eq;

    /// Builds `Parameters` from an explicit list of per-input operations.
    fn params<const N: usize>(operations: [SumType; N]) -> Parameters<N> {
        Parameters { operations }
    }

    /// Builds a 2x2 `f64` matrix from its nested data array.
    fn mat2(data: [[f64; 2]; 2]) -> Matrix<2, 2, f64> {
        Matrix { data }
    }

    #[test]
    fn test_parameters_new() {
        let parameters = Parameters::new([1.0, -1.0, 0.0, -0.5]);
        assert_eq!(
            parameters.operations,
            [Addition, Subtraction, Addition, Subtraction]
        );
    }

    #[test]
    fn test_one_scalar() {
        let stub_context = StubContext::default();
        let mut block = SumBlock::<f64>::default();

        let result = block.process(&params([Addition]), &stub_context, 3.0);
        assert_relative_eq!(result, 3.0);

        let result = block.process(&params([Subtraction]), &stub_context, 3.0);
        assert_relative_eq!(result, -3.0);
    }

    #[test]
    fn test_one_matrix() {
        let stub_context = StubContext::default();
        let mut block = SumBlock::<Matrix<2, 2, f64>>::default();
        let input = mat2([[1.0, 2.0], [3.0, 4.0]]);

        let result = block.process(&params([Addition]), &stub_context, &input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[1.0, 2.0], [3.0, 4.0]].as_flattened()
        );

        let result = block.process(&params([Subtraction]), &stub_context, &input);
        assert_relative_eq!(
            result.data.as_flattened(),
            [[-1.0, -2.0], [-3.0, -4.0]].as_flattened()
        );
    }

    #[test]
    fn test_multiple_scalars() {
        let stub_context = StubContext::default();

        // Two inputs, every combination of operations. Reusing one block also checks that
        // no state leaks between calls.
        let mut two_block = SumBlock::<(f64, f64)>::default();
        let input = (3.0, 4.0);
        let cases = [
            ([Addition, Addition], 7.0),
            ([Addition, Subtraction], -1.0),
            ([Subtraction, Addition], 1.0),
            ([Subtraction, Subtraction], -7.0),
        ];
        for (operations, expected) in cases {
            let result = two_block.process(&params(operations), &stub_context, input);
            assert_relative_eq!(result, expected);
        }

        // Three inputs
        let mut three_block = SumBlock::<(f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0);
        let result = three_block.process(&params([Addition; 3]), &stub_context, input);
        assert_relative_eq!(result, 12.0);
        let result = three_block.process(
            &params([Addition, Addition, Subtraction]),
            &stub_context,
            input,
        );
        assert_relative_eq!(result, 2.0);

        // Four inputs
        let mut four_block = SumBlock::<(f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0);
        let result = four_block.process(&params([Addition; 4]), &stub_context, input);
        assert_relative_eq!(result, 18.0);

        // Five inputs
        let mut five_block = SumBlock::<(f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0);
        let result = five_block.process(&params([Addition; 5]), &stub_context, input);
        assert_relative_eq!(result, 25.0);

        // Six inputs
        let mut six_block = SumBlock::<(f64, f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let result = six_block.process(&params([Addition; 6]), &stub_context, input);
        assert_relative_eq!(result, 33.0);

        // Seven inputs
        let mut seven_block = SumBlock::<(f64, f64, f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
        let result = seven_block.process(&params([Addition; 7]), &stub_context, input);
        assert_relative_eq!(result, 42.0);

        // Eight inputs
        let mut eight_block = SumBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
        let input = (3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0);
        let result = eight_block.process(&params([Addition; 8]), &stub_context, input);
        assert_relative_eq!(result, 52.0);
    }

    #[test]
    fn test_multiple_matrices() {
        let stub_context = StubContext::default();
        let m1 = mat2([[1.0, 2.0], [3.0, 4.0]]);
        let m2 = mat2([[5.0, 6.0], [7.0, 8.0]]);
        let m3 = mat2([[9.0, 10.0], [11.0, 12.0]]);
        let m4 = mat2([[13.0, 14.0], [15.0, 16.0]]);

        // Two inputs, every combination of operations.
        let mut two_block = SumBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, f64>)>::default();
        let cases = [
            ([Addition, Addition], [[6.0, 8.0], [10.0, 12.0]]),
            ([Addition, Subtraction], [[-4.0, -4.0], [-4.0, -4.0]]),
            ([Subtraction, Addition], [[4.0, 4.0], [4.0, 4.0]]),
            ([Subtraction, Subtraction], [[-6.0, -8.0], [-10.0, -12.0]]),
        ];
        for (operations, expected) in cases {
            let result = two_block.process(&params(operations), &stub_context, (&m1, &m2));
            assert_relative_eq!(result.data.as_flattened(), expected.as_flattened());
        }

        // Three inputs
        let mut three_block =
            SumBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, f64>, Matrix<2, 2, f64>)>::default();
        let result = three_block.process(&params([Addition; 3]), &stub_context, (&m1, &m2, &m3));
        assert_relative_eq!(
            result.data.as_flattened(),
            [[15.0, 18.0], [21.0, 24.0]].as_flattened()
        );

        // Four inputs
        let mut four_block = SumBlock::<(
            Matrix<2, 2, f64>,
            Matrix<2, 2, f64>,
            Matrix<2, 2, f64>,
            Matrix<2, 2, f64>,
        )>::default();
        let result = four_block.process(
            &params([Addition; 4]),
            &stub_context,
            (&m1, &m2, &m3, &m4),
        );
        assert_relative_eq!(
            result.data.as_flattened(),
            [[28.0, 32.0], [36.0, 40.0]].as_flattened()
        );
    }

    #[test]
    fn test_mixed_scalars_and_matrices() {
        let stub_context = StubContext::default();
        let m1 = mat2([[1.0, 2.0], [3.0, 4.0]]);
        let m2 = mat2([[5.0, 6.0], [7.0, 8.0]]);

        // Two inputs
        let mut two_block = SumBlock::<(f64, Matrix<2, 2, f64>)>::default();
        let result = two_block.process(&params([Addition; 2]), &stub_context, (3.0, &m1));
        assert_relative_eq!(
            result.data.as_flattened(),
            [[4.0, 5.0], [6.0, 7.0]].as_flattened()
        );

        // Three inputs, scalar / matrix / scalar
        let mut three_block_1 = SumBlock::<(f64, Matrix<2, 2, f64>, f64)>::default();
        let result =
            three_block_1.process(&params([Addition; 3]), &stub_context, (3.0, &m1, 5.0));
        assert_relative_eq!(
            result.data.as_flattened(),
            [[9.0, 10.0], [11.0, 12.0]].as_flattened()
        );

        // Three inputs, matrix / scalar / matrix
        let mut three_block_2 = SumBlock::<(Matrix<2, 2, f64>, f64, Matrix<2, 2, f64>)>::default();
        let result =
            three_block_2.process(&params([Addition; 3]), &stub_context, (&m1, 5.0, &m2));
        assert_relative_eq!(
            result.data.as_flattened(),
            [[11.0, 13.0], [15.0, 17.0]].as_flattened()
        );
    }

    #[test]
    fn test_mixed_scalars_and_matrices_with_subtraction() {
        let stub_context = StubContext::default();
        let m1 = mat2([[1.0, 2.0], [3.0, 4.0]]);

        // Scalar first: the scalar is applied to every element of the zeroed accumulator
        // before the matrix is added or subtracted.
        let mut scalar_first = SumBlock::<(f64, Matrix<2, 2, f64>)>::default();
        let result = scalar_first.process(
            &params([Subtraction, Addition]),
            &stub_context,
            (3.0, &m1),
        );
        assert_relative_eq!(
            result.data.as_flattened(),
            [[-2.0, -1.0], [0.0, 1.0]].as_flattened()
        );
        let result = scalar_first.process(
            &params([Addition, Subtraction]),
            &stub_context,
            (3.0, &m1),
        );
        assert_relative_eq!(
            result.data.as_flattened(),
            [[2.0, 1.0], [0.0, -1.0]].as_flattened()
        );

        // Matrix first: the scalar is subtracted from every element.
        let mut matrix_first = SumBlock::<(Matrix<2, 2, f64>, f64)>::default();
        let result = matrix_first.process(
            &params([Addition, Subtraction]),
            &stub_context,
            (&m1, 3.0),
        );
        assert_relative_eq!(
            result.data.as_flattened(),
            [[-2.0, -1.0], [0.0, 1.0]].as_flattened()
        );
    }
}