use crate::{error::Fallible, traits::Integer};
use super::fill_bytes;
use dashu::{base::BitTest, integer::UBig};
use num::Unsigned;
use opendp_derive::proven;
#[cfg(test)]
mod test;
/// Construction of a value from a fixed-size array of exactly `N` raw bytes.
///
/// `N` is a const generic so that each implementor can pin the array length to its own
/// byte width (see the `impl_from_bytes!` invocation in this module).
pub trait FromBytes<const N: usize> {
    /// Build `Self` from `raw`, a byte array of length `N`.
    fn from_ne_bytes(raw: [u8; N]) -> Self;
}
/// Implements [`FromBytes`] for each listed primitive integer type.
///
/// `N` is fixed to the type's size in bytes. Decoding delegates to the primitive's own
/// `from_ne_bytes`.
macro_rules! impl_from_bytes {
    ($($t:ty)+) => {
        $(
            impl FromBytes<{ ::core::mem::size_of::<$t>() }> for $t {
                fn from_ne_bytes(raw: [u8; ::core::mem::size_of::<$t>()]) -> Self {
                    // Path lookup prefers inherent associated functions over trait items,
                    // so this calls the primitive's inherent `from_ne_bytes`. It does not
                    // recurse into this impl.
                    <$t>::from_ne_bytes(raw)
                }
            }
        )+
    };
}
// All signed and unsigned primitive integer widths, including the pointer-sized ones.
impl_from_bytes!(u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize);
/// Fill an `N`-byte buffer via [`fill_bytes`] and decode it with `T::from_ne_bytes`.
///
/// The result is uniform over `T`'s byte patterns only insofar as `fill_bytes` produces
/// uniform bytes. NOTE(review): that property lives in `fill_bytes` and is not visible here.
///
/// # Errors
/// Propagates any error returned by `fill_bytes`.
pub fn sample_from_uniform_bytes<T: FromBytes<N>, const N: usize>() -> Fallible<T> {
    let mut raw = [0u8; N];
    fill_bytes(&mut raw)?;
    let sample = T::from_ne_bytes(raw);
    Ok(sample)
}
#[proven]
pub fn sample_uniform_uint_below<T: Integer + Unsigned + FromBytes<N>, const N: usize>(
upper: T,
) -> Fallible<T> {
let threshold = T::MAX_FINITE - T::MAX_FINITE % upper;
Ok(loop {
let sample = sample_from_uniform_bytes::<T, N>()?;
if sample < threshold {
break sample % upper;
}
})
}
#[proven]
pub fn sample_uniform_ubig_below(upper: UBig) -> Fallible<UBig> {
let byte_len = upper.bit_len().div_ceil(8);
let max = UBig::from_be_bytes(&vec![u8::MAX; byte_len]);
let threshold = &max - &max % &upper;
let mut buffer = vec![0; byte_len];
Ok(loop {
fill_bytes(&mut buffer)?;
let sample = UBig::from_be_bytes(&buffer);
if sample < threshold {
break sample % &upper;
}
})
}