mod kernel;
use core::num::FpCategory;
/// Bit position of the least-significant exponent bit in an `f32` bit pattern.
///
/// Equals the number of explicitly stored significand bits: `MANTISSA_DIGITS`
/// minus the implicit leading bit (23 for IEEE 754 binary32).
pub const EXP_SHIFT: u32 = f32::MANTISSA_DIGITS - 1;
/// Classification of the absolute value of an `f32`, as produced by `normalize`.
///
/// Only `Normalized` carries data. Its `i32` is laid out like an `f32` bit
/// pattern, with the exponent field above bit `EXP_SHIFT` and the fraction below.
/// The exponent may be zero or negative, so subnormal inputs can be represented
/// with an explicit leading one.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Magnitude {
    /// Not a number.
    Nan,
    /// Positive or negative infinity.
    Infinite,
    /// Positive or negative zero.
    Zero,
    /// Finite nonzero value, encoded as described on the type.
    Normalized(i32),
}
/// Splits `x` into its sign and a `Magnitude` describing `|x|`.
///
/// The boolean is `x.is_sign_negative()`, so it is `true` for `-0.0` and for
/// NaNs with the sign bit set.
///
/// For normal numbers the magnitude is `Normalized(|x|.to_bits())`. Subnormals
/// are rescaled so that their leading one sits at bit `EXP_SHIFT`, as the
/// implicit one of a normal number would. The exponent field is lowered by the
/// same amount, so it can become zero or negative; hence the `i32`.
#[allow(clippy::cast_possible_wrap)]
fn normalize(x: f32) -> (bool, Magnitude) {
    let negative = x.is_sign_negative();
    let bits = x.abs().to_bits() as i32;
    let magnitude = match x.classify() {
        FpCategory::Normal => Magnitude::Normalized(bits),
        FpCategory::Subnormal => {
            // Bit EXP_SHIFT (23) has 8 bits above it in a 32-bit word. Lift the
            // highest set bit up to it, then subtract the same amount from the
            // exponent field so the encoded value is unchanged.
            let lift = bits.leading_zeros() as i32 - 8;
            Magnitude::Normalized((bits << lift) - (lift << EXP_SHIFT))
        }
        FpCategory::Zero => Magnitude::Zero,
        FpCategory::Infinite => Magnitude::Infinite,
        FpCategory::Nan => Magnitude::Nan,
    };
    (negative, magnitude)
}
/// Returns the least `f32` that compares greater than `x`.
///
/// NaN and `+∞` are returned unchanged. Both zeros step to the smallest
/// positive subnormal, and `-∞` steps to `-f32::MAX`. Otherwise the bit pattern
/// moves one unit toward `+∞`: it is decremented for negative values and
/// incremented for positive ones.
#[must_use]
pub fn next_up(x: f32) -> f32 {
    if x.is_nan() || x == f32::INFINITY {
        return x;
    }
    let bits = x.to_bits();
    let next_bits = match (x == 0.0, x.is_sign_negative()) {
        (true, _) => 1,
        (false, true) => bits - 1,
        (false, false) => bits + 1,
    };
    f32::from_bits(next_bits)
}
/// Returns the greatest `f32` that compares less than `x`.
///
/// NaN and `-∞` are returned unchanged. Both zeros step to the negative
/// subnormal closest to zero, and `+∞` steps to `f32::MAX`. Otherwise the bit
/// pattern moves one unit toward `-∞`: it is incremented for negative values
/// and decremented for positive ones.
#[must_use]
pub fn next_down(x: f32) -> f32 {
    if x.is_nan() || x == f32::NEG_INFINITY {
        return x;
    }
    let bits = x.to_bits();
    let next_bits = if x == 0.0 {
        0x8000_0001
    } else if x.is_sign_negative() {
        bits + 1
    } else {
        bits - 1
    };
    f32::from_bits(next_bits)
}
/// Returns the real cube root of `x`.
///
/// NaN, infinities and zeros of either sign are returned unchanged.
///
/// For other inputs, an initial estimate comes from dividing the integer
/// produced by `normalize` by three and adding a bias constant. The sign of `x`
/// is restored, and the estimate is refined with three Newton steps
/// `y <- y + (x / y^2 - y) / 3`.
#[must_use]
pub fn cbrt(x: f32) -> f32 {
    const ONE_THIRD: f32 = 1.0 / 3.0;
    let (negative, Magnitude::Normalized(scaled)) = normalize(x) else {
        return x;
    };
    // Dividing the exponent-carrying integer by three roughly takes the cube
    // root; the constant re-biases the exponent. Its exact value presumably
    // comes from tuning and cannot be derived from this file.
    #[allow(clippy::cast_sign_loss)]
    let guess_bits = (0x2A51_2CE3 + scaled / 3) as u32;
    let mut y = f32::from_bits(u32::from(negative) << 31 | guess_bits);
    for _ in 0..3 {
        y = ONE_THIRD.mul_add(x / (y * y) - y, y);
    }
    y
}
/// Returns `e^x`.
///
/// Inputs below `-150·ln 2` flush to `0.0`, and inputs above `128·ln 2` return
/// `+∞`. Otherwise the work is done in `f64`:
/// - `x` is reduced to `x = n·ln 2 + r`, with `n` an integer (ties to even).
/// - `e^r` is formed as `kernel::exp(r)·r + 1`.
/// - The result is scaled by `2^n` via `kernel::fast_ldexp` and rounded to `f32`.
#[must_use]
pub fn exp(x: f32) -> f32 {
    use core::f32::consts::LN_2;
    use core::f64::consts;

    // Below -150·ln 2, e^x < 2^-150, half the smallest subnormal: rounds to 0.
    #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap)]
    let underflow_bound = (f32::MIN_EXP - f32::MANTISSA_DIGITS as i32 - 1) as f32 * LN_2;
    // Above 128·ln 2, e^x > 2^128, which exceeds f32::MAX.
    #[allow(clippy::cast_precision_loss)]
    let overflow_bound = f32::MAX_EXP as f32 * LN_2;

    if x < underflow_bound {
        return 0.0;
    }
    if x > overflow_bound {
        return f32::INFINITY;
    }

    let wide = f64::from(x);
    let k = (wide * consts::LOG2_E).round_ties_even();
    let r = k.mul_add(-consts::LN_2, wide);
    let e_r = kernel::exp(r).mul_add(r, 1.0);
    #[allow(clippy::cast_possible_truncation)]
    return kernel::fast_ldexp(e_r, k as i64) as f32;
}
/// Returns `2^x`.
///
/// Inputs below `-150` flush to `0.0`, and inputs above `128` return `+∞`.
/// Otherwise `x` is split into its nearest integer `n` (ties to even) and a
/// remainder `r = x - n`. `2^r` is approximated in `f32` by a degree-6
/// polynomial, evaluated in Horner form with fused multiply-adds. The result is
/// then scaled by `2^n`.
#[must_use]
pub fn exp2(x: f32) -> f32 {
    // Coefficients of r^1 through r^6; the constant term is 1.
    const COEFFS: [f32; 6] = [
        6.931_472e-1,
        2.402_265e-1,
        5.550_357e-2,
        9.618_031e-3,
        1.339_086_7e-3,
        1.546_973_5e-4,
    ];

    #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap)]
    let underflow_bound = (f32::MIN_EXP - f32::MANTISSA_DIGITS as i32 - 1) as f32;
    #[allow(clippy::cast_precision_loss)]
    let overflow_bound = f32::MAX_EXP as f32;
    if x < underflow_bound {
        return 0.0;
    }
    if x > overflow_bound {
        return f32::INFINITY;
    }

    let n = x.round_ties_even();
    let r = x - n;
    // Horner evaluation: start from the r^6 coefficient and fold in r^5 .. r^1,
    // then the constant term.
    let poly = COEFFS[..5]
        .iter()
        .rev()
        .fold(COEFFS[5], |acc, &c| acc.mul_add(r, c))
        .mul_add(r, 1.0);
    #[allow(clippy::cast_possible_truncation)]
    return kernel::fast_ldexp(f64::from(poly), n as i64) as f32;
}
/// Returns `e^x - 1`, avoiding the cancellation that `exp(x) - 1.0` suffers
/// for `x` near zero.
///
/// Inputs at or below about `-25·ln 2` return `-1.0`, and inputs above
/// `128·ln 2` return `+∞`. Otherwise the work is done in `f64` with
/// `x = n·ln 2 + r`:
/// - when `n == 0`, the result is `r · kernel::exp(r)`;
/// - otherwise it is `2^n · (r · kernel::exp(r) + 1) - 1`.
#[must_use]
pub fn exp_m1(x: f32) -> f32 {
    use core::f32::consts::LN_2;
    use core::f64::consts;
    // The f32 just above -1.0 is -1.0 + 2^-24, so e^x - 1 rounds to exactly -1.0
    // only once e^x <= 2^-25, i.e. x <= -25·ln 2. Hence MANTISSA_DIGITS + 1 (= 25).
    // Using MANTISSA_DIGITS (= 24) would wrongly return -1.0 for
    // x in (-25·ln 2, -24·ln 2), e.g. x = -17.0, whose correctly rounded result
    // is -0.99999994.
    // The rounded f32 product below (-17.32868004) is slightly more negative
    // than the exact -25·ln 2, so the early return never fires where it is wrong.
    #[allow(clippy::cast_precision_loss)]
    if x < (f32::MANTISSA_DIGITS + 1) as f32 * -LN_2 {
        return -1.0;
    }
    // Above 128·ln 2, e^x - 1 exceeds f32::MAX.
    #[allow(clippy::cast_precision_loss)]
    if x > f32::MAX_EXP as f32 * LN_2 {
        return f32::INFINITY;
    }
    let x = f64::from(x);
    // `+ 0.0` turns a rounded `-0.0` into `+0.0`. The reduction below then maps
    // `x = -0.0` to `-0.0` rather than `+0.0`, so `exp_m1(-0.0)` keeps its sign.
    let n = (x * consts::LOG2_E).round_ties_even() + 0.0;
    let x = n.mul_add(-consts::LN_2, x);
    // `kernel::exp(r)` appears to return (e^r - 1) / r: `exp` in this file forms
    // e^r as `kernel::exp(r) * r + 1`.
    let y = kernel::exp(x);
    if n == 0.0 {
        // No scaling needed; the product avoids the cancellation in (1 + x·y) - 1.
        #[allow(clippy::cast_possible_truncation)]
        return (x * y) as f32;
    }
    #[allow(clippy::cast_possible_truncation)]
    return (kernel::fast_ldexp(y.mul_add(x, 1.0), n as i64) - 1.0) as f32;
}
/// Returns `x · 2^n`.
///
/// The scale factor `2^n` is built directly as an `f64` bit pattern. The
/// product is formed in `f64` and then rounded once to `f32`. Exponents outside
/// the normal `f64` range are clamped: below it the scale is `2^-1023`, above it
/// `f64::MAX`. Either is far enough to underflow or overflow any `f32`.
#[must_use]
pub fn ldexp(x: f32, n: i32) -> f32 {
    const LOWEST: i32 = f64::MIN_EXP - 1;
    const HIGHEST: i32 = f64::MAX_EXP;
    let scale = if n < LOWEST {
        0.5 * f64::MIN_POSITIVE
    } else if n >= HIGHEST {
        f64::MAX
    } else {
        // Biased exponent field of 2^n; the f64 bias is MAX_EXP - 1 = 1023.
        // Assumes crate::f64::EXP_SHIFT is the f64 exponent bit position (52).
        #[allow(clippy::cast_sign_loss)]
        let biased = (HIGHEST - 1 + n) as u64;
        f64::from_bits(biased << crate::f64::EXP_SHIFT)
    };
    #[allow(clippy::cast_possible_truncation)]
    return (f64::from(x) * scale) as f32;
}
/// Splits `x` into a significand and an exponent with
/// `x == significand * 2^exponent`.
///
/// The significand has magnitude in `[0.5, 1)` and carries the sign of `x`.
/// NaN, infinities and zeros are returned unchanged with exponent `0`.
/// Subnormal inputs go through `normalize`, so they also get a significand in
/// `[0.5, 1)`.
#[must_use]
pub fn frexp(x: f32) -> (f32, i32) {
    let (negative, Magnitude::Normalized(bits)) = normalize(x) else {
        return (x, 0);
    };
    // Keep the fraction field and force the exponent field to that of 0.5.
    let fraction_mask = f32::MIN_POSITIVE.to_bits() - 1;
    #[allow(clippy::cast_sign_loss)]
    let fraction = bits as u32 & fraction_mask;
    let sign_bit = u32::from(negative) << 31;
    let significand = f32::from_bits(sign_bit | fraction | 0.5f32.to_bits());
    // Arithmetic shift: normalized subnormals have a zero or negative exponent field.
    let exponent = f32::MIN_EXP - 1 + (bits >> EXP_SHIFT);
    (significand, exponent)
}