1use core::cmp::{Ord, Ordering};
15
16use fixed::{
17 traits::Fixed,
18 types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8},
19 FixedI128, FixedI16, FixedI32, FixedI64, FixedI8, FixedU128, FixedU16, FixedU32, FixedU64,
20 FixedU8,
21};
22use num_traits::{One, PrimInt, Zero};
23use typenum::{
24 Bit, IsLessOrEqual, LeEq, True, Unsigned, U126, U127, U14, U15, U30, U31, U6, U62, U63, U7,
25};
26
/// Convenience trait for fixed-point types that support exponentiation with
/// both integer exponents ([`FixedPowI`]) and fixed-point exponents
/// ([`FixedPowF`]).
pub trait FixedPow: Fixed + FixedPowI + FixedPowF {}
// Blanket implementation: any type that satisfies all the supertraits is a
// `FixedPow` automatically, so the trait never has to be implemented by hand.
impl<T: Fixed + FixedPowI + FixedPowF> FixedPow for T {}
30
/// Extension trait that adds integer exponentiation to fixed-point numbers.
pub trait FixedPowI: Fixed {
    /// Raises `self` to the integer power `n`.
    ///
    /// In the implementations generated in this file, `n == 0` yields `1` and
    /// a negative `n` is evaluated as `(1 / self)` raised to the power `-n`.
    ///
    /// # Panics
    ///
    /// The implementations generated in this file panic when `n <= 0` and `1`
    /// cannot be represented in `Self`.
    fn powi(self, n: i32) -> Self;
}
52
/// Extension trait that adds exponentiation with a fixed-point exponent to
/// fixed-point numbers.
pub trait FixedPowF: Fixed {
    /// Raises `self` to the fixed-point power `n`.
    ///
    /// In the implementations generated in this file, `n == 0` yields `1` and
    /// a negative `n` is evaluated as `(1 / self)` raised to the power `-n`.
    ///
    /// # Panics
    ///
    /// The implementations generated in this file panic when `n` has a
    /// non-zero fractional part and the base is not strictly positive.
    fn powf(self, n: Self) -> Self;
}
74
75#[cfg(not(feature = "fast-powi"))]
76fn powi_positive<T: Fixed>(x: T, n: i32) -> T {
77 assert!(n > 0, "exponent should be positive");
78 let mut res = x;
79 for _ in 1..n {
80 res = res * x;
81 }
82 res
83}
84
/// Raises `x` to the strictly positive integer power `n` using exponentiation
/// by squaring.
///
/// `acc` starts as `x` (the first power); the remaining `n - 1` factors are
/// folded in by walking the bits of `n - 1` from least to most significant
/// while `x` is repeatedly squared.
///
/// The base is only squared while exponent bits remain to be consumed, so no
/// power of `x` is computed that does not end up in the result. Squaring once
/// more after the last bit could overflow even when the result itself is
/// representable.
///
/// # Panics
///
/// Panics if `n` is not strictly positive.
#[cfg(feature = "fast-powi")]
fn powi_positive<T: Fixed>(mut x: T, mut n: i32) -> T {
    assert!(n > 0, "exponent should be positive");
    let mut acc = x;
    n -= 1;

    while n > 0 {
        if n & 1 == 1 {
            acc *= x;
        }
        n >>= 1;
        // Skip the final squaring: its value would never be used.
        if n > 0 {
            x *= x;
        }
    }

    acc
}
103
104fn sqrt<T>(x: T) -> T
105where
106 T: Fixed + Helper,
107 T::Bits: PrimInt,
108{
109 if x.is_zero() || x.is_one() {
110 return x;
111 }
112
113 let mut pow2 = T::one();
114 let mut result;
115
116 if x < T::one() {
117 while x <= pow2 * pow2 {
118 pow2 >>= 1;
119 }
120
121 result = pow2;
122 } else {
123 while pow2 * pow2 <= x {
125 pow2 <<= 1;
126 }
127
128 result = pow2 >> 1;
129 }
130
131 for _ in 0..T::NUM_BITS {
132 pow2 >>= 1;
133 let next_result = result + pow2;
134 if next_result * next_result <= x {
135 result = next_result;
136 }
137 }
138
139 result
140}
141
142fn powf_01<T>(mut x: T, n: T) -> T
143where
144 T: Fixed + Helper,
145 T::Bits: PrimInt + std::fmt::Debug,
146{
147 let mut n = n.to_bits();
148 if n.is_zero() {
149 panic!("fractional exponent should not be zero");
150 }
151
152 let top = T::Bits::one() << ((T::Frac::U32 - 1) as usize);
153 let mask = !(T::Bits::one() << ((T::Frac::U32) as usize));
154 let mut acc = None;
155
156 while !n.is_zero() {
157 x = sqrt(x);
158 if !(n & top).is_zero() {
159 acc = match acc {
160 Some(acc) => Some(acc * x),
161 None => Some(x),
162 };
163 }
164 n = (n << 1) & mask;
165 }
166
167 acc.unwrap()
168}
169
170fn powf_positive<T>(x: T, n: T) -> T
171where
172 T: Fixed + Helper,
173 T::Bits: PrimInt + std::fmt::Debug,
174{
175 assert!(Helper::is_positive(n), "exponent should be positive");
176 let ni = n.int().to_num();
177 let powi = if ni == 0 {
178 T::from_num(1)
179 } else {
180 powi_positive(x, ni)
181 };
182 let frac = n.frac();
183
184 if frac.is_zero() {
185 powi
186 } else {
187 assert!(Helper::is_positive(x), "base should be positive");
188 powi * powf_01(x, frac)
189 }
190}
191
// Implements `FixedPowI` and `FixedPowF` for one fixed-point type.
//
// * `$fixed`     - the fixed-point type constructor (e.g. `FixedI32`).
// * `$le_eq`     - the bound the `fixed` crate places on `Frac` for that type.
// * `$le_eq_one` - compared against `Frac` to decide whether `1` is
//                  representable (`Frac <= $le_eq_one`).
macro_rules! impl_fixed_pow {
    ($fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
        impl<Frac> FixedPowI for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
        {
            fn powi(self, n: i32) -> Self {
                // Positive exponents never need the constant `1`.
                if n > 0 {
                    return powi_positive(self, n);
                }

                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    panic!(
                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
                        self, n
                    );
                }

                // Only built once `1` is known to be representable.
                let one = Self::from_bits(1 << Frac::U32);
                if n == 0 {
                    one
                } else {
                    powi_positive(one / self, -n)
                }
            }
        }

        impl<Frac> FixedPowF for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one, Output = True>,
        {
            fn powf(self, n: Self) -> Self {
                let zero = Self::from_bits(0);

                // Positive exponents never need the constant `1`.
                if n > zero {
                    return powf_positive(self, n);
                }

                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    panic!(
                        "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
                        self, n
                    );
                }

                let one = Self::from_bits(1 << Frac::U32);
                if n == zero {
                    one
                } else {
                    powf_positive(one / self, Helper::neg(n))
                }
            }
        }
    };
}
239
// Signed types: the last argument is the largest number of fractional bits
// for which `1` is still representable (total width minus two, because one
// bit is taken by the sign).
impl_fixed_pow!(FixedI8, LeEqU8, U6);
impl_fixed_pow!(FixedI16, LeEqU16, U14);
impl_fixed_pow!(FixedI32, LeEqU32, U30);
impl_fixed_pow!(FixedI64, LeEqU64, U62);
impl_fixed_pow!(FixedI128, LeEqU128, U126);

