pictorus_blocks/core_blocks/
aggregate_block.rs1use crate::nalgebra_interop::MatrixExt;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock, Scalar};
4
5pub struct AggregateBlock<T: Apply> {
7 pub data: OldBlockData,
8 buffer: Option<T::Output>,
9}
10
11impl<T: Apply> Default for AggregateBlock<T>
12where
13 T: Pass + Default,
14 OldBlockData: FromPass<T::Output>,
15{
16 fn default() -> Self {
17 Self {
18 data: <OldBlockData as FromPass<T::Output>>::from_pass(<T::Output>::default().as_by()),
19 buffer: None,
20 }
21 }
22}
23
24impl<T> ProcessBlock for AggregateBlock<T>
25where
26 T: Apply + Default,
27 OldBlockData: FromPass<T::Output>,
28{
29 type Inputs = T;
30 type Output = T::Output;
31 type Parameters = Parameters;
32
33 fn process<'b>(
34 &'b mut self,
35 parameters: &Self::Parameters,
36 _context: &dyn pictorus_traits::Context,
37 inputs: pictorus_traits::PassBy<'_, Self::Inputs>,
38 ) -> pictorus_traits::PassBy<'b, Self::Output> {
39 let output = T::apply(&mut self.buffer, inputs, parameters.method);
40 self.data = OldBlockData::from_pass(output);
41 output
42 }
43}
44
45pub trait Apply: Pass {
46 type Output: Scalar;
47
48 fn apply<'s>(
49 store: &mut Option<Self::Output>,
50 input: PassBy<Self>,
51 method: AggregateMethod,
52 ) -> PassBy<'s, Self::Output>;
53}
54
55macro_rules! scalar_impls {
56 () => {};
57 ($type:ty, $($rest:tt),+) => {
58 scalar_impls!($type);
59 scalar_impls!($($rest),+);
60 };
61 ($type:ty) => {
62 impl Apply for $type {
63 type Output = $type;
64
65 fn apply<'s>(
66 store: &mut Option<Self::Output>,
67 input: PassBy<Self>,
68 _method: AggregateMethod,
69 ) -> PassBy<'s, Self::Output> {
70 *store = Some(input);
71 input
72 }
73 }
74 };
75}
76scalar_impls!(f64, f32); macro_rules! float_matrix_impl {
79 ($type:ty) => {
80 impl<const NROWS: usize, const NCOLS: usize> Apply for Matrix<NROWS, NCOLS, $type> {
81 type Output = $type;
82
83 fn apply<'s>(
84 store: &mut Option<Self::Output>,
85 input: PassBy<Self>,
86 method: AggregateMethod,
87 ) -> PassBy<'s, Self::Output> {
88 let view = input.as_view();
89 let output = match method {
90 AggregateMethod::Sum => view.sum(),
91 AggregateMethod::Mean => view.mean(),
92 AggregateMethod::Median => {
93 let mut data = *input;
95 let data = data.data.as_flattened_mut();
96 view.iter().enumerate().for_each(|(i, &x)| data[i] = x);
97 data.sort_by(|a, b| a.partial_cmp(b).expect("NaNs are not supported"));
98 let mid = data.len() / 2;
99 if data.len() % 2 == 0 {
100 (data[mid - 1] + data[mid]) / Self::Output::from(2u8)
101 } else {
102 data[mid]
103 }
104 }
105 AggregateMethod::Min => view.min(),
106 AggregateMethod::Max => view.max(),
107 };
108 *store = Some(output);
109 output
110 }
111 }
112 };
113}
114
115float_matrix_impl!(f64);
116float_matrix_impl!(f32);
117
118#[derive(Debug, Clone, Copy, PartialEq, strum::EnumString)]
120pub enum AggregateMethod {
121 Sum,
123 Mean,
125 Median,
127 Min,
129 Max,
131}
132
133pub struct Parameters {
134 pub method: AggregateMethod,
135}
136impl Parameters {
137 pub fn new(method: &str) -> Self {
138 Self {
139 method: method.parse().expect("Invalid aggregate method"),
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;
    use alloc::str::FromStr;
    use approx::assert_relative_eq;

    #[test]
    fn test_aggregate_sum_f32() {
        let mut block = AggregateBlock::<Matrix<4, 7, f32>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Sum,
        };
        let input: Matrix<4, 7, f32> = Matrix {
            data: [[1.0; 4]; 7],
        };
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 28.0);
        assert_relative_eq!(block.data.scalar(), 28.0);
    }

    #[test]
    fn test_aggregate_sum_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Sum,
        };
        let input: Matrix<4, 7, f64> = Matrix {
            data: [[1.0; 4]; 7],
        };
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 28.0);
        assert_relative_eq!(block.data.scalar(), 28.0);
    }

    #[test]
    fn test_aggregate_max_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Max,
        };
        let mut input: Matrix<4, 7, f64> = Matrix {
            data: [[1.0; 4]; 7],
        };
        input.data[5][3] = 42.0;
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 42.0);
        assert_relative_eq!(block.data.scalar(), 42.0);
    }

    #[test]
    fn test_aggregate_max_f32() {
        let mut block = AggregateBlock::<Matrix<2, 3, f32>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Max,
        };
        let input: Matrix<2, 3, f32> = Matrix {
            data: [[-1.0, 2.5], [0.0, -3.0], [1.0, 2.0]],
        };
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 2.5);
        assert_relative_eq!(block.data.scalar(), 2.5);
    }

    #[test]
    fn test_aggregate_min_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Min,
        };
        let mut input: Matrix<4, 7, f64> = Matrix {
            data: [[11.0; 4]; 7],
        };
        input.data[1][2] = 10.99;
        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 10.99);
        assert_relative_eq!(block.data.scalar(), 10.99);
    }

    #[test]
    fn test_aggregate_mean_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Mean,
        };
        let mut input: Matrix<4, 7, f64> = Matrix::zeroed();
        for (idx, elem) in input.data.as_flattened_mut().iter_mut().enumerate() {
            *elem = idx as f64;
        }

        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 13.5);
        assert_relative_eq!(block.data.scalar(), 13.5);
    }

    #[test]
    fn test_aggregate_median_f64() {
        let mut block = AggregateBlock::<Matrix<4, 7, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Median,
        };
        let mut input: Matrix<4, 7, f64> = Matrix::zeroed();
        for (idx, elem) in input.data.as_flattened_mut().iter_mut().enumerate() {
            *elem = idx as f64;
        }

        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 13.5);
        assert_relative_eq!(block.data.scalar(), 13.5);
    }

    #[test]
    fn test_aggregate_median_odd_unsorted_f64() {
        let mut block = AggregateBlock::<Matrix<1, 5, f64>>::default();
        let context = StubContext::default();
        let params = Parameters {
            method: AggregateMethod::Median,
        };
        // Odd element count in scrambled order: the sort must pick the middle value.
        let input: Matrix<1, 5, f64> = Matrix {
            data: [[5.0], [1.0], [4.0], [2.0], [3.0]],
        };

        let output = block.process(&params, &context, &input);
        assert_relative_eq!(output, 3.0);
        assert_relative_eq!(block.data.scalar(), 3.0);
    }

    #[test]
    fn test_aggregate_method_from_str() {
        assert_eq!(
            AggregateMethod::from_str("Sum").unwrap(),
            AggregateMethod::Sum
        );
        assert_eq!(
            AggregateMethod::from_str("Mean").unwrap(),
            AggregateMethod::Mean
        );
        assert_eq!(
            AggregateMethod::from_str("Median").unwrap(),
            AggregateMethod::Median
        );
        assert_eq!(
            AggregateMethod::from_str("Min").unwrap(),
            AggregateMethod::Min
        );
        assert_eq!(
            AggregateMethod::from_str("Max").unwrap(),
            AggregateMethod::Max
        );
        assert!(AggregateMethod::from_str("Invalid").is_err());
    }

    #[test]
    fn test_parameters_new() {
        assert_eq!(Parameters::new("Max").method, AggregateMethod::Max);
        assert_eq!(Parameters::new("Median").method, AggregateMethod::Median);
    }

    #[test]
    #[should_panic(expected = "Invalid aggregate method")]
    fn test_parameters_new_invalid() {
        let _ = Parameters::new("Invalid");
    }
}