tract_core/ops/nn/softmax/fixedpoint.rs

1pub use num_traits::{AsPrimitive, PrimInt};
2use std::fmt::{Binary, Debug, LowerHex};
3
4use super::math::*;
5
/// Generates a method `$func_name(&self) -> Self` that feeds the raw integer
/// to the free function of the same name and wraps its result back into `Self`.
///
/// The free function is resolved at the expansion site, so it must be in scope
/// wherever this macro is invoked. The enclosing type must provide `as_raw` and
/// `from_raw`.
macro_rules! impl_fixed_point_func_unary {
    ($func_name:ident) => {
        #[allow(dead_code)]
        pub fn $func_name(&self) -> Self {
            let raw = $func_name(self.as_raw());
            Self::from_raw(raw)
        }
    };
}
14
/// Generates a method `$func_name(&self, b: Self) -> Self` that applies the
/// two-argument free function of the same name to both raw integers and wraps
/// the result back into `Self`.
///
/// As with the unary variant, the free function is resolved at the expansion
/// site and the enclosing type must provide `as_raw` and `from_raw`.
macro_rules! impl_fixed_point_func_binary {
    ($func_name:ident) => {
        pub fn $func_name(&self, b: Self) -> Self {
            let (lhs, rhs) = (self.as_raw(), b.as_raw());
            Self::from_raw($func_name(lhs, rhs))
        }
    };
}
22
23pub type Q0_31 = FixedPoint<i32, 0>;
24pub type Q1_30 = FixedPoint<i32, 1>;
25pub type Q2_29 = FixedPoint<i32, 2>;
26pub type Q5_26 = FixedPoint<i32, 5>;
27
28#[derive(PartialEq, Eq,PartialOrd, Copy, Clone)]
29pub struct FixedPoint<T: PrimInt, const INTEGER_BITS: usize>(T);
30
31impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
32where
33    T: PrimInt,
34{
35    pub fn from_raw(x: T) -> Self {
36        Self(x)
37    }
38
39    pub fn one() -> Self {
40        if INTEGER_BITS == 0 {
41            Self(T::max_value())
42        } else {
43            Self(T::one() << Self::fractional_bits())
44        }
45    }
46
47    pub fn fractional_bits() -> usize {
48        if Self::is_signed() {
49            std::mem::size_of::<T>() * 8 - 1 - INTEGER_BITS
50        } else {
51            std::mem::size_of::<T>() * 8 - INTEGER_BITS
52        }
53    }
54
55    #[allow(dead_code)]
56    pub fn zero() -> Self {
57        Self(T::zero())
58    }
59
60    pub fn as_raw(&self) -> T {
61        self.0
62    }
63
64    pub fn is_signed() -> bool {
65        is_signed::<T>()
66    }
67}
68
69impl<T: 'static, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
70where
71    T: PrimInt + Debug,
72    usize: AsPrimitive<T>,
73{
74    pub fn constant_pot(exponent: isize) -> Self {
75        let offset = (Self::fractional_bits() as isize + exponent) as usize;
76        assert!(offset < 31);
77        Self(1_usize.as_() << offset)
78    }
79}
80
81impl FixedPoint<i32, 0> {
82    impl_fixed_point_func_unary!(exp_on_interval_between_negative_one_quarter_and_0_excl);
83    impl_fixed_point_func_unary!(one_over_one_plus_x_for_x_in_0_1);
84}
85
86impl FixedPoint<i32, 5> {
87    #[allow(dead_code)]
88    pub fn exp_on_negative_values(&self) -> FixedPoint<i32, 0> {
89        FixedPoint::<i32, 0>::from_raw(exp_on_negative_values(self.as_raw()))
90    }
91}
92
93impl<const INTEGER_BITS: usize> FixedPoint<i32, INTEGER_BITS> {
94    impl_fixed_point_func_unary!(mask_if_non_zero);
95    impl_fixed_point_func_unary!(mask_if_zero);
96    impl_fixed_point_func_binary!(rounding_half_sum);
97
98    pub fn saturating_rounding_multiply_by_pot(&self, exponent: i32) -> Self {
99        Self::from_raw(saturating_rounding_multiply_by_pot(self.as_raw(), exponent))
100    }
101
102    #[allow(dead_code)]
103    pub fn rounding_divide_by_pot(&self, exponent: i32) -> Self {
104        Self::from_raw(rounding_divide_by_pot(self.as_raw(), exponent))
105    }
106
107    pub fn select_using_mask(mask: i32, a: Self, b: Self) -> Self {
108        Self::from_raw(select_using_mask(mask, a.as_raw(), b.as_raw()))
109    }
110
111    pub fn rescale<const DST_INTEGER_BITS: usize>(&self) -> FixedPoint<i32, DST_INTEGER_BITS> {
112        FixedPoint::<i32, DST_INTEGER_BITS>::from_raw(rescale(
113            self.as_raw(),
114            INTEGER_BITS,
115            DST_INTEGER_BITS,
116        ))
117    }
118
119    #[allow(dead_code)]
120    pub fn get_reciprocal(&self) -> (FixedPoint<i32, 0>, usize) {
121        let (raw_res, num_bits_over_units) = get_reciprocal(self.as_raw(), INTEGER_BITS);
122        (FixedPoint::<i32, 0>::from_raw(raw_res), num_bits_over_units)
123    }
124}
125
126impl<T, const INTEGER_BITS: usize> Debug for FixedPoint<T, INTEGER_BITS>
127where
128    T: AsPrimitive<f32> + PrimInt + LowerHex + Debug + Binary,
129    f32: AsPrimitive<T>,
130{
131    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
132        write!(fmt, "{:032b}({:?})({})", self.0, self.0, self.as_f32())
133    }
134}
135
136impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
137where
138    T: AsPrimitive<f32> + PrimInt,
139{
140    pub fn as_f32(&self) -> f32 {
141        self.0.as_() / 2_f32.powi(Self::fractional_bits() as i32)
142    }
143}
144
145impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
146where
147    T: AsPrimitive<f32> + PrimInt,
148    f32: AsPrimitive<T>,
149{
150    #[allow(dead_code)]
151    pub fn from_f32(x: f32) -> Self {
152        Self::from_raw(
153            f32::min(
154                f32::max(
155                    f32::round(x * 2f32.powi(Self::fractional_bits().as_())),
156                    T::min_value().as_(),
157                ),
158                T::max_value().as_(),
159            )
160            .as_(),
161        )
162    }
163}
164
165impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Add for FixedPoint<T, INTEGER_BITS> {
166    type Output = FixedPoint<T, INTEGER_BITS>;
167    fn add(self, rhs: Self) -> Self::Output {
168        Self::from_raw(self.0 + rhs.0)
169    }
170}
171
172impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Sub for FixedPoint<T, INTEGER_BITS> {
173    type Output = FixedPoint<T, INTEGER_BITS>;
174    fn sub(self, rhs: Self) -> Self::Output {
175        Self::from_raw(self.0 - rhs.0)
176    }
177}
178
179impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Shl<usize> for FixedPoint<T, INTEGER_BITS> {
180    type Output = FixedPoint<T, INTEGER_BITS>;
181    fn shl(self, rhs: usize) -> Self::Output {
182        Self::from_raw(self.0 << rhs)
183    }
184}
185
186impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Shr<usize> for FixedPoint<T, INTEGER_BITS> {
187    type Output = FixedPoint<T, INTEGER_BITS>;
188    fn shr(self, rhs: usize) -> Self::Output {
189        Self::from_raw(self.0 >> rhs)
190    }
191}
192
193impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::BitAnd for FixedPoint<T, INTEGER_BITS> {
194    type Output = FixedPoint<T, INTEGER_BITS>;
195    fn bitand(self, rhs: Self) -> Self::Output {
196        Self::from_raw(self.0 & rhs.0)
197    }
198}
199
/// Implements `FixedPoint<$T, LHS> * FixedPoint<$T, RHS> -> FixedPoint<$T, OUT>`.
///
/// The raw operands are combined with `saturating_rounding_doubling_high_mul`,
/// which is resolved at the expansion site. The caller states the output
/// integer-bit count explicitly; every invocation in this file uses
/// `OUT = LHS + RHS`.
macro_rules! impl_mul {
    ($T:ty, $LHS_INTEGER_BITS:literal, $RHS_INTEGER_BITS:literal, $OUT_INTEGER_BITS:literal) => {
        impl std::ops::Mul<FixedPoint<$T, $RHS_INTEGER_BITS>>
            for FixedPoint<$T, $LHS_INTEGER_BITS>
        {
            type Output = FixedPoint<$T, $OUT_INTEGER_BITS>;

            fn mul(self, rhs: FixedPoint<$T, $RHS_INTEGER_BITS>) -> Self::Output {
                let raw = saturating_rounding_doubling_high_mul(self.0, rhs.0);
                FixedPoint::<$T, $OUT_INTEGER_BITS>::from_raw(raw)
            }
        }
    };
}
212
213impl_mul!(i32, 0, 0, 0);
214impl_mul!(i32, 0, 2, 2);
215impl_mul!(i32, 2, 0, 2);
216impl_mul!(i32, 2, 2, 4);
217impl_mul!(i32, 5, 5, 10);
218
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;
    pub type Q10_21 = FixedPoint<i32, 10>;
    pub type Q12_19 = FixedPoint<i32, 12>;
    pub type Q26_5 = FixedPoint<i32, 26>;
    type Q0_7 = FixedPoint<i8, 0>;

    #[test]
    fn test_to_f32() {
        let x = Q26_5::from_raw(32);
        assert_eq!(x.as_f32(), 1.0);
    }

    #[test]
    fn test_to_f32_1() {
        let x = Q0_7::from_raw(32);
        assert_eq!(x.as_f32(), 0.25);
    }

    #[test]
    fn test_one() {
        let x = Q26_5::one();
        assert_eq!(x, Q26_5::from_raw(32));
    }

    #[test]
    fn test_one_limit() {
        let x = Q0_31::one();
        assert_eq!(x, Q0_31::from_raw(i32::MAX));
    }

    #[test]
    fn test_constant_pot() {
        // 2^0 with 26 fractional bits is exactly one.
        assert_eq!(Q5_26::constant_pot(0), Q5_26::one());
        assert_eq!(Q0_31::constant_pot(-1), Q0_31::from_f32(0.5));
        assert_eq!(Q2_29::constant_pot(1), Q2_29::from_f32(2.0));
    }

    #[test]
    fn test_debug_format_i32() {
        // 32 in binary is `100000`, zero-padded to 32 bits; 32 / 2^5 == 1.0.
        let x = Q26_5::from_raw(32);
        let expected = format!("{}100000(32)(1)", "0".repeat(26));
        assert_eq!(format!("{:?}", x), expected);
    }

    #[test]
    fn test_mul_1() {
        // Multiplying two Q5_26 values yields Q10_21: the integer bits add up.
        let a = Q5_26::from_f32(8.0);
        let b = Q5_26::from_f32(3.0);
        let product = a * b;
        let expected = Q10_21::from_f32(24.0);

        assert_eq!(product, expected);
    }

    #[test]
    fn test_add() {
        let a = Q5_26::from_f32(16.0);
        let b = Q5_26::from_f32(5.0);
        let sum = a + b;
        let expected = Q5_26::from_f32(21.0);
        assert_eq!(sum, expected);
    }

    #[test]
    fn test_one_over_one_plus_x_for_x_in_0_1() {
        let a = Q0_31::from_f32(0.75);
        let expected_res = Q0_31::from_f32(1.0 / 1.75);
        let res = a.one_over_one_plus_x_for_x_in_0_1();
        assert_eq!(res.as_f32(), expected_res.as_f32());
    }

    #[test]
    fn test_one_over_one_plus_x_for_x_in_0_1_1() {
        let a = Q0_31::from_f32(0.0);
        let expected_res = Q0_31::from_f32(1.0 / 1.0);
        let res = a.one_over_one_plus_x_for_x_in_0_1();
        assert_eq!(res.as_f32(), expected_res.as_f32());
    }

    #[test]
    fn test_get_reciprocal_1() {
        let a = Q5_26::from_f32(4.5);
        let expected_res = Q0_31::from_f32(1.0 / 4.5);
        let (shifted_res, num_bits_over_unit) = a.get_reciprocal();
        let res = shifted_res.rounding_divide_by_pot(num_bits_over_unit as i32);
        assert_eq!(res.as_f32(), expected_res.as_f32());
        assert_eq!(num_bits_over_unit, 2);
    }

    #[test]
    fn test_get_reciprocal_2() {
        // Same value as test_get_reciprocal_1 but in a different format: the
        // normalized mantissa is identical, so the reciprocal must match too.
        let a = Q12_19::from_f32(4.5);
        let expected_res = Q0_31::from_f32(1.0 / 4.5);
        let (shifted_res, num_bits_over_unit) = a.get_reciprocal();
        let res = shifted_res.rounding_divide_by_pot(num_bits_over_unit as i32);
        assert_eq!(res.as_f32(), expected_res.as_f32());
        assert_eq!(num_bits_over_unit, 2);
    }

    #[test]
    fn test_get_reciprocal_3() {
        let a = Q12_19::from_f32(2.0);
        let expected_res = Q0_31::from_f32(1.0 / 2.0);
        let (shifted_res, num_bits_over_unit) = a.get_reciprocal();
        let res = shifted_res.rounding_divide_by_pot(num_bits_over_unit as i32);
        assert_eq!(res.as_f32(), expected_res.as_f32());
        assert_eq!(num_bits_over_unit, 1);
    }

    #[test]
    fn test_rescale_1() {
        let a = Q0_31::from_f32(0.75);
        let expected_res = Q12_19::from_f32(0.75);
        let res = a.rescale::<12>();
        assert_eq!(res, expected_res);
    }

    #[test]
    fn test_exp_on_interval_between_negative_one_quarter_and_0_excl() {
        let a = Q0_31::from_f32(-0.125);
        let expected_res = Q0_31::from_f32((-0.125_f32).exp());
        let res = a.exp_on_interval_between_negative_one_quarter_and_0_excl();
        assert_eq!(res.as_f32(), expected_res.as_f32());
    }

    #[test]
    fn test_exp_on_negative_values_1() {
        let a = Q5_26::from_f32(-0.125);
        let expected_res = Q0_31::from_f32((-0.125_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }

    #[test]
    fn test_exp_on_negative_values_2() {
        let a = Q5_26::from_f32(0.0);
        let expected_res = Q0_31::from_f32((0_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }

    #[test]
    fn test_exp_on_negative_values_3() {
        let a = Q5_26::from_f32(-0.25);
        let expected_res = Q0_31::from_f32((-0.25_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }

    #[test]
    fn test_exp_on_negative_values_4() {
        let a = Q5_26::from_f32(-1.1875_f32);
        let expected_res = Q0_31::from_f32((-1.1875_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }
}