fixed_exp2/
lib.rs

1//! Exponentiation for fixed-point numbers.
2//!
3//! # Usage
4//!
5//! ```rust
6//! use fixed::types::I32F32;
7//! use fixed_exp::{FixedPowI, FixedPowF};
8//!
9//! let x = I32F32::from_num(4.0);
10//! assert_eq!(I32F32::from_num(1024.0), x.powi(5));
11//! assert_eq!(I32F32::from_num(8.0), x.powf(I32F32::from_num(1.5)));
12//! ```
13
14use core::cmp::{Ord, Ordering};
15
16use fixed::{
17    traits::Fixed,
18    types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8},
19    FixedI128, FixedI16, FixedI32, FixedI64, FixedI8, FixedU128, FixedU16, FixedU32, FixedU64,
20    FixedU8,
21};
22use num_traits::{One, PrimInt, Zero};
23use typenum::{
24    Bit, IsLessOrEqual, LeEq, True, Unsigned, U126, U127, U14, U15, U30, U31, U6, U62, U63, U7,
25};
26
27/// Trait alias for fixed-points numbers that support both integer and fixed-point exponentiation.
28pub trait FixedPow: Fixed + FixedPowI + FixedPowF {}
29impl<T: Fixed + FixedPowI + FixedPowF> FixedPow for T {}
30
31/// Extension trait providing integer exponentiation for fixed-point numbers.
32pub trait FixedPowI: Fixed {
33    /// Raises a number to an integer power, using exponentiation by squaring.
34    ///
35    /// Using this function is generally faster than using `powf`.
36    ///
37    /// # Panics
38    ///
39    /// Panics if `1` cannot be represented in `Self`, and `n` is non-positive.
40    ///
41    /// # Examples
42    ///
43    /// ```rust
44    /// use fixed::types::I32F32;
45    /// use fixed_exp::FixedPowI;
46    ///
47    /// let x = I32F32::from_num(4.0);
48    /// assert_eq!(I32F32::from_num(1024.0), x.powi(5));
49    /// ```
50    fn powi(self, n: i32) -> Self;
51}
52
53/// Extension trait providing fixed-point exponentiation for fixed-point numbers.
54///
55/// This is only implemented for types that can represent numbers larger than `1`.
56pub trait FixedPowF: Fixed {
57    /// Raises a number to a fixed-point power.
58    ///
59    /// # Panics
60    ///
61    /// - If `self` is negative and `n` is fractional.
62    ///
63    /// # Examples
64    ///
65    /// ```rust
66    /// use fixed::types::I32F32;
67    /// use fixed_exp::FixedPowF;
68    ///
69    /// let x = I32F32::from_num(4.0);
70    /// assert_eq!(I32F32::from_num(8.0), x.powf(I32F32::from_num(1.5)));
71    /// ```
72    fn powf(self, n: Self) -> Self;
73}
74
75#[cfg(not(feature = "fast-powi"))]
76fn powi_positive<T: Fixed>(x: T, n: i32) -> T {
77    assert!(n > 0, "exponent should be positive");
78    let mut res = x;
79    for _ in 1..n {
80        res = res * x;
81    }
82    res
83}
84
/// Computes `x^n` for a strictly positive `n` using exponentiation by squaring.
///
/// Enabled by the `fast-powi` feature. It needs `O(log n)` multiplications instead of the
/// `n - 1` of the default implementation. Because the factors are grouped differently,
/// the rounding of the result may differ slightly from that version.
///
/// The running power `x^(2^k)` is only squared again while higher exponent bits remain.
/// Squaring after the last bit has been consumed would compute an unused value that can
/// overflow even though the final result fits in `T`.
///
/// # Panics
///
/// Panics if `n` is not positive.
#[cfg(feature = "fast-powi")]
fn powi_positive<T: Fixed>(mut x: T, mut n: i32) -> T {
    assert!(n > 0, "exponent should be positive");
    // `acc` already holds one factor of `x`, so `n - 1` factors remain to multiply in.
    let mut acc = x;
    n -= 1;

    while n > 0 {
        if n & 1 == 1 {
            acc *= x;
        }
        n >>= 1;
        // Only advance to the next power when another exponent bit still needs it.
        if n > 0 {
            x *= x;
        }
    }

    acc
}
103
104fn sqrt<T>(x: T) -> T
105where
106    T: Fixed + Helper,
107    T::Bits: PrimInt,
108{
109    if x.is_zero() || x.is_one() {
110        return x;
111    }
112
113    let mut pow2 = T::one();
114    let mut result;
115
116    if x < T::one() {
117        while x <= pow2 * pow2 {
118            pow2 >>= 1;
119        }
120
121        result = pow2;
122    } else {
123        // x >= T::one()
124        while pow2 * pow2 <= x {
125            pow2 <<= 1;
126        }
127
128        result = pow2 >> 1;
129    }
130
131    for _ in 0..T::NUM_BITS {
132        pow2 >>= 1;
133        let next_result = result + pow2;
134        if next_result * next_result <= x {
135            result = next_result;
136        }
137    }
138
139    result
140}
141
142fn powf_01<T>(mut x: T, n: T) -> T
143where
144    T: Fixed + Helper,
145    T::Bits: PrimInt + std::fmt::Debug,
146{
147    let mut n = n.to_bits();
148    if n.is_zero() {
149        panic!("fractional exponent should not be zero");
150    }
151
152    let top = T::Bits::one() << ((T::Frac::U32 - 1) as usize);
153    let mask = !(T::Bits::one() << ((T::Frac::U32) as usize));
154    let mut acc = None;
155
156    while !n.is_zero() {
157        x = sqrt(x);
158        if !(n & top).is_zero() {
159            acc = match acc {
160                Some(acc) => Some(acc * x),
161                None => Some(x),
162            };
163        }
164        n = (n << 1) & mask;
165    }
166
167    acc.unwrap()
168}
169
170fn powf_positive<T>(x: T, n: T) -> T
171where
172    T: Fixed + Helper,
173    T::Bits: PrimInt + std::fmt::Debug,
174{
175    assert!(Helper::is_positive(n), "exponent should be positive");
176    let ni = n.int().to_num();
177    let powi = if ni == 0 {
178        T::from_num(1)
179    } else {
180        powi_positive(x, ni)
181    };
182    let frac = n.frac();
183
184    if frac.is_zero() {
185        powi
186    } else {
187        assert!(Helper::is_positive(x), "base should be positive");
188        powi * powf_01(x, frac)
189    }
190}
191
/// Implements [`FixedPowI`] and [`FixedPowF`] for one fixed-point type family.
///
/// Arguments:
/// - `$fixed`: the fixed-point type constructor (e.g. `FixedI32`).
/// - `$le_eq`: the `LeEqU*` bound that limits `Frac` to the type's bit width.
/// - `$le_eq_one`: the largest `Frac` for which `1` is still representable.
///
/// `FixedPowI` is implemented for every valid `Frac`. `FixedPowF` additionally requires
/// that `1` be representable.
macro_rules! impl_fixed_pow {
    ($fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
        impl<Frac> FixedPowI for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
        {
            fn powi(self, n: i32) -> Self {
                // `1` has the bit pattern `1 << Frac`; it is only built once it is known to fit.
                let one = || Self::from_bits(1 << Frac::U32);

                let one_fits = <LeEq<Frac, $le_eq_one>>::BOOL;
                if n <= 0 && !one_fits {
                    panic!(
                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
                        self, n
                    );
                }

                match n.cmp(&0) {
                    Ordering::Less => powi_positive(one() / self, -n),
                    Ordering::Equal => one(),
                    Ordering::Greater => powi_positive(self, n),
                }
            }
        }

        impl<Frac> FixedPowF for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one, Output = True>,
        {
            fn powf(self, n: Self) -> Self {
                let zero = Self::from_bits(0);
                let one = || Self::from_bits(1 << Frac::U32);

                // Mirrors `powi`; the `Output = True` bound already guarantees that `1` fits.
                let one_fits = <LeEq<Frac, $le_eq_one>>::BOOL;
                if n <= zero && !one_fits {
                    panic!(
                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
                        self, n
                    );
                }

                match n.cmp(&zero) {
                    Ordering::Less => powf_positive(one() / self, Helper::neg(n)),
                    Ordering::Equal => one(),
                    Ordering::Greater => powf_positive(self, n),
                }
            }
        }
    };
}
239
240impl_fixed_pow!(FixedI8, LeEqU8, U6);
241impl_fixed_pow!(FixedI16, LeEqU16, U14);
242impl_fixed_pow!(FixedI32, LeEqU32, U30);
243impl_fixed_pow!(FixedI64, LeEqU64, U62);
244impl_fixed_pow!(FixedI128, LeEqU128, U126);
245
246impl_fixed_pow!(FixedU8, LeEqU8, U7);
247impl_fixed_pow!(FixedU16, LeEqU16, U15);
248impl_fixed_pow!(FixedU32, LeEqU32, U31);
249impl_fixed_pow!(FixedU64, LeEqU64, U63);
250impl_fixed_pow!(FixedU128, LeEqU128, U127);
251
/// Internal operations whose behaviour differs between signed and unsigned fixed-point types.
trait Helper {
    /// Total number of bits (integer plus fractional) in the type.
    const NUM_BITS: u32;

