Skip to main content

unsigned_float/
pow.rs

1#[cfg(feature = "f128")]
2use crate::Uf64;
3use crate::{Uf8, Uf8E5M3, Uf16, Uf16E6M10, Uf32, dispatch};
4
/// Exponentiation of a value by an unsigned-float exponent.
///
/// The implementing type is the power base and `Rhs` is the exponent type.
/// The method is shaped like the native `powf` methods, so it reads as
/// `base.powuf(exponent)`.
pub trait PowUf<Rhs> {
    /// Type returned by [`PowUf::powuf`].
    type Output;

    /// Computes `self` raised to the power `rhs`.
    fn powuf(self, rhs: Rhs) -> Self::Output;
}
16
/// Complement exponentiation: evaluates `(1 - self)^rhs`.
///
/// Here `self` is the complement input `u`, not the power base; the base is
/// `1 - u`. This shape suits probability and interpolation kernels that
/// already hold `u`.
pub trait Pow1mUf<Rhs> {
    /// Type returned by [`Pow1mUf::pow1muf`].
    type Output;

    /// Computes `1 - self` raised to the unsigned-float power `rhs`.
    fn pow1muf(self, rhs: Rhs) -> Self::Output;
}
28
// Generates `impl PowUf<T> for f32` for each listed exponent type `T`.
//
// Every generated impl converts the exponent with `T::to_f32` and hands the
// pair to `powuf_f32`. A trailing comma in the type list is accepted.
macro_rules! impl_powuf_f32 {
    ($($exponent_ty:ty),* $(,)?) => {
        $(
            impl PowUf<$exponent_ty> for f32 {
                type Output = Self;

                #[inline]
                fn powuf(self, rhs: $exponent_ty) -> f32 {
                    let exponent = rhs.to_f32();
                    powuf_f32(self, exponent)
                }
            }
        )*
    };
}
43
// Generates `impl PowUf<T> for f64` for each listed exponent type `T`.
//
// Every generated impl converts the exponent with `T::to_f64` and hands the
// pair to `powuf_f64`. A trailing comma in the type list is accepted.
macro_rules! impl_powuf_f64 {
    ($($exponent_ty:ty),* $(,)?) => {
        $(
            impl PowUf<$exponent_ty> for f64 {
                type Output = Self;

                #[inline]
                fn powuf(self, rhs: $exponent_ty) -> f64 {
                    let exponent = rhs.to_f64();
                    powuf_f64(self, exponent)
                }
            }
        )*
    };
}
58
// Generates `impl Pow1mUf<T> for f32` for each listed exponent type `T`.
//
// Every generated impl converts the exponent with `T::to_f32` and hands the
// pair to `pow1muf_f32`. A trailing comma in the type list is accepted.
macro_rules! impl_pow1muf_f32 {
    ($($exponent_ty:ty),* $(,)?) => {
        $(
            impl Pow1mUf<$exponent_ty> for f32 {
                type Output = Self;

                #[inline]
                fn pow1muf(self, rhs: $exponent_ty) -> f32 {
                    let exponent = rhs.to_f32();
                    pow1muf_f32(self, exponent)
                }
            }
        )*
    };
}
73
// Generates `impl Pow1mUf<T> for f64` for each listed exponent type `T`.
//
// Every generated impl converts the exponent with `T::to_f64` and hands the
// pair to `pow1muf_f64`. A trailing comma in the type list is accepted.
macro_rules! impl_pow1muf_f64 {
    ($($exponent_ty:ty),* $(,)?) => {
        $(
            impl Pow1mUf<$exponent_ty> for f64 {
                type Output = Self;

                #[inline]
                fn pow1muf(self, rhs: $exponent_ty) -> f64 {
                    let exponent = rhs.to_f64();
                    pow1muf_f64(self, exponent)
                }
            }
        )*
    };
}
88
// Native-float bases: each listed unsigned-float exponent type gets `PowUf`
// and `Pow1mUf` impls for both `f32` and `f64`, generated by the `impl_*`
// macros. `Uf64` is not listed here; its impls are written out separately
// behind the `f128` feature.
impl_powuf_f32!(Uf8, Uf8E5M3, Uf16, Uf16E6M10, Uf32);
impl_powuf_f64!(Uf8, Uf8E5M3, Uf16, Uf16E6M10, Uf32);
impl_pow1muf_f32!(Uf8, Uf8E5M3, Uf16, Uf16E6M10, Uf32);
impl_pow1muf_f64!(Uf8, Uf8E5M3, Uf16, Uf16E6M10, Uf32);
93
94impl PowUf<Uf8> for Uf8 {
95    type Output = Uf8;
96
97    #[inline]
98    fn powuf(self, rhs: Uf8) -> Self::Output {
99        Uf8::from_bits(dispatch::pow_uf8(self.to_bits(), rhs.to_bits()))
100    }
101}
102
103impl PowUf<Uf8E5M3> for Uf8 {
104    type Output = Uf8;
105
106    #[inline]
107    fn powuf(self, rhs: Uf8E5M3) -> Self::Output {
108        Uf8::from_bits(dispatch::pow_uf8_by_uf8_e5m3(self.to_bits(), rhs.to_bits()))
109    }
110}
111
112impl PowUf<Uf8> for Uf8E5M3 {
113    type Output = Uf8E5M3;
114
115    #[inline]
116    fn powuf(self, rhs: Uf8) -> Self::Output {
117        Uf8E5M3::from_bits(dispatch::pow_uf8_e5m3_by_uf8(self.to_bits(), rhs.to_bits()))
118    }
119}
120
121impl PowUf<Uf8E5M3> for Uf8E5M3 {
122    type Output = Uf8E5M3;
123
124    #[inline]
125    fn powuf(self, rhs: Uf8E5M3) -> Self::Output {
126        Uf8E5M3::from_bits(dispatch::pow_uf8_e5m3(self.to_bits(), rhs.to_bits()))
127    }
128}
129
130impl Pow1mUf<Uf8> for Uf8 {
131    type Output = Uf8;
132
133    #[inline]
134    fn pow1muf(self, rhs: Uf8) -> Self::Output {
135        Uf8::from_bits(dispatch::pow1m_uf8(self.to_bits(), rhs.to_bits()))
136    }
137}
138
139impl Pow1mUf<Uf8E5M3> for Uf8 {
140    type Output = Uf8;
141
142    #[inline]
143    fn pow1muf(self, rhs: Uf8E5M3) -> Self::Output {
144        Uf8::from_bits(dispatch::pow1m_uf8_by_uf8_e5m3(
145            self.to_bits(),
146            rhs.to_bits(),
147        ))
148    }
149}
150
151impl Pow1mUf<Uf8> for Uf8E5M3 {
152    type Output = Uf8E5M3;
153
154    #[inline]
155    fn pow1muf(self, rhs: Uf8) -> Self::Output {
156        Uf8E5M3::from_bits(dispatch::pow1m_uf8_e5m3_by_uf8(
157            self.to_bits(),
158            rhs.to_bits(),
159        ))
160    }
161}
162
163impl Pow1mUf<Uf8E5M3> for Uf8E5M3 {
164    type Output = Uf8E5M3;
165
166    #[inline]
167    fn pow1muf(self, rhs: Uf8E5M3) -> Self::Output {
168        Uf8E5M3::from_bits(dispatch::pow1m_uf8_e5m3(self.to_bits(), rhs.to_bits()))
169    }
170}
171
/// `f32` base raised to a `Uf64` exponent.
///
/// Only compiled when the `f128` feature is enabled.
#[cfg(feature = "f128")]
impl PowUf<Uf64> for f32 {
    type Output = Self;

