fixed_exp/
lib.rs

1//! Exponentiation for fixed-point numbers.
2//!
3//! # Usage
4//!
5//! ```rust
6//! use fixed::types::I32F32;
7//! use fixed_exp::{FixedPowI, FixedPowF};
8//!
9//! let x = I32F32::from_num(4.0);
10//! assert_eq!(I32F32::from_num(1024.0), x.powi(5));
11//! assert_eq!(I32F32::from_num(8.0), x.powf(I32F32::from_num(1.5)));
12//! ```
13
14use std::cmp::{Ord, Ordering};
15
16use fixed::traits::Fixed;
17use fixed::types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8};
18use fixed::{
19    FixedI128, FixedI16, FixedI32, FixedI64, FixedI8, FixedU128, FixedU16, FixedU32, FixedU64,
20    FixedU8,
21};
22use num_traits::{One, PrimInt, Zero};
23use typenum::{
24    Bit, IsLessOrEqual, LeEq, True, Unsigned, U126, U127, U14, U15, U30, U31, U6, U62, U63, U7,
25};
26
27/// Trait alias for fixed-points numbers that support both integer and fixed-point exponentiation.
28pub trait FixedPow: Fixed + FixedPowI + FixedPowF {}
29impl<T: Fixed + FixedPowI + FixedPowF> FixedPow for T {}
30
31/// Extension trait providing integer exponentiation for fixed-point numbers.
32pub trait FixedPowI: Fixed {
33    /// Raises a number to an integer power, using exponentiation by squaring.
34    ///
35    /// Using this function is generally faster than using `powf`.
36    ///
37    /// # Panics
38    ///
39    /// Panics if `1` cannot be represented in `Self`, and `n` is non-positive.
40    ///
41    /// # Examples
42    ///
43    /// ```rust
44    /// use fixed::types::I32F32;
45    /// use fixed_exp::FixedPowI;
46    ///
47    /// let x = I32F32::from_num(4.0);
48    /// assert_eq!(I32F32::from_num(1024.0), x.powi(5));
49    /// ```
50    fn powi(self, n: i32) -> Self;
51}
52
53/// Extension trait providing fixed-point exponentiation for fixed-point numbers.
54///
55/// This is only implemented for types that can represent numbers larger than `1`.
56pub trait FixedPowF: Fixed {
57    /// Raises a number to a fixed-point power.
58    ///
59    /// # Panics
60    ///
61    /// - If `self` is negative and `n` is fractional.
62    ///
63    /// # Examples
64    ///
65    /// ```rust
66    /// use fixed::types::I32F32;
67    /// use fixed_exp::FixedPowF;
68    ///
69    /// let x = I32F32::from_num(4.0);
70    /// assert_eq!(I32F32::from_num(8.0), x.powf(I32F32::from_num(1.5)));
71    /// ```
72    fn powf(self, n: Self) -> Self;
73}
74
75fn powi_positive<T: Fixed>(mut x: T, mut n: i32) -> T {
76    assert!(n > 0, "exponent should be positive");
77
78    let mut acc = x;
79    n -= 1;
80
81    while n > 0 {
82        if n & 1 == 1 {
83            acc *= x;
84        }
85        x *= x;
86        n >>= 1;
87    }
88
89    acc
90}
91
92fn sqrt<T>(x: T) -> T
93where
94    T: Fixed + Helper,
95    T::Bits: PrimInt,
96{
97    if x.is_zero() || x.is_one() {
98        return x;
99    }
100
101    let mut pow2 = T::one();
102    let mut result;
103
104    if x < T::one() {
105        while x <= pow2 * pow2 {
106            pow2 >>= 1;
107        }
108
109        result = pow2;
110    } else {
111        // x >= T::one()
112        while pow2 * pow2 <= x {
113            pow2 <<= 1;
114        }
115
116        result = pow2 >> 1;
117    }
118
119    for _ in 0..T::NUM_BITS {
120        pow2 >>= 1;
121        let next_result = result + pow2;
122        if next_result * next_result <= x {
123            result = next_result;
124        }
125    }
126
127    result
128}
129
130fn powf_01<T>(mut x: T, n: T) -> T
131where
132    T: Fixed + Helper,
133    T::Bits: PrimInt + std::fmt::Debug,
134{
135    let mut n = n.to_bits();
136    if n.is_zero() {
137        panic!("fractional exponent should not be zero");
138    }
139
140    let top = T::Bits::one() << ((T::Frac::U32 - 1) as usize);
141    let mask = !(T::Bits::one() << ((T::Frac::U32) as usize));
142    let mut acc = None;
143
144    while !n.is_zero() {
145        x = sqrt(x);
146        if !(n & top).is_zero() {
147            acc = match acc {
148                Some(acc) => Some(acc * x),
149                None => Some(x),
150            };
151        }
152        n = (n << 1) & mask;
153    }
154
155    acc.unwrap()
156}
157
158fn powf_positive<T>(x: T, n: T) -> T
159where
160    T: Fixed + Helper,
161    T::Bits: PrimInt + std::fmt::Debug,
162{
163    assert!(Helper::is_positive(n), "exponent should be positive");
164
165    let powi = powi_positive(x, n.int().to_num());
166    let frac = n.frac();
167
168    if frac.is_zero() {
169        powi
170    } else {
171        assert!(Helper::is_positive(x), "base should be positive");
172        powi * powf_01(x, frac)
173    }
174}
175
176macro_rules! impl_fixed_pow {
177    ($fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
178        impl<Frac> FixedPowI for $fixed<Frac>
179        where
180            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
181        {
182            fn powi(self, n: i32) -> Self {
183                if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= 0 {
184                    panic!(
185                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
186                        self, n
187                    );
188                }
189
190                match n.cmp(&0) {
191                    Ordering::Greater => powi_positive(self, n),
192                    Ordering::Equal => Self::from_bits(1 << Frac::U32),
193                    Ordering::Less => powi_positive(Self::from_bits(1 << Frac::U32) / self, -n),
194                }
195            }
196        }
197
198        impl<Frac> FixedPowF for $fixed<Frac>
199        where
200            Frac: $le_eq + IsLessOrEqual<$le_eq_one, Output = True>,
201        {
202            fn powf(self, n: Self) -> Self {
203                let zero = Self::from_bits(0);
204
205                if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= zero {
206                    panic!(
207                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
208                        self, n
209                    );
210                }
211
212                match n.cmp(&zero) {
213                    Ordering::Greater => powf_positive(self, n),
214                    Ordering::Equal => Self::from_bits(1 << Frac::U32),
215                    Ordering::Less => powf_positive(Self::from_bits(1 << Frac::U32) / self, Helper::neg(n)),
216                }
217            }
218        }
219    };
220}
221
222impl_fixed_pow!(FixedI8, LeEqU8, U6);
223impl_fixed_pow!(FixedI16, LeEqU16, U14);
224impl_fixed_pow!(FixedI32, LeEqU32, U30);
225impl_fixed_pow!(FixedI64, LeEqU64, U62);
226impl_fixed_pow!(FixedI128, LeEqU128, U126);
227
228impl_fixed_pow!(FixedU8, LeEqU8, U7);
229impl_fixed_pow!(FixedU16, LeEqU16, U15);
230impl_fixed_pow!(FixedU32, LeEqU32, U31);
231impl_fixed_pow!(FixedU64, LeEqU64, U63);
232impl_fixed_pow!(FixedU128, LeEqU128, U127);
233
/// Private helper operations shared by the generic exponentiation routines.
///
/// The trait lets generic code treat signed and unsigned fixed-point types uniformly.
trait Helper {
    /// Total width of the type in bits.
    const NUM_BITS: u32;