    /// Returns the value `1`; implementations panic if `1` is not representable.
    fn one() -> Self;

    /// Returns `true` if `self` is exactly `1`.
    fn is_one(self) -> bool;

    /// Returns `true` if `self` is strictly greater than zero.
    fn is_positive(self) -> bool;

    /// Returns `-self`; unsigned implementations panic.
    fn neg(self) -> Self;
}
259
/// Implements [`Helper`] for one fixed-point type family.
///
/// The first token selects the flavour. `signed` types compare their bits against zero
/// and negate normally. `unsigned` types treat every non-zero value as positive and panic
/// on negation. The remaining arguments are the type constructor, its `LeEqU*` bound, and
/// the largest `Frac` for which `1` is representable.
macro_rules! impl_sign_helper {
    (signed, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
        impl<Frac: $le_eq> Helper for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
        {
            const NUM_BITS: u32 = <Self as Fixed>::FRAC_NBITS + <Self as Fixed>::INT_NBITS;

            fn one() -> Self {
                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    panic!("one should be possible to represent");
                }
                Self::from_bits(1 << Frac::U32)
            }

            fn is_one(self) -> bool {
                // The guard short-circuits so `1 << Frac` is only evaluated when it fits.
                <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
            }

            fn is_positive(self) -> bool {
                self.to_bits() > 0
            }

            fn neg(self) -> Self {
                -self
            }
        }
    };
    (unsigned, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
        impl<Frac: $le_eq> Helper for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
        {
            const NUM_BITS: u32 = <Self as Fixed>::FRAC_NBITS + <Self as Fixed>::INT_NBITS;

            fn one() -> Self {
                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    panic!("one should be possible to represent");
                }
                Self::from_bits(1 << Frac::U32)
            }

            fn is_one(self) -> bool {
                // The guard short-circuits so `1 << Frac` is only evaluated when it fits.
                <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
            }

            fn is_positive(self) -> bool {
                // Unsigned values are never negative, so "positive" just means non-zero.
                self.to_bits() != 0
            }

            fn neg(self) -> Self {
                panic!("cannot negate an unsigned number")
            }
        }
    };
}
310
311impl_sign_helper!(signed, FixedI8, LeEqU8, U6);
312impl_sign_helper!(signed, FixedI16, LeEqU16, U14);
313impl_sign_helper!(signed, FixedI32, LeEqU32, U30);
314impl_sign_helper!(signed, FixedI64, LeEqU64, U62);
315impl_sign_helper!(signed, FixedI128, LeEqU128, U126);
316
317impl_sign_helper!(unsigned, FixedU8, LeEqU8, U7);
318impl_sign_helper!(unsigned, FixedU16, LeEqU16, U15);
319impl_sign_helper!(unsigned, FixedU32, LeEqU32, U31);
320impl_sign_helper!(unsigned, FixedU64, LeEqU64, U63);
321impl_sign_helper!(unsigned, FixedU128, LeEqU128, U127);
322
#[cfg(test)]
mod tests {
    use fixed::types::{I1F31, I32F32, U0F32, U32F32};

