use rand::{
Rng,
distr::{Distribution, StandardUniform},
};
/// Bit-field constants describing how a floating-point type is laid out in its raw integer form.
///
/// The `finite!` macro reads these constants to pick apart and reassemble the sign, exponent,
/// and mantissa fields of a sampled bit pattern.
trait Layout {
/// Integer type that holds the raw bit pattern.
type Bits;
/// Mask selecting the sign bit.
const SIGN_MASK: Self::Bits;
/// Mask selecting the exponent field.
const EXPONENT_MASK: Self::Bits;
/// Exponent-field pattern that `finite!` substitutes when a sampled exponent is all zeros or
/// all ones.
const EXPONENT_ZERO: Self::Bits;
/// Mask selecting the mantissa field.
const MANTISSA_MASK: Self::Bits;
}
/// Half-precision layout: 1 sign bit, 5 exponent bits, 10 mantissa bits.
impl Layout for half::f16 {
    type Bits = u16;
    // The sign occupies the top bit (bit 15).
    const SIGN_MASK: u16 = 1 << 15;
    // Five exponent bits sit directly above the 10-bit mantissa.
    const EXPONENT_MASK: u16 = 0x1F << 10;
    // Biased exponent 15, i.e. the exponent field of `1.0`.
    const EXPONENT_ZERO: u16 = 15 << 10;
    // The low 10 bits hold the mantissa.
    const MANTISSA_MASK: u16 = (1 << 10) - 1;
}
/// Single-precision layout: 1 sign bit, 8 exponent bits, 23 mantissa bits.
impl Layout for f32 {
    type Bits = u32;
    // The sign occupies the top bit (bit 31).
    const SIGN_MASK: u32 = 1 << 31;
    // Eight exponent bits sit directly above the 23-bit mantissa.
    const EXPONENT_MASK: u32 = 0xFF << 23;
    // Biased exponent 127, i.e. the exponent field of `1.0`.
    const EXPONENT_ZERO: u32 = 127 << 23;
    // The low 23 bits hold the mantissa.
    const MANTISSA_MASK: u32 = (1 << 23) - 1;
}
/// Distribution that yields only finite floating-point values (never NaN or infinity).
///
/// Sampling is weighted so that roughly 90% of draws are normal numbers, 5% are subnormal,
/// and 5% are signed zeros. Implementations of `Distribution` for the supported float types
/// are generated by the `finite!` macro.
#[derive(Debug, Clone, Copy, Default)]
pub struct Finite;
/// Implements `Distribution<$T>` for [`Finite`]; `$bits` is the integer type taken by
/// `<$T>::from_bits`.
///
/// A single uniformly random `$bits` word drives the whole sample. Its value modulo 100 picks
/// the category, and its bits supply the sign, exponent, and mantissa:
///
/// * remainder below 90: a normal number (an all-zeros or all-ones exponent is replaced by
///   `EXPONENT_ZERO`),
/// * remainder in `90..95`: a subnormal (exponent cleared, mantissa forced non-zero),
/// * otherwise: a signed zero.
macro_rules! finite {
    ($T:ty, $bits:ty) => {
        impl Distribution<$T> for Finite {
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $T {
                let raw: $bits = StandardUniform {}.sample(rng);

                // The sign bit is carried over unchanged in every category.
                let sign = raw & <$T>::SIGN_MASK;
                let mantissa = raw & <$T>::MANTISSA_MASK;

                let magnitude: $bits = match raw % 100 {
                    // Normal: keep the sampled exponent unless it would encode
                    // zero/subnormal (all zeros) or infinity/NaN (all ones).
                    bucket if bucket < 90 => {
                        let sampled = raw & <$T>::EXPONENT_MASK;
                        let exponent = if sampled == 0 || sampled == <$T>::EXPONENT_MASK {
                            <$T>::EXPONENT_ZERO
                        } else {
                            sampled
                        };
                        exponent | mantissa
                    }
                    // Subnormal: zero exponent, and the mantissa must not be zero.
                    bucket if bucket < 95 => {
                        if mantissa == 0 {
                            1
                        } else {
                            mantissa
                        }
                    }
                    // Zero: only the sign bit survives.
                    _ => 0,
                };

                <$T>::from_bits(sign | magnitude)
            }
        }
    };
}
// Generate `Distribution` impls on `Finite` for each supported float type, paired with the
// integer type its `from_bits` accepts.
finite!(half::f16, u16);
finite!(f32, u32);
#[cfg(not(miri))]
#[cfg(test)]
mod tests {
use rand::{SeedableRng, distr::Distribution, rngs::StdRng};
use super::*;
/// Tally of sampled values, bucketed by floating-point category.
#[derive(Debug, Default)]
struct Kinds {
    /// Values that are neither subnormal nor zero.
    normal: i64,
    /// Subnormal values.
    subnormal: i64,
    /// Zeros.
    zero: i64,
}

impl Kinds {
    /// Total number of values recorded across all three categories.
    fn sum(&self) -> i64 {
        let Kinds {
            normal,
            subnormal,
            zero,
        } = self;
        normal + subnormal + zero
    }
}
/// Per-sign tallies of sampled values.
#[derive(Debug, Default)]
struct Counts {
    /// Values whose sign bit is clear.
    positive: Kinds,
    /// Values whose sign bit is set.
    negative: Kinds,
}

impl Counts {
    /// Collapses the sign split, adding the positive and negative tallies category by category.
    fn sum_accross(&self) -> Kinds {
        let (pos, neg) = (&self.positive, &self.negative);
        Kinds {
            normal: pos.normal + neg.normal,
            subnormal: pos.subnormal + neg.subnormal,
            zero: pos.zero + neg.zero,
        }
    }
}
/// Test-only hook implemented by each float type that `Finite` can produce.
trait TestDistribution {
/// Draws `num_trials` samples using an RNG seeded with `seed` and returns how many fell into
/// each sign/category bucket.
fn test_distribution(num_trials: usize, seed: u64) -> Counts;
}
impl TestDistribution for f32 {
    /// Samples `num_trials` `f32` values from `Finite` with a seeded RNG and tallies them by
    /// sign and category. Panics if any sample is not finite.
    fn test_distribution(num_trials: usize, seed: u64) -> Counts {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut tally = Counts::default();

        for _ in 0..num_trials {
            let v: f32 = (Finite).sample(&mut rng);
            assert!(v.is_finite());

            // Pick the per-sign bucket first; negative zero lands on the negative side.
            let bucket = if v.is_sign_positive() {
                &mut tally.positive
            } else {
                &mut tally.negative
            };

            if v.is_subnormal() {
                bucket.subnormal += 1;
            } else if v == 0.0 {
                bucket.zero += 1;
            } else {
                bucket.normal += 1;
            }
        }

        tally
    }
}
impl TestDistribution for half::f16 {
    /// Samples `num_trials` `half::f16` values from `Finite` with a seeded RNG and tallies them
    /// by sign and category. Panics if any sample is not finite.
    fn test_distribution(num_trials: usize, seed: u64) -> Counts {
        // A value is subnormal when its exponent field is all zeros but its mantissa is not.
        let is_subnormal = |x: half::f16| -> bool {
            let bits = x.to_bits();
            let exponent_clear = (bits & half::f16::EXPONENT_MASK) == 0;
            let mantissa_set = (bits & half::f16::MANTISSA_MASK) != 0;
            exponent_clear && mantissa_set
        };

        let mut rng = StdRng::seed_from_u64(seed);
        let mut tally = Counts::default();

        for _ in 0..num_trials {
            let v: half::f16 = (Finite).sample(&mut rng);
            assert!(v.is_finite());

            // Pick the per-sign bucket first; negative zero lands on the negative side.
            let bucket = if v.is_sign_positive() {
                &mut tally.positive
            } else {
                &mut tally.negative
            };

            if is_subnormal(v) {
                bucket.subnormal += 1;
            } else if v == half::f16::default() {
                bucket.zero += 1;
            } else {
                bucket.normal += 1;
            }
        }

        tally
    }
}
/// Draws one million samples through `T::test_distribution` and asserts that the observed
/// counts have the expected shape: signs split evenly, and categories split 90/5/5 between
/// normal, subnormal, and zero.
///
/// Every comparison tolerates an absolute deviation of `num_trials / 500`. Panics on the first
/// check that falls outside that margin.
fn test_end_to_end<T>(seed: u64)
where
T: TestDistribution,
{
// Expected category proportions as integer weights out of `total_weight`; these match the
// 90 / 95 thresholds used by the `finite!` macro.
let normal_weight = 90;
let subnormal_weight = 5;
let zero_weight = 5;
let total_weight = normal_weight + subnormal_weight + zero_weight;
let num_trials: i64 = 1_000_000;
// Allowed absolute deviation for every count comparison below.
let margin = num_trials / 500;
let counts = T::test_distribution(num_trials as usize, seed);
let positive_count = counts.positive.sum();
let negative_count = counts.negative.sum();
// Printed so the raw tallies are available when a check fails.
println!("Counts = {:?}", counts);
// Each sign should account for about half of all samples.
assert!((positive_count - num_trials / 2).abs() < margin);
assert!((negative_count - num_trials / 2).abs() < margin);
// Within each category, positive and negative counts should be about equal.
assert!((counts.positive.normal - counts.negative.normal).abs() < margin);
assert!((counts.positive.subnormal - counts.negative.subnormal).abs() < margin);
assert!((counts.positive.zero - counts.negative.zero).abs() < margin);
// Ignoring sign, each category should match its expected share of the trials.
let kinds = counts.sum_accross();
assert!((kinds.normal - num_trials * normal_weight / total_weight).abs() < margin);
assert!((kinds.subnormal - num_trials * subnormal_weight / total_weight).abs() < margin);
assert!((kinds.zero - num_trials * zero_weight / total_weight).abs() < margin);
}
/// Checks the sampled `half::f16` distribution under a fixed seed so the run is reproducible.
#[test]
fn test_f16_distribution() {
    const SEED: u64 = 0xb1e3_a209_6f17_ec6d;
    test_end_to_end::<half::f16>(SEED);
}
/// Checks the sampled `f32` distribution under a fixed seed so the run is reproducible.
#[test]
fn test_f32_distribution() {
    const SEED: u64 = 0x8686_02b1_20b1_7347;
    test_end_to_end::<f32>(SEED);
}
}