1use super::EstimatedLog2;
2
/// 8-bit fixed-point table of base-2 logarithms on `[1, 2)`.
///
/// Entry `i` is `floor(256 * log2((0x80 + i) / 0x80))`: the fractional part of `log2(x)`
/// for the 8-bit prefix `x = 0x80 + i`, scaled by 256 and rounded down. For example,
/// entry 0 is `log2(1) = 0` and entry 127 is `floor(256 * log2(255 / 128)) = 0xfe`.
///
/// Because every entry is rounded down, a plain lookup yields a lower bound. The
/// `ceil_log2_fp8` helper adds 1 to a lookup to obtain an upper bound.
#[cfg(not(feature = "std"))]
const LOG2_TAB: [u8; 128] = [
    0x00, 0x02, 0x05, 0x08, 0x0b, 0x0e, 0x10, 0x13, 0x16, 0x19, 0x1b, 0x1e, 0x21, 0x23, 0x26, 0x28,
    0x2b, 0x2e, 0x30, 0x33, 0x35, 0x38, 0x3a, 0x3d, 0x3f, 0x41, 0x44, 0x46, 0x49, 0x4b, 0x4d, 0x50,
    0x52, 0x54, 0x57, 0x59, 0x5b, 0x5d, 0x60, 0x62, 0x64, 0x66, 0x68, 0x6a, 0x6d, 0x6f, 0x71, 0x73,
    0x75, 0x77, 0x79, 0x7b, 0x7d, 0x7f, 0x81, 0x84, 0x86, 0x88, 0x8a, 0x8c, 0x8d, 0x8f, 0x91, 0x93,
    0x95, 0x97, 0x99, 0x9b, 0x9d, 0x9f, 0xa1, 0xa2, 0xa4, 0xa6, 0xa8, 0xaa, 0xac, 0xad, 0xaf, 0xb1,
    0xb3, 0xb5, 0xb6, 0xb8, 0xba, 0xbc, 0xbd, 0xbf, 0xc1, 0xc2, 0xc4, 0xc6, 0xc8, 0xc9, 0xcb, 0xcd,
    0xce, 0xd0, 0xd1, 0xd3, 0xd5, 0xd6, 0xd8, 0xda, 0xdb, 0xdd, 0xde, 0xe0, 0xe1, 0xe3, 0xe5, 0xe6,
    0xe8, 0xe9, 0xeb, 0xec, 0xee, 0xef, 0xf1, 0xf2, 0xf4, 0xf5, 0xf7, 0xf8, 0xfa, 0xfb, 0xfd, 0xfe,
];
15
16#[cfg(not(feature = "std"))]
19const fn log2_fp8(n: u16) -> u16 {
20 debug_assert!(n > 0xff); let nbits = (u16::BITS - n.leading_zeros()) as u16;
23 if n < 0x200 {
24 let lookup = LOG2_TAB[(n >> 1) as usize - 0x80];
26 let est = lookup as u16 + (7 + 1) * 256;
27 est + (n < 354 && n & 1 > 0) as u16
28 } else if n < (0x4000 + 0x80) {
29 let shift = nbits - 8;
31 let mask = n >> (shift - 2);
32 let lookup = LOG2_TAB[(mask >> 2) as usize - 0x80];
33 let est = lookup as u16 + (7 + shift) * 256;
34
35 est + (mask & 3 == 3) as u16
37 } else {
38 let shift = nbits - 8;
40 let mask = n >> (shift - 7);
41 let top_est = LOG2_TAB[(mask >> 7) as usize - 0x80];
42 let est = top_est as u16 + (7 + shift) * 256;
43
44 est + (mask & 127 >= 80) as u16
46 }
47}
48
49#[cfg(not(feature = "std"))]
57const fn ceil_log2_fp8(n: u16) -> u16 {
58 debug_assert!(n > 0xff); debug_assert!(!n.is_power_of_two());
60
61 let nbits = (u16::BITS - n.leading_zeros()) as u16;
62 if n < 0x80 {
63 let shift = 8 - nbits;
65 let top_est = LOG2_TAB[(n << shift) as usize - 0x80];
66 top_est as u16 + (7 - shift) * 256 + 1
67 } else if n < 0x200 {
68 let shift = nbits - 8;
70 let top_est = LOG2_TAB[(n >> shift) as usize - 0x80];
71 let est = top_est as u16 + (7 + shift) * 256 + 1;
72
73 if n > 0x100 && n & 1 == 1 {
74 est + 2
75 } else {
76 est
77 }
78 } else {
79 let shift = nbits - 8;
81 let mask10 = n >> (shift - 2);
82 let mask8 = mask10 >> 2;
83 if mask8 == 255 {
84 0x100 + (7 + shift) * 256
85 } else {
86 let top_est = LOG2_TAB[mask8 as usize + 1 - 0x80];
88 let est = top_est as u16 + (7 + shift) * 256 + 1;
89 est - (mask10 & 3 == 0) as u16
90 }
91 }
92}
93
/// Returns the smallest `f32` that compares strictly greater than `f`.
///
/// Both `+0.0` and `-0.0` step to the smallest positive subnormal, and `f32::MAX` steps
/// to `f32::INFINITY`. For finite inputs this agrees with `f32::next_up` from newer
/// standard libraries.
///
/// # Panics
///
/// Panics if `f` is NaN or infinite.
#[inline]
pub fn next_up(f: f32) -> f32 {
    assert!(!f.is_nan() && !f.is_infinite());

    const SIGN_BIT: u32 = 0x8000_0000;
    let bits = f.to_bits();
    // For finite IEEE-754 values, stepping the magnitude bits by one reaches the adjacent
    // representable value. Moving "up" grows positive magnitudes and shrinks negative ones.
    let next_bits = match (f.is_sign_negative(), bits & !SIGN_BIT) {
        (_, 0) => 1,
        (false, _) => bits + 1,
        (true, _) => bits - 1,
    };
    f32::from_bits(next_bits)
}
115
/// Returns the largest `f32` that compares strictly less than `f`.
///
/// Both `+0.0` and `-0.0` step to the negative subnormal of smallest magnitude, and
/// `f32::MIN` steps to `f32::NEG_INFINITY`. For finite inputs this agrees with
/// `f32::next_down` from newer standard libraries.
///
/// # Panics
///
/// Panics if `f` is NaN or infinite.
#[inline]
pub fn next_down(f: f32) -> f32 {
    assert!(!f.is_nan() && !f.is_infinite());

    const SIGN_BIT: u32 = 0x8000_0000;
    let bits = f.to_bits();
    // Mirror of `next_up`: moving "down" shrinks positive magnitudes and grows negative ones.
    let next_bits = match (f.is_sign_negative(), bits & !SIGN_BIT) {
        (_, 0) => SIGN_BIT | 1,
        (false, _) => bits - 1,
        (true, _) => bits + 1,
    };
    f32::from_bits(next_bits)
}
137
138#[cfg(not(feature = "std"))]
139impl EstimatedLog2 for u8 {
140 #[inline]
141 fn log2_bounds(&self) -> (f32, f32) {
142 match *self {
143 0 => (f32::NEG_INFINITY, f32::NEG_INFINITY),
144 1 => (0., 0.),
145 i if i.is_power_of_two() => {
146 let log = self.trailing_zeros() as f32;
147 (log, log)
148 }
149 3 => (1.5849625, 1.5849626),
150 i if i < 16 => {
151 let pow = (i as u16).pow(4);
152 let lb = log2_fp8(pow) as f32 / 256.0;
153 let ub = ceil_log2_fp8(pow) as f32 / 256.0;
154 (lb / 4., ub / 4.)
155 }
156 i => {
157 let pow = (i as u16).pow(2);
158 let lb = log2_fp8(pow) as f32 / 256.0;
159 let ub = ceil_log2_fp8(pow) as f32 / 256.0;
160 (lb / 2., ub / 2.)
161 }
162 }
163 }
164}
165
166#[cfg(not(feature = "std"))]
167impl EstimatedLog2 for u16 {
168 #[inline]
169 fn log2_bounds(&self) -> (f32, f32) {
170 if *self <= 0xff {
171 return (*self as u8).log2_bounds();
172 } else if self.is_power_of_two() {
173 let log = self.trailing_zeros() as f32;
174 return (log, log);
175 }
176
177 let lb = log2_fp8(*self) as f32 / 256.0;
178 let ub = ceil_log2_fp8(*self) as f32 / 256.0;
179 (lb, ub)
180 }
181}
182
183#[cfg(not(feature = "std"))]
184macro_rules! impl_log2_bounds_for_uint {
185 ($($t:ty)*) => {$(
186 impl EstimatedLog2 for $t {
187 #[inline]
188 fn log2_bounds(&self) -> (f32, f32) {
189 if *self <= 0xff {
190 return (*self as u8).log2_bounds();
191 } else if self.is_power_of_two() {
192 let log = self.trailing_zeros() as f32;
193 return (log, log);
194 }
195
196 let bits = <$t>::BITS - self.leading_zeros();
197 if bits <= u16::BITS {
198 let lb = log2_fp8(*self as u16) as f32 / 256.0;
199 let ub = ceil_log2_fp8(*self as u16) as f32 / 256.0;
200 (lb, ub)
201 } else {
202 let shift = bits - u16::BITS;
203 let hi = (*self >> shift) as u16;
204 let lb = log2_fp8(hi) as f32 / 256.0;
205 let ub = if hi == 1 << (u16::BITS - 1) {
206 (u16::BITS as u16 - 1) * 256 + 1
208 } else {
209 ceil_log2_fp8(hi)
212 };
213 let ub = ub as f32 / 256.0;
214 (next_down(lb + shift as f32), next_up(ub + shift as f32))
215 }
216 }
217 }
218 )*};
219}
220
221#[cfg(not(feature = "std"))]
222impl_log2_bounds_for_uint!(u32 u64 u128 usize);
223
/// std `log2` bounds and estimates for all unsigned integer types.
///
/// Values with at most 24 significant bits convert to `f32` exactly, so their bounds are
/// the platform `log2` widened by one ulp. Wider values keep their leading 24 bits `top`
/// and bound `log2(n)` between `log2(top)` and `log2(top + 1)`, each shifted by the number
/// of dropped bits and widened by one ulp.
#[cfg(feature = "std")]
macro_rules! impl_log2_bounds_for_uint {
    ($($t:ty)*) => {$(
        impl EstimatedLog2 for $t {
            fn log2_bounds(&self) -> (f32, f32) {
                let n = *self;
                if n == 0 {
                    return (f32::NEG_INFINITY, f32::NEG_INFINITY);
                }
                if n.is_power_of_two() {
                    let exact = n.trailing_zeros() as f32;
                    return (exact, exact);
                }

                // 24 = width of the f32 significand
                let width = <$t>::BITS - n.leading_zeros();
                if width <= 24 {
                    let est = (n as f32).log2();
                    return (next_down(est), next_up(est));
                }

                let dropped = width - 24;
                let top = (n >> dropped) as f32;
                let scale = dropped as f32;
                (
                    next_down(top.log2() + scale),
                    next_up((top + 1.).log2() + scale),
                )
            }

            #[inline]
            fn log2_est(&self) -> f32 {
                (*self as f32).log2()
            }
        }
    )*}
}

