1use std::cmp::{Ord, Ordering};
15
16use fixed::traits::Fixed;
17use fixed::types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8};
18use fixed::{
19 FixedI128, FixedI16, FixedI32, FixedI64, FixedI8, FixedU128, FixedU16, FixedU32, FixedU64,
20 FixedU8,
21};
22use num_traits::{One, PrimInt, Zero};
23use typenum::{
24 Bit, IsLessOrEqual, LeEq, True, Unsigned, U126, U127, U14, U15, U30, U31, U6, U62, U63, U7,
25};
26
27pub trait FixedPow: Fixed + FixedPowI + FixedPowF {}
29impl<T: Fixed + FixedPowI + FixedPowF> FixedPow for T {}
30
31pub trait FixedPowI: Fixed {
33 fn powi(self, n: i32) -> Self;
51}
52
53pub trait FixedPowF: Fixed {
57 fn powf(self, n: Self) -> Self;
73}
74
75fn powi_positive<T: Fixed>(mut x: T, mut n: i32) -> T {
76 assert!(n > 0, "exponent should be positive");
77
78 let mut acc = x;
79 n -= 1;
80
81 while n > 0 {
82 if n & 1 == 1 {
83 acc *= x;
84 }
85 x *= x;
86 n >>= 1;
87 }
88
89 acc
90}
91
92fn sqrt<T>(x: T) -> T
93where
94 T: Fixed + Helper,
95 T::Bits: PrimInt,
96{
97 if x.is_zero() || x.is_one() {
98 return x;
99 }
100
101 let mut pow2 = T::one();
102 let mut result;
103
104 if x < T::one() {
105 while x <= pow2 * pow2 {
106 pow2 >>= 1;
107 }
108
109 result = pow2;
110 } else {
111 while pow2 * pow2 <= x {
113 pow2 <<= 1;
114 }
115
116 result = pow2 >> 1;
117 }
118
119 for _ in 0..T::NUM_BITS {
120 pow2 >>= 1;
121 let next_result = result + pow2;
122 if next_result * next_result <= x {
123 result = next_result;
124 }
125 }
126
127 result
128}
129
130fn powf_01<T>(mut x: T, n: T) -> T
131where
132 T: Fixed + Helper,
133 T::Bits: PrimInt + std::fmt::Debug,
134{
135 let mut n = n.to_bits();
136 if n.is_zero() {
137 panic!("fractional exponent should not be zero");
138 }
139
140 let top = T::Bits::one() << ((T::Frac::U32 - 1) as usize);
141 let mask = !(T::Bits::one() << ((T::Frac::U32) as usize));
142 let mut acc = None;
143
144 while !n.is_zero() {
145 x = sqrt(x);
146 if !(n & top).is_zero() {
147 acc = match acc {
148 Some(acc) => Some(acc * x),
149 None => Some(x),
150 };
151 }
152 n = (n << 1) & mask;
153 }
154
155 acc.unwrap()
156}
157
158fn powf_positive<T>(x: T, n: T) -> T
159where
160 T: Fixed + Helper,
161 T::Bits: PrimInt + std::fmt::Debug,
162{
163 assert!(Helper::is_positive(n), "exponent should be positive");
164
165 let powi = powi_positive(x, n.int().to_num());
166 let frac = n.frac();
167
168 if frac.is_zero() {
169 powi
170 } else {
171 assert!(Helper::is_positive(x), "base should be positive");
172 powi * powf_01(x, frac)
173 }
174}
175
176macro_rules! impl_fixed_pow {
177 ($fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
178 impl<Frac> FixedPowI for $fixed<Frac>
179 where
180 Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
181 {
182 fn powi(self, n: i32) -> Self {
183 if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= 0 {
184 panic!(
185 "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
186 self, n
187 );
188 }
189
190 match n.cmp(&0) {
191 Ordering::Greater => powi_positive(self, n),
192 Ordering::Equal => Self::from_bits(1 << Frac::U32),
193 Ordering::Less => powi_positive(Self::from_bits(1 << Frac::U32) / self, -n),
194 }
195 }
196 }
197
198 impl<Frac> FixedPowF for $fixed<Frac>
199 where
200 Frac: $le_eq + IsLessOrEqual<$le_eq_one, Output = True>,
201 {
202 fn powf(self, n: Self) -> Self {
203 let zero = Self::from_bits(0);
204
205 if !<LeEq<Frac, $le_eq_one>>::BOOL && n <= zero {
206 panic!(
207 "cannot raise `{}` to the power of `{}` because numbers larger than or equal to `1` are not representable",
208 self, n
209 );
210 }
211
212 match n.cmp(&zero) {
213 Ordering::Greater => powf_positive(self, n),
214 Ordering::Equal => Self::from_bits(1 << Frac::U32),
215 Ordering::Less => powf_positive(Self::from_bits(1 << Frac::U32) / self, Helper::neg(n)),
216 }
217 }
218 }
219 };
220}
221
222impl_fixed_pow!(FixedI8, LeEqU8, U6);
223impl_fixed_pow!(FixedI16, LeEqU16, U14);
224impl_fixed_pow!(FixedI32, LeEqU32, U30);
225impl_fixed_pow!(FixedI64, LeEqU64, U62);
226impl_fixed_pow!(FixedI128, LeEqU128, U126);
227
228impl_fixed_pow!(FixedU8, LeEqU8, U7);
229impl_fixed_pow!(FixedU16, LeEqU16, U15);
230impl_fixed_pow!(FixedU32, LeEqU32, U31);
231impl_fixed_pow!(FixedU64, LeEqU64, U63);
232impl_fixed_pow!(FixedU128, LeEqU128, U127);
233
/// Private helper operations that the generic routines in this file need
/// but that differ between signed and unsigned fixed-point types.
trait Helper {
    /// Total width of the type in bits (integer bits plus fractional bits).
    const NUM_BITS: u32;
    /// Returns `true` if the value is positive; for unsigned types this
    /// means "not zero".
    fn is_positive(self) -> bool;
    /// Returns `true` if the bit pattern of the value is zero.
    fn is_zero(self) -> bool;
    /// Returns `true` if the value equals `1`; always `false` for types in
    /// which `1` is not representable.
    fn is_one(self) -> bool;
    /// Returns the value `1`.
    ///
    /// Panics if `1` is not representable in the type.
    fn one() -> Self;
    /// Returns the negated value.
    ///
    /// Panics for unsigned types.
    fn neg(self) -> Self;
}
242
243macro_rules! impl_sign_helper {
244 (signed, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
245 impl<Frac: $le_eq> Helper for $fixed<Frac>
246 where
247 Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
248 {
249 const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
250 fn is_positive(self) -> bool {
251 $fixed::is_positive(self)
252 }
253 fn is_zero(self) -> bool {
254 self.to_bits() == 0
255 }
256 fn is_one(self) -> bool {
257 <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
258 }
259 fn one() -> Self {
260 assert!(
261 <LeEq<Frac, $le_eq_one>>::BOOL,
262 "one should be possible to represent"
263 );
264 Self::from_bits(1 << Frac::U32)
265 }
266 fn neg(self) -> Self {
267 -self
268 }
269 }
270 };
271 (unsigned, $fixed:ident, $le_eq:ident, $le_eq_one:ident) => {
272 impl<Frac: $le_eq> Helper for $fixed<Frac>
273 where
274 Frac: $le_eq + IsLessOrEqual<$le_eq_one>,
275 {
276 const NUM_BITS: u32 = <Self as Fixed>::INT_NBITS + <Self as Fixed>::FRAC_NBITS;
277 fn is_positive(self) -> bool {
278 self != Self::from_bits(0)
279 }
280 fn is_zero(self) -> bool {
281 self.to_bits() == 0
282 }
283 fn is_one(self) -> bool {
284 <LeEq<Frac, $le_eq_one>>::BOOL && self.to_bits() == 1 << Frac::U32
285 }
286 fn one() -> Self {
287 assert!(
288 <LeEq<Frac, $le_eq_one>>::BOOL,
289 "one should be possible to represent"
290 );
291 Self::from_bits(1 << Frac::U32)
292 }
293 fn neg(self) -> Self {
294 panic!("cannot negate an unsigned number")
295 }
296 }
297 };
298}
299
300impl_sign_helper!(signed, FixedI8, LeEqU8, U6);
301impl_sign_helper!(signed, FixedI16, LeEqU16, U14);
302impl_sign_helper!(signed, FixedI32, LeEqU32, U30);
303impl_sign_helper!(signed, FixedI64, LeEqU64, U62);
304impl_sign_helper!(signed, FixedI128, LeEqU128, U126);
305
306impl_sign_helper!(unsigned, FixedU8, LeEqU8, U7);
307impl_sign_helper!(unsigned, FixedU16, LeEqU16, U15);
308impl_sign_helper!(unsigned, FixedU32, LeEqU32, U31);
309impl_sign_helper!(unsigned, FixedU64, LeEqU64, U63);
310impl_sign_helper!(unsigned, FixedU128, LeEqU128, U127);
311
#[cfg(test)]
mod tests {
    use super::*;

