pictorus_blocks/core_blocks/
quantize_block.rs1use core::ops::MulAssign;
2
3use crate::nalgebra_interop::MatrixExt;
4use crate::traits::{MatrixOps, Scalar};
5use nalgebra::ClosedDivAssign;
6use num_traits::Float;
7use pictorus_block_data::{BlockData as OldBlockData, FromPass};
8use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
9
10pub struct Parameters<I: Scalar + Float> {
11 interval: I,
13}
14
15impl<I: Scalar + Float> Parameters<I> {
16 pub fn new(interval: I) -> Self {
17 Parameters { interval }
18 }
19}
20
21pub struct QuantizeBlock<I, T>
27where
28 I: Scalar + Float,
29 T: Apply<I>,
30 OldBlockData: FromPass<T::Output>,
31{
32 pub data: OldBlockData,
33 buffer: Option<T::Output>,
34}
35
36impl<I, T> Default for QuantizeBlock<I, T>
37where
38 I: Scalar + Float,
39 T: Apply<I>,
40 OldBlockData: FromPass<T::Output>,
41{
42 fn default() -> Self {
43 Self {
44 data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
45 buffer: None,
46 }
47 }
48}
49
50impl<I, T> ProcessBlock for QuantizeBlock<I, T>
51where
52 I: Scalar + Float,
53 T: Apply<I>,
54 OldBlockData: FromPass<T::Output>,
55{
56 type Parameters = Parameters<I>;
57 type Inputs = T;
58 type Output = T::Output;
59
60 fn process(
61 &mut self,
62 parameters: &Self::Parameters,
63 _context: &dyn pictorus_traits::Context,
64 inputs: PassBy<'_, Self::Inputs>,
65 ) -> PassBy<Self::Output> {
66 let res = T::apply(inputs, parameters.interval, &mut self.buffer);
67 self.data = OldBlockData::from_pass(res);
68 res
69 }
70}
71
72pub trait Apply<I: Scalar + Float>: Pass + Default {
73 type Output: Pass + Default;
74
75 fn apply<'a>(
76 input: PassBy<Self>,
77 interval: I,
78 dest: &'a mut Option<Self::Output>,
79 ) -> PassBy<'a, Self::Output>;
80}
81
82impl<I: Scalar + Float> Apply<I> for I {
83 type Output = I;
84
85 fn apply<'a>(
86 input: PassBy<Self>,
87 interval: I,
88 dest: &'a mut Option<Self::Output>,
89 ) -> PassBy<'a, Self::Output> {
90 let input_divided_interval = input / interval;
91 let rounded = input_divided_interval.round();
92 let res = rounded * interval;
93 *dest = Some(res);
94 res
95 }
96}
97
98impl<const R: usize, const C: usize, I: Scalar + Float + ClosedDivAssign + MulAssign> Apply<I>
99 for Matrix<R, C, I>
100{
101 type Output = Matrix<R, C, I>;
102
103 fn apply<'a>(
104 input: PassBy<Self>,
105 interval: I,
106 dest: &'a mut Option<Self::Output>,
107 ) -> PassBy<'a, Self::Output> {
108 let interval_matrix = Self::from_element(interval);
109 let input_divided_interval = input.as_view().component_div(&interval_matrix.as_view());
110 let rounded = input_divided_interval.map(Float::round);
111 let res = rounded * interval;
112 let res = Self::from_view(&res.as_view());
113 *dest = Some(res);
114 dest.as_ref().unwrap().as_by()
115 }
116}
117
#[cfg(test)]
mod tests {
    use std::vec::Vec;

    use crate::testing::StubContext;
    use paste::paste;

    use super::*;

    /// Generates scalar and matrix quantization tests for the given float type.
    macro_rules! test_quantize_block {
        ($type:ty) => {
            paste! {
                #[test]
                fn [<test_quantize_block_scalar _$type>]() {
                    let context = StubContext::default();
                    let params = Parameters::new(0.5);
                    let mut block = QuantizeBlock::<$type, $type>::default();
                    let input = 0.51;
                    let res = block.process(&params, &context, input);

                    assert_eq!(res, 0.5);
                    assert_eq!(block.data.scalar(), 0.5);
                }

                #[test]
                fn [<test_quantize_block_scalar_negative_half _$type>]() {
                    // -0.75 / 0.5 = -1.5, which `round` takes away from zero to -2.
                    let context = StubContext::default();
                    let params = Parameters::new(0.5);
                    let mut block = QuantizeBlock::<$type, $type>::default();
                    let res = block.process(&params, &context, -0.75);

                    assert_eq!(res, -1.0);
                    assert_eq!(block.data.scalar(), -1.0);
                }

                #[test]
                fn [<test_quantize_block_matrix _$type>]() {
                    let context = StubContext::default();
                    let params = Parameters::new(0.5);
                    let mut block = QuantizeBlock::<$type, Matrix<4, 1, $type>>::default();
                    let input = Matrix {
                        data: [[0.24, 0.25, 0.51, 0.75]],
                    };
                    // 0.25 and 0.75 are half-way cases and round away from zero.
                    let expected = Matrix {
                        data: [[0.0, 0.5, 0.5, 1.0]],
                    };
                    let res = block.process(&params, &context, &input);

                    assert_eq!(res.data, expected.data);
                    assert_eq!(
                        block.data.get_data().as_slice(),
                        expected
                            .data
                            .as_flattened()
                            .iter()
                            .copied()
                            .map(f64::from)
                            .collect::<Vec<f64>>()
                    );
                }
            }
        };
    }

    test_quantize_block!(f32);
    test_quantize_block!(f64);
}