1extern crate alloc;
2use alloc::vec::Vec;
3use pictorus_block_data::{BlockData as OldBlockData, FromPass};
4use pictorus_traits::{ByteSliceSignal, Matrix, Pass, PassBy, ProcessBlock};
5
6use crate::traits::{CopyInto, DefaultStorage, Scalar};
7
8pub struct SwitchBlock<T: Apply>
52where
53 T::Output: DefaultStorage,
54 OldBlockData: FromPass<T::Output>,
55{
56 pub data: OldBlockData,
57 buffer: <T::Output as DefaultStorage>::Storage,
58}
59
60impl<T: Apply> Default for SwitchBlock<T>
61where
62 T::Output: DefaultStorage,
63 OldBlockData: FromPass<T::Output>,
64{
65 fn default() -> Self {
66 Self {
67 data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::from_storage(
68 &T::Output::default_storage(),
69 )),
70 buffer: T::Output::default_storage(),
71 }
72 }
73}
74
75impl<T: Apply> ProcessBlock for SwitchBlock<T>
76where
77 T::Output: DefaultStorage,
78 OldBlockData: FromPass<T::Output>,
79{
80 type Inputs = T;
81 type Output = T::Output;
82 type Parameters = T::Parameters;
83
84 fn process<'b>(
85 &'b mut self,
86 parameters: &Self::Parameters,
87 _context: &dyn pictorus_traits::Context,
88 inputs: PassBy<'_, Self::Inputs>,
89 ) -> PassBy<'b, Self::Output> {
90 T::apply(inputs, parameters, &mut self.buffer);
91 let res = T::Output::from_storage(&self.buffer);
92 self.data = <OldBlockData as FromPass<T::Output>>::from_pass(res);
93 res
94 }
95}
96
97pub struct Parameters<C: Scalar, const N: usize> {
99 pub cases: [C; N],
103}
104
105impl<const N: usize> Parameters<f64, N> {
108 pub fn new(cases: &OldBlockData) -> Self {
109 assert!(cases.len() == N, "Invalid number of switch cases");
110
111 let mut case_arr: [f64; N] = [0.0; N];
112 for (idx, case) in cases.iter().enumerate() {
113 case_arr[idx] = *case;
114 }
115 Self { cases: case_arr }
116 }
117}
118
119pub trait ApplyInto<C: Scalar, const N: usize>: Pass + DefaultStorage {
120 fn apply_into(
121 condition: C,
122 cases: &[C; N],
123 inputs: &[PassBy<Self>; N],
124 dest: &mut Self::Storage,
125 );
126}
127
128impl<C: Scalar, const N: usize> ApplyInto<C, N> for C {
129 fn apply_into(condition: C, cases: &[C; N], inputs: &[PassBy<C>; N], dest: &mut C) {
130 for (idx, case) in cases.iter().enumerate() {
131 if condition == *case {
132 let res = inputs[idx];
133 *dest = res;
134 return;
135 }
136 }
137 let res = inputs[inputs.len() - 1];
138 *dest = res;
139 }
140}
141
142impl<C: Scalar, const NROWS: usize, const NCOLS: usize, const N: usize> ApplyInto<C, N>
143 for Matrix<NROWS, NCOLS, C>
144{
145 fn apply_into(
146 condition: C,
147 cases: &[C; N],
148 inputs: &[PassBy<Matrix<NROWS, NCOLS, C>>; N],
149 dest: &mut Matrix<NROWS, NCOLS, C>,
150 ) {
151 for (idx, case) in cases.iter().enumerate() {
152 if condition == *case {
153 let res = inputs[idx];
154 Matrix::copy_into(res, dest);
155 return;
156 }
157 }
158 let res = inputs[inputs.len() - 1];
159 Matrix::copy_into(res, dest);
160 }
161}
162
163impl<C: Scalar, const N: usize> ApplyInto<C, N> for ByteSliceSignal {
164 fn apply_into(
165 condition: C,
166 cases: &[C; N],
167 inputs: &[PassBy<ByteSliceSignal>; N],
168 dest: &mut Vec<u8>,
169 ) {
170 for (idx, case) in cases.iter().enumerate() {
171 if condition == *case {
172 let res = inputs[idx];
173 dest.clear();
174 dest.extend_from_slice(res);
175 return;
176 }
177 }
178 let res = inputs[inputs.len() - 1];
179 dest.clear();
182 dest.extend_from_slice(res);
183 }
184}
185
186pub trait Apply: Pass {
187 type Parameters;
188 type Output: Pass + DefaultStorage;
189
190 fn apply(
191 input: PassBy<Self>,
192 params: &Self::Parameters,
193 buffer: &mut <Self::Output as DefaultStorage>::Storage,
194 );
195}
196
197impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 2>> Apply for (C, T, T) {
202 type Output = T;
203 type Parameters = Parameters<C, 2>;
204
205 fn apply(
206 input: PassBy<Self>,
207 params: &Self::Parameters,
208 buffer: &mut <Self::Output as DefaultStorage>::Storage,
209 ) {
210 let condition = input.0;
211 T::apply_into(condition, ¶ms.cases, &[input.1, input.2], buffer);
212 }
213}
214
215impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 3>> Apply for (C, T, T, T) {
217 type Output = T;
218 type Parameters = Parameters<C, 3>;
219
220 fn apply(
221 input: PassBy<Self>,
222 params: &Self::Parameters,
223 buffer: &mut <Self::Output as DefaultStorage>::Storage,
224 ) {
225 let condition = input.0;
226 T::apply_into(
227 condition,
228 ¶ms.cases,
229 &[input.1, input.2, input.3],
230 buffer,
231 );
232 }
233}
234
235impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 4>> Apply for (C, T, T, T, T) {
237 type Output = T;
238 type Parameters = Parameters<C, 4>;
239
240 fn apply(
241 input: PassBy<Self>,
242 params: &Self::Parameters,
243 buffer: &mut <Self::Output as DefaultStorage>::Storage,
244 ) {
245 let condition = input.0;
246 T::apply_into(
247 condition,
248 ¶ms.cases,
249 &[input.1, input.2, input.3, input.4],
250 buffer,
251 );
252 }
253}
254
255impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 5>> Apply for (C, T, T, T, T, T) {
257 type Output = T;
258 type Parameters = Parameters<C, 5>;
259
260 fn apply(
261 input: PassBy<Self>,
262 params: &Self::Parameters,
263 buffer: &mut <Self::Output as DefaultStorage>::Storage,
264 ) {
265 let condition = input.0;
266 T::apply_into(
267 condition,
268 ¶ms.cases,
269 &[input.1, input.2, input.3, input.4, input.5],
270 buffer,
271 );
272 }
273}
274
275impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 6>> Apply for (C, T, T, T, T, T, T) {
277 type Output = T;
278 type Parameters = Parameters<C, 6>;
279
280 fn apply(
281 input: PassBy<Self>,
282 params: &Self::Parameters,
283 buffer: &mut <Self::Output as DefaultStorage>::Storage,
284 ) {
285 let condition = input.0;
286 T::apply_into(
287 condition,
288 ¶ms.cases,
289 &[input.1, input.2, input.3, input.4, input.5, input.6],
290 buffer,
291 );
292 }
293}
294
295impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 7>> Apply for (C, T, T, T, T, T, T, T) {
297 type Output = T;
298 type Parameters = Parameters<C, 7>;
299
300 fn apply(
301 input: PassBy<Self>,
302 params: &Self::Parameters,
303 buffer: &mut <Self::Output as DefaultStorage>::Storage,
304 ) {
305 let condition = input.0;
306 T::apply_into(
307 condition,
308 ¶ms.cases,
309 &[
310 input.1, input.2, input.3, input.4, input.5, input.6, input.7,
311 ],
312 buffer,
313 );
314 }
315}
316
#[cfg(test)]
mod tests {
    use crate::traits::MatrixOps;

