use core::ops::{Add, Div, Mul, Neg, Sub};
/// A transparent wrapper around `f32` whose results are meant to be compared
/// bit-for-bit (see [`DetF32::to_bits`] / [`DetF32::from_bits`]).
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct DetF32(f32);

/// Marker trait for types whose operations are intended to be bit-reproducible.
///
/// # Safety
///
/// NOTE(review): the precise contract an implementor must uphold is not stated
/// in this file; presumably "every operation yields identical bits on every
/// supported target". Confirm before adding new implementations.
pub unsafe trait Deterministic {}

// SAFETY: NOTE(review) — assumed to rely on DetF32 only exposing IEEE 754
// operations that are required to be correctly rounded (+, -, *, /, sqrt) plus
// sign-bit manipulation; confirm this matches the intended trait contract.
unsafe impl Deterministic for DetF32 {}

impl DetF32 {
    /// `+0.0`.
    pub const ZERO: Self = Self(0.0);
    /// `1.0`.
    pub const ONE: Self = Self(1.0);
    /// `2.0`.
    pub const TWO: Self = Self(2.0);
    /// `0.5`.
    pub const HALF: Self = Self(0.5);
    /// π rounded to the nearest `f32`.
    pub const PI: Self = Self(std::f32::consts::PI);

    /// Builds a value from its raw IEEE 754 binary32 bit pattern.
    #[inline]
    pub fn from_bits(bits: u32) -> Self {
        Self(f32::from_bits(bits))
    }

    /// Returns the raw IEEE 754 binary32 bit pattern.
    #[inline]
    pub fn to_bits(self) -> u32 {
        self.0.to_bits()
    }

    /// Wraps a plain `f32` without modification.
    #[inline]
    pub fn from_f32(value: f32) -> Self {
        Self(value)
    }

    /// Unwraps to the plain `f32`.
    #[inline]
    pub fn to_f32(self) -> f32 {
        self.0
    }

    /// Absolute value, computed by clearing the sign bit.
    ///
    /// Working on the bits means `-0.0` becomes `+0.0` and a NaN keeps its
    /// payload (only the sign bit changes).
    #[inline(never)]
    pub fn abs(self) -> Self {
        Self::from_bits(self.to_bits() & 0x7fff_ffff)
    }

    /// Square root.
    ///
    /// Non-positive inputs (negatives, `+0.0` and `-0.0`) return
    /// [`DetF32::ZERO`] rather than NaN. NaN stays NaN and `+inf` returns `+inf`.
    ///
    /// IEEE 754 requires `squareRoot` to be correctly rounded, exactly like
    /// `+ - * /`, so `f32::sqrt` is as reproducible as the basic operators.
    /// A fixed-count Newton iteration seeded with `self` is not: it fails to
    /// converge for inputs far from 1 (e.g. `1e10`) and yields NaN for `+inf`.
    #[inline(never)]
    pub fn sqrt(self) -> Self {
        if self.0 <= 0.0 {
            return Self::ZERO;
        }
        Self(self.0.sqrt())
    }

    /// Reciprocal square root, `1 / sqrt(self)`.
    ///
    /// Because [`DetF32::sqrt`] maps non-positive inputs to zero, this returns
    /// `+inf` for them.
    #[inline(never)]
    pub fn rsqrt(self) -> Self {
        Self(Self::ONE.0 / self.sqrt().0)
    }
}
impl Add for DetF32 {
type Output = Self;
#[inline(never)]
fn add(self, rhs: Self) -> Self {
Self(self.0 + rhs.0)
}
}
impl Sub for DetF32 {
type Output = Self;
#[inline(never)]
fn sub(self, rhs: Self) -> Self {
Self(self.0 - rhs.0)
}
}
impl Mul for DetF32 {
type Output = Self;
#[inline(never)]
fn mul(self, rhs: Self) -> Self {
Self(self.0 * rhs.0)
}
}
impl Div for DetF32 {
type Output = Self;
#[inline(never)]
fn div(self, rhs: Self) -> Self {
Self(self.0 / rhs.0)
}
}
impl Neg for DetF32 {
type Output = Self;
#[inline]
fn neg(self) -> Self {
Self(-self.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Arithmetic on exactly representable operands must give exact bit patterns.
    #[test]
    fn test_basic_ops_exact_bits() {
        let one = DetF32::from_bits(0x3f80_0000); // 1.0
        let two = DetF32::from_bits(0x4000_0000); // 2.0
        assert_eq!((one + two).to_bits(), 0x4040_0000); // 3.0
        assert_eq!((one * two).to_bits(), 0x4000_0000); // 2.0
        assert_eq!((two - one).to_bits(), 0x3f80_0000); // 1.0
        assert_eq!((two / two).to_bits(), 0x3f80_0000); // 1.0
    }

    /// sqrt(2) must be reproducible and close to the true value.
    #[test]
    fn test_sqrt_determinism() {
        // Build the input two independent ways (one hidden from the optimizer)
        // so the comparison actually checks two separate evaluations instead
        // of comparing a value with itself.
        let first = DetF32::from_bits(0x4000_0000).sqrt();
        let second = DetF32::from_f32(std::hint::black_box(2.0)).sqrt();
        assert_eq!(first.to_bits(), second.to_bits());

        let expected = DetF32::from_f32(std::f32::consts::SQRT_2);
        let error = (first - expected).abs();
        assert!(error < DetF32::from_f32(1e-6));
    }

    /// Non-positive inputs are clamped to +0.0 rather than producing NaN.
    #[test]
    fn test_sqrt_non_positive_is_zero() {
        assert_eq!(DetF32::from_f32(0.0).sqrt().to_bits(), 0);
        assert_eq!(DetF32::from_f32(-0.0).sqrt().to_bits(), 0);
        assert_eq!(DetF32::from_f32(-4.0).sqrt().to_bits(), 0);
    }
}