Skip to main content

pictorus_blocks/core_blocks/
counter_block.rs

1use pictorus_block_data::{BlockData as OldBlockData, FromPass};
2use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
3
4use crate::traits::Scalar;
5
/// Parameters for [`CounterBlock`].
///
/// The counter has nothing to configure, so this type carries no fields; it is used as the
/// block's [`ProcessBlock::Parameters`] type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Parameters {}

impl Parameters {
    /// Creates the (empty) parameter set; equivalent to [`Parameters::default`].
    pub const fn new() -> Self {
        Self {}
    }
}
19
20/// Increments a counter every time the count input is truthy.
21///
22/// The counters can be reset using non-zero values of either a single scalar to
23/// to reset all counters or a vector/matrix of values that is the same size as the input to
24/// reset individual counters.
25///
26/// The block is generic over a type 'T'. This is expected to be a tuple of two types, the first
27/// is the input type and the second is the reset type. For both types they accepts either a scalar
28/// or a matrix of scalars. However they are interpreted as bools or matrices of bools, where true or
29/// false is determined by whether the value is non-zero or zero respectively. See the [`Scalar::is_truthy`]
30/// function for more details.
31pub struct CounterBlock<T: Apply>
32where
33    OldBlockData: FromPass<T::Counter>,
34{
35    pub data: OldBlockData,
36    counter: T::Counter,
37}
38
39impl<T: Apply> ProcessBlock for CounterBlock<T>
40where
41    OldBlockData: FromPass<T::Counter>,
42{
43    type Inputs = T;
44    type Output = T::Counter;
45    type Parameters = Parameters;
46
47    fn process<'b>(
48        &'b mut self,
49        _parameters: &Self::Parameters,
50        _context: &dyn pictorus_traits::Context,
51        inputs: PassBy<'_, Self::Inputs>,
52    ) -> PassBy<'b, Self::Output> {
53        T::apply(&mut self.counter, inputs)
54    }
55}
56
57impl<T: Apply> Default for CounterBlock<T>
58where
59    OldBlockData: FromPass<T::Counter>,
60{
61    fn default() -> Self {
62        let counter = T::Counter::default();
63        Self {
64            data: OldBlockData::from_pass(counter.as_by()),
65            counter,
66        }
67    }
68}
69
70pub trait Apply: Pass {
71    type Counter: Default + Pass;
72    fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter>;
73}
74
75impl<I: Scalar, R: Scalar> Apply for (I, R) {
76    type Counter = f64;
77    fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter> {
78        if input.1.is_truthy() {
79            *count = 0.0;
80        } else if input.0.is_truthy() {
81            *count += 1.0;
82        }
83        count.as_by()
84    }
85}
86
87impl<I: Scalar, R: Scalar, const NROWS: usize, const NCOLS: usize> Apply
88    for (Matrix<NROWS, NCOLS, I>, R)
89{
90    type Counter = Matrix<NROWS, NCOLS, f64>;
91    fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter> {
92        for i in 0..NROWS {
93            for j in 0..NCOLS {
94                if input.1.is_truthy() {
95                    count.data[j][i] = 0.0;
96                } else if input.0.data[j][i].is_truthy() {
97                    count.data[j][i] += 1.0;
98                }
99            }
100        }
101        count.as_by()
102    }
103}
104
105impl<I: Scalar, R: Scalar, const NROWS: usize, const NCOLS: usize> Apply
106    for (Matrix<NROWS, NCOLS, I>, Matrix<NROWS, NCOLS, R>)
107{
108    type Counter = Matrix<NROWS, NCOLS, f64>;
109    fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter> {
110        for i in 0..NROWS {
111            for j in 0..NCOLS {
112                if input.1.data[j][i].is_truthy() {
113                    count.data[j][i] = 0.0;
114                } else if input.0.data[j][i].is_truthy() {
115                    count.data[j][i] += 1.0;
116                }
117            }
118        }
119        count.as_by()
120    }
121}
122
#[cfg(test)]
mod tests {
    use crate::testing::StubContext;

    use super::*;

    /// Exercises the scalar `(I, R)` `Apply` impl, including a non-unit truthy count input
    /// and a reset arriving on the same tick as an increment.
    #[test]
    fn test_counter_block_scalar() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(f64, bool)>::default();
        let c = StubContext::default();

        assert_eq!(block.process(&p, &c, (1.0, false)), 1.0);
        // A zero count input leaves the counter untouched.
        assert_eq!(block.process(&p, &c, (0.0, false)), 1.0);
        // Any non-zero value counts as truthy.
        assert_eq!(block.process(&p, &c, (2.5, false)), 2.0);
        // Reset wins over a simultaneous increment.
        assert_eq!(block.process(&p, &c, (1.0, true)), 0.0);
        assert_eq!(block.process(&p, &c, (1.0, false)), 1.0);
    }

    #[test]
    fn test_counter_block_simple_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<1, 1, bool>, Matrix<1, 1, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<1, 1, bool>::zeroed();
        increment.data[0][0] = true;

        let mut reset = Matrix::<1, 1, bool>::zeroed();
        reset.data[0][0] = false;

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
    }

    #[test]
    fn test_counter_block_1x2_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<1, 2, bool>, Matrix<1, 2, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<1, 2, bool>::zeroed();
        increment.data[0][0] = true;

        let mut reset = Matrix::<1, 2, bool>::zeroed();
        reset.data[0][0] = false;

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 0.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 0.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 0.0);
    }

    #[test]
    fn test_counter_block_2x2_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<2, 2, f64>::zeroed();
        increment.data[0][0] = 1.0;
        increment.data[1][0] = 1.0;
        increment.data[0][1] = 1.0;
        increment.data[1][1] = 1.0;

        let mut reset = Matrix::<2, 2, bool>::zeroed();

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 3.0);
        assert_eq!(output.data[0][1], 3.0);
        assert_eq!(output.data[1][1], 3.0);

        reset.data[0][0] = false;
        reset.data[1][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 0.0);
        assert_eq!(output.data[0][1], 4.0);
        assert_eq!(output.data[1][1], 4.0);

        reset.data[0][0] = false;
        reset.data[1][0] = false;
        reset.data[0][1] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 0.0);
        assert_eq!(output.data[1][1], 5.0);
    }

    #[test]
    fn test_counter_block_2x2_single_reset_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<2, 2, f64>, bool)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<2, 2, f64>::zeroed();
        increment.data[0][0] = 1.0;
        increment.data[1][0] = 1.0;
        increment.data[0][1] = 1.0;
        increment.data[1][1] = 1.0;

        let mut reset = false;

        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);

        reset = true;
        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 0.0);
        assert_eq!(output.data[0][1], 0.0);
        assert_eq!(output.data[1][1], 0.0);

        reset = false;
        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);
    }

    #[test]
    fn test_counter_block_2x2_u8() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<2, 2, u8>, Matrix<2, 2, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<2, 2, u8>::zeroed();
        increment.data[0][0] = 1;
        increment.data[1][0] = 1;
        increment.data[0][1] = 1;
        increment.data[1][1] = 1;

        let mut reset = Matrix::<2, 2, bool>::zeroed();

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 3.0);
        assert_eq!(output.data[0][1], 3.0);
        assert_eq!(output.data[1][1], 3.0);

        reset.data[0][0] = false;
        reset.data[1][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 0.0);
        assert_eq!(output.data[0][1], 4.0);
        assert_eq!(output.data[1][1], 4.0);

        reset.data[0][0] = false;
        reset.data[1][0] = false;
        reset.data[0][1] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 0.0);
        assert_eq!(output.data[1][1], 5.0);
    }
}