fixed_exp2/
lib.rs

1//! Exponentiation for fixed-point numbers.
2//!
3//! # Usage
4//!
5//! ```rust
6//! use fixed::types::I32F32;
7//! use fixed_exp::{FixedPowI, FixedPowF};
8//!
9//! let x = I32F32::from_num(4.0);
10//! assert_eq!(I32F32::from_num(1024.0), x.powi(5));
11//! assert_eq!(I32F32::from_num(8.0), x.powf(I32F32::from_num(1.5)));
12//! ```
13#![no_std]
14use core::cmp::{Ord, Ordering};
15use core::fmt;
16
17use fixed::{
18    traits::Fixed,
19    types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8},
20    FixedI128, FixedI16, FixedI32, FixedI64, FixedI8, FixedU128, FixedU16, FixedU32, FixedU64,
21    FixedU8,
22};
23use num_traits::{One, PrimInt, Zero};
24use typenum::{
25    Bit, IsLessOrEqual, LeEq, True, Unsigned, U126, U127, U14, U15, U30, U31, U6, U62, U63, U7,
26};
27
28/// Trait alias for fixed-points numbers that support both integer and fixed-point exponentiation.
29pub trait FixedPow: Fixed + FixedPowI + FixedPowF {}
30impl<T: Fixed + FixedPowI + FixedPowF> FixedPow for T {}
31
32/// Extension trait providing integer exponentiation for fixed-point numbers.
33pub trait FixedPowI: Fixed {
34    /// Raises a number to an integer power, using exponentiation by squaring.
35    ///
36    /// Using this function is generally faster than using `powf`.
37    ///
38    /// # Panics
39    ///
40    /// Panics if `1` cannot be represented in `Self`, and `n` is non-positive.
41    ///
42    /// # Examples
43    ///
44    /// ```rust
45    /// use fixed::types::I32F32;
46    /// use fixed_exp::FixedPowI;
47    ///
48    /// let x = I32F32::from_num(4.0);
49    /// assert_eq!(I32F32::from_num(1024.0), x.powi(5));
50    /// ```
51    fn powi(self, n: i32) -> Self;
52}
53
54/// Extension trait providing fixed-point exponentiation for fixed-point numbers.
55///
56/// This is only implemented for types that can represent numbers larger than `1`.
57pub trait FixedPowF: Fixed {
58    /// Raises a number to a fixed-point power.
59    ///
60    /// # Panics
61    ///
62    /// - If `self` is negative and `n` is fractional.
63    ///
64    /// # Examples
65    ///
66    /// ```rust
67    /// use fixed::types::I32F32;
68    /// use fixed_exp::FixedPowF;
69    ///
70    /// let x = I32F32::from_num(4.0);
71    /// assert_eq!(I32F32::from_num(8.0), x.powf(I32F32::from_num(1.5)));
72    /// ```
73    fn powf(self, n: Self) -> Self;
74}
75
76#[cfg(not(feature = "fast-powi"))]
77fn powi_positive<T: Fixed>(x: T, n: i32) -> T {
78    assert!(n > 0, "exponent should be positive");
79    let mut res = x;
80    for _ in 1..n {
81        res = res * x;
82    }
83    res
84}
85
// NOTE: exponentiation by squaring rounds at different points than the naive product, so
// results can differ from it in the last bits.
/// Computes `x^n` for `n > 0` using exponentiation by squaring.
///
/// `x` is squared only while another exponent bit still needs it. The largest square ever
/// formed is therefore `x^(2^k)` with `2^k <= n - 1`, so no power larger than `x^n` is
/// computed as a throw-away intermediate.
///
/// # Panics
///
/// Panics if `n` is not strictly positive.
#[cfg(feature = "fast-powi")]
fn powi_positive<T: Fixed>(mut x: T, mut n: i32) -> T {
    assert!(n > 0, "exponent should be positive");
    // `acc` already holds one factor of `x`; the loop multiplies in the remaining `n - 1`.
    let mut acc = x;
    n -= 1;

    while n > 0 {
        if n & 1 == 1 {
            acc *= x;
        }
        n >>= 1;
        // Squaring after the last set bit would compute an unused, possibly overflowing,
        // power of `x`.
        if n > 0 {
            x *= x;
        }
    }

    acc
}
104
105fn sqrt<T>(x: T) -> T
106where
107    T: Fixed + Helper,
108    T::Bits: PrimInt,
109{
110    if x.is_zero() || x.is_one() {
111        return x;
112    }
113
114    let mut pow2 = T::one();
115    let mut result;
116
117    if x < T::one() {
118        while x <= pow2 * pow2 {
119            pow2 >>= 1;
120        }
121
122        result = pow2;
123    } else {
124        // x >= T::one()
125        while pow2 * pow2 <= x {
126            pow2 <<= 1;
127        }
128
129        result = pow2 >> 1;
130    }
131
132    for _ in 0..T::NUM_BITS {
133        pow2 >>= 1;
134        let next_result = result + pow2;
135        if next_result * next_result <= x {
136            result = next_result;
137        }
138    }
139
140    result
141}
142
143fn powf_01<T>(mut x: T, n: T) -> T
144where
145    T: Fixed + Helper,
146    T::Bits: PrimInt + fmt::Debug,
147{
148    let mut n = n.to_bits();
149    if n.is_zero() {
150        panic!("fractional exponent should not be zero");
151    }
152
153    let top = T::Bits::one() << ((T::Frac::U32 - 1) as usize);
154    let mask = !(T::Bits::one() << ((T::Frac::U32) as usize));
155    let mut acc = None;
156
157    while !n.is_zero() {
158        x = sqrt(x);
159        if !(n & top).is_zero() {
160            acc = match acc {
161                Some(acc) => Some(acc * x),
162                None => Some(x),
163            };
164        }
165        n = (n << 1) & mask;
166    }
167
168    acc.unwrap()
169}
170
171fn powf_positive<T>(x: T, n: T) -> T
172where
173    T: Fixed + Helper,
174    T::Bits: PrimInt + fmt::Debug,
175{
176    assert!(Helper::is_positive(n), "exponent should be positive");
177    let ni = n.int().to_num();
178    let powi = if ni == 0 {
179        T::from_num(1)
180    } else {
181        powi_positive(x, ni)
182    };
183    let frac = n.frac();
184
185    if frac.is_zero() {
186        powi
187    } else {
188        assert!(Helper::is_positive(x), "base should be positive");
189        powi * powf_01(x, frac)
190    }
191}
192
193macro_rules! impl_fixed_pow {
194    ($fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
195        impl<Frac> FixedPowI for $fixed<Frac>
196        where
197            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
198        {
199            fn powi(self, n: i32) -> Self {
200                if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= 0 {
201                    panic!(
202                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
203                        self, n
204                    );
205                }
206
207                match n.cmp(&0) {
208                    Ordering::Greater => powi_positive(self, n),
209                    Ordering::Equal => Self::from_bits(1 << Frac::U32),
210                    Ordering::Less => powi_positive(Self::from_bits(1 << Frac::U32) / self, -n),
211                }
212            }
213        }
214
215        impl<Frac> FixedPowF for $fixed<Frac>
216        where
217            Frac: $le_eq + IsLessOrEqual<$le_eq_one, Output = True>,
218        {
219            fn powf(self, n: Self) -> Self {
220                let zero = Self::from_bits(0);
221
222                if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= zero {
223                    panic!(
224                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
225                        self, n
226                    );
227                }
228
229                match n.cmp(&zero) {
230                    Ordering::Greater => powf_positive(self, n),
231                    Ordering::Equal => Self::from_bits(1 << Frac::U32),
232                    Ordering::Less => {
233                        powf_positive(Self::from_bits(1 << Frac::U32) / self, Helper::neg(n))
234                    },
235                }
236            }
237        }
238    };
239}
240
241impl_fixed_pow!(FixedI8, LeEqU8, U6);
242impl_fixed_pow!(FixedI16, LeEqU16, U14);
243impl_fixed_pow!(FixedI32, LeEqU32, U30);
244impl_fixed_pow!(FixedI64, LeEqU64, U62);
245impl_fixed_pow!(FixedI128, LeEqU128, U126);
246
247impl_fixed_pow!(FixedU8, LeEqU8, U7);
248impl_fixed_pow!(FixedU16, LeEqU16, U15);
249impl_fixed_pow!(FixedU32, LeEqU32, U31);
250impl_fixed_pow!(FixedU64, LeEqU64, U63);
251impl_fixed_pow!(FixedU128, LeEqU128, U127);
252
/// Crate-internal operations whose implementation differs between signed and unsigned
/// fixed-point types.
trait Helper {
    /// Total number of bits (integer plus fractional) in the type.
    const NUM_BITS: u32;