#[cfg(feature = "std")]
impl_log2_bounds_for_uint!(u8 u16 u32 u64 u128 usize);
263
264macro_rules! impl_log2_bounds_for_int {
265 ($($t:ty)*) => {$(
266 impl EstimatedLog2 for $t {
267 fn log2_bounds(&self) -> (f32, f32) {
268 self.unsigned_abs().log2_bounds()
269 }
270 }
271 )*};
272}
273impl_log2_bounds_for_int!(i8 i16 i32 i64 i128 isize);
274
275#[cfg(not(feature = "std"))]
276macro_rules! impl_log2_bounds_for_float {
277 ($($t:ty)*) => {$(
278 impl EstimatedLog2 for $t {
279 fn log2_bounds(&self) -> (f32, f32) {
280 use crate::FloatEncoding;
281 use core::num::FpCategory::*;
282
283 if *self == 0. {
284 (f32::NEG_INFINITY, f32::NEG_INFINITY)
285 } else {
286 match self.decode() {
287 Ok((man, exp)) => {
288 let (est_lb, est_ub) = man.log2_bounds();
289 (est_lb + exp as f32, est_ub + exp as f32)
290 },
291 Err(Nan) => panic!("calling log2 on nans is forbidden!"),
292 Err(Infinite) => (f32::INFINITY, f32::INFINITY),
293 _ => unreachable!()
294 }
295 }
296 }
297 }
298 )*};
299}
300#[cfg(not(feature = "std"))]
301impl_log2_bounds_for_float!(f32 f64);
302
/// std `log2` bounds and estimates for floats, using the platform `log2` on `|x|`.
///
/// Infinities map to `+inf` and zeros to `-inf`. Other values return the f32-rounded
/// logarithm, widened by one ulp on each side for the bounds.
///
/// # Panics
///
/// Both methods panic if the value is NaN.
#[cfg(feature = "std")]
macro_rules! impl_log2_bounds_for_float {
    ($($t:ty)*) => {$(
        impl EstimatedLog2 for $t {
            #[inline]
            fn log2_bounds(&self) -> (f32, f32) {
                assert!(!self.is_nan());

                if self.is_infinite() {
                    return (f32::INFINITY, f32::INFINITY);
                }
                if *self == 0. {
                    return (f32::NEG_INFINITY, f32::NEG_INFINITY);
                }
                let est = self.abs().log2() as f32;
                (next_down(est), next_up(est))
            }

            #[inline]
            fn log2_est(&self) -> f32 {
                assert!(!self.is_nan());

                if self.is_infinite() {
                    f32::INFINITY
                } else if *self == 0. {
                    f32::NEG_INFINITY
                } else {
                    self.abs().log2() as f32
                }
            }
        }
    )*};
}