    use super::*;

    /// Reference `x^n` for `n > 0`: multiply `x` into an accumulator `n - 1` times.
    fn powi_positive_naive<T: Fixed>(x: T, n: i32) -> T {
        assert!(n > 0, "exponent should be positive");
        (1..n).fold(x, |acc, _| acc * x)
    }

    /// Absolute difference `|a - b|`, usable for unsigned types, which have no `abs`.
    fn delta<T: Fixed>(a: T, b: T) -> T {
        if a > b {
            a - b
        } else {
            b - a
        }
    }

    /// Reference `x^n` computed in `f64` and converted back to `T`.
    fn powf_float<T: Fixed>(x: T, n: T) -> T {
        let base = x.to_num::<f64>();
        let exponent = n.to_num::<f64>();
        T::from_num(base.powf(exponent))
    }

    #[test]
    fn test_powi_positive() {
        let eps = I32F32::from_num(0.0001);
        let signed_cases = [
            (1.0, 42),
            (0.8, 7),
            (1.2, 11),
            (2.6, 16),
            (-2.2, 4),
            (-2.2, 5),
        ];
        for &(base, n) in signed_cases.iter() {
            let x = I32F32::from_num(base);
            let expected = powi_positive_naive(x, n);
            assert!((expected - x.powi(n)).abs() < eps);
        }

        let eps = U32F32::from_num(0.0001);
        let unsigned_cases = [(1.0, 42), (0.8, 7), (1.2, 11), (2.6, 16)];
        for &(base, n) in unsigned_cases.iter() {
            let x = U32F32::from_num(base);
            let expected = powi_positive_naive(x, n);
            assert!(delta(expected, x.powi(n)) < eps);
        }
    }