    /// Returns the value `1`; implementations panic when `1` is not representable.
    fn one() -> Self;

    /// Returns `true` when `self` is exactly `1`.
    fn is_one(self) -> bool;

    /// Returns `true` when `self` is strictly greater than zero.
    fn is_positive(self) -> bool;

    /// Returns `-self`; unsigned implementations panic.
    fn neg(self) -> Self;
}
260
261macro_rules! impl_sign_helper {
262    (signed, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
263        impl<Frac: $le_eq> Helper for $fixed<Frac>
264        where
265            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
266        {
267            const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
268            fn is_positive(self) -> bool {
269                $fixed::is_positive(self)
270            }
271            fn is_one(self) -> bool {
272                <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
273            }
274            fn one() -> Self {
275                assert!(
276                    <LeEq<Frac, $le_eq_one>>::BOOL,
277                    "one should be possible to represent"
278                );
279                Self::from_bits(1 << Frac::U32)
280            }
281            fn neg(self) -> Self {
282                -self
283            }
284        }
285    };
286    (unsigned, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
287        impl<Frac: $le_eq> Helper for $fixed<Frac>
288        where
289            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
290        {
291            const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
292            fn is_positive(self) -> bool {
293                self != Self::from_bits(0)
294            }
295            fn is_one(self) -> bool {
296                <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
297            }
298            fn one() -> Self {
299                assert!(
300                    <LeEq<Frac, $le_eq_one>>::BOOL,
301                    "one should be possible to represent"
302                );
303                Self::from_bits(1 << Frac::U32)
304            }
305            fn neg(self) -> Self {
306                panic!("cannot negate an unsigned number")
307            }
308        }
309    };
310}
311
312impl_sign_helper!(signed, FixedI8, LeEqU8, U6);
313impl_sign_helper!(signed, FixedI16, LeEqU16, U14);
314impl_sign_helper!(signed, FixedI32, LeEqU32, U30);
315impl_sign_helper!(signed, FixedI64, LeEqU64, U62);
316impl_sign_helper!(signed, FixedI128, LeEqU128, U126);
317
318impl_sign_helper!(unsigned, FixedU8, LeEqU8, U7);
319impl_sign_helper!(unsigned, FixedU16, LeEqU16, U15);
320impl_sign_helper!(unsigned, FixedU32, LeEqU32, U31);
321impl_sign_helper!(unsigned, FixedU64, LeEqU64, U63);
322impl_sign_helper!(unsigned, FixedU128, LeEqU128, U127);
323
#[cfg(test)]
mod tests {
    use fixed::types::{I1F31, I32F32, U0F32, U32F32};

