pictorus_blocks/core_blocks/
transpose_block.rs1use crate::nalgebra_interop::MatrixExt;
2use pictorus_block_data::{BlockData as OldBlockData, FromPass};
3use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
4
5use crate::traits::Scalar;
6
/// Parameters for [`TransposeBlock`].
///
/// The block has no tunable settings, so this type carries no fields. It
/// exists so the block can name a `Parameters` type for `ProcessBlock`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Parameters {}

impl Parameters {
    /// Creates the (empty) parameter set.
    ///
    /// Equivalent to [`Parameters::default`]; usable in `const` contexts.
    pub const fn new() -> Self {
        Self {}
    }
}
20pub struct TransposeBlock<T: Apply> {
24 pub data: OldBlockData,
25 store: Option<T::Output>,
26}
27
28impl<T: Apply> Default for TransposeBlock<T>
29where
30 OldBlockData: FromPass<T::Output>,
31{
32 fn default() -> Self {
33 Self {
34 data: <OldBlockData as FromPass<T::Output>>::from_pass(<T::Output>::default().as_by()),
35 store: None,
36 }
37 }
38}
39
40impl<T: Apply> ProcessBlock for TransposeBlock<T>
41where
42 OldBlockData: FromPass<T::Output>,
43{
44 type Inputs = T;
45 type Output = T::Output;
46 type Parameters = Parameters;
47
48 fn process(
49 &mut self,
50 _parameters: &Self::Parameters,
51 _context: &dyn pictorus_traits::Context,
52 input: PassBy<Self::Inputs>,
53 ) -> PassBy<Self::Output> {
54 let output = T::apply(&mut self.store, input);
55 self.data = OldBlockData::from_pass(output);
56 output
57 }
58}
59
60pub trait Apply: Pass {
61 type Output: Pass + Default;
62
63 fn apply<'s>(
64 store: &'s mut Option<Self::Output>,
65 input: PassBy<Self>,
66 ) -> PassBy<'s, Self::Output>;
67}
68
69impl<S: Scalar> Apply for S {
70 type Output = S;
71
72 fn apply<'s>(
73 store: &'s mut Option<Self::Output>,
74 input: PassBy<Self>,
75 ) -> PassBy<'s, Self::Output> {
76 let output = store.insert(input);
77 output.as_by()
78 }
79}
80
81impl<const NROWS: usize, const NCOLS: usize, S: Scalar> Apply for Matrix<NROWS, NCOLS, S> {
82 type Output = Matrix<NCOLS, NROWS, S>;
83
84 fn apply<'s>(
85 store: &'s mut Option<Self::Output>,
86 input: PassBy<Self>,
87 ) -> PassBy<'s, Self::Output> {
88 let input = input.as_view();
89 let transposed = input.transpose();
90 let output = store.insert(Matrix::from_view(&transposed.as_view()));
91 output
92 }
93}
94
#[cfg(test)]
mod tests {
    use crate::testing::StubContext;

    use super::*;

    /// A scalar input comes back unchanged and is mirrored into `data`.
    #[test]
    fn test_transpose_scalar_input() {
        let ctxt = StubContext::default();
        let params = Parameters::default();
        let mut transpose_block = TransposeBlock::<f64>::default();

        let output = transpose_block.process(&params, &ctxt, 1.0);
        assert_eq!(output, 1.0);
        assert_eq!(transpose_block.data.scalar(), 1.0);

        // A second call must replace the previously stored value.
        let output = transpose_block.process(&params, &ctxt, 42.0);
        assert_eq!(output, 42.0);
        assert_eq!(transpose_block.data.scalar(), 42.0);
    }

    /// A 3x2 matrix input yields the corresponding 2x3 matrix.
    #[test]
    fn test_transpose_matrix_input() {
        let ctxt = StubContext::default();
        let params = Parameters::default();
        let mut transpose_block = TransposeBlock::<Matrix<3, 2, f64>>::default();

        // `data` holds NCOLS inner arrays of NROWS elements each.
        let input = Matrix {
            data: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        };
        let expected = Matrix {
            data: [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]],
        };
        let output = transpose_block.process(&params, &ctxt, &input);
        assert_eq!(output.data, expected.data);
        assert_eq!(
            transpose_block.data.get_data().as_slice(),
            expected.data.as_flattened()
        );
    }
}