1#![no_std]
14use core::cmp::{Ord, Ordering};
15use core::fmt;
16
17use fixed::{
18 traits::Fixed,
19 types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8},
20 FixedI128, FixedI16, FixedI32, FixedI64, FixedI8, FixedU128, FixedU16, FixedU32, FixedU64,
21 FixedU8,
22};
23use num_traits::{One, PrimInt, Zero};
24use typenum::{
25 Bit, IsLessOrEqual, LeEq, True, Unsigned, U126, U127, U14, U15, U30, U31, U6, U62, U63, U7,
26};
27
28pub trait FixedPow: Fixed + FixedPowI + FixedPowF {}
30impl<T: Fixed + FixedPowI + FixedPowF> FixedPow for T {}
31
32pub trait FixedPowI: Fixed {
34 fn powi(self, n: i32) -> Self;
52}
53
54pub trait FixedPowF: Fixed {
58 fn powf(self, n: Self) -> Self;
74}
75
76#[cfg(not(feature = "fast-powi"))]
77fn powi_positive<T: Fixed>(x: T, n: i32) -> T {
78 assert!(n > 0, "exponent should be positive");
79 let mut res = x;
80 for _ in 1..n {
81 res = res * x;
82 }
83 res
84}
85
/// Raises `x` to the strictly positive integer power `n` using binary
/// exponentiation (square-and-multiply), i.e. O(log n) multiplications.
///
/// # Panics
///
/// Panics if `n <= 0`.
#[cfg(feature = "fast-powi")]
fn powi_positive<T: Fixed>(mut x: T, mut n: i32) -> T {
    assert!(n > 0, "exponent should be positive");
    // One factor of `x` seeds the accumulator; the remaining `n - 1` factors
    // are assembled from the binary digits of `n - 1`.
    let mut acc = x;
    n -= 1;

    while n > 0 {
        if n & 1 == 1 {
            acc *= x;
        }
        n >>= 1;
        // Square only while more bits remain. Squaring after the last bit
        // would compute a power the result never uses (e.g. x^4 for n == 3),
        // and that unused power can overflow (panicking when debug assertions
        // are enabled) even though the final result fits.
        if n > 0 {
            x *= x;
        }
    }

    acc
}
104
105fn sqrt<T>(x: T) -> T
106where
107 T: Fixed + Helper,
108 T::Bits: PrimInt,
109{
110 if x.is_zero() || x.is_one() {
111 return x;
112 }
113
114 let mut pow2 = T::one();
115 let mut result;
116
117 if x < T::one() {
118 while x <= pow2 * pow2 {
119 pow2 >>= 1;
120 }
121
122 result = pow2;
123 } else {
124 while pow2 * pow2 <= x {
126 pow2 <<= 1;
127 }
128
129 result = pow2 >> 1;
130 }
131
132 for _ in 0..T::NUM_BITS {
133 pow2 >>= 1;
134 let next_result = result + pow2;
135 if next_result * next_result <= x {
136 result = next_result;
137 }
138 }
139
140 result
141}
142
143fn powf_01<T>(mut x: T, n: T) -> T
144where
145 T: Fixed + Helper,
146 T::Bits: PrimInt + fmt::Debug,
147{
148 let mut n = n.to_bits();
149 if n.is_zero() {
150 panic!("fractional exponent should not be zero");
151 }
152
153 let top = T::Bits::one() << ((T::Frac::U32 - 1) as usize);
154 let mask = !(T::Bits::one() << ((T::Frac::U32) as usize));
155 let mut acc = None;
156
157 while !n.is_zero() {
158 x = sqrt(x);
159 if !(n & top).is_zero() {
160 acc = match acc {
161 Some(acc) => Some(acc * x),
162 None => Some(x),
163 };
164 }
165 n = (n << 1) & mask;
166 }
167
168 acc.unwrap()
169}
170
171fn powf_positive<T>(x: T, n: T) -> T
172where
173 T: Fixed + Helper,
174 T::Bits: PrimInt + fmt::Debug,
175{
176 assert!(Helper::is_positive(n), "exponent should be positive");
177 let ni = n.int().to_num();
178 let powi = if ni == 0 {
179 T::from_num(1)
180 } else {
181 powi_positive(x, ni)
182 };
183 let frac = n.frac();
184
185 if frac.is_zero() {
186 powi
187 } else {
188 assert!(Helper::is_positive(x), "base should be positive");
189 powi * powf_01(x, frac)
190 }
191}
192
193macro_rules! impl_fixed_pow {
194 ($fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
195 impl<Frac> FixedPowI for $fixed<Frac>
196 where
197 Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
198 {
199 fn powi(self, n: i32) -> Self {
200 if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= 0 {
201 panic!(
202 "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
203 self, n
204 );
205 }
206
207 match n.cmp(&0) {
208 Ordering::Greater => powi_positive(self, n),
209 Ordering::Equal => Self::from_bits(1 << Frac::U32),
210 Ordering::Less => powi_positive(Self::from_bits(1 << Frac::U32) / self, -n),
211 }
212 }
213 }
214
215 impl<Frac> FixedPowF for $fixed<Frac>
216 where
217 Frac: $le_eq + IsLessOrEqual<$le_eq_one, Output = True>,
218 {
219 fn powf(self, n: Self) -> Self {
220 let zero = Self::from_bits(0);
221
222 if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= zero {
223 panic!(
224 "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
225 self, n
226 );
227 }
228
229 match n.cmp(&zero) {
230 Ordering::Greater => powf_positive(self, n),
231 Ordering::Equal => Self::from_bits(1 << Frac::U32),
232 Ordering::Less => {
233 powf_positive(Self::from_bits(1 << Frac::U32) / self, Helper::neg(n))
234 },
235 }
236 }
237 }
238 };
239}
240
241impl_fixed_pow!(FixedI8, LeEqU8, U6);
242impl_fixed_pow!(FixedI16, LeEqU16, U14);
243impl_fixed_pow!(FixedI32, LeEqU32, U30);
244impl_fixed_pow!(FixedI64, LeEqU64, U62);
245impl_fixed_pow!(FixedI128, LeEqU128, U126);
246
247impl_fixed_pow!(FixedU8, LeEqU8, U7);
248impl_fixed_pow!(FixedU16, LeEqU16, U15);
249impl_fixed_pow!(FixedU32, LeEqU32, U31);
250impl_fixed_pow!(FixedU64, LeEqU64, U63);
251impl_fixed_pow!(FixedU128, LeEqU128, U127);
252
/// Small per-type primitives used by the generic `sqrt`/`powf` helpers.
///
/// Implemented for the signed and unsigned fixed-point families by
/// `impl_sign_helper!`.
trait Helper {
    /// Total number of bits (integer plus fractional) of the representation.
    const NUM_BITS: u32;

