pictorus_blocks/core_blocks/
constant_block.rs1use pictorus_block_data::{BlockData as OldBlockData, FromPass};
2use pictorus_traits::{GeneratorBlock, Matrix, Pass, PassBy, Scalar};
3
/// Parameters consumed by [`ConstantBlock`] on every `generate` call.
///
/// `T` is the type of the constant itself (the visible `Apply` impls cover
/// `f64` and `Matrix`).
pub struct Parameters<T> {
    /// The value the block emits.
    pub constant: T,
}
7
8impl<T> Parameters<T> {
9 pub fn new(constant: T) -> Self {
10 Self { constant }
11 }
12}
13
14pub struct ConstantBlock<T>
16where
17 T: Apply,
18{
19 pub data: OldBlockData,
20 buffer: Option<T::Output>,
21}
22
23impl<T> Default for ConstantBlock<T>
24where
25 T: Apply,
26 OldBlockData: FromPass<T::Output>,
27{
28 fn default() -> Self {
29 Self {
30 buffer: None,
31 data: <OldBlockData as FromPass<T::Output>>::from_pass(<T::Output>::default().as_by()),
32 }
33 }
34}
35
36impl<T> GeneratorBlock for ConstantBlock<T>
37where
38 T: Apply,
39 OldBlockData: FromPass<T::Output>,
40{
41 type Output = T::Output;
42 type Parameters = Parameters<T>;
43
44 fn generate(
45 &mut self,
46 parameters: &Self::Parameters,
47 _context: &dyn pictorus_traits::Context,
48 ) -> pictorus_traits::PassBy<Self::Output> {
49 let output = T::apply(&mut self.buffer, parameters);
50 self.data = OldBlockData::from_pass(output);
51 output
52 }
53}
54
55pub trait Apply: Pass + Sized {
56 type Output: Pass + Default;
57
58 fn apply<'s>(
59 store: &'s mut Option<Self::Output>,
60 parameters: &Parameters<Self>,
61 ) -> PassBy<'s, Self::Output>;
62}
63
64impl Apply for f64 {
65 type Output = f64;
66
67 fn apply<'s>(
68 store: &'s mut Option<Self::Output>,
69 parameters: &Parameters<Self>,
70 ) -> PassBy<'s, Self::Output> {
71 *store = Some(parameters.constant);
72 parameters.constant
73 }
74}
75
76impl<const NROWS: usize, const NCOLS: usize, T> Apply for Matrix<NROWS, NCOLS, T>
77where
78 T: Scalar,
79{
80 type Output = Matrix<NROWS, NCOLS, T>;
81
82 fn apply<'s>(
83 store: &'s mut Option<Self::Output>,
84 parameters: &Parameters<Self>,
85 ) -> PassBy<'s, Self::Output> {
86 let output = store.insert(Matrix::zeroed());
87 *output = parameters.constant;
88 output
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubContext;
    use pictorus_block_data::{BlockData, ToPass};

    /// A scalar constant is returned as-is and mirrored into `block.data`.
    #[test]
    fn test_constant_scalar() {
        let mut block = ConstantBlock::<f64>::default();
        let parameters = Parameters::new(3.0);
        let context = StubContext::default();

        let output = block.generate(&parameters, &context);
        assert_eq!(output, 3.0);
        assert_eq!(block.data, BlockData::from_scalar(3.0));
    }

    /// A two-element vector is emitted as a `Matrix<1, 2, f64>`.
    #[test]
    fn test_constant_vector() {
        let vector: [f64; 2] = [1.0, 2.0];

        let mut block = ConstantBlock::<Matrix<1, 2, f64>>::default();
        let parameters = Parameters::new(BlockData::from_vector(&vector).to_pass());
        let context = StubContext::default();

        let output = block.generate(&parameters, &context);
        // The second element sits at `data[1][0]`, i.e. the first index
        // selects among the two entries of the single row.
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 2.0);

        assert_eq!(block.data, BlockData::from_matrix(&[&vector]));
    }

    /// A 2x2 matrix constant round-trips through the block unchanged.
    #[test]
    fn test_constant_matrix() {
        let matrix_as_blockdata = BlockData::from_matrix(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let mut block = ConstantBlock::<Matrix<2, 2, f64>>::default();
        let parameters = Parameters::new(matrix_as_blockdata.to_pass());
        let context = StubContext::default();

        let output = block.generate(&parameters, &context);
        assert_eq!(output.data[0][0], 1.0);
        assert_eq!(output.data[1][0], 2.0);
        assert_eq!(output.data[0][1], 3.0);
        assert_eq!(output.data[1][1], 4.0);
        assert_eq!(block.data, matrix_as_blockdata);
    }
}