extern crate alloc;
use alloc::vec::Vec;
use pictorus_block_data::{BlockData as OldBlockData, FromPass};
use pictorus_traits::{ByteSliceSignal, Matrix, Pass, PassBy, ProcessBlock};
use crate::traits::{CopyInto, DefaultStorage, Scalar};
pub struct SwitchBlock<T: Apply>
where
T::Output: DefaultStorage,
OldBlockData: FromPass<T::Output>,
{
pub data: OldBlockData,
buffer: <T::Output as DefaultStorage>::Storage,
}
impl<T: Apply> Default for SwitchBlock<T>
where
T::Output: DefaultStorage,
OldBlockData: FromPass<T::Output>,
{
fn default() -> Self {
Self {
data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::from_storage(
&T::Output::default_storage(),
)),
buffer: T::Output::default_storage(),
}
}
}
impl<T: Apply> ProcessBlock for SwitchBlock<T>
where
T::Output: DefaultStorage,
OldBlockData: FromPass<T::Output>,
{
type Inputs = T;
type Output = T::Output;
type Parameters = T::Parameters;
fn process<'b>(
&'b mut self,
parameters: &Self::Parameters,
_context: &dyn pictorus_traits::Context,
inputs: PassBy<'_, Self::Inputs>,
) -> PassBy<'b, Self::Output> {
T::apply(inputs, parameters, &mut self.buffer);
let res = T::Output::from_storage(&self.buffer);
self.data = <OldBlockData as FromPass<T::Output>>::from_pass(res);
res
}
}
pub struct Parameters<C: Scalar, const N: usize> {
pub cases: [C; N],
}
impl<const N: usize> Parameters<f64, N> {
pub fn new(cases: &OldBlockData) -> Self {
assert!(cases.len() == N, "Invalid number of switch cases");
let mut case_arr: [f64; N] = [0.0; N];
for (idx, case) in cases.iter().enumerate() {
case_arr[idx] = *case;
}
Self { cases: case_arr }
}
}
pub trait ApplyInto<C: Scalar, const N: usize>: Pass + DefaultStorage {
fn apply_into(
condition: C,
cases: &[C; N],
inputs: &[PassBy<Self>; N],
dest: &mut Self::Storage,
);
}
impl<C: Scalar, const N: usize> ApplyInto<C, N> for C {
fn apply_into(condition: C, cases: &[C; N], inputs: &[PassBy<C>; N], dest: &mut C) {
for (idx, case) in cases.iter().enumerate() {
if condition == *case {
let res = inputs[idx];
*dest = res;
return;
}
}
let res = inputs[inputs.len() - 1];
*dest = res;
}
}
impl<C: Scalar, const NROWS: usize, const NCOLS: usize, const N: usize> ApplyInto<C, N>
for Matrix<NROWS, NCOLS, C>
{
fn apply_into(
condition: C,
cases: &[C; N],
inputs: &[PassBy<Matrix<NROWS, NCOLS, C>>; N],
dest: &mut Matrix<NROWS, NCOLS, C>,
) {
for (idx, case) in cases.iter().enumerate() {
if condition == *case {
let res = inputs[idx];
Matrix::copy_into(res, dest);
return;
}
}
let res = inputs[inputs.len() - 1];
Matrix::copy_into(res, dest);
}
}
impl<C: Scalar, const N: usize> ApplyInto<C, N> for ByteSliceSignal {
fn apply_into(
condition: C,
cases: &[C; N],
inputs: &[PassBy<ByteSliceSignal>; N],
dest: &mut Vec<u8>,
) {
for (idx, case) in cases.iter().enumerate() {
if condition == *case {
let res = inputs[idx];
dest.clear();
dest.extend_from_slice(res);
return;
}
}
let res = inputs[inputs.len() - 1];
dest.clear();
dest.extend_from_slice(res);
}
}
pub trait Apply: Pass {
type Parameters;
type Output: Pass + DefaultStorage;
fn apply(
input: PassBy<Self>,
params: &Self::Parameters,
buffer: &mut <Self::Output as DefaultStorage>::Storage,
);
}
impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 2>> Apply for (C, T, T) {
type Output = T;
type Parameters = Parameters<C, 2>;
fn apply(
input: PassBy<Self>,
params: &Self::Parameters,
buffer: &mut <Self::Output as DefaultStorage>::Storage,
) {
let condition = input.0;
T::apply_into(condition, ¶ms.cases, &[input.1, input.2], buffer);
}
}
impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 3>> Apply for (C, T, T, T) {
type Output = T;
type Parameters = Parameters<C, 3>;
fn apply(
input: PassBy<Self>,
params: &Self::Parameters,
buffer: &mut <Self::Output as DefaultStorage>::Storage,
) {
let condition = input.0;
T::apply_into(
condition,
¶ms.cases,
&[input.1, input.2, input.3],
buffer,
);
}
}
impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 4>> Apply for (C, T, T, T, T) {
type Output = T;
type Parameters = Parameters<C, 4>;
fn apply(
input: PassBy<Self>,
params: &Self::Parameters,
buffer: &mut <Self::Output as DefaultStorage>::Storage,
) {
let condition = input.0;
T::apply_into(
condition,
¶ms.cases,
&[input.1, input.2, input.3, input.4],
buffer,
);
}
}
impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 5>> Apply for (C, T, T, T, T, T) {
type Output = T;
type Parameters = Parameters<C, 5>;
fn apply(
input: PassBy<Self>,
params: &Self::Parameters,
buffer: &mut <Self::Output as DefaultStorage>::Storage,
) {
let condition = input.0;
T::apply_into(
condition,
¶ms.cases,
&[input.1, input.2, input.3, input.4, input.5],
buffer,
);
}
}
impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 6>> Apply for (C, T, T, T, T, T, T) {
type Output = T;
type Parameters = Parameters<C, 6>;
fn apply(
input: PassBy<Self>,
params: &Self::Parameters,
buffer: &mut <Self::Output as DefaultStorage>::Storage,
) {
let condition = input.0;
T::apply_into(
condition,
¶ms.cases,
&[input.1, input.2, input.3, input.4, input.5, input.6],
buffer,
);
}
}
impl<C: Scalar, T: Pass + DefaultStorage + ApplyInto<C, 7>> Apply for (C, T, T, T, T, T, T, T) {
type Output = T;
type Parameters = Parameters<C, 7>;
fn apply(
input: PassBy<Self>,
params: &Self::Parameters,
buffer: &mut <Self::Output as DefaultStorage>::Storage,
) {
let condition = input.0;
T::apply_into(
condition,
¶ms.cases,
&[
input.1, input.2, input.3, input.4, input.5, input.6, input.7,
],
buffer,
);
}
}
#[cfg(test)]
mod tests {
use crate::traits::MatrixOps;
use super::*;
use crate::testing::StubContext;
#[test]
fn test_switch_block_2_scalars() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(f64, f64, f64)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));
let input = (0.0, 1.0, 2.0);
let output = block.process(¶meters, &ctxt, input);
assert_eq!(output, 1.0);
assert_eq!(block.data.scalar(), 1.0);
}
#[test]
fn test_switch_block_7_scalars() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(f64, f64, f64, f64, f64, f64, f64, f64)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
]));
let input = (6.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
let output = block.process(¶meters, &ctxt, input);
assert_eq!(output, 7.0);
assert_eq!(block.data.scalar(), 7.0);
}
#[test]
fn test_switch_block_scalar_default() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(f64, f64, f64)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));
let input = (1.2345, 1.0, 2.0);
let output = block.process(¶meters, &ctxt, input);
assert_eq!(output, 2.0);
assert_eq!(block.data.scalar(), 2.0);
}
#[test]
fn test_switch_block_2_matrices() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(f64, Matrix<3, 3, f64>, Matrix<3, 3, f64>)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));
let input = (0.0, &Matrix::from_element(1.0), &Matrix::from_element(2.0));
let output = block.process(¶meters, &ctxt, input);
let expected = Matrix::from_element(1.0);
assert_eq!(output, &expected);
assert_eq!(
block.data.get_data().as_slice(),
expected.data.as_flattened()
);
}
#[test]
fn test_switch_block_7_matrices() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(
f64,
Matrix<3, 3, f64>,
Matrix<3, 3, f64>,
Matrix<3, 3, f64>,
Matrix<3, 3, f64>,
Matrix<3, 3, f64>,
Matrix<3, 3, f64>,
Matrix<3, 3, f64>,
)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
]));
let input = (
6.0,
&Matrix::from_element(1.0),
&Matrix::from_element(2.0),
&Matrix::from_element(3.0),
&Matrix::from_element(4.0),
&Matrix::from_element(5.0),
&Matrix::from_element(6.0),
&Matrix::from_element(7.0),
);
let output = block.process(¶meters, &ctxt, input);
let expected = Matrix::from_element(7.0);
assert_eq!(output, &expected);
assert_eq!(
block.data.get_data().as_slice(),
expected.data.as_flattened()
);
}
#[test]
fn test_switch_block_matrix_default() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(f64, Matrix<3, 3, f64>, Matrix<3, 3, f64>)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));
let input = (
1.2345,
&Matrix::from_element(1.0),
&Matrix::from_element(2.0),
);
let output = block.process(¶meters, &ctxt, input);
let expected = Matrix::from_element(2.0);
assert_eq!(output, &expected);
assert_eq!(
block.data.get_data().as_slice(),
expected.data.as_flattened()
);
}
#[test]
fn test_switch_block_2_bytes() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));
let input = (0.0, b"foo".as_slice(), b"bar".as_slice());
let output = block.process(¶meters, &ctxt, input);
assert_eq!(output, b"foo");
assert_eq!(block.data.raw_string().as_bytes(), b"foo".as_slice());
}
#[test]
fn test_switch_block_2_bytes_default() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(f64, ByteSliceSignal, ByteSliceSignal)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[0.0, 1.0]));
let input = (1.2345, b"foo".as_slice(), b"bar".as_slice());
let output = block.process(¶meters, &ctxt, input);
assert_eq!(output, b"bar");
assert_eq!(block.data.raw_string().as_bytes(), b"bar".as_slice());
}
#[test]
fn test_switch_block_7_bytes() {
let ctxt = StubContext::default();
let mut block = SwitchBlock::<(
f64,
ByteSliceSignal,
ByteSliceSignal,
ByteSliceSignal,
ByteSliceSignal,
ByteSliceSignal,
ByteSliceSignal,
ByteSliceSignal,
)>::default();
let parameters = Parameters::new(&OldBlockData::from_vector(&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
]));
let input = (
6.0,
b"foo".as_slice(),
b"bar".as_slice(),
b"baz".as_slice(),
b"qux".as_slice(),
b"quux".as_slice(),
b"corge".as_slice(),
b"grault".as_slice(),
);
let output = block.process(¶meters, &ctxt, input);
assert_eq!(output, b"grault");
assert_eq!(block.data.raw_string().as_bytes(), b"grault".as_slice());
}
}