    use super::*;

    /// Reference `x^n` for `n > 0`: plain left-to-right repeated multiplication.
    fn powi_positive_naive<T: Fixed>(x: T, n: i32) -> T {
        assert!(n > 0, "exponent should be positive");
        (1..n).fold(x, |acc, _| acc * x)
    }

    /// `|a - b|`, computed without going negative so it also works for unsigned types.
    fn delta<T: Fixed>(a: T, b: T) -> T {
        if a >= b {
            a - b
        } else {
            b - a
        }
    }

    /// Asserts that `actual` lies strictly within `epsilon` of `expected`.
    fn assert_close<T: Fixed>(expected: T, actual: T, epsilon: T) {
        assert!(
            delta(expected, actual) < epsilon,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Reference `x^n` computed in `f64` and converted back to `T`.
    fn powf_float<T: Fixed>(x: T, n: T) -> T {
        let base: f64 = x.to_num();
        let exponent: f64 = n.to_num();
        T::from_num(base.powf(exponent))
    }

    #[test]
    fn test_powi_positive() {
        let signed_cases = [
            (1.0, 42),
            (0.8, 7),
            (1.2, 11),
            (2.6, 16),
            (-2.2, 4),
            (-2.2, 5),
        ];
        for &(base, n) in signed_cases.iter() {
            let x = I32F32::from_num(base);
            assert_close(powi_positive_naive(x, n), x.powi(n), I32F32::from_num(0.0001));
        }

        let unsigned_cases = [(1.0, 42), (0.8, 7), (1.2, 11), (2.6, 16)];
        for &(base, n) in unsigned_cases.iter() {
            let x = U32F32::from_num(base);
            assert_close(powi_positive_naive(x, n), x.powi(n), U32F32::from_num(0.0001));
        }
    }

    #[test]
    fn test_powi_positive_sub_one() {
        let signed_cases = [
            (0.5, 3),
            (0.8, 5),
            (0.2, 7),
            (0.6, 9),
            (-0.6, 3),
            (-0.6, 4),
        ];
        for &(base, n) in signed_cases.iter() {
            let x = I1F31::from_num(base);
            assert_close(powi_positive_naive(x, n), x.powi(n), I1F31::from_num(0.0001));
        }

        let unsigned_cases = [(0.5, 3), (0.8, 5), (0.2, 7), (0.6, 9)];
        for &(base, n) in unsigned_cases.iter() {
            let x = U0F32::from_num(base);
            assert_close(powi_positive_naive(x, n), x.powi(n), U0F32::from_num(0.0001));
        }
    }

    #[test]
    fn test_powi_non_positive() {
        let epsilon = I32F32::from_num(0.0001);

        let negative_cases = [(1.0, -17), (0.8, -7), (1.2, -9), (2.6, -3)];
        for &(base, n) in negative_cases.iter() {
            let x = I32F32::from_num(base);
            let reciprocal = I32F32::from_num(1) / x;
            assert_close(powi_positive_naive(reciprocal, -n), x.powi(n), epsilon);
        }

        let one = I32F32::from_num(1);
        for &base in [42, -42, 0].iter() {
            assert_eq!(one, I32F32::from_num(base).powi(0));
        }
    }

    #[test]
    fn test_powf() {
        let signed_cases = [
            (1.0, 7.2),
            (0.8, -4.5),
            (1.2, 5.0),
            (2.6, -6.7),
            (-1.2, 4.0),
            (-1.2, -3.0),
        ];
        for &(base, exponent) in signed_cases.iter() {
            let (x, n) = (I32F32::from_num(base), I32F32::from_num(exponent));
            assert_close(powf_float(x, n), x.powf(n), I32F32::from_num(0.0001));
        }

        let unsigned_cases = [(1.0, 7.2), (0.8, 4.5), (1.2, 5.0), (2.6, 6.7)];
        for &(base, exponent) in unsigned_cases.iter() {
            let (x, n) = (U32F32::from_num(base), U32F32::from_num(exponent));
            assert_close(powf_float(x, n), x.powf(n), U32F32::from_num(0.0001));
        }
    }
}