opendp/transformations/count_cdf/
mod.rs1use std::ops::Sub;
2
3use num::Zero;
4use opendp_derive::bootstrap;
5
6use crate::{
7 core::Function,
8 error::Fallible,
9 traits::{Float, Number, RoundCast},
10};
11
12#[cfg(feature = "ffi")]
13mod ffi;
14
15#[bootstrap(features("contrib"), generics(TA(default = "float")))]
16pub fn make_cdf<TA: Float>() -> Fallible<Function<Vec<TA>, Vec<TA>>> {
21 Ok(Function::new_fallible(|arg: &Vec<TA>| {
22 let cumsum = arg
23 .iter()
24 .scan(TA::zero(), |acc, v| {
25 *acc += *v;
26 Some(*acc)
27 })
28 .collect::<Vec<TA>>();
29 let sum = cumsum[cumsum.len() - 1];
30 Ok(cumsum.into_iter().map(|v| v / sum).collect())
31 }))
32}
33
/// Strategy for placing a quantile inside the bin that contains it.
#[doc(hidden)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interpolation {
    /// Snap to whichever bin edge has a CDF value closest to alpha.
    Nearest,
    /// Interpolate linearly between the two bin edges according to alpha.
    Linear,
}
39
40#[bootstrap(
41 features("contrib"),
42 arguments(interpolation(c_type = "char *", rust_type = "String", default = "linear")),
43 generics(F(default = "float"))
44)]
45pub fn make_quantiles_from_counts<TA, F>(
56 bin_edges: Vec<TA>,
57 alphas: Vec<F>,
58 interpolation: Interpolation,
59) -> Fallible<Function<Vec<TA>, Vec<TA>>>
60where
61 TA: Number + RoundCast<F>,
62 F: Float + RoundCast<TA>,
63{
64 if bin_edges.len().is_zero() {
65 return fallible!(MakeTransformation, "bin_edges.len() must be positive");
66 }
67 if bin_edges.windows(2).any(|w| w[0] >= w[1]) {
68 return fallible!(MakeTransformation, "bin_edges must be increasing");
69 }
70 if alphas.windows(2).any(|w| w[0] >= w[1]) {
71 return fallible!(MakeTransformation, "alphas must be increasing");
72 }
73 if let Some(lower) = alphas.first() {
74 if lower.is_sign_negative() {
75 return fallible!(
76 MakeTransformation,
77 "alphas must be greater than or equal to zero"
78 );
79 }
80 }
81 if let Some(upper) = alphas.last() {
82 if upper > &F::one() {
83 return fallible!(
84 MakeTransformation,
85 "alphas must be less than or equal to one"
86 );
87 }
88 }
89
90 Ok(Function::new_fallible(move |arg: &Vec<TA>| {
91 if abs_diff(bin_edges.len(), arg.len()) != 1 {
93 return fallible!(
94 FailedFunction,
95 "there must be one more bin edge than there are counts"
96 );
97 }
98 if arg.is_empty() {
99 return Ok(vec![bin_edges[0].clone(); alphas.len()]);
100 }
101 let arg = if bin_edges.len() + 1 == arg.len() {
103 &arg[1..arg.len() - 1]
104 } else {
105 &arg[..]
106 };
107 let cumsum = (arg.iter())
109 .scan(TA::zero(), |acc, v| {
110 *acc += v.clone();
111 Some(acc.clone())
112 })
113 .map(F::round_cast)
114 .collect::<Fallible<Vec<F>>>()?;
115
116 let sum = cumsum[cumsum.len() - 1];
118
119 let cdf: Vec<F> = cumsum.into_iter().map(|v| v / sum).collect();
120
121 let mut indices = vec![0; alphas.len()];
123 count_lt_recursive(indices.as_mut_slice(), alphas.as_slice(), cdf.as_slice(), 0);
124
125 indices
126 .into_iter()
127 .zip(&alphas)
128 .map(|(idx, &alpha)| {
129 let left_cdf = if idx == 0 { F::zero() } else { cdf[idx - 1] };
132 let right_cdf = cdf[idx];
133
134 match interpolation {
137 Interpolation::Nearest => {
138 Ok(bin_edges[idx + (alpha - left_cdf > right_cdf - alpha) as usize])
140 }
141 Interpolation::Linear => {
142 let left_edge = F::round_cast(bin_edges[idx])?;
143 let right_edge = F::round_cast(bin_edges[idx + 1])?;
144
145 let t = (alpha - left_cdf) / (right_cdf - left_cdf);
148 let v = (F::one() - t) * left_edge + t * right_edge;
149 TA::round_cast(v)
150 }
151 }
152 })
153 .collect()
154 }))
155}
156
/// Absolute difference between `a` and `b`.
///
/// The smaller operand is always subtracted from the larger one, so this is safe for
/// unsigned integers.
fn abs_diff<T: PartialOrd + Sub<Output = T>>(a: T, b: T) -> T {
    // order the operands so the subtraction never goes below zero
    let (larger, smaller) = if a < b { (b, a) } else { (a, b) };
    larger - smaller
}
160
/// Sets `counts[i]` to `x_start_idx` plus the number of elements of `x` strictly less
/// than `edges[i]`, for every `i`.
///
/// Divide and conquer: the count for a pivot edge is found by binary search, and the
/// resulting split point divides `x` between the edges left and right of the pivot.
///
/// NOTE(review): assumes `edges` and `x` are both sorted ascending and that
/// `counts.len() == edges.len()`; results are only meaningful under that assumption.
fn count_lt_recursive<TI: PartialOrd>(
    counts: &mut [usize],
    edges: &[TI],
    x: &[TI],
    x_start_idx: usize,
) {
    match edges.len() {
        // nothing left to count
        0 => {}
        // base case: one binary search over whatever remains of `x`
        1 => counts[0] = x_start_idx + count_lt(x, &edges[0]),
        len => {
            let pivot = (len + 1) / 2;
            let split = count_lt(x, &edges[pivot]);
            counts[pivot] = x_start_idx + split;

            let (counts_below, counts_from_pivot) = counts.split_at_mut(pivot);
            let (x_below, x_above) = x.split_at(split);

            // edges left of the pivot only need the elements below the split
            count_lt_recursive(counts_below, &edges[..pivot], x_below, x_start_idx);
            // edges right of the pivot count everything below the split, plus part of the rest
            count_lt_recursive(
                &mut counts_from_pivot[1..],
                &edges[pivot + 1..],
                x_above,
                x_start_idx + split,
            );
        }
    }
}

/// Returns the number of elements of `x` that are strictly less than `target`.
///
/// Uses binary search, so `x` is assumed to be sorted ascending. An empty slice yields 0.
fn count_lt<TI: PartialOrd>(x: &[TI], target: &TI) -> usize {
    if x.is_empty() {
        return 0;
    }
    // shrink the half-open window [lo, hi) until it holds a single candidate `lo`
    let mut lo = 0usize;
    let mut hi = x.len();
    while lo + 1 < hi {
        let mid = lo + (hi - lo) / 2;
        if x[mid] < *target {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    // the remaining candidate counts only if it is itself below the target
    lo + usize::from(x[lo] < *target)
}
226
227#[cfg(test)]
228mod test;