    /// Returns the value `1`.
    fn one() -> Self;

    /// Returns `true` if the value equals zero.
    fn is_zero(self) -> bool;

    /// Returns `true` if the value equals one.
    fn is_one(self) -> bool;

    /// Returns `true` if the value is strictly greater than zero.
    fn is_positive(self) -> bool;

    /// Returns the arithmetic negation of the value.
    fn neg(self) -> Self;
}
242
243macro_rules! impl_sign_helper {
244    (signed, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
245        impl<Frac: $le_eq> Helper for $fixed<Frac>
246        where
247            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
248        {
249            const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
250            fn is_positive(self) -> bool {
251                $fixed::is_positive(self)
252            }
253            fn is_zero(self) -> bool {
254                self.to_bits() == 0
255            }
256            fn is_one(self) -> bool {
257                <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
258            }
259            fn one() -> Self {
260                assert!(
261                    <LeEq<Frac, $le_eq_one>>::BOOL,
262                    "one should be possible to represent"
263                );
264                Self::from_bits(1 << Frac::U32)
265            }
266            fn neg(self) -> Self {
267                -self
268            }
269        }
270    };
271    (unsigned, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
272        impl<Frac: $le_eq> Helper for $fixed<Frac>
273        where
274            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
275        {
276            const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
277            fn is_positive(self) -> bool {
278                self != Self::from_bits(0)
279            }
280            fn is_zero(self) -> bool {
281                self.to_bits() == 0
282            }
283            fn is_one(self) -> bool {
284                <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
285            }
286            fn one() -> Self {
287                assert!(
288                    <LeEq<Frac, $le_eq_one>>::BOOL,
289                    "one should be possible to represent"
290                );
291                Self::from_bits(1 << Frac::U32)
292            }
293            fn neg(self) -> Self {
294                panic!("cannot negate an unsigned number")
295            }
296        }
297    };
298}
299
300impl_sign_helper!(signed, FixedI8, LeEqU8, U6);
301impl_sign_helper!(signed, FixedI16, LeEqU16, U14);
302impl_sign_helper!(signed, FixedI32, LeEqU32, U30);
303impl_sign_helper!(signed, FixedI64, LeEqU64, U62);
304impl_sign_helper!(signed, FixedI128, LeEqU128, U126);
305
306impl_sign_helper!(unsigned, FixedU8, LeEqU8, U7);
307impl_sign_helper!(unsigned, FixedU16, LeEqU16, U15);
308impl_sign_helper!(unsigned, FixedU32, LeEqU32, U31);
309impl_sign_helper!(unsigned, FixedU64, LeEqU64, U63);
310impl_sign_helper!(unsigned, FixedU128, LeEqU128, U127);
311
#[cfg(test)]
mod tests {
    use super::*;