    use super::*;
    use crate::testing::StubContext;

    #[test]
    fn test_switch_block_2_scalars() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (0.0, 1.0, 2.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 1.0);
        assert_eq!(block.data.scalar(), 1.0);
    }

    #[test]
    fn test_switch_block_7_scalars() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (6.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 7.0);
        assert_eq!(block.data.scalar(), 7.0);
    }

    // The 7-input tests select case 6.0, whose input is also the no-match fallback, so
    // they cannot tell a real case match apart from the fallback path. The
    // `*_middle_case` tests pick a case whose input differs from the fallback.
    #[test]
    fn test_switch_block_7_scalars_middle_case() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 4.0);
        assert_eq!(block.data.scalar(), 4.0);
    }

    #[test]
    fn test_switch_block_scalar_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (1.2345, 1.0, 2.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 2.0);
        assert_eq!(block.data.scalar(), 2.0);
    }

    #[test]
    fn test_switch_block_2_matrices() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, Matrix<3, 3, f64>, Matrix<3, 3, f64>)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (0.0, &Matrix::from_element(1.0), &Matrix::from_element(2.0));
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(1.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_7_matrices() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            6.0,
            &Matrix::from_element(1.0),
            &Matrix::from_element(2.0),
            &Matrix::from_element(3.0),
            &Matrix::from_element(4.0),
            &Matrix::from_element(5.0),
            &Matrix::from_element(6.0),
            &Matrix::from_element(7.0),
        );
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(7.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_7_matrices_middle_case() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            2.0,
            &Matrix::from_element(1.0),
            &Matrix::from_element(2.0),
            &Matrix::from_element(3.0),
            &Matrix::from_element(4.0),
            &Matrix::from_element(5.0),
            &Matrix::from_element(6.0),
            &Matrix::from_element(7.0),
        );
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(3.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_matrix_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, Matrix<3, 3, f64>, Matrix<3, 3, f64>)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (
            1.2345,
            &Matrix::from_element(1.0),
            &Matrix::from_element(2.0),
        );
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(2.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_2_bytes() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (0.0, b"foo".as_slice(), b"bar".as_slice());
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"foo");
        assert_eq!(block.data.raw_string().as_bytes(), b"foo".as_slice());
    }

    #[test]
    fn test_switch_block_2_bytes_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (1.2345, b"foo".as_slice(), b"bar".as_slice());
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"bar");
        assert_eq!(block.data.raw_string().as_bytes(), b"bar".as_slice());
    }

    #[test]
    fn test_switch_block_bytes_shorter_output_replaces_longer() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let output = block.process(
            &parameters,
            &ctxt,
            (1.0, b"foo".as_slice(), b"longer".as_slice()),
        );
        assert_eq!(output, b"longer");

        // The buffer is reused across calls; a shorter selection must not leave a tail of
        // the previous, longer output behind.
        let output = block.process(
            &parameters,
            &ctxt,
            (0.0, b"foo".as_slice(), b"longer".as_slice()),
        );
        assert_eq!(output, b"foo");
        assert_eq!(block.data.raw_string().as_bytes(), b"foo".as_slice());
    }

    #[test]
    fn test_switch_block_7_bytes() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            6.0,
            b"foo".as_slice(),
            b"bar".as_slice(),
            b"baz".as_slice(),
            b"qux".as_slice(),
            b"quux".as_slice(),
            b"corge".as_slice(),
            b"grault".as_slice(),
        );
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"grault");
        assert_eq!(block.data.raw_string().as_bytes(), b"grault".as_slice());
    }

    #[test]
    fn test_switch_block_7_bytes_middle_case() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            4.0,
            b"foo".as_slice(),
            b"bar".as_slice(),
            b"baz".as_slice(),
            b"qux".as_slice(),
            b"quux".as_slice(),
            b"corge".as_slice(),
            b"grault".as_slice(),
        );
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"quux");
        assert_eq!(block.data.raw_string().as_bytes(), b"quux".as_slice());
    }

    #[test]
    #[should_panic(expected = "Invalid number of switch cases")]
    fn test_parameters_rejects_wrong_case_count() {
        let _ = Parameters::<f64, 2>::new(&OldBlockData::from_vector(&[0.0, 1.0, 2.0]));
    }
}