Skip to main content

pictorus_blocks/core_blocks/
switch_block.rs

1extern crate alloc;
2use alloc::vec::Vec;
3use pictorus_block_data::{BlockData as OldBlockData, FromPass};
4use pictorus_traits::{ByteSliceSignal, Matrix, Pass, PassBy, ProcessBlock};
5
6use crate::traits::{CopyInto, DefaultStorage, Scalar};
7
8/// Switches between multiple input signals based on a condition.
9///
10/// The condition is the first input, and the rest are the signals to switch between.
11/// The block will output the signal that corresponds to the index of the `cases`` parameter
12/// that matches the condition input. If no matches are found, it will output the last input.
13/// For example:
14/// ```
15/// use core::time::Duration;
16/// use pictorus_blocks::SwitchBlock;
17/// use pictorus_traits::ProcessBlock;
18/// use pictorus_block_data::BlockData as OldBlockData;
19/// use pictorus_traits::Context;
20///
21/// #[derive(Default)]
22/// struct StubContext {}
23///
24/// impl Context for StubContext {
25///     fn time(&self) -> Duration {
26///         Duration::from_secs(0)
27///     }
28///
29///     fn timestep(&self) -> Option<Duration> {
30///         None
31///     }
32///
33///     fn fundamental_timestep(&self) -> Duration {
34///         Duration::from_millis(100)
35///     }
36/// }
37///
38/// let ctxt = StubContext::default();
39/// let mut block = SwitchBlock::<(f64, f64, f64)>::default();
40/// // If condition is 0, output the signal at index 0
41/// // If condition is 1, output the signal at index 1
42/// // If condition is anything else, output the signal at index 1
43/// let cases = OldBlockData::from_vector(&[0.0, 1.0]);
44/// let parameters = <SwitchBlock<(f64, f64, f64)> as ProcessBlock>::Parameters::new(&cases);
45/// // Here we have a condition of 0.0, and inputs of [1.0, 2.0]
46/// // Since condition matches case 0, the output will be 1.0
47/// let input = (0.0, 1.0, 2.0);
48/// let output = block.process(&parameters, &ctxt, input);
49/// assert_eq!(output, 1.0);
50///
51pub struct SwitchBlock<T: Apply>
52where
53    T::Output: DefaultStorage,
54    OldBlockData: FromPass<T::Output>,
55{
56    pub data: OldBlockData,
57    buffer: <T::Output as DefaultStorage>::Storage,
58}
59
60impl<T: Apply> Default for SwitchBlock<T>
61where
62    T::Output: DefaultStorage,
63    OldBlockData: FromPass<T::Output>,
64{
65    fn default() -> Self {
66        Self {
67            data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::from_storage(
68                &T::Output::default_storage(),
69            )),
70            buffer: T::Output::default_storage(),
71        }
72    }
73}
74
75impl<T: Apply> ProcessBlock for SwitchBlock<T>
76where
77    T::Output: DefaultStorage,
78    OldBlockData: FromPass<T::Output>,
79{
80    type Inputs = T;
81    type Output = T::Output;
82    type Parameters = T::Parameters;
83
84    fn process<'b>(
85        &'b mut self,
86        parameters: &Self::Parameters,
87        _context: &dyn pictorus_traits::Context,
88        inputs: PassBy<'_, Self::Inputs>,
89    ) -> PassBy<'b, Self::Output> {
90        T::apply(inputs, parameters, &mut self.buffer);
91        let res = T::Output::from_storage(&self.buffer);
92        self.data = <OldBlockData as FromPass<T::Output>>::from_pass(res);
93        res
94    }
95}
96
97/// Parameters for the SwitchBlock
98pub struct Parameters<C: Scalar, const N: usize> {
99    /// The cases to compare the input condition against
100    /// The cases array must be exactly the same length as the number of inputs
101    /// The last case is the default value
102    pub cases: [C; N],
103}
104
105// TODO: This is currently only implemented for f64 and is constructed from OldBlockData.
106// In the future this should either accept an array of [C; N] or a &[C]
107impl<const N: usize> Parameters<f64, N> {
108    pub fn new(cases: &OldBlockData) -> Self {
109        assert!(cases.len() == N, "Invalid number of switch cases");
110
111        let mut case_arr: [f64; N] = [0.0; N];
112        for (idx, case) in cases.iter().enumerate() {
113            case_arr[idx] = *case;
114        }
115        Self { cases: case_arr }
116    }
117}
118
119pub trait ApplyInto<C: Scalar, const N: usize>: Pass + DefaultStorage {
120    fn apply_into(
121        condition: C,
122        cases: &[C; N],
123        inputs: &[PassBy<Self>; N],
124        dest: &mut Self::Storage,
125    );
126}
127
128impl<C: Scalar, const N: usize> ApplyInto<C, N> for C {
129    fn apply_into(condition: C, cases: &[C; N], inputs: &[PassBy<C>; N], dest: &mut C) {
130        for (idx, case) in cases.iter().enumerate() {
131            if condition == *case {
132                let res = inputs[idx];
133                *dest = res;
134                return;
135            }
136        }
137        let res = inputs[inputs.len() - 1];
138        *dest = res;
139    }
140}
141
142impl<C: Scalar, const NROWS: usize, const NCOLS: usize, const N: usize> ApplyInto<C, N>
143    for Matrix<NROWS, NCOLS, C>
144{
145    fn apply_into(
146        condition: C,
147        cases: &[C; N],
148        inputs: &[PassBy<Matrix<NROWS, NCOLS, C>>; N],
149        dest: &mut Matrix<NROWS, NCOLS, C>,
150    ) {
151        for (idx, case) in cases.iter().enumerate() {
152            if condition == *case {
153                let res = inputs[idx];
154                Matrix::copy_into(res, dest);
155                return;
156            }
157        }
158        let res = inputs[inputs.len() - 1];
159        Matrix::copy_into(res, dest);
160    }
161}
162
163impl<C: Scalar, const N: usize> ApplyInto<C, N> for ByteSliceSignal {
164    fn apply_into(
165        condition: C,
166        cases: &[C; N],
167        inputs: &[PassBy<ByteSliceSignal>; N],
168        dest: &mut Vec<u8>,
169    ) {
170        for (idx, case) in cases.iter().enumerate() {
171            if condition == *case {
172                let res = inputs[idx];
173                dest.clear();
174                dest.extend_from_slice(res);
175                return;
176            }
177        }
178        let res = inputs[inputs.len() - 1];
179        // We use clear and extend rather than copy_from_slice because
180        // copy_from_slice requires the destination to be the same length as the source
181        dest.clear();
182        dest.extend_from_slice(res);
183    }
184}
185
186pub trait Apply: Pass {
187    type Parameters;
188    type Output: Pass + DefaultStorage;
189
190    fn apply(
191        input: PassBy<Self>,
192        params: &Self::Parameters,
193        buffer: &mut <Self::Output as DefaultStorage>::Storage,
194    );
195}
196
197// SwitchBlock requires at least 3 inputs. The first is the condition,
198// the rest are inputs to maybe pass through
199
200// 1 condition + 2 inputs
201impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 2>> Apply for (C, T, T) {
202    type Output = T;
203    type Parameters = Parameters<C, 2>;
204
205    fn apply(
206        input: PassBy<Self>,
207        params: &Self::Parameters,
208        buffer: &mut <Self::Output as DefaultStorage>::Storage,
209    ) {
210        let condition = input.0;
211        T::apply_into(condition, &params.cases, &[input.1, input.2], buffer);
212    }
213}
214
215// 1 condition + 3 inputs
216impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 3>> Apply for (C, T, T, T) {
217    type Output = T;
218    type Parameters = Parameters<C, 3>;
219
220    fn apply(
221        input: PassBy<Self>,
222        params: &Self::Parameters,
223        buffer: &mut <Self::Output as DefaultStorage>::Storage,
224    ) {
225        let condition = input.0;
226        T::apply_into(
227            condition,
228            &params.cases,
229            &[input.1, input.2, input.3],
230            buffer,
231        );
232    }
233}
234
235// 1 condition + 4 inputs
236impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 4>> Apply for (C, T, T, T, T) {
237    type Output = T;
238    type Parameters = Parameters<C, 4>;
239
240    fn apply(
241        input: PassBy<Self>,
242        params: &Self::Parameters,
243        buffer: &mut <Self::Output as DefaultStorage>::Storage,
244    ) {
245        let condition = input.0;
246        T::apply_into(
247            condition,
248            &params.cases,
249            &[input.1, input.2, input.3, input.4],
250            buffer,
251        );
252    }
253}
254
255// 1 condition + 5 inputs
256impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 5>> Apply for (C, T, T, T, T, T) {
257    type Output = T;
258    type Parameters = Parameters<C, 5>;
259
260    fn apply(
261        input: PassBy<Self>,
262        params: &Self::Parameters,
263        buffer: &mut <Self::Output as DefaultStorage>::Storage,
264    ) {
265        let condition = input.0;
266        T::apply_into(
267            condition,
268            &params.cases,
269            &[input.1, input.2, input.3, input.4, input.5],
270            buffer,
271        );
272    }
273}
274
275// 1 condition + 6 inputs
276impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 6>> Apply for (C, T, T, T, T, T, T) {
277    type Output = T;
278    type Parameters = Parameters<C, 6>;
279
280    fn apply(
281        input: PassBy<Self>,
282        params: &Self::Parameters,
283        buffer: &mut <Self::Output as DefaultStorage>::Storage,
284    ) {
285        let condition = input.0;
286        T::apply_into(
287            condition,
288            &params.cases,
289            &[input.1, input.2, input.3, input.4, input.5, input.6],
290            buffer,
291        );
292    }
293}
294
295// 1 condition + 7 inputs
296impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 7>> Apply for (C, T, T, T, T, T, T, T) {
297    type Output = T;
298    type Parameters = Parameters<C, 7>;
299
300    fn apply(
301        input: PassBy<Self>,
302        params: &Self::Parameters,
303        buffer: &mut <Self::Output as DefaultStorage>::Storage,
304    ) {
305        let condition = input.0;
306        T::apply_into(
307            condition,
308            &params.cases,
309            &[
310                input.1, input.2, input.3, input.4, input.5, input.6, input.7,
311            ],
312            buffer,
313        );
314    }
315}
316
#[cfg(test)]
mod tests {
    use crate::traits::MatrixOps;

