Skip to main content

pictorus_blocks/core_blocks/
aggregate_block.rs

1use crate::nalgebra_interop::MatrixExt;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock, Scalar};
4
5/// Block for performing an aggregation operation (i.e. sum, min, max) on input data.
6pub struct AggregateBlock<T: Apply> {
7    pub data: OldBlockData,
8    buffer: Option<T::Output>,
9}
10
11impl<T: Apply> Default for AggregateBlock<T>
12where
13    T: Pass + Default,
14    OldBlockData: FromPass<T::Output>,
15{
16    fn default() -> Self {
17        Self {
18            data: <OldBlockData as FromPass<T::Output>>::from_pass(<T::Output>::default().as_by()),
19            buffer: None,
20        }
21    }
22}
23
24impl<T> ProcessBlock for AggregateBlock<T>
25where
26    T: Apply + Default,
27    OldBlockData: FromPass<T::Output>,
28{
29    type Inputs = T;
30    type Output = T::Output;
31    type Parameters = Parameters;
32
33    fn process<'b>(
34        &'b mut self,
35        parameters: &Self::Parameters,
36        _context: &dyn pictorus_traits::Context,
37        inputs: pictorus_traits::PassBy<'_, Self::Inputs>,
38    ) -> pictorus_traits::PassBy<'b, Self::Output> {
39        let output = T::apply(&mut self.buffer, inputs, parameters.method);
40        self.data = OldBlockData::from_pass(output);
41        output
42    }
43}
44
45pub trait Apply: Pass {
46    type Output: Scalar;
47
48    fn apply<'s>(
49        store: &mut Option<Self::Output>,
50        input: PassBy<Self>,
51        method: AggregateMethod,
52    ) -> PassBy<'s, Self::Output>;
53}
54
55macro_rules! scalar_impls {
56    () => {};
57    ($type:ty, $($rest:tt),+) => {
58        scalar_impls!($type);
59        scalar_impls!($($rest),+);
60    };
61    ($type:ty) => {
62        impl Apply for $type {
63            type Output = $type;
64
65            fn apply<'s>(
66                store: &mut Option<Self::Output>,
67                input: PassBy<Self>,
68                _method: AggregateMethod,
69            ) -> PassBy<'s, Self::Output> {
70                *store = Some(input);
71                input
72            }
73        }
74    };
75}
76scalar_impls!(f64, f32); // We could also just easily add u8, u16 and bool here but they wouldn't have equivalent matrix impls
77
78macro_rules! float_matrix_impl {
79    ($type:ty) => {
80        impl<const NROWS: usize, const NCOLS: usize> Apply for Matrix<NROWS, NCOLS, $type> {
81            type Output = $type;
82
83            fn apply<'s>(
84                store: &mut Option<Self::Output>,
85                input: PassBy<Self>,
86                method: AggregateMethod,
87            ) -> PassBy<'s, Self::Output> {
88                let view = input.as_view();
89                let output = match method {
90                    AggregateMethod::Sum => view.sum(),
91                    AggregateMethod::Mean => view.mean(),
92                    AggregateMethod::Median => {
93                        // Have to copy the data to the stack so we can sort it
94                        let mut data = *input;
95                        let data = data.data.as_flattened_mut();
96                        view.iter().enumerate().for_each(|(i, &x)| data[i] = x);
97                        data.sort_by(|a, b| a.partial_cmp(b).expect("NaNs are not supported"));
98                        let mid = data.len() / 2;
99                        if data.len() % 2 == 0 {
100                            (data[mid - 1] + data[mid]) / Self::Output::from(2u8)
101                        } else {
102                            data[mid]
103                        }
104                    }
105                    AggregateMethod::Min => view.min(),
106                    AggregateMethod::Max => view.max(),
107                };
108                *store = Some(output);
109                output
110            }
111        }
112    };
113}
114
115float_matrix_impl!(f64);
116float_matrix_impl!(f32);
117
118/// Represents the method of aggregation to be performed.
119#[derive(Debug, Clone, Copy, PartialEq, strum::EnumString)]
120pub enum AggregateMethod {
121    /// Sum of all elements.
122    Sum,
123    /// Mean (average) of all elements.
124    Mean,
125    /// Median of all elements.
126    Median,
127    /// Minimum value among all elements.
128    Min,
129    /// Maximum value among all elements.
130    Max,
131}
132
133pub struct Parameters {
134    pub method: AggregateMethod,
135}
136impl Parameters {
137    pub fn new(method: &str) -> Self {
138        Self {
139            method: method.parse().expect("Invalid aggregate method"),
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;
    use alloc::str::FromStr;
    use approx::assert_relative_eq;

    #[test]
    fn test_aggregate_sum_f32() {
        let mut block = AggregateBlock::<Matrix<4, 7, f32>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Sum,
        };
        let input: Matrix<4, 7, f32> = Matrix {
            data: [[1.0; 4]; 7],
        };
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 28.0);
        assert_relative_eq!(block.data.scalar(), 28.0);
    }

    #[test]
    fn test_aggregate_sum_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Sum,
        };
        let input: Matrix<4, 7, f64> = Matrix {
            data: [[1.0; 4]; 7],
        };
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 28.0);
        assert_relative_eq!(block.data.scalar(), 28.0);
    }

    #[test]
    fn test_aggregate_max_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Max,
        };
        let mut input: Matrix<4, 7, f64> = Matrix {
            data: [[1.0; 4]; 7],
        };
        input.data[5][3] = 42.0;
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 42.0);
        assert_relative_eq!(block.data.scalar(), 42.0);
    }

    #[test]
    fn test_aggregate_min_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Min,
        };
        let mut input: Matrix<4, 7, f64> = Matrix {
            data: [[11.0; 4]; 7],
        };
        input.data[1][2] = 10.99;
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 10.99);
        assert_relative_eq!(block.data.scalar(), 10.99);
    }

    #[test]
    fn test_aggregate_mean_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Mean,
        };
        let mut input: Matrix<4, 7, f64> = Matrix::zeroed();
        for (idx, elem) in input.data.as_flattened_mut().iter_mut().enumerate() {
            *elem = idx as f64;
        }

        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 13.5);
        assert_relative_eq!(block.data.scalar(), 13.5);
    }

    #[test]
    fn test_aggregate_median_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Median,
        };
        let mut input: Matrix<4, 7, f64> = Matrix::zeroed();
        for (idx, elem) in input.data.as_flattened_mut().iter_mut().enumerate() {
            *elem = idx as f64;
        }

        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 13.5);
        assert_relative_eq!(block.data.scalar(), 13.5);
    }

    /// Odd element count: the median is the single middle value of the sorted data.
    #[test]
    fn test_aggregate_median_odd_f64() {
        let mut block = AggregateBlock::<Matrix<3, 3, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Median,
        };
        let input: Matrix<3, 3, f64> = Matrix {
            data: [[5.0, 1.0, 9.0], [3.0, 7.0, 2.0], [8.0, 4.0, 6.0]],
        };
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 5.0);
        assert_relative_eq!(block.data.scalar(), 5.0);
    }

    /// Even element count on f32: mean of the two middle values of the sorted data.
    #[test]
    fn test_aggregate_median_even_f32() {
        let mut block = AggregateBlock::<Matrix<2, 2, f32>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Median,
        };
        let input: Matrix<2, 2, f32> = Matrix {
            data: [[4.0, 1.0], [3.0, 2.0]],
        };
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 2.5);
        assert_relative_eq!(block.data.scalar(), 2.5);
    }

    #[test]
    #[should_panic(expected = "NaNs are not supported")]
    fn test_aggregate_median_nan_panics() {
        let mut block = AggregateBlock::<Matrix<2, 2, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Median,
        };
        let input: Matrix<2, 2, f64> = Matrix {
            data: [[1.0, f64::NAN], [3.0, 2.0]],
        };
        let _ = block.process(&params, &context, &input);
    }

    /// Scalar inputs pass straight through regardless of the configured method.
    #[test]
    fn test_aggregate_scalar_passthrough() {
        let mut block = AggregateBlock::<f64>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Max,
        };
        let output = block.process(&params, &context, 3.25);
        assert_relative_eq!(output, 3.25);
        assert_relative_eq!(block.data.scalar(), 3.25);
    }

    #[test]
    fn test_parameters_new() {
        assert_eq!(Parameters::new("Median").method, AggregateMethod::Median);
        assert_eq!(Parameters::new("Max").method, AggregateMethod::Max);
    }

    #[test]
    #[should_panic(expected = "Invalid aggregate method")]
    fn test_parameters_new_invalid() {
        let _ = Parameters::new("Mode");
    }

    #[test]
    fn test_aggregate_method_from_str() {
        assert_eq!(
            AggregateMethod::from_str("Sum").unwrap(),
            AggregateMethod::Sum
        );
        assert_eq!(
            AggregateMethod::from_str("Mean").unwrap(),
            AggregateMethod::Mean
        );
        assert_eq!(
            AggregateMethod::from_str("Median").unwrap(),
            AggregateMethod::Median
        );
        assert_eq!(
            AggregateMethod::from_str("Min").unwrap(),
            AggregateMethod::Min
        );
        assert_eq!(
            AggregateMethod::from_str("Max").unwrap(),
            AggregateMethod::Max
        );
        assert!(AggregateMethod::from_str("Invalid").is_err());
    }
}