    use fixed::types::{I1F31, I32F32, U0F32, U32F32};

    /// Reference implementation: `x^n` by `n - 1` successive multiplications.
    fn powi_positive_naive<T: Fixed>(x: T, n: i32) -> T {
        assert!(n > 0, "exponent should be positive");
        (1..n).fold(x, |mut acc, _| {
            acc *= x;
            acc
        })
    }

    /// Absolute difference that also works for unsigned types.
    fn delta<T: Fixed>(a: T, b: T) -> T {
        let larger = Ord::max(a, b);
        let smaller = Ord::min(a, b);
        larger - smaller
    }

    /// Reference implementation of `powf` that goes through `f64`.
    fn powf_float<T: Fixed>(x: T, n: T) -> T {
        let base: f64 = x.to_num();
        let exponent: f64 = n.to_num();
        T::from_num(base.powf(exponent))
    }

    #[test]
    fn test_powi_positive() {
        // Signed type with an integer part.
        let tolerance = I32F32::from_num(0.0001);
        let signed_cases = [
            (I32F32::from_num(1.0), 42),
            (I32F32::from_num(0.8), 7),
            (I32F32::from_num(1.2), 11),
            (I32F32::from_num(2.6), 16),
            (I32F32::from_num(-2.2), 4),
            (I32F32::from_num(-2.2), 5),
        ];
        for &(x, n) in signed_cases.iter() {
            let expected = powi_positive_naive(x, n);
            let actual = x.powi(n);
            assert!((expected - actual).abs() < tolerance);
        }

        // Unsigned type with an integer part.
        let tolerance = U32F32::from_num(0.0001);
        let unsigned_cases = [
            (U32F32::from_num(1.0), 42),
            (U32F32::from_num(0.8), 7),
            (U32F32::from_num(1.2), 11),
            (U32F32::from_num(2.6), 16),
        ];
        for &(x, n) in unsigned_cases.iter() {
            let expected = powi_positive_naive(x, n);
            let actual = x.powi(n);
            assert!(delta(expected, actual) < tolerance);
        }
    }