    /// Converts the exponent with `Uf64::to_f32` and defers to `powuf_f32`.
    #[inline]
    fn powuf(self, rhs: Uf64) -> f32 {
        let exponent = rhs.to_f32();
        powuf_f32(self, exponent)
    }
}
181
/// `f64` base raised to a `Uf64` exponent.
///
/// Only compiled when the `f128` feature is enabled.
#[cfg(feature = "f128")]
impl PowUf<Uf64> for f64 {
    type Output = Self;

    /// Converts the exponent with `Uf64::to_f64` and defers to `powuf_f64`.
    #[inline]
    fn powuf(self, rhs: Uf64) -> f64 {
        let exponent = rhs.to_f64();
        powuf_f64(self, exponent)
    }
}
191
/// `(1 - u)^exponent` for an `f32` input and a `Uf64` exponent.
///
/// Only compiled when the `f128` feature is enabled.
#[cfg(feature = "f128")]
impl Pow1mUf<Uf64> for f32 {
    type Output = Self;

    /// Converts the exponent with `Uf64::to_f32` and defers to `pow1muf_f32`.
    #[inline]
    fn pow1muf(self, rhs: Uf64) -> f32 {
        let exponent = rhs.to_f32();
        pow1muf_f32(self, exponent)
    }
}
201
/// `(1 - u)^exponent` for an `f64` input and a `Uf64` exponent.
///
/// Only compiled when the `f128` feature is enabled.
#[cfg(feature = "f128")]
impl Pow1mUf<Uf64> for f64 {
    type Output = Self;

