use std::ops::Sub;
use num::Zero;
use opendp_derive::bootstrap;
use crate::{
core::Function,
error::Fallible,
traits::{Float, Number, RoundCast},
};
#[cfg(feature = "ffi")]
mod ffi;
#[bootstrap(features("contrib"), generics(TA(default = "float")))]
pub fn make_cdf<TA: Float>() -> Fallible<Function<Vec<TA>, Vec<TA>>> {
Ok(Function::new_fallible(|arg: &Vec<TA>| {
let cumsum = arg
.iter()
.scan(TA::zero(), |acc, v| {
*acc += *v;
Some(*acc)
})
.collect::<Vec<TA>>();
let sum = cumsum[cumsum.len() - 1];
Ok(cumsum.into_iter().map(|v| v / sum).collect())
}))
}
/// Strategy used by `make_quantiles_from_counts` to place an alpha-quantile that falls
/// inside a bin.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Interpolation {
    /// Return whichever edge of the bin is closer to `alpha` in cumulative mass
    /// (ties resolve to the left edge).
    Nearest,
    /// Linearly interpolate between the two edges of the bin, weighted by where
    /// `alpha` falls between the bin's cumulative masses.
    Linear,
}
#[bootstrap(
    features("contrib"),
    arguments(interpolation(c_type = "char *", rust_type = "String", default = "linear")),
    generics(F(default = "float"))
)]
/// Postprocess a noisy array of summary counts into quantiles.
///
/// The returned function accepts either `bin_edges.len() - 1` counts (one per bin),
/// or `bin_edges.len() + 1` counts, in which case the first and last counts are
/// discarded (presumably the tails beyond the outermost edges).
/// The remaining counts are normalized into a cumulative distribution, and each alpha
/// is mapped to a location between two neighboring bin edges.
///
/// The bin search assumes the cumulative distribution is non-decreasing (non-negative
/// counts); with negative counts the chosen bin is not guaranteed to bracket `alpha`.
///
/// # Arguments
/// * `bin_edges` - The edges that the input data was binned into before counting. Must be non-empty and strictly increasing.
/// * `alphas` - Return all specified `alpha`-quantiles. Must be non-NaN, strictly increasing, and within `[0, 1]`.
/// * `interpolation` - Must be one of `linear` or `nearest`.
///
/// # Generics
/// * `TA` - Atomic Type of the bin edges and data.
/// * `F` - Float type of the alpha argument.
pub fn make_quantiles_from_counts<TA, F>(
    bin_edges: Vec<TA>,
    alphas: Vec<F>,
    interpolation: Interpolation,
) -> Fallible<Function<Vec<TA>, Vec<TA>>>
where
    TA: Number + RoundCast<F>,
    F: Float + RoundCast<TA>,
{
    if bin_edges.len().is_zero() {
        return fallible!(MakeTransformation, "bin_edges.len() must be positive");
    }
    if bin_edges.windows(2).any(|w| w[0] >= w[1]) {
        return fallible!(MakeTransformation, "bin_edges must be increasing");
    }
    // NaN compares false against everything, so it would slip past every check below.
    if alphas.iter().any(|alpha| alpha.is_nan()) {
        return fallible!(MakeTransformation, "alphas must not be NaN");
    }
    if alphas.windows(2).any(|w| w[0] >= w[1]) {
        return fallible!(MakeTransformation, "alphas must be increasing");
    }
    // Compare numerically rather than by sign bit, so -0.0 (which equals zero) is accepted.
    if let Some(lower) = alphas.first() {
        if *lower < F::zero() {
            return fallible!(
                MakeTransformation,
                "alphas must be greater than or equal to zero"
            );
        }
    }
    if let Some(upper) = alphas.last() {
        if *upper > F::one() {
            return fallible!(
                MakeTransformation,
                "alphas must be less than or equal to one"
            );
        }
    }

    Ok(Function::new_fallible(move |arg: &Vec<TA>| {
        // Allow one fewer count than bin edges, or one more (an extra count at each end).
        if abs_diff(bin_edges.len(), arg.len()) != 1 {
            return fallible!(
                FailedFunction,
                "there must be one more bin edge than there are counts"
            );
        }

        // Discard the two extra end counts when present.
        let arg = if bin_edges.len() + 1 == arg.len() {
            &arg[1..arg.len() - 1]
        } else {
            &arg[..]
        };

        // A single bin edge has no interior bins, so every quantile collapses onto it.
        // This check must follow the trim: one edge with two end counts also leaves
        // nothing here, and the cumulative sum below would otherwise be empty.
        if arg.is_empty() {
            return Ok(vec![bin_edges[0].clone(); alphas.len()]);
        }

        // Running totals of the counts, cast to the float type used for alphas.
        let cumsum = arg
            .iter()
            .scan(TA::zero(), |acc, v| {
                *acc += v.clone();
                Some(acc.clone())
            })
            .map(F::round_cast)
            .collect::<Fallible<Vec<F>>>()?;

        // The last running total is the grand total; `arg` is non-empty, so it exists.
        let sum = cumsum[cumsum.len() - 1];
        let cdf: Vec<F> = cumsum.into_iter().map(|v| v / sum).collect();

        // indices[j] = number of CDF entries strictly below alphas[j], i.e. the bin alpha lands in.
        let mut indices = vec![0; alphas.len()];
        count_lt_recursive(indices.as_mut_slice(), alphas.as_slice(), cdf.as_slice(), 0);

        indices
            .into_iter()
            .zip(&alphas)
            .map(|(idx, &alpha)| {
                // Cumulative mass at the left and right edges of the selected bin.
                let left_cdf = if idx == 0 { F::zero() } else { cdf[idx - 1] };
                let right_cdf = cdf[idx];

                match interpolation {
                    Interpolation::Nearest => {
                        Ok(bin_edges[idx + (alpha - left_cdf > right_cdf - alpha) as usize])
                    }
                    Interpolation::Linear => {
                        let left_edge = F::round_cast(bin_edges[idx])?;
                        let right_edge = F::round_cast(bin_edges[idx + 1])?;

                        // A bin without positive cumulative width (e.g. a leading zero count
                        // with alpha = 0, or all counts zero) would make `t` 0/0 = NaN;
                        // fall back to the bin's left edge instead.
                        let t = if right_cdf > left_cdf {
                            (alpha - left_cdf) / (right_cdf - left_cdf)
                        } else {
                            F::zero()
                        };
                        let v = (F::one() - t) * left_edge + t * right_edge;
                        TA::round_cast(v)
                    }
                }
            })
            .collect()
    }))
}
/// Magnitude of the difference between `a` and `b`.
///
/// The smaller operand is always subtracted from the larger one, so unsigned
/// types such as `usize` never underflow.
fn abs_diff<T: PartialOrd + Sub<Output = T>>(a: T, b: T) -> T {
    let (larger, smaller) = if b > a { (b, a) } else { (a, b) };
    larger - smaller
}
/// For every `j`, store in `counts[j]` the number of elements of `x` strictly less than
/// `edges[j]`, plus `x_start_idx`.
///
/// Both `edges` and `x` are expected to be ascending. Each level binary-searches the
/// middle edge once, then recurses on the two halves with `x` narrowed accordingly,
/// so later searches scan ever-smaller slices.
fn count_lt_recursive<TI: PartialOrd>(
    counts: &mut [usize],
    edges: &[TI],
    x: &[TI],
    x_start_idx: usize,
) {
    match edges.len() {
        0 => {}
        1 => counts[0] = x_start_idx + count_lt(x, &edges[0]),
        n => {
            let pivot = (n + 1) / 2;
            let split = count_lt(x, &edges[pivot]);
            counts[pivot] = x_start_idx + split;

            // Edges below the pivot can only be preceded by elements in x[..split];
            // edges above it only by elements in x[split..].
            let (counts_below, counts_from_pivot) = counts.split_at_mut(pivot);
            let (x_below, x_above) = x.split_at(split);
            count_lt_recursive(counts_below, &edges[..pivot], x_below, x_start_idx);
            count_lt_recursive(
                &mut counts_from_pivot[1..],
                &edges[pivot + 1..],
                x_above,
                x_start_idx + split,
            );
        }
    }
}
/// Number of elements of the ascending slice `x` that are strictly less than `target`,
/// i.e. `|{i : x[i] < target}|`, found by binary search.
fn count_lt<TI: PartialOrd>(x: &[TI], target: &TI) -> usize {
    let (mut lo, mut hi) = match x.len() {
        0 => return 0,
        len => (0, len),
    };

    // Shrink the window [lo, hi) until only one candidate position remains.
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        if x[mid] < *target {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    // The surviving position counts toward the result only if it is below the target.
    match x[lo] < *target {
        true => hi,
        false => lo,
    }
}
#[cfg(test)]
mod test;