Skip to main content

pictorus_blocks/core_blocks/
quantize_block.rs

1use core::ops::MulAssign;
2
3use crate::nalgebra_interop::MatrixExt;
4use crate::traits::{MatrixOps, Scalar};
5use nalgebra::ClosedDivAssign;
6use num_traits::Float;
7use pictorus_block_data::{BlockData as OldBlockData, FromPass};
8use pictorus_traits::{Matrix, Pass, PassBy, ProcessBlock};
9
10pub struct Parameters<I: Scalar + Float> {
11    /// The scalar interval to quantize to
12    interval: I,
13}
14
15impl<I: Scalar + Float> Parameters<I> {
16    pub fn new(interval: I) -> Self {
17        Parameters { interval }
18    }
19}
20
21/// Quantizes the input to the nearest integer multiple of the provided interval.
22///
23/// For example, if the interval is 0.5, the input 0.51 will be quantized to 0.5
24/// If the interval is 0.5, the input 0.75 will be quantized to 1.0
25/// For matrices, the process is applied element-wise
26pub struct QuantizeBlock<I, T>
27where
28    I: Scalar + Float,
29    T: Apply<I>,
30    OldBlockData: FromPass<T::Output>,
31{
32    pub data: OldBlockData,
33    buffer: Option<T::Output>,
34}
35
36impl<I, T> Default for QuantizeBlock<I, T>
37where
38    I: Scalar + Float,
39    T: Apply<I>,
40    OldBlockData: FromPass<T::Output>,
41{
42    fn default() -> Self {
43        Self {
44            data: <OldBlockData as FromPass<T::Output>>::from_pass(T::Output::default().as_by()),
45            buffer: None,
46        }
47    }
48}
49
50impl<I, T> ProcessBlock for QuantizeBlock<I, T>
51where
52    I: Scalar + Float,
53    T: Apply<I>,
54    OldBlockData: FromPass<T::Output>,
55{
56    type Parameters = Parameters<I>;
57    type Inputs = T;
58    type Output = T::Output;
59
60    fn process(
61        &mut self,
62        parameters: &Self::Parameters,
63        _context: &dyn pictorus_traits::Context,
64        inputs: PassBy<'_, Self::Inputs>,
65    ) -> PassBy<Self::Output> {
66        let res = T::apply(inputs, parameters.interval, &mut self.buffer);
67        self.data = OldBlockData::from_pass(res);
68        res
69    }
70}
71
72pub trait Apply<I: Scalar + Float>: Pass + Default {
73    type Output: Pass + Default;
74
75    fn apply<'a>(
76        input: PassBy<Self>,
77        interval: I,
78        dest: &'a mut Option<Self::Output>,
79    ) -> PassBy<'a, Self::Output>;
80}
81
82impl<I: Scalar + Float> Apply<I> for I {
83    type Output = I;
84
85    fn apply<'a>(
86        input: PassBy<Self>,
87        interval: I,
88        dest: &'a mut Option<Self::Output>,
89    ) -> PassBy<'a, Self::Output> {
90        let input_divided_interval = input / interval;
91        let rounded = input_divided_interval.round();
92        let res = rounded * interval;
93        *dest = Some(res);
94        res
95    }
96}
97
98impl<const R: usize, const C: usize, I: Scalar + Float + ClosedDivAssign + MulAssign> Apply<I>
99    for Matrix<R, C, I>
100{
101    type Output = Matrix<R, C, I>;
102
103    fn apply<'a>(
104        input: PassBy<Self>,
105        interval: I,
106        dest: &'a mut Option<Self::Output>,
107    ) -> PassBy<'a, Self::Output> {
108        let interval_matrix = Self::from_element(interval);
109        let input_divided_interval = input.as_view().component_div(&interval_matrix.as_view());
110        let rounded = input_divided_interval.map(Float::round);
111        let res = rounded * interval;
112        let res = Self::from_view(&res.as_view());
113        *dest = Some(res);
114        dest.as_ref().unwrap().as_by()
115    }
116}
117
#[cfg(test)]
mod tests {
    use std::vec::Vec;

    use crate::testing::StubContext;
    use paste::paste;

    use super::*;

    macro_rules! test_quantize_block {
        ($type:ty) => {
            paste! {
                #[test]
                fn [<test_quantize_block_scalar _$type>]() {
                    let context = StubContext::default();
                    let params = Parameters::new(0.5);
                    let mut block = QuantizeBlock::<$type, $type>::default();
                    let input = 0.51;
                    let res = block.process(&params, &context, input);

                    assert_eq!(res, 0.5);
                    assert_eq!(block.data.scalar(), 0.5);
                }

                #[test]
                fn [<test_quantize_block_scalar_negative _$type>]() {
                    let context = StubContext::default();
                    let params = Parameters::new(0.5);
                    let mut block = QuantizeBlock::<$type, $type>::default();

                    // -0.25 / 0.5 is exactly -0.5, a half-way case. `round` goes
                    // away from zero, so the result is -1 * 0.5.
                    let res = block.process(&params, &context, -0.25);
                    assert_eq!(res, -0.5);
                    assert_eq!(block.data.scalar(), -0.5);

                    // -0.74 / 0.5 = -1.48, which rounds to -1.
                    let res = block.process(&params, &context, -0.74);
                    assert_eq!(res, -0.5);
                    assert_eq!(block.data.scalar(), -0.5);
                }

                #[test]
                fn [<test_quantize_block_matrix _$type>]() {
                    let context = StubContext::default();
                    let params = Parameters::new(0.5);
                    let mut block = QuantizeBlock::<$type, Matrix<4, 1, $type>>::default();
                    let input = Matrix {
                        data: [[0.24, 0.25, 0.51, 0.75]],
                    };
                    let expected = Matrix {
                        data: [[0.0, 0.5, 0.5, 1.0]],
                    };
                    let res = block.process(&params, &context, &input);

                    assert_eq!(res.data, expected.data);
                    assert_eq!(
                        block.data.get_data().as_slice(),
                        expected
                            .data
                            .as_flattened()
                            .iter()
                            .map(|x| *x as f64)
                            .collect::<Vec<f64>>()
                    );
                }
            }
        };
    }

    test_quantize_block!(f32);
    test_quantize_block!(f64);
}