    #[test]
    fn test_powi_positive_sub_one() {
        // Signed type in which `1` is not representable.
        let tolerance = I1F31::from_num(0.0001);
        let signed_cases = [
            (I1F31::from_num(0.5), 3),
            (I1F31::from_num(0.8), 5),
            (I1F31::from_num(0.2), 7),
            (I1F31::from_num(0.6), 9),
            (I1F31::from_num(-0.6), 3),
            (I1F31::from_num(-0.6), 4),
        ];
        for &(x, n) in signed_cases.iter() {
            let expected = powi_positive_naive(x, n);
            let actual = x.powi(n);
            assert!((expected - actual).abs() < tolerance);
        }

        // Unsigned type in which `1` is not representable.
        let tolerance = U0F32::from_num(0.0001);
        let unsigned_cases = [
            (U0F32::from_num(0.5), 3),
            (U0F32::from_num(0.8), 5),
            (U0F32::from_num(0.2), 7),
            (U0F32::from_num(0.6), 9),
        ];
        for &(x, n) in unsigned_cases.iter() {
            let expected = powi_positive_naive(x, n);
            let actual = x.powi(n);
            assert!(delta(expected, actual) < tolerance);
        }
    }

    #[test]
    fn test_powi_non_positive() {
        let tolerance = I32F32::from_num(0.0001);
        let cases = [
            (I32F32::from_num(1.0), -17),
            (I32F32::from_num(0.8), -7),
            (I32F32::from_num(1.2), -9),
            (I32F32::from_num(2.6), -3),
        ];
        for &(x, n) in cases.iter() {
            // A negative power is the positive power of the reciprocal.
            let expected = powi_positive_naive(I32F32::from_num(1) / x, -n);
            let actual = x.powi(n);
            assert!((expected - actual).abs() < tolerance);
        }

        // Anything raised to the power zero is one, including zero itself.
        for &base in [42, -42, 0].iter() {
            assert_eq!(I32F32::from_num(1), I32F32::from_num(base).powi(0));
        }
    }

    #[test]
    fn test_powf() {
        let tolerance = I32F32::from_num(0.0001);
        let signed_cases = [
            (I32F32::from_num(1.0), I32F32::from_num(7.2)),
            (I32F32::from_num(0.8), I32F32::from_num(-4.5)),
            (I32F32::from_num(1.2), I32F32::from_num(5.0)),
            (I32F32::from_num(2.6), I32F32::from_num(-6.7)),
            (I32F32::from_num(-1.2), I32F32::from_num(4.0)),
            (I32F32::from_num(-1.2), I32F32::from_num(-3.0)),
        ];
        for &(x, n) in signed_cases.iter() {
            let expected = powf_float(x, n);
            let actual = x.powf(n);
            assert!((expected - actual).abs() < tolerance);
        }

        let tolerance = U32F32::from_num(0.0001);
        let unsigned_cases = [
            (U32F32::from_num(1.0), U32F32::from_num(7.2)),
            (U32F32::from_num(0.8), U32F32::from_num(4.5)),
            (U32F32::from_num(1.2), U32F32::from_num(5.0)),
            (U32F32::from_num(2.6), U32F32::from_num(6.7)),
        ];
        for &(x, n) in unsigned_cases.iter() {
            let expected = powf_float(x, n);
            let actual = x.powf(n);
            assert!(delta(expected, actual) < tolerance);
        }
    }
}