#[cfg(feature = "std")]
impl_log2_bounds_for_float!(f32 f64);
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the 8-bit fixed-point helpers against known values
    /// (256 * log2(n), rounded down / up with the helpers' extra margin).
    #[test]
    #[cfg(not(feature = "std"))]
    fn test_log2_fp8() {
        let floor_cases: [(u16, u16); 6] = [
            (1234, 2628),
            (12345, 3478),
            (0x100, 2048),
            (0x101, 2049),
            (0xff00, 4094),
            (0xffff, 4095),
        ];
        for (n, expected) in floor_cases {
            assert_eq!(log2_fp8(n), expected);
        }

        let ceil_cases: [(u16, u16); 5] = [
            (1234, 2631),
            (12345, 3480),
            (0x101, 2051),
            (0xff00, 4096),
            (0xffff, 4096),
        ];
        for (n, expected) in ceil_cases {
            assert_eq!(ceil_log2_fp8(n), expected);
        }
    }

    /// Checks that `2^lb <= x <= 2^ub` for integers, their negations, and floats.
    #[test]
    fn test_log2_bounds() {
        assert_eq!(0u8.log2_bounds(), (f32::NEG_INFINITY, f32::NEG_INFINITY));
        assert_eq!(0i8.log2_bounds(), (f32::NEG_INFINITY, f32::NEG_INFINITY));
        assert_eq!(0f32.log2_bounds(), (f32::NEG_INFINITY, f32::NEG_INFINITY));

        for i in 1..1000u16 {
            let (lb, ub) = i.log2_bounds();
            assert!(2f64.powf(lb as f64) <= i as f64);
            assert!(2f64.powf(ub as f64) >= i as f64);
            assert_eq!((-(i as i16)).log2_bounds(), (lb, ub));

            let (lb, ub) = (i as f32).log2_bounds();
            assert!(2f64.powf(lb as f64) <= i as f64);
            assert!(2f64.powf(ub as f64) >= i as f64);

            let (lb, ub) = (i as f64).log2_bounds();
            assert!(2f64.powf(lb as f64) <= i as f64);
            assert!(2f64.powf(ub as f64) >= i as f64);
        }

        for i in (0x4000..0x400000u32).step_by(0x1001) {
            let (lb, ub) = i.log2_bounds();
            assert!(2f64.powf(lb as f64) <= i as f64);
            assert!(2f64.powf(ub as f64) >= i as f64);
        }

        let (lb, ub) = 1e20f32.log2_bounds();
        assert!(2f64.powf(lb as f64) <= 1e20);
        assert!(2f64.powf(ub as f64) >= 1e20);
        assert_eq!((-1e20f32).log2_bounds(), (lb, ub));

        let (lb, ub) = 1e40f64.log2_bounds();
        assert!(2f64.powf(lb as f64) <= 1e40);
        assert!(2f64.powf(ub as f64) >= 1e40);
        assert_eq!((-1e40f64).log2_bounds(), (lb, ub));
    }
}