Skip to main content

pictorus_blocks/core_blocks/
constant_block.rs

1use pictorus_block_data::{BlockData as OldBlockData, FromPass};
2use pictorus_traits::{GeneratorBlock, Matrix, Pass, PassBy, Scalar};
3
/// Configuration for a [`ConstantBlock`]: the value the block should emit.
pub struct Parameters<T> {
    /// The value handed back by every `generate` call.
    pub constant: T,
}

impl<T> Parameters<T> {
    /// Wraps `constant` so it can be passed to a [`ConstantBlock`].
    pub fn new(constant: T) -> Self {
        Parameters { constant }
    }
}
13
14/// Outputs a constant numeric value.
15pub struct ConstantBlock<T>
16where
17    T: Apply,
18{
19    pub data: OldBlockData,
20    buffer: Option<T::Output>,
21}
22
23impl<T> Default for ConstantBlock<T>
24where
25    T: Apply,
26    OldBlockData: FromPass<T::Output>,
27{
28    fn default() -> Self {
29        Self {
30            buffer: None,
31            data: <OldBlockData as FromPass<T::Output>>::from_pass(<T::Output>::default().as_by()),
32        }
33    }
34}
35
36impl<T> GeneratorBlock for ConstantBlock<T>
37where
38    T: Apply,
39    OldBlockData: FromPass<T::Output>,
40{
41    type Output = T::Output;
42    type Parameters = Parameters<T>;
43
44    fn generate(
45        &mut self,
46        parameters: &Self::Parameters,
47        _context: &dyn pictorus_traits::Context,
48    ) -> pictorus_traits::PassBy<Self::Output> {
49        let output = T::apply(&mut self.buffer, parameters);
50        self.data = OldBlockData::from_pass(output);
51        output
52    }
53}
54
55pub trait Apply: Pass + Sized {
56    type Output: Pass + Default;
57
58    fn apply<'s>(
59        store: &'s mut Option<Self::Output>,
60        parameters: &Parameters<Self>,
61    ) -> PassBy<'s, Self::Output>;
62}
63
64impl Apply for f64 {
65    type Output = f64;
66
67    fn apply<'s>(
68        store: &'s mut Option<Self::Output>,
69        parameters: &Parameters<Self>,
70    ) -> PassBy<'s, Self::Output> {
71        *store = Some(parameters.constant);
72        parameters.constant
73    }
74}
75
76impl<const NROWS: usize, const NCOLS: usize, T> Apply for Matrix<NROWS, NCOLS, T>
77where
78    T: Scalar,
79{
80    type Output = Matrix<NROWS, NCOLS, T>;
81
82    fn apply<'s>(
83        store: &'s mut Option<Self::Output>,
84        parameters: &Parameters<Self>,
85    ) -> PassBy<'s, Self::Output> {
86        let output = store.insert(Matrix::zeroed());
87        *output = parameters.constant;
88        output
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;
    use pictorus_block_data::{BlockData, ToPass};

    #[test]
    fn test_constant_scalar() {
        let mut block = ConstantBlock::<f64>::default();
        let params = Parameters::new(3.0);
        let ctx = StubContext::default();

        let produced = block.generate(&params, &ctx);

        assert_eq!(produced, 3.0);
        assert_eq!(block.data, BlockData::from_scalar(3.0));
    }

    #[test]
    fn test_constant_vector() {
        let row: [f64; 2] = [1.0, 2.0];

        let mut block = ConstantBlock::<Matrix<1, 2, f64>>::default();
        // `to_pass` converts the vector `BlockData` into a 1x2 matrix parameter.
        let params = Parameters::new(BlockData::from_vector(&row).to_pass());
        let ctx = StubContext::default();

        // `generate` converts the matrix output back into `BlockData` via `from_pass`.
        let produced = block.generate(&params, &ctx);
        let (first, second) = (produced.data[0][0], produced.data[1][0]);
        assert_eq!(first, 1.0);
        assert_eq!(second, 2.0);

        assert_eq!(block.data, BlockData::from_matrix(&[&row]));
    }

    #[test]
    fn test_constant_matrix() {
        let expected = BlockData::from_matrix(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let mut block = ConstantBlock::<Matrix<2, 2, f64>>::default();
        let params = Parameters::new(expected.to_pass());
        let ctx = StubContext::default();

        let produced = block.generate(&params, &ctx);
        // `data` is indexed [column][row].
        assert_eq!(produced.data[0][0], 1.0);
        assert_eq!(produced.data[1][0], 2.0);
        assert_eq!(produced.data[0][1], 3.0);
        assert_eq!(produced.data[1][1], 4.0);
        assert_eq!(block.data, expected);
    }
}