use std::fmt::Debug;
use dashu::{
float::round::{
Round,
mode::{Down, Up},
},
integer::{IBig, UBig},
rational::RBig,
};
use super::sample_from_uniform_bytes;
use crate::{error::Fallible, traits::RoundCast};
#[cfg(all(feature = "contrib", test))]
mod test;
mod canonical;
pub use canonical::CanonicalRV;
/// A distribution that [`PartialSample`] can sample by inverting its cumulative distribution
/// function.
pub trait InverseCDF: Sized {
    /// Values produced by the inverse CDF; `PartialSample` compares edges of two samples with
    /// `PartialOrd`.
    type Edge: Debug + PartialOrd;

    /// Maps `uniform` to an edge of the sample space.
    ///
    /// `PartialSample` builds `uniform` as a fraction with denominator `2^refinements`, and
    /// passes `Down` as `R` when asking for its lower edge and `Up` when asking for its upper edge.
    ///
    /// Returning `None` makes `PartialSample` treat the edge as undetermined at this precision;
    /// it then draws more random bits and asks again.
    fn inverse_cdf<R>(&self, uniform: RBig, refinements: usize) -> Option<Self::Edge>
    where
        R: ODPRound;
}
/// A lazily evaluated draw from `distribution`.
///
/// The uniform variate behind the draw is known only to lie in the dyadic interval
/// `[randomness / 2^refinements, (randomness + 1) / 2^refinements]`.
pub struct PartialSample<D>
where
    D: InverseCDF,
{
    /// Uniform bits drawn so far, read as the numerator over `2^refinements`.
    randomness: UBig,
    /// How many random bits are held in `randomness`.
    refinements: usize,
    /// The distribution being sampled.
    pub distribution: D,
}
impl<D> PartialSample<D>
where
    D: InverseCDF,
{
    /// Wraps `distribution` in a partial sample that has not drawn any random bits yet.
    ///
    /// With no bits drawn, the uniform interval of the sample is all of `[0, 1]`.
    pub fn new(distribution: D) -> Self {
        Self {
            distribution,
            refinements: 0,
            randomness: UBig::ZERO,
        }
    }
}
impl<D: InverseCDF> PartialSample<D> {
    /// Evaluates one edge of the interval that currently contains the sample.
    ///
    /// The bits drawn so far confine the underlying uniform variate to
    /// `[randomness / 2^refinements, (randomness + 1) / 2^refinements]`. `R::UBIG` picks the
    /// endpoint (`0` for [`Down`], `1` for [`Up`]), and `R` is forwarded to `inverse_cdf` as its
    /// rounding mode.
    ///
    /// Returns `None` when `inverse_cdf` cannot produce an edge for this input.
    fn edge<R: ODPRound>(&self) -> Option<D::Edge> {
        let uniform_edge = RBig::from_parts(
            IBig::from(self.randomness.clone() + R::UBIG),
            UBig::ONE << self.refinements,
        );
        self.distribution
            .inverse_cdf::<R>(uniform_edge, self.refinements)
    }

    /// Appends 64 fresh uniform bits to the sample, shrinking its interval by a factor of `2^64`.
    ///
    /// # Errors
    /// Propagates any failure from `sample_from_uniform_bytes`. The bits are drawn before any
    /// field is modified, so a failed draw leaves the sample unchanged.
    fn refine(&mut self) -> Fallible<()> {
        // Draw first. Shifting `randomness` before a fallible draw would leave it scaled by 2^64
        // without the matching bump to `refinements`, so a caller that recovers from the error
        // would later compute edges from a mis-scaled numerator.
        let fresh_bits = UBig::from(sample_from_uniform_bytes::<u64, 8>()?);
        self.randomness <<= 64;
        self.randomness += fresh_bits;
        self.refinements += 64;
        Ok(())
    }

    /// Lower edge of the current interval, if `inverse_cdf` can determine it.
    fn lower(&self) -> Option<D::Edge> {
        self.edge::<Down>()
    }

    /// Upper edge of the current interval, if `inverse_cdf` can determine it.
    fn upper(&self) -> Option<D::Edge> {
        self.edge::<Up>()
    }

    /// Decides whether this sample is strictly greater than `other`.
    ///
    /// Returns `true` once this sample's lower edge exceeds `other`'s upper edge. Returns `false`
    /// once this sample's upper edge is below `other`'s lower edge. Until one of these holds, the
    /// less-refined sample is refined (`other` on ties).
    ///
    /// NOTE(review): this loops until the intervals separate. It does not return if they never
    /// do, for example if edges stay `None` or both samples converge to the same value.
    ///
    /// # Errors
    /// Propagates failures from drawing random bits.
    pub fn greater_than(&mut self, other: &mut PartialSample<D>) -> Fallible<bool> {
        loop {
            if matches!(self.lower().zip(other.upper()), Some((l, r)) if l > r) {
                return Ok(true);
            }
            if matches!(self.upper().zip(other.lower()), Some((l, r)) if l < r) {
                return Ok(false);
            }
            // Still overlapping: refine whichever sample is coarser so both stay balanced.
            if self.refinements < other.refinements {
                self.refine()?;
            } else {
                other.refine()?;
            }
        }
    }

    /// Refines the sample until its lower and upper edges round-cast to the same `TO`, then
    /// returns that value.
    ///
    /// # Errors
    /// Propagates failures from drawing random bits and from `TO::round_cast`.
    pub fn value<TO: RoundCast<D::Edge> + PartialEq>(&mut self) -> Fallible<TO> {
        loop {
            // A missing edge means the current precision is insufficient; fall through to refine.
            if let Some((lower, upper)) = self.lower().zip(self.upper()) {
                let lower = TO::round_cast(lower)?;
                if lower == TO::round_cast(upper)? {
                    return Ok(lower);
                }
            }
            self.refine()?;
        }
    }
}
/// A `dashu` rounding mode that `PartialSample` uses to select one edge of its interval.
pub trait ODPRound: Round {
    /// The complementary rounding mode; taking the complement twice returns `Self`.
    type C: ODPRound<C = Self>;

    /// Offset added to the drawn bits to reach this edge of the dyadic interval.
    const UBIG: UBig;
}
/// `Down` selects the lower edge: the drawn bits are used without an offset.
impl ODPRound for Down {
    type C = Up;
    const UBIG: UBig = UBig::ZERO;
}
/// `Up` selects the upper edge: one unit in the last drawn bit is added to the drawn bits.
impl ODPRound for Up {
    type C = Down;
    const UBIG: UBig = UBig::ONE;
}