    /// Returns the value `1`; implementations may panic when `1` is not
    /// representable in the type.
    fn one() -> Self;

    /// Returns whether `self` equals `1`.
    fn is_one(self) -> bool;

    /// Returns whether `self` is strictly greater than zero.
    fn is_positive(self) -> bool;

    /// Returns `-self`; implementations for unsigned types panic.
    fn neg(self) -> Self;
}
260
261macro_rules! impl_sign_helper {
262 (signed, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
263 impl<Frac: $le_eq> Helper for $fixed<Frac>
264 where
265 Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
266 {
267 const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
268 fn is_positive(self) -> bool {
269 $fixed::is_positive(self)
270 }
271 fn is_one(self) -> bool {
272 <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
273 }
274 fn one() -> Self {
275 assert!(
276 <LeEq<Frac, $le_eq_one>>::BOOL,
277 "one should be possible to represent"
278 );
279 Self::from_bits(1 << Frac::U32)
280 }
281 fn neg(self) -> Self {
282 -self
283 }
284 }
285 };
286 (unsigned, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
287 impl<Frac: $le_eq> Helper for $fixed<Frac>
288 where
289 Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
290 {
291 const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
292 fn is_positive(self) -> bool {
293 self != Self::from_bits(0)
294 }
295 fn is_one(self) -> bool {
296 <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
297 }
298 fn one() -> Self {
299 assert!(
300 <LeEq<Frac, $le_eq_one>>::BOOL,
301 "one should be possible to represent"
302 );
303 Self::from_bits(1 << Frac::U32)
304 }
305 fn neg(self) -> Self {
306 panic!("cannot negate an unsigned number")
307 }
308 }
309 };
310}
311
312impl_sign_helper!(signed, FixedI8, LeEqU8, U6);
313impl_sign_helper!(signed, FixedI16, LeEqU16, U14);
314impl_sign_helper!(signed, FixedI32, LeEqU32, U30);
315impl_sign_helper!(signed, FixedI64, LeEqU64, U62);
316impl_sign_helper!(signed, FixedI128, LeEqU128, U126);
317
318impl_sign_helper!(unsigned, FixedU8, LeEqU8, U7);
319impl_sign_helper!(unsigned, FixedU16, LeEqU16, U15);
320impl_sign_helper!(unsigned, FixedU32, LeEqU32, U31);
321impl_sign_helper!(unsigned, FixedU64, LeEqU64, U63);
322impl_sign_helper!(unsigned, FixedU128, LeEqU128, U127);
323
#[cfg(test)]
mod tests {
    use fixed::types::{I1F31, I32F32, U0F32, U32F32};