    #[test]
    fn test_powi_positive_sub_one() {
        let eps = I1F31::from_num(0.0001);
        let signed_cases = [
            (0.5, 3),
            (0.8, 5),
            (0.2, 7),
            (0.6, 9),
            (-0.6, 3),
            (-0.6, 4),
        ];
        for &(base, n) in signed_cases.iter() {
            let x = I1F31::from_num(base);
            let expected = powi_positive_naive(x, n);
            assert!((expected - x.powi(n)).abs() < eps);
        }

        let eps = U0F32::from_num(0.0001);
        let unsigned_cases = [(0.5, 3), (0.8, 5), (0.2, 7), (0.6, 9)];
        for &(base, n) in unsigned_cases.iter() {
            let x = U0F32::from_num(base);
            let expected = powi_positive_naive(x, n);
            assert!(delta(expected, x.powi(n)) < eps);
        }
    }

    #[test]
    fn test_powi_non_positive() {
        let eps = I32F32::from_num(0.0001);
        let one = I32F32::from_num(1);

        let cases = [(1.0, -17), (0.8, -7), (1.2, -9), (2.6, -3)];
        for &(base, n) in cases.iter() {
            let x = I32F32::from_num(base);
            let expected = powi_positive_naive(one / x, -n);
            assert!((expected - x.powi(n)).abs() < eps);
        }

        // Anything (including zero) raised to the zeroth power is one.
        for &base in [42, -42, 0].iter() {
            assert_eq!(one, I32F32::from_num(base).powi(0));
        }
    }

    #[test]
    fn test_powf() {
        let eps = I32F32::from_num(0.0001);
        let signed_cases = [
            (1.0, 7.2),
            (0.8, -4.5),
            (1.2, 5.0),
            (2.6, -6.7),
            (-1.2, 4.0),
            (-1.2, -3.0),
        ];
        for &(base, exponent) in signed_cases.iter() {
            let (x, n) = (I32F32::from_num(base), I32F32::from_num(exponent));
            let expected = powf_float(x, n);
            assert!((expected - x.powf(n)).abs() < eps);
        }

        let eps = U32F32::from_num(0.0001);
        let unsigned_cases = [(1.0, 7.2), (0.8, 4.5), (1.2, 5.0), (2.6, 6.7)];
        for &(base, exponent) in unsigned_cases.iter() {
            let (x, n) = (U32F32::from_num(base), U32F32::from_num(exponent));
            let expected = powf_float(x, n);
            assert!(delta(expected, x.powf(n)) < eps);
        }
    }
}