use dashu::{integer::Sign, rational::RBig};
use num::{One, Zero};
use opendp_derive::proven;
use crate::{
error::Fallible,
traits::{ExactIntCast, Float, InfDiv},
};
use super::{fill_bytes, sample_geometric_buffer, sample_uniform_ubig_below};
#[cfg(test)]
mod test;
/// Flip a single coin, assumed fair when `fill_bytes` yields uniformly random bytes.
///
/// Fills a one-byte buffer with `fill_bytes` and reports whether the least-significant bit of
/// that byte is set.
///
/// # Errors
/// Propagates any error returned by `fill_bytes`.
pub fn sample_standard_bernoulli() -> Fallible<bool> {
    let mut random_byte = [0u8];
    fill_bytes(&mut random_byte)?;
    let [value] = random_byte;
    Ok((value & 1) != 0)
}
#[proven]
pub fn sample_bernoulli_float<T>(prob: T, constant_time: bool) -> Fallible<bool>
where
T: Float,
T::Bits: PartialOrd + ExactIntCast<usize>,
usize: ExactIntCast<T::Bits>,
{
if prob.is_one() {
return Ok(true);
}
let first_heads_index = {
let max_coin_flips =
usize::exact_int_cast(T::EXPONENT_BIAS)? + usize::exact_int_cast(T::MANTISSA_BITS)?;
let buffer_len = max_coin_flips.inf_div(&8)?;
match sample_geometric_buffer(buffer_len, constant_time)? {
Some(i) => T::Bits::exact_int_cast(i)?,
None => return Ok(false),
}
};
let leading_zeros = T::EXPONENT_BIAS - T::Bits::one() - prob.raw_exponent();
Ok(match first_heads_index {
i if i < leading_zeros => false,
i if i == leading_zeros => !prob.raw_exponent().is_zero(),
i if i > leading_zeros + T::MANTISSA_BITS => false,
i => !(prob.to_bits() & T::Bits::one() << (leading_zeros + T::MANTISSA_BITS - i)).is_zero(),
})
}
#[proven]
/// Sample a single bit that is `true` with exactly the rational probability `prob`.
///
/// Splits `prob` into `numerator / denominator` and draws `s` with
/// `sample_uniform_ubig_below(denominator)`, which is expected to be uniform over
/// `0..denominator`. Returns `s < numerator`, which holds for exactly `numerator` of those
/// values.
///
/// # Errors
/// * `prob` is negative.
/// * `prob` is greater than one.
/// * `sample_uniform_ubig_below` fails.
pub fn sample_bernoulli_rational(prob: RBig) -> Fallible<bool> {
    let (signed_numerator, denominator) = prob.into_parts();

    let numerator = match signed_numerator.into_parts() {
        (Sign::Positive, magnitude) => magnitude,
        (Sign::Negative, _) => {
            return fallible!(FailedFunction, "numerator must not be negative");
        }
    };

    if numerator > denominator {
        return fallible!(FailedFunction, "prob must not be greater than one");
    }

    sample_uniform_ubig_below(denominator).map(|draw| draw < numerator)
}