// Unsigned types: `1` is representable with up to (total width minus one)
// fractional bits.
impl_fixed_pow!(FixedU8, LeEqU8, U7);
impl_fixed_pow!(FixedU16, LeEqU16, U15);
impl_fixed_pow!(FixedU32, LeEqU32, U31);
impl_fixed_pow!(FixedU64, LeEqU64, U63);
impl_fixed_pow!(FixedU128, LeEqU128, U127);
251
/// Private helper operations that the generic routines in this file need but
/// that differ between signed and unsigned fixed-point types.
trait Helper {
    /// Total width of the type in bits (integer bits plus fractional bits).
    const NUM_BITS: u32;
    /// Returns `true` when the value is strictly greater than zero.
    fn is_positive(self) -> bool;
    /// Returns `true` when `1` is representable in the type and the value
    /// equals `1`.
    fn is_one(self) -> bool;
    /// Returns the value `1`; the implementations in this file panic when `1`
    /// is not representable.
    fn one() -> Self;
    /// Returns the negated value; the unsigned implementations in this file
    /// panic instead.
    fn neg(self) -> Self;
}
259
// Implements `Helper` for one fixed-point type.
//
// The first token selects the flavour (`signed` or `unsigned`); the remaining
// arguments are the type constructor, the `Frac` bound required by the
// `fixed` crate, and the largest `Frac` for which `1` is representable.
macro_rules! impl_sign_helper {
    (signed, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
        impl<Frac> Helper for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
        {
            const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;

            fn is_positive(self) -> bool {
                $fixed::is_positive(self)
            }

            fn is_one(self) -> bool {
                // The shift below is only meaningful when `1` is representable.
                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    return false;
                }
                self.to_bits() == 1 << Frac::U32
            }

            fn one() -> Self {
                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    panic!("one should be possible to represent");
                }
                Self::from_bits(1 << Frac::U32)
            }

            fn neg(self) -> Self {
                -self
            }
        }
    };
    (unsigned, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
        impl<Frac> Helper for $fixed<Frac>
        where
            Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
        {
            const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;

            fn is_positive(self) -> bool {
                // An unsigned value is positive exactly when it is non-zero.
                self.to_bits() != 0
            }

            fn is_one(self) -> bool {
                // The shift below is only meaningful when `1` is representable.
                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    return false;
                }
                self.to_bits() == 1 << Frac::U32
            }

            fn one() -> Self {
                if !<LeEq<Frac, $le_eq_one>>::BOOL {
                    panic!("one should be possible to represent");
                }
                Self::from_bits(1 << Frac::U32)
            }

            fn neg(self) -> Self {
                panic!("cannot negate an unsigned number")
            }
        }
    };
}
310
// Signed types: `1` is representable with up to (total width minus two)
// fractional bits.
impl_sign_helper!(signed, FixedI8, LeEqU8, U6);
impl_sign_helper!(signed, FixedI16, LeEqU16, U14);
impl_sign_helper!(signed, FixedI32, LeEqU32, U30);
impl_sign_helper!(signed, FixedI64, LeEqU64, U62);
impl_sign_helper!(signed, FixedI128, LeEqU128, U126);