    /// Converts the exponent with `Uf64::to_f64` and defers to `pow1muf_f64`.
    #[inline]
    fn pow1muf(self, rhs: Uf64) -> f64 {
        let exponent = rhs.to_f64();
        pow1muf_f64(self, exponent)
    }
}
211
212#[inline]
213fn powuf_f32(base: f32, exponent: f32) -> f32 {
214    if exponent == 0.0 {
215        1.0
216    } else if exponent == 1.0 {
217        base
218    } else if exponent == 0.5 {
219        libm::sqrtf(base)
220    } else if let Some(integer) = small_integer_exponent_f32(exponent) {
221        powi_u32_f32(base, integer)
222    } else {
223        libm::powf(base, exponent)
224    }
225}
226
227#[inline]
228fn powuf_f64(base: f64, exponent: f64) -> f64 {
229    if exponent == 0.0 {
230        1.0
231    } else if exponent == 1.0 {
232        base
233    } else if exponent == 0.5 {
234        libm::sqrt(base)
235    } else if let Some(integer) = small_integer_exponent_f64(exponent) {
236        powi_u32_f64(base, integer)
237    } else {
238        libm::pow(base, exponent)
239    }
240}
241
242#[inline]
243fn pow1muf_f32(u: f32, exponent: f32) -> f32 {
244    if exponent == 0.0 || u == 0.0 {
245        1.0
246    } else if u == 1.0 {
247        0.0
248    } else if !(0.0..=1.0).contains(&u) {
249        powuf_f32(1.0 - u, exponent)
250    } else if exponent == 0.5 {
251        libm::sqrtf(1.0 - u)
252    } else {
253        libm::expm1f(exponent * libm::log1pf(-u)) + 1.0
254    }
255}
256
257#[inline]
258fn pow1muf_f64(u: f64, exponent: f64) -> f64 {
259    if exponent == 0.0 || u == 0.0 {
260        1.0
261    } else if u == 1.0 {
262        0.0
263    } else if !(0.0..=1.0).contains(&u) {
264        powuf_f64(1.0 - u, exponent)
265    } else if exponent == 0.5 {
266        libm::sqrt(1.0 - u)
267    } else {
268        libm::expm1(exponent * libm::log1p(-u)) + 1.0
269    }
270}
271
/// Returns `Some(n)` when `exponent` is exactly an integer `n` in `2..=32`,
/// and `None` for every other value (fractions, out-of-range values, NaN).
#[inline]
fn small_integer_exponent_f32(exponent: f32) -> Option<u32> {
    // The float-to-int cast saturates and maps NaN to 0, so anything that is
    // not a whole number in range fails either the pattern or the guard.
    match exponent as u32 {
        whole @ 2..=32 if whole as f32 == exponent => Some(whole),
        _ => None,
    }
}
285
/// Returns `Some(n)` when `exponent` is exactly an integer `n` in `2..=32`,
/// and `None` for every other value (fractions, out-of-range values, NaN).
#[inline]
fn small_integer_exponent_f64(exponent: f64) -> Option<u32> {
    // The float-to-int cast saturates and maps NaN to 0, so anything that is
    // not a whole number in range fails either the pattern or the guard.
    match exponent as u32 {
        whole @ 2..=32 if whole as f64 == exponent => Some(whole),
        _ => None,
    }
}
299
/// Raises `base` to the integer power `exponent` by square-and-multiply.
///
/// Exponent bits are consumed from least to most significant. The running
/// square is not advanced past the highest set bit, so no squaring happens
/// beyond what the result needs. `exponent == 0` returns `1.0`.
#[inline]
fn powi_u32_f32(base: f32, exponent: u32) -> f32 {
    let significant_bits = u32::BITS - exponent.leading_zeros();
    let mut product = 1.0_f32;
    let mut square = base;

    for bit in 0..significant_bits {
        if (exponent >> bit) & 1 != 0 {
            product *= square;
        }
        if bit + 1 < significant_bits {
            square *= square;
        }
    }

    product
}
317
/// Raises `base` to the integer power `exponent` by square-and-multiply.
///
/// Exponent bits are consumed from least to most significant. The running
/// square is not advanced past the highest set bit, so no squaring happens
/// beyond what the result needs. `exponent == 0` returns `1.0`.
#[inline]
fn powi_u32_f64(base: f64, exponent: u32) -> f64 {
    let significant_bits = u32::BITS - exponent.leading_zeros();
    let mut product = 1.0_f64;
    let mut square = base;

    for bit in 0..significant_bits {
        if (exponent >> bit) & 1 != 0 {
            product *= square;
        }
        if bit + 1 < significant_bits {
            square *= square;
        }
    }

    product
}
334}