    use super::*;
    use crate::testing::StubContext;

    #[test]
    fn test_switch_block_2_scalars() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (0.0, 1.0, 2.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 1.0);
        assert_eq!(block.data.scalar(), 1.0);
    }

    #[test]
    fn test_switch_block_7_scalars() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (6.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 7.0);
        assert_eq!(block.data.scalar(), 7.0);
    }

    #[test]
    fn test_switch_block_7_scalars_middle_case() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        // A match that is neither the first nor the last case, so the fallback path
        // cannot produce the same answer by accident
        let input = (3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 4.0);
        assert_eq!(block.data.scalar(), 4.0);
    }

    #[test]
    fn test_switch_block_7_scalars_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (42.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 7.0);
        assert_eq!(block.data.scalar(), 7.0);
    }

    #[test]
    fn test_switch_block_scalar_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        // Should use the last value by default
        let input = (1.2345, 1.0, 2.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 2.0);
        assert_eq!(block.data.scalar(), 2.0);
    }

    #[test]
    fn test_switch_block_first_matching_case_wins() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[1.0, 1.0]));

        let input = (1.0, 10.0, 20.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 10.0);
        assert_eq!(block.data.scalar(), 10.0);
    }

    #[test]
    fn test_switch_block_nan_condition_uses_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, f64, f64)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        // NaN never compares equal to any case, so the last input is selected
        let input = (f64::NAN, 1.0, 2.0);
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, 2.0);
        assert_eq!(block.data.scalar(), 2.0);
    }

    #[test]
    fn test_switch_block_2_matrices() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, Matrix<3, 3, f64>, Matrix<3, 3, f64>)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (0.0, &Matrix::from_element(1.0), &Matrix::from_element(2.0));
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(1.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_7_matrices() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            6.0,
            &Matrix::from_element(1.0),
            &Matrix::from_element(2.0),
            &Matrix::from_element(3.0),
            &Matrix::from_element(4.0),
            &Matrix::from_element(5.0),
            &Matrix::from_element(6.0),
            &Matrix::from_element(7.0),
        );
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(7.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_7_matrices_middle_case() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
            Matrix<3, 3, f64>,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            2.0,
            &Matrix::from_element(1.0),
            &Matrix::from_element(2.0),
            &Matrix::from_element(3.0),
            &Matrix::from_element(4.0),
            &Matrix::from_element(5.0),
            &Matrix::from_element(6.0),
            &Matrix::from_element(7.0),
        );
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(3.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_matrix_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, Matrix<3, 3, f64>, Matrix<3, 3, f64>)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        // Should use the last value by default
        let input = (
            1.2345,
            &Matrix::from_element(1.0),
            &Matrix::from_element(2.0),
        );
        let output = block.process(&parameters, &ctxt, input);
        let expected = Matrix::from_element(2.0);
        assert_eq!(output, &expected);
        assert_eq!(
            block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }

    #[test]
    fn test_switch_block_2_bytes() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        let input = (0.0, b"foo".as_slice(), b"bar".as_slice());
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"foo");
        assert_eq!(block.data.raw_string().as_bytes(), b"foo".as_slice());
    }

    #[test]
    fn test_switch_block_2_bytes_default() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        // Should use the last value by default
        let input = (1.2345, b"foo".as_slice(), b"bar".as_slice());
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"bar");
        assert_eq!(block.data.raw_string().as_bytes(), b"bar".as_slice());
    }

    #[test]
    fn test_switch_block_bytes_buffer_resizes_between_calls() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));

        // Select the longer payload first, then the shorter one: no stale bytes may remain
        let input = (1.0, b"ab".as_slice(), b"abcdef".as_slice());
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"abcdef");

        let input = (0.0, b"ab".as_slice(), b"abcdef".as_slice());
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"ab");
        assert_eq!(block.data.raw_string().as_bytes(), b"ab".as_slice());
    }

    #[test]
    fn test_switch_block_7_bytes() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            6.0,
            b"foo".as_slice(),
            b"bar".as_slice(),
            b"baz".as_slice(),
            b"qux".as_slice(),
            b"quux".as_slice(),
            b"corge".as_slice(),
            b"grault".as_slice(),
        );
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"grault");
        assert_eq!(block.data.raw_string().as_bytes(), b"grault".as_slice());
    }

    #[test]
    fn test_switch_block_7_bytes_middle_case() {
        let ctxt = StubContext::default();

        let mut block = SwitchBlock::<(
            f64,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
            ByteSliceSignal,
        )>::default();
        let parameters = Parameters::new(&OldBlockData::from_vector(&[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
        ]));

        let input = (
            4.0,
            b"foo".as_slice(),
            b"bar".as_slice(),
            b"baz".as_slice(),
            b"qux".as_slice(),
            b"quux".as_slice(),
            b"corge".as_slice(),
            b"grault".as_slice(),
        );
        let output = block.process(&parameters, &ctxt, input);
        assert_eq!(output, b"quux");
        assert_eq!(block.data.raw_string().as_bytes(), b"quux".as_slice());
    }
}