// Unsigned types: `1` is representable with up to (total width minus one)
// fractional bits.
impl_sign_helper!(unsigned, FixedU8, LeEqU8, U7);
impl_sign_helper!(unsigned, FixedU16, LeEqU16, U15);
impl_sign_helper!(unsigned, FixedU32, LeEqU32, U31);
impl_sign_helper!(unsigned, FixedU64, LeEqU64, U63);
impl_sign_helper!(unsigned, FixedU128, LeEqU128, U127);
322
#[cfg(test)]
mod tests {
    use fixed::types::{I1F31, I32F32, U0F32, U32F32};

    use super::*;

    /// Reference integer power: multiplies `x` into itself `n - 1` times.
    fn powi_positive_naive<T: Fixed>(x: T, n: i32) -> T {
        assert!(n > 0, "exponent should be positive");
        (1..n).fold(x, |product, _| product * x)
    }

    /// Absolute difference that also works for unsigned types.
    fn delta<T: Fixed>(a: T, b: T) -> T {
        if a >= b {
            a - b
        } else {
            b - a
        }
    }

    #[test]
    fn test_powi_positive() {
        // Signed values with an integer part.
        let tolerance = I32F32::from_num(0.0001);
        let cases = [
            (I32F32::from_num(1.0), 42),
            (I32F32::from_num(0.8), 7),
            (I32F32::from_num(1.2), 11),
            (I32F32::from_num(2.6), 16),
            (I32F32::from_num(-2.2), 4),
            (I32F32::from_num(-2.2), 5),
        ];
        for (base, exponent) in cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!((expected - actual).abs() < tolerance);
        }

