1use pictorus_block_data::{BlockData as OldBlockData, FromPass};
2use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
3
4use crate::traits::Scalar;
5
/// Configuration for [`CounterBlock`].
///
/// The counter exposes no tunable settings, so this is an empty, zero-sized type.
#[derive(Default)]
pub struct Parameters {}

impl Parameters {
    /// Returns the (empty) parameter set; equivalent to `Parameters::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
19
20pub struct CounterBlock<T: Apply>
32where
33 OldBlockData: FromPass<T::Counter>,
34{
35 pub data: OldBlockData,
36 counter: T::Counter,
37}
38
39impl<T: Apply> ProcessBlock for CounterBlock<T>
40where
41 OldBlockData: FromPass<T::Counter>,
42{
43 type Inputs = T;
44 type Output = T::Counter;
45 type Parameters = Parameters;
46
47 fn process<'b>(
48 &'b mut self,
49 _parameters: &Self::Parameters,
50 _context: &dyn pictorus_traits::Context,
51 inputs: PassBy<'_, Self::Inputs>,
52 ) -> PassBy<'b, Self::Output> {
53 T::apply(&mut self.counter, inputs)
54 }
55}
56
57impl<T: Apply> Default for CounterBlock<T>
58where
59 OldBlockData: FromPass<T::Counter>,
60{
61 fn default() -> Self {
62 let counter = T::Counter::default();
63 Self {
64 data: OldBlockData::from_pass(counter.as_by()),
65 counter,
66 }
67 }
68}
69
70pub trait Apply: Pass {
71 type Counter: Default + Pass;
72 fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter>;
73}
74
75impl<I: Scalar, R: Scalar> Apply for (I, R) {
76 type Counter = f64;
77 fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter> {
78 if input.1.is_truthy() {
79 *count = 0.0;
80 } else if input.0.is_truthy() {
81 *count += 1.0;
82 }
83 count.as_by()
84 }
85}
86
87impl<I: Scalar, R: Scalar, const NROWS: usize, const NCOLS: usize> Apply
88 for (Matrix<NROWS, NCOLS, I>, R)
89{
90 type Counter = Matrix<NROWS, NCOLS, f64>;
91 fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter> {
92 for i in 0..NROWS {
93 for j in 0..NCOLS {
94 if input.1.is_truthy() {
95 count.data[j][i] = 0.0;
96 } else if input.0.data[j][i].is_truthy() {
97 count.data[j][i] += 1.0;
98 }
99 }
100 }
101 count.as_by()
102 }
103}
104
105impl<I: Scalar, R: Scalar, const NROWS: usize, const NCOLS: usize> Apply
106 for (Matrix<NROWS, NCOLS, I>, Matrix<NROWS, NCOLS, R>)
107{
108 type Counter = Matrix<NROWS, NCOLS, f64>;
109 fn apply<'a>(count: &'a mut Self::Counter, input: PassBy<Self>) -> PassBy<'a, Self::Counter> {
110 for i in 0..NROWS {
111 for j in 0..NCOLS {
112 if input.1.data[j][i].is_truthy() {
113 count.data[j][i] = 0.0;
114 } else if input.0.data[j][i].is_truthy() {
115 count.data[j][i] += 1.0;
116 }
117 }
118 }
119 count.as_by()
120 }
121}
122
#[cfg(test)]
mod tests {
    use crate::testing::StubContext;

    use super::*;

    /// Exercises the scalar `(I, R)` implementation of `Apply`, which the
    /// matrix-based tests below do not reach.
    #[test]
    fn test_counter_block_scalar() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(bool, bool)>::default();
        let c = StubContext::default();

        assert_eq!(block.process(&p, &c, (true, false)), 1.0);
        assert_eq!(block.process(&p, &c, (true, false)), 2.0);

        // A falsy increment leaves the count unchanged.
        assert_eq!(block.process(&p, &c, (false, false)), 2.0);

        // Reset takes priority over a simultaneous increment.
        assert_eq!(block.process(&p, &c, (true, true)), 0.0);

        // Counting resumes from zero once reset is released.
        assert_eq!(block.process(&p, &c, (true, false)), 1.0);
    }

    #[test]
    fn test_counter_block_simple_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<1, 1, bool>, Matrix<1, 1, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<1, 1, bool>::zeroed();
        increment.data[0][0] = true;

        let mut reset = Matrix::<1, 1, bool>::zeroed();
        reset.data[0][0] = false;

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
    }

    #[test]
    fn test_counter_block_1x2_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<1, 2, bool>, Matrix<1, 2, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<1, 2, bool>::zeroed();
        increment.data[0][0] = true;

        let mut reset = Matrix::<1, 2, bool>::zeroed();
        reset.data[0][0] = false;

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 0.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 0.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 0.0);
    }

    #[test]
    fn test_counter_block_2x2_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<2, 2, f64>, Matrix<2, 2, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<2, 2, f64>::zeroed();
        increment.data[0][0] = 1.0;
        increment.data[1][0] = 1.0;
        increment.data[0][1] = 1.0;
        increment.data[1][1] = 1.0;

        let mut reset = Matrix::<2, 2, bool>::zeroed();

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 3.0);
        assert_eq!(output.data[0][1], 3.0);
        assert_eq!(output.data[1][1], 3.0);

        reset.data[0][0] = false;
        reset.data[1][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 0.0);
        assert_eq!(output.data[0][1], 4.0);
        assert_eq!(output.data[1][1], 4.0);

        reset.data[0][0] = false;
        reset.data[1][0] = false;
        reset.data[0][1] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 0.0);
        assert_eq!(output.data[1][1], 5.0);
    }

    #[test]
    fn test_counter_block_2x2_single_reset_f64() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<2, 2, f64>, bool)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<2, 2, f64>::zeroed();
        increment.data[0][0] = 1.0;
        increment.data[1][0] = 1.0;
        increment.data[0][1] = 1.0;
        increment.data[1][1] = 1.0;

        let mut reset = false;

        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);

        reset = true;
        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 0.0);
        assert_eq!(output.data[0][1], 0.0);
        assert_eq!(output.data[1][1], 0.0);

        reset = false;
        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);
    }

    #[test]
    fn test_counter_block_2x2_u8() {
        let p = Parameters::new();
        let mut block = CounterBlock::<(Matrix<2, 2, u8>, Matrix<2, 2, bool>)>::default();
        let c = StubContext::default();

        let mut increment = Matrix::<2, 2, u8>::zeroed();
        increment.data[0][0] = 1;
        increment.data[1][0] = 1;
        increment.data[0][1] = 1;
        increment.data[1][1] = 1;

        let mut reset = Matrix::<2, 2, bool>::zeroed();

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 1.0);
        assert_eq!(output.data[1][1], 1.0);

        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 2.0);
        assert_eq!(output.data[1][1], 2.0);

        reset.data[0][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 0.0);
        assert_eq!(output.data[1][0], 3.0);
        assert_eq!(output.data[0][1], 3.0);
        assert_eq!(output.data[1][1], 3.0);

        reset.data[0][0] = false;
        reset.data[1][0] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 0.0);
        assert_eq!(output.data[0][1], 4.0);
        assert_eq!(output.data[1][1], 4.0);

        reset.data[0][0] = false;
        reset.data[1][0] = false;
        reset.data[0][1] = true;
        let output = block.process(&p, &c, (&increment, &reset));
        assert_eq!(output.data[0][0], 2.0);
        assert_eq!(output.data[1][0], 1.0);
        assert_eq!(output.data[0][1], 0.0);
        assert_eq!(output.data[1][1], 5.0);
    }
}
326}