use std::ops::{AddAssign, SubAssign};
use num::{One, Zero};
use opendp_derive::proven;
use crate::{
error::Fallible,
traits::{ExactIntCast, FiniteBounds, Float, Integer},
};
use super::{fill_bytes, sample_bernoulli_float, sample_standard_bernoulli};
/// Walk `shift` one unit at a time until a Bernoulli(`prob`) trial succeeds.
///
/// Each iteration draws one Bernoulli(`prob`) sample. Every failure observed before the first
/// success moves `shift` by one unit: upward when `positive` is `true`, downward otherwise.
/// The walk saturates at `T::MAX_FINITE` (upward) or `T::MIN_FINITE` (downward), so it never
/// steps past the finite range of `T`.
///
/// # Arguments
/// * `shift` - starting point of the walk.
/// * `positive` - direction of the walk (`true` increments, `false` decrements).
/// * `prob` - success probability of each trial; must lie in `[0, 1]`.
/// * `trials` - `None` stops at the first success. `Some(k)` always draws exactly `k + 1`
///   samples; once a success has been seen, the remaining samples no longer move `shift`.
///   This caps the displacement at `k + 1`. `trials.is_some()` is forwarded as the second
///   argument of `sample_bernoulli_float` (presumably its constant-time flag).
///
/// # Errors
/// Fails if `prob` is outside `[0, 1]` (NaN included), or if drawing a Bernoulli sample fails.
pub fn sample_geometric_linear<T, P>(
    mut shift: T,
    positive: bool,
    prob: P,
    mut trials: Option<usize>,
) -> Fallible<T>
where
    T: Clone + Zero + One + PartialEq + AddAssign + SubAssign + FiniteBounds,
    P: Float,
    usize: ExactIntCast<P::Bits>,
    P::Bits: ExactIntCast<usize>,
{
    // NaN compares false on both sides, so it is rejected here as well
    let in_unit_interval = P::zero() <= prob && prob <= P::one();
    if !in_unit_interval {
        return fallible!(FailedFunction, "probability is not within [0, 1]");
    }

    // saturation point of the walk in the chosen direction
    let limit = if positive {
        T::MAX_FINITE
    } else {
        T::MIN_FINITE
    };
    // `trials` is only ever decremented, never cleared, so this mode is fixed for the whole walk
    let fixed_trials = trials.is_some();

    let mut seen_success = false;
    loop {
        seen_success |= sample_bernoulli_float(prob, fixed_trials)?;

        if !seen_success && shift != limit {
            if positive {
                shift += T::one();
            } else {
                shift -= T::one();
            }
        }

        match &mut trials {
            // fixed-trial mode: stop only once the budget is exhausted
            Some(0) => break,
            Some(remaining) => *remaining -= 1,
            // unbounded mode: stop at the first success
            None if seen_success => break,
            None => {}
        }
    }

    Ok(shift)
}
/// Sample from a discrete Laplace distribution centered at `shift`, restricted to `[lower, upper]`.
///
/// The noise is built from a random sign and a geometric displacement drawn with
/// [`sample_geometric_linear`]. Each Bernoulli step succeeds with probability
/// `1 - exp(-1/scale)`, computed with the directed-rounding helpers `inf_exp` / `neg_inf_sub`.
/// A draw with negative sign and zero displacement is rejected and redrawn, so that zero is
/// not counted twice. Exactly `upper - lower` Bernoulli trials are drawn per geometric sample,
/// because `trials` is always `Some`. That is enough to reach either bound from any clamped
/// `shift`. The result is clamped into `[lower, upper]`.
///
/// # Arguments
/// * `shift` - center of the distribution; clamped into `[lower, upper]` before noise is added.
/// * `scale` - noise scale; a zero scale adds no noise.
/// * `(lower, upper)` - inclusive bounds on the output.
///
/// # Errors
/// * if `upper - lower - 1` cannot be computed in `T` or cast to `usize`
///   (presumably the case when `lower > upper`),
/// * if the derived success probability is not within `[0, 1]` (e.g. negative or NaN `scale`),
/// * if clamping or random sampling fails.
pub fn sample_discrete_laplace_linear<T, P>(
    shift: T,
    scale: P,
    (lower, upper): (T, T),
) -> Fallible<T>
where
    T: Integer,
    P: Float,
    usize: ExactIntCast<P::Bits> + ExactIntCast<T>,
    P::Bits: ExactIntCast<usize>,
{
    // No noise is added, but the output must still respect the bounds,
    // exactly like the clamped result of the noisy path below.
    if scale.is_zero() {
        return shift.total_clamp(lower, upper);
    }
    // the output interval is a single point
    if lower == upper {
        return Ok(lower);
    }

    // `Some(k)` makes the geometric sampler draw exactly k + 1 = upper - lower trials
    let trials: Option<usize> = Some(usize::exact_int_cast(
        upper.alerting_sub(&lower)?.alerting_sub(&T::one())?,
    )?);
    // success probability of each step: 1 - exp(-1/scale)
    let prob = P::one().neg_inf_sub(&(-scale.recip()).inf_exp()?)?;

    let shift = shift.total_clamp(lower, upper)?;
    let noised = loop {
        let direction = sample_standard_bernoulli()?;
        let sample = sample_geometric_linear(shift, direction, prob, trials)?;
        // reject (negative sign, zero displacement) so that zero is not double-counted
        if direction || sample != shift {
            break sample;
        }
    };
    noised.total_clamp(lower, upper)
}
#[proven]
/// Count the zero bits that precede the first one bit in a stream of random bytes.
///
/// Bytes are examined in order, and the bits within each byte from most- to least-significant
/// (via `u8::leading_zeros`). A first non-zero byte at index `i` therefore yields
/// `8 * i + leading_zeros`, a value in `[0, 8 * buffer_len - 1]`. If every examined byte is
/// zero, the result is `Ok(None)`.
///
/// Assuming `fill_bytes` produces uniformly random bytes (not shown here), this is a geometric
/// count with success probability 1/2, censored at `8 * buffer_len`.
///
/// # Arguments
/// * `buffer_len` - number of random bytes that may be consumed.
/// * `constant_time` - when `true`, all `buffer_len` bytes are drawn up front and every byte is
///   scanned with no early exit (presumably so that run time does not depend on the sample).
///   When `false`, bytes are drawn one at a time and the function returns at the first
///   non-zero byte.
///
/// # Errors
/// Propagates any error from `fill_bytes`.
pub(super) fn sample_geometric_buffer(
    buffer_len: usize,
    constant_time: bool,
) -> Fallible<Option<usize>> {
    Ok(if constant_time {
        // draw all randomness at once
        let mut buffer = vec![0_u8; buffer_len];
        fill_bytes(&mut buffer)?;
        // `min` consumes the whole iterator, so every byte is visited.
        // Since `leading_zeros < 8`, `8 * i + leading_zeros` grows with `i`,
        // so the minimum is the value for the first non-zero byte.
        (buffer.iter())
            .enumerate()
            .filter(|(_, sample)| **sample > 0)
            .map(|(i, sample)| 8 * i + sample.leading_zeros() as usize)
            .min()
    } else {
        // a single-byte buffer, refilled once per iteration
        let mut buffer = vec![0_u8; 1];
        for i in 0..buffer_len {
            fill_bytes(&mut buffer)?;
            if buffer[0] > 0 {
                // early exit: the remaining bytes are never drawn
                return Ok(Some(i * 8 + buffer[0].leading_zeros() as usize));
            }
        }
        // every drawn byte was zero
        None
    })
}