float8 0.7.0

8-bit floating point types for Rust
Documentation
use crate::{F8E4M3, F8E5M2};

use rand::{distr::Distribution, Rng};
use rand_distr::uniform::UniformFloat;

/// Implements `Distribution<$Target>` for the distribution type `$Source`.
///
/// The generated `sample` draws an `f32` through `$Source`'s
/// `Distribution<f32>` implementation and converts that value with
/// `<$Target>::from_f32`.
macro_rules! impl_distribution_via_f32 {
    ($Target:ty, $Source:ty) => {
        impl Distribution<$Target> for $Source {
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $Target {
                // Draw in `f32` first, then convert to the target type.
                let drawn: f32 = <$Source as Distribution<f32>>::sample(self, rng);
                <$Target>::from_f32(drawn)
            }
        }
    };
}

// `Distribution<F8E4M3>` impls for the listed `rand_distr` distributions.
// Each one samples an `f32` and converts it with `F8E4M3::from_f32`
// (see `impl_distribution_via_f32`).
impl_distribution_via_f32!(F8E4M3, rand_distr::StandardUniform);
impl_distribution_via_f32!(F8E4M3, rand_distr::StandardNormal);
impl_distribution_via_f32!(F8E4M3, rand_distr::Exp1);
impl_distribution_via_f32!(F8E4M3, rand_distr::Open01);
impl_distribution_via_f32!(F8E4M3, rand_distr::OpenClosed01);

// The same set of distributions for `F8E5M2`, converting with
// `F8E5M2::from_f32`.
impl_distribution_via_f32!(F8E5M2, rand_distr::StandardUniform);
impl_distribution_via_f32!(F8E5M2, rand_distr::StandardNormal);
impl_distribution_via_f32!(F8E5M2, rand_distr::Exp1);
impl_distribution_via_f32!(F8E5M2, rand_distr::Open01);
impl_distribution_via_f32!(F8E5M2, rand_distr::OpenClosed01);

/// Uniform sampler for `F8E5M2`, backed by a `UniformFloat<f32>`.
///
/// NOTE(review): despite the `Float16` in its name, this file uses the type
/// as the `SampleUniform::Sampler` of the 8-bit `F8E5M2` type.
#[derive(Debug, Clone, Copy)]
pub struct Float16Sampler(UniformFloat<f32>);

/// Selects `Float16Sampler` as the `SampleUniform` sampler type for `F8E5M2`.
impl rand_distr::uniform::SampleUniform for F8E5M2 {
    type Sampler = Float16Sampler;
}

/// Uniform sampling of `F8E5M2` values, carried out in `f32` by the wrapped
/// `UniformFloat<f32>`.
impl rand_distr::uniform::UniformSampler for Float16Sampler {
    type X = F8E5M2;

    /// Builds a sampler from `low` and `high`.
    ///
    /// Both bounds are converted with `to_f32` and handed to
    /// `UniformFloat::new`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `UniformFloat::new` for the
    /// converted bounds.
    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
    where
        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
    {
        let lower = low.borrow().to_f32();
        let upper = high.borrow().to_f32();
        UniformFloat::<f32>::new(lower, upper).map(Self)
    }

    /// Builds a sampler from `low` and `high` via
    /// `UniformFloat::new_inclusive`.
    ///
    /// Both bounds are converted with `to_f32` first.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `UniformFloat::new_inclusive` for the
    /// converted bounds.
    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
    where
        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
    {
        let lower = low.borrow().to_f32();
        let upper = high.borrow().to_f32();
        UniformFloat::<f32>::new_inclusive(lower, upper).map(Self)
    }

    /// Draws an `f32` from the wrapped sampler and converts it with
    /// `F8E5M2::from_f32`.
    ///
    /// NOTE(review): the value is produced in `f32` and then converted, so
    /// whether a draw just below the upper bound can convert to the bound
    /// itself depends on how `from_f32` rounds — confirm.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
        let drawn: f32 = self.0.sample(rng);
        F8E5M2::from_f32(drawn)
    }
}

/// Uniform sampler for `F8E4M3`, backed by a `UniformFloat<f32>`.
///
/// NOTE(review): despite the `BFloat16` in its name, this file uses the type
/// as the `SampleUniform::Sampler` of the 8-bit `F8E4M3` type.
#[derive(Debug, Clone, Copy)]
pub struct BFloat16Sampler(UniformFloat<f32>);

/// Selects `BFloat16Sampler` as the `SampleUniform` sampler type for `F8E4M3`.
impl rand_distr::uniform::SampleUniform for F8E4M3 {
    type Sampler = BFloat16Sampler;
}

/// Uniform sampling of `F8E4M3` values, carried out in `f32` by the wrapped
/// `UniformFloat<f32>`.
impl rand_distr::uniform::UniformSampler for BFloat16Sampler {
    type X = F8E4M3;

    /// Builds a sampler from `low` and `high`.
    ///
    /// Both bounds are converted with `to_f32` and handed to
    /// `UniformFloat::new`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `UniformFloat::new` for the
    /// converted bounds.
    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
    where
        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
    {
        let lower = low.borrow().to_f32();
        let upper = high.borrow().to_f32();
        UniformFloat::<f32>::new(lower, upper).map(Self)
    }

    /// Builds a sampler from `low` and `high` via
    /// `UniformFloat::new_inclusive`.
    ///
    /// Both bounds are converted with `to_f32` first.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `UniformFloat::new_inclusive` for the
    /// converted bounds.
    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
    where
        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
    {
        let lower = low.borrow().to_f32();
        let upper = high.borrow().to_f32();
        UniformFloat::<f32>::new_inclusive(lower, upper).map(Self)
    }

    /// Draws an `f32` from the wrapped sampler and converts it with
    /// `F8E4M3::from_f32`.
    ///
    /// NOTE(review): the value is produced in `f32` and then converted, so
    /// whether a draw just below the upper bound can convert to the bound
    /// itself depends on how `from_f32` rounds — confirm.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
        let drawn: f32 = self.0.sample(rng);
        F8E4M3::from_f32(drawn)
    }
}