Skip to main content

pictorus_blocks/core_blocks/
transpose_block.rs

1use crate::nalgebra_interop::MatrixExt;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
4
5use crate::traits::Scalar;
6
/// Parameters for [`TransposeBlock`].
///
/// The transpose block ignores its parameters when processing, so this type
/// carries no fields; it exists to fill the `ProcessBlock::Parameters` slot.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Parameters {}

impl Parameters {
    /// Creates the (empty) parameter set. Equivalent to `Parameters::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self {}
    }
}
20/// Outputs the transpose of the input signal.
21///
22/// For scalar inputs this is just a pass-through
23pub struct TransposeBlock<T: Apply> {
24    pub data: OldBlockData,
25    store: Option<T::Output>,
26}
27
28impl<T: Apply> Default for TransposeBlock<T>
29where
30    OldBlockData: FromPass<T::Output>,
31{
32    fn default() -> Self {
33        Self {
34            data: <OldBlockData as FromPass<T::Output>>::from_pass(<T::Output>::default().as_by()),
35            store: None,
36        }
37    }
38}
39
40impl<T: Apply> ProcessBlock for TransposeBlock<T>
41where
42    OldBlockData: FromPass<T::Output>,
43{
44    type Inputs = T;
45    type Output = T::Output;
46    type Parameters = Parameters;
47
48    fn process(
49        &mut self,
50        _parameters: &Self::Parameters,
51        _context: &dyn pictorus_traits::Context,
52        input: PassBy<Self::Inputs>,
53    ) -> PassBy<Self::Output> {
54        let output = T::apply(&mut self.store, input);
55        self.data = OldBlockData::from_pass(output);
56        output
57    }
58}
59
60pub trait Apply: Pass {
61    type Output: Pass + Default;
62
63    fn apply<'s>(
64        store: &'s mut Option<Self::Output>,
65        input: PassBy<Self>,
66    ) -> PassBy<'s, Self::Output>;
67}
68
69impl<S: Scalar> Apply for S {
70    type Output = S;
71
72    fn apply<'s>(
73        store: &'s mut Option<Self::Output>,
74        input: PassBy<Self>,
75    ) -> PassBy<'s, Self::Output> {
76        let output = store.insert(input);
77        output.as_by()
78    }
79}
80
81impl<const NROWS: usize, const NCOLS: usize, S: Scalar> Apply for Matrix<NROWS, NCOLS, S> {
82    type Output = Matrix<NCOLS, NROWS, S>;
83
84    fn apply<'s>(
85        store: &'s mut Option<Self::Output>,
86        input: PassBy<Self>,
87    ) -> PassBy<'s, Self::Output> {
88        let input = input.as_view();
89        let transposed = input.transpose();
90        let output = store.insert(Matrix::from_view(&transposed.as_view()));
91        output
92    }
93}
94
#[cfg(test)]
mod tests {
    use crate::testing::StubContext;

    use super::*;

    /// Scalars pass through unchanged, and each call replaces the stored value.
    #[test]
    fn test_transpose_scalar_input() {
        let ctxt = StubContext::default();
        let params = Parameters::default();
        let mut transpose_block = TransposeBlock::<f64>::default();

        let output = transpose_block.process(&params, &ctxt, 1.0);
        assert_eq!(output, 1.0);
        assert_eq!(transpose_block.data.scalar(), 1.0);

        let output = transpose_block.process(&params, &ctxt, 42.0);
        assert_eq!(output, 42.0);
        assert_eq!(transpose_block.data.scalar(), 42.0);
    }

    /// A 3x2 matrix becomes 2x3, and a second call overwrites the first result.
    #[test]
    fn test_transpose_matrix_input() {
        let ctxt = StubContext::default();
        let params = Parameters::default();
        let mut transpose_block = TransposeBlock::<Matrix<3, 2, f64>>::default();

        let input = Matrix {
            data: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        };
        let expected = Matrix {
            data: [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]],
        };
        let output = transpose_block.process(&params, &ctxt, &input);
        assert_eq!(output.data, expected.data);
        assert_eq!(
            transpose_block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );

        // Process again with different values to confirm the stored output
        // from the first call is replaced rather than reused.
        let second_input = Matrix {
            data: [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
        };
        let second_expected = Matrix {
            data: [[7.0, 10.0], [8.0, 11.0], [9.0, 12.0]],
        };
        let output = transpose_block.process(&params, &ctxt, &second_input);
        assert_eq!(output.data, second_expected.data);
        assert_eq!(
            transpose_block.data.get_data().as_slice(),
            second_expected.data.as_flattened()
        );
    }
}