    use super::*;

    /// Reference `x^n` for `n > 0`: plain repeated multiplication, used to
    /// cross-check `powi` (including the `fast-powi` variant).
    fn powi_positive_naive<T: Fixed>(x: T, n: i32) -> T {
        assert!(n > 0, "exponent should be positive");
        (1..n).fold(x, |acc, _| acc * x)
    }

    /// `|a - b|` computed without going below zero, so unsigned types work.
    fn delta<T: Fixed>(a: T, b: T) -> T {
        let hi = Ord::max(a, b);
        let lo = Ord::min(a, b);
        hi - lo
    }

    /// Reference `x^n` evaluated in `f64` and converted back to `T`.
    fn powf_float<T: Fixed>(x: T, n: T) -> T {
        let base = x.to_num::<f64>();
        let exponent = n.to_num::<f64>();
        T::from_num(base.powf(exponent))
    }

    #[test]
    fn test_powi_positive() {
        let eps = I32F32::from_num(0.0001);
        let signed_cases: &[(f64, i32)] = &[
            (1.0, 42),
            (0.8, 7),
            (1.2, 11),
            (2.6, 16),
            (-2.2, 4),
            (-2.2, 5),
        ];
        for &(base, n) in signed_cases {
            let x = I32F32::from_num(base);
            assert!((powi_positive_naive(x, n) - x.powi(n)).abs() < eps);
        }

        let eps = U32F32::from_num(0.0001);
        let unsigned_cases: &[(f64, i32)] = &[(1.0, 42), (0.8, 7), (1.2, 11), (2.6, 16)];
        for &(base, n) in unsigned_cases {
            let x = U32F32::from_num(base);
            assert!(delta(powi_positive_naive(x, n), x.powi(n)) < eps);
        }
    }

    #[test]
    fn test_powi_positive_sub_one() {
        let eps = I1F31::from_num(0.0001);
        let signed_cases: &[(f64, i32)] = &[
            (0.5, 3),
            (0.8, 5),
            (0.2, 7),
            (0.6, 9),
            (-0.6, 3),
            (-0.6, 4),
        ];
        for &(base, n) in signed_cases {
            let x = I1F31::from_num(base);
            assert!((powi_positive_naive(x, n) - x.powi(n)).abs() < eps);
        }

        let eps = U0F32::from_num(0.0001);
        let unsigned_cases: &[(f64, i32)] = &[(0.5, 3), (0.8, 5), (0.2, 7), (0.6, 9)];
        for &(base, n) in unsigned_cases {
            let x = U0F32::from_num(base);
            assert!(delta(powi_positive_naive(x, n), x.powi(n)) < eps);
        }
    }

    #[test]
    fn test_powi_non_positive() {
        let eps = I32F32::from_num(0.0001);
        let negative_cases: &[(f64, i32)] = &[(1.0, -17), (0.8, -7), (1.2, -9), (2.6, -3)];
        for &(base, n) in negative_cases {
            let x = I32F32::from_num(base);
            let reciprocal = I32F32::from_num(1) / x;
            assert!((powi_positive_naive(reciprocal, -n) - x.powi(n)).abs() < eps);
        }

        for &base in &[42, -42, 0] {
            assert_eq!(I32F32::from_num(1), I32F32::from_num(base).powi(0));
        }
    }

    #[test]
    fn test_powf() {
        let eps = I32F32::from_num(0.0001);
        let signed_cases: &[(f64, f64)] = &[
            (1.0, 7.2),
            (0.8, -4.5),
            (1.2, 5.0),
            (2.6, -6.7),
            (-1.2, 4.0),
            (-1.2, -3.0),
        ];
        for &(base, exponent) in signed_cases {
            let (x, n) = (I32F32::from_num(base), I32F32::from_num(exponent));
            assert!((powf_float(x, n) - x.powf(n)).abs() < eps);
        }

        let eps = U32F32::from_num(0.0001);
        let unsigned_cases: &[(f64, f64)] = &[(1.0, 7.2), (0.8, 4.5), (1.2, 5.0), (2.6, 6.7)];
        for &(base, exponent) in unsigned_cases {
            let (x, n) = (U32F32::from_num(base), U32F32::from_num(exponent));
            assert!(delta(powf_float(x, n), x.powf(n)) < eps);
        }
    }
}