        // Unsigned values with an integer part.
        let tolerance = U32F32::from_num(0.0001);
        let cases = [
            (U32F32::from_num(1.0), 42),
            (U32F32::from_num(0.8), 7),
            (U32F32::from_num(1.2), 11),
            (U32F32::from_num(2.6), 16),
        ];
        for (base, exponent) in cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!(delta(expected, actual) < tolerance);
        }
    }

    #[test]
    fn test_powi_positive_sub_one() {
        // Signed type that cannot represent `1`.
        let tolerance = I1F31::from_num(0.0001);
        let cases = [
            (I1F31::from_num(0.5), 3),
            (I1F31::from_num(0.8), 5),
            (I1F31::from_num(0.2), 7),
            (I1F31::from_num(0.6), 9),
            (I1F31::from_num(-0.6), 3),
            (I1F31::from_num(-0.6), 4),
        ];
        for (base, exponent) in cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!((expected - actual).abs() < tolerance);
        }

        // Unsigned type that cannot represent `1`.
        let tolerance = U0F32::from_num(0.0001);
        let cases = [
            (U0F32::from_num(0.5), 3),
            (U0F32::from_num(0.8), 5),
            (U0F32::from_num(0.2), 7),
            (U0F32::from_num(0.6), 9),
        ];
        for (base, exponent) in cases.iter().copied() {
            let expected = powi_positive_naive(base, exponent);
            let actual = base.powi(exponent);
            assert!(delta(expected, actual) < tolerance);
        }
    }

    #[test]
    fn test_powi_non_positive() {
        let tolerance = I32F32::from_num(0.0001);
        let cases = [
            (I32F32::from_num(1.0), -17),
            (I32F32::from_num(0.8), -7),
            (I32F32::from_num(1.2), -9),
            (I32F32::from_num(2.6), -3),
        ];
        for (base, exponent) in cases.iter().copied() {
            // A negative power is the positive power of the reciprocal.
            let expected = powi_positive_naive(I32F32::from_num(1) / base, -exponent);
            let actual = base.powi(exponent);
            assert!((expected - actual).abs() < tolerance);
        }

        // Anything raised to the power zero is one.
        assert_eq!(I32F32::from_num(1), I32F32::from_num(42).powi(0));
        assert_eq!(I32F32::from_num(1), I32F32::from_num(-42).powi(0));
        assert_eq!(I32F32::from_num(1), I32F32::from_num(0).powi(0));
    }

    /// Reference power computed in `f64` and converted back to `T`.
    fn powf_float<T: Fixed>(x: T, n: T) -> T {
        let base = x.to_num::<f64>();
        let exponent = n.to_num::<f64>();
        T::from_num(base.powf(exponent))
    }

    #[test]
    fn test_powf() {
        let tolerance = I32F32::from_num(0.0001);
        let cases = [
            (I32F32::from_num(1.0), I32F32::from_num(7.2)),
            (I32F32::from_num(0.8), I32F32::from_num(-4.5)),
            (I32F32::from_num(1.2), I32F32::from_num(5.0)),
            (I32F32::from_num(2.6), I32F32::from_num(-6.7)),
            (I32F32::from_num(-1.2), I32F32::from_num(4.0)),
            (I32F32::from_num(-1.2), I32F32::from_num(-3.0)),
        ];
        for (base, exponent) in cases.iter().copied() {
            let expected = powf_float(base, exponent);
            let actual = base.powf(exponent);
            assert!((expected - actual).abs() < tolerance);
        }

        let tolerance = U32F32::from_num(0.0001);
        let cases = [
            (U32F32::from_num(1.0), U32F32::from_num(7.2)),
            (U32F32::from_num(0.8), U32F32::from_num(4.5)),
            (U32F32::from_num(1.2), U32F32::from_num(5.0)),
            (U32F32::from_num(2.6), U32F32::from_num(6.7)),
        ];
        for (base, exponent) in cases.iter().copied() {
            let expected = powf_float(base, exponent);
            let actual = base.powf(exponent);
            assert!(delta(expected, actual) < tolerance);
        }
    }
}