    use fixed::types::{I1F31, I32F32, U0F32, U32F32};

    /// Reference implementation of integer powers: repeated multiplication.
    fn powi_positive_naive<T: Fixed>(x: T, n: i32) -> T {
        assert!(n > 0, "exponent should be positive");
        (1..n).fold(x, |product, _| product * x)
    }

    /// Absolute difference that also works for unsigned types.
    fn delta<T: Fixed>(a: T, b: T) -> T {
        if a < b {
            b - a
        } else {
            a - b
        }
    }

    #[test]
    fn test_powi_positive() {
        let signed_tolerance = I32F32::from_num(0.0001);
        let signed_cases = [
            (I32F32::from_num(1.0), 42),
            (I32F32::from_num(0.8), 7),
            (I32F32::from_num(1.2), 11),
            (I32F32::from_num(2.6), 16),
            (I32F32::from_num(-2.2), 4),
            (I32F32::from_num(-2.2), 5),
        ];

        for (base, exponent) in signed_cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!((expected - actual).abs() < signed_tolerance);
        }

        let unsigned_tolerance = U32F32::from_num(0.0001);
        let unsigned_cases = [
            (U32F32::from_num(1.0), 42),
            (U32F32::from_num(0.8), 7),
            (U32F32::from_num(1.2), 11),
            (U32F32::from_num(2.6), 16),
        ];

        for (base, exponent) in unsigned_cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!(delta(expected, actual) < unsigned_tolerance);
        }
    }

    #[test]
    fn test_powi_positive_sub_one() {
        let signed_tolerance = I1F31::from_num(0.0001);
        let signed_cases = [
            (I1F31::from_num(0.5), 3),
            (I1F31::from_num(0.8), 5),
            (I1F31::from_num(0.2), 7),
            (I1F31::from_num(0.6), 9),
            (I1F31::from_num(-0.6), 3),
            (I1F31::from_num(-0.6), 4),
        ];

        for (base, exponent) in signed_cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!((expected - actual).abs() < signed_tolerance);
        }

        let unsigned_tolerance = U0F32::from_num(0.0001);
        let unsigned_cases = [
            (U0F32::from_num(0.5), 3),
            (U0F32::from_num(0.8), 5),
            (U0F32::from_num(0.2), 7),
            (U0F32::from_num(0.6), 9),
        ];

        for (base, exponent) in unsigned_cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!(delta(expected, actual) < unsigned_tolerance);
        }
    }

    #[test]
    fn test_powi_non_positive() {
        let tolerance = I32F32::from_num(0.0001);
        let negative_cases = [
            (I32F32::from_num(1.0), -17),
            (I32F32::from_num(0.8), -7),
            (I32F32::from_num(1.2), -9),
            (I32F32::from_num(2.6), -3),
        ];

        for (base, exponent) in negative_cases.iter().copied() {
            // A negative power is the matching positive power of the reciprocal.
            let reciprocal = I32F32::from_num(1) / base;
            let expected = powi_positive_naive(reciprocal, -exponent);
            let actual = base.powi(exponent);
            assert!((expected - actual).abs() < tolerance);
        }

        // Anything raised to the power zero is one, including zero itself.
        assert_eq!(I32F32::from_num(1), I32F32::from_num(42).powi(0));
        assert_eq!(I32F32::from_num(1), I32F32::from_num(-42).powi(0));
        assert_eq!(I32F32::from_num(1), I32F32::from_num(0).powi(0));
    }

    /// Reference implementation of fractional powers via `f64::powf`.
    fn powf_float<T: Fixed>(x: T, n: T) -> T {
        let base = x.to_num::<f64>();
        let exponent = n.to_num::<f64>();
        T::from_num(base.powf(exponent))
    }

    #[test]
    fn test_powf() {
        let signed_tolerance = I32F32::from_num(0.0001);
        let signed_cases = [
            (I32F32::from_num(1.0), I32F32::from_num(7.2)),
            (I32F32::from_num(0.8), I32F32::from_num(-4.5)),
            (I32F32::from_num(1.2), I32F32::from_num(5.0)),
            (I32F32::from_num(2.6), I32F32::from_num(-6.7)),
            (I32F32::from_num(-1.2), I32F32::from_num(4.0)),
            (I32F32::from_num(-1.2), I32F32::from_num(-3.0)),
        ];

        for (base, exponent) in signed_cases.iter().copied() {
            let expected = powf_float(base, exponent);
            let actual = base.powf(exponent);
            assert!((expected - actual).abs() < signed_tolerance);
        }

        let unsigned_tolerance = U32F32::from_num(0.0001);
        let unsigned_cases = [
            (U32F32::from_num(1.0), U32F32::from_num(7.2)),
            (U32F32::from_num(0.8), U32F32::from_num(4.5)),
            (U32F32::from_num(1.2), U32F32::from_num(5.0)),
            (U32F32::from_num(2.6), U32F32::from_num(6.7)),
        ];

        for (base, exponent) in unsigned_cases.iter().copied() {
            let expected = powf_float(base, exponent);
            let actual = base.powf(exponent);
            assert!(delta(expected, actual) < unsigned_tolerance);
        }
    }
}