opendp/transformations/count_cdf/mod.rs

1use std::ops::Sub;
2
3use num::Zero;
4use opendp_derive::bootstrap;
5
6use crate::{
7    core::Function,
8    error::Fallible,
9    traits::{Float, Number, RoundCast},
10};
11
12#[cfg(feature = "ffi")]
13mod ffi;
14
15#[bootstrap(features("contrib"), generics(TA(default = "float")))]
16/// Postprocess a noisy array of float summary counts into a cumulative distribution.
17///
18/// # Generics
19/// * `TA` - Atomic Type. One of `f32` or `f64`
20pub fn make_cdf<TA: Float>() -> Fallible<Function<Vec<TA>, Vec<TA>>> {
21    Ok(Function::new_fallible(|arg: &Vec<TA>| {
22        let cumsum = arg
23            .iter()
24            .scan(TA::zero(), |acc, v| {
25                *acc += *v;
26                Some(*acc)
27            })
28            .collect::<Vec<TA>>();
29        let sum = cumsum[cumsum.len() - 1];
30        Ok(cumsum.into_iter().map(|v| v / sum).collect())
31    }))
32}
33
/// Strategy for turning the cumulative distribution over bins into a quantile value.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Interpolation {
    /// Return whichever edge of the selected bin the alpha is nearer to, measured in
    /// cumulative mass; ties resolve to the left edge.
    Nearest,
    /// Linearly interpolate between the two edges of the selected bin.
    Linear,
}
39
40#[bootstrap(
41    features("contrib"),
42    arguments(interpolation(c_type = "char *", rust_type = "String", default = "linear")),
43    generics(F(default = "float"))
44)]
45/// Postprocess a noisy array of summary counts into quantiles.
46///
47/// # Arguments
48/// * `bin_edges` - The edges that the input data was binned into before counting.
49/// * `alphas` - Return all specified `alpha`-quantiles.
50/// * `interpolation` - Must be one of `linear` or `nearest`
51///
52/// # Generics
53/// * `TA` - Atomic Type of the bin edges and data.
54/// * `F` - Float type of the alpha argument. One of `f32` or `f64`
55pub fn make_quantiles_from_counts<TA, F>(
56    bin_edges: Vec<TA>,
57    alphas: Vec<F>,
58    interpolation: Interpolation,
59) -> Fallible<Function<Vec<TA>, Vec<TA>>>
60where
61    TA: Number + RoundCast<F>,
62    F: Float + RoundCast<TA>,
63{
64    if bin_edges.len().is_zero() {
65        return fallible!(MakeTransformation, "bin_edges.len() must be positive");
66    }
67    if bin_edges.windows(2).any(|w| w[0] >= w[1]) {
68        return fallible!(MakeTransformation, "bin_edges must be increasing");
69    }
70    if alphas.windows(2).any(|w| w[0] >= w[1]) {
71        return fallible!(MakeTransformation, "alphas must be increasing");
72    }
73    if let Some(lower) = alphas.first() {
74        if lower.is_sign_negative() {
75            return fallible!(
76                MakeTransformation,
77                "alphas must be greater than or equal to zero"
78            );
79        }
80    }
81    if let Some(upper) = alphas.last() {
82        if upper > &F::one() {
83            return fallible!(
84                MakeTransformation,
85                "alphas must be less than or equal to one"
86            );
87        }
88    }
89
90    Ok(Function::new_fallible(move |arg: &Vec<TA>| {
91        // one fewer args than bin edges, or one greater args than bin edges are allowed
92        if abs_diff(bin_edges.len(), arg.len()) != 1 {
93            return fallible!(
94                FailedFunction,
95                "there must be one more bin edge than there are counts"
96            );
97        }
98        if arg.is_empty() {
99            return Ok(vec![bin_edges[0].clone(); alphas.len()]);
100        }
101        // if args includes extremal bins for (-inf, edge_0] and [edge_n, inf), discard them
102        let arg = if bin_edges.len() + 1 == arg.len() {
103            &arg[1..arg.len() - 1]
104        } else {
105            &arg[..]
106        };
107        // compute the cumulative sum of the input counts
108        let cumsum = (arg.iter())
109            .scan(TA::zero(), |acc, v| {
110                *acc += v.clone();
111                Some(acc.clone())
112            })
113            .map(F::round_cast)
114            .collect::<Fallible<Vec<F>>>()?;
115
116        // reuse the last element of the cumsum
117        let sum = cumsum[cumsum.len() - 1];
118
119        let cdf: Vec<F> = cumsum.into_iter().map(|v| v / sum).collect();
120
121        // each index is the number of bins whose combined mass is less than the alpha_edge mass
122        let mut indices = vec![0; alphas.len()];
123        count_lt_recursive(indices.as_mut_slice(), alphas.as_slice(), cdf.as_slice(), 0);
124
125        indices
126            .into_iter()
127            .zip(&alphas)
128            .map(|(idx, &alpha)| {
129                // Want to find the cumulative values to the left and right of edge
130                // When no elements less than edge, consider cumulative value to be zero
131                let left_cdf = if idx == 0 { F::zero() } else { cdf[idx - 1] };
132                let right_cdf = cdf[idx];
133
134                // println!("x's {:?}, {:?}", edge, (left.clone(), right.clone()));
135                // println!("y's {:?}", (&bin_edges[idx], &bin_edges[idx + 1]));
136                match interpolation {
137                    Interpolation::Nearest => {
138                        // if edge nearer to right than to left, then increment index
139                        Ok(bin_edges[idx + (alpha - left_cdf > right_cdf - alpha) as usize])
140                    }
141                    Interpolation::Linear => {
142                        let left_edge = F::round_cast(bin_edges[idx])?;
143                        let right_edge = F::round_cast(bin_edges[idx + 1])?;
144
145                        // find the interpolant between the bin edges.
146                        // denominator is never zero because bin edges is strictly increasing
147                        let t = (alpha - left_cdf) / (right_cdf - left_cdf);
148                        let v = (F::one() - t) * left_edge + t * right_edge;
149                        TA::round_cast(v)
150                    }
151                }
152            })
153            .collect()
154    }))
155}
156
/// Absolute difference between `a` and `b`.
///
/// The smaller operand is always subtracted from the larger one, so unsigned types such
/// as `usize` never underflow. Operands that do not compare as less-than (including
/// equal ones) take the `a - b` path.
fn abs_diff<T: PartialOrd + Sub<Output = T>>(a: T, b: T) -> T {
    let (larger, smaller) = if a < b { (b, a) } else { (a, b) };
    larger - smaller
}
160
161/// Compute number of elements less than each edge
162/// Formula is #(`x` <= e) for each e in `edges`.
163///
164/// # Arguments
165/// * `counts` - location to write the result
166/// * `edges` - edges to collect counts for. Must be sorted
167/// * `x` - dataset to count against
168/// * `x_start_idx` - value to add to the count. Useful for recursion on subslices
169fn count_lt_recursive<TI: PartialOrd>(
170    counts: &mut [usize],
171    edges: &[TI],
172    x: &[TI],
173    x_start_idx: usize,
174) {
175    if edges.is_empty() {
176        return;
177    }
178    if edges.len() == 1 {
179        counts[0] = x_start_idx + count_lt(x, &edges[0]);
180        return;
181    }
182    // use binary search to find |{i; x[i] < middle edge}|
183    let mid_edge_idx = (edges.len() + 1) / 2;
184    let mid_edge = &edges[mid_edge_idx];
185    let mid_x_idx = count_lt(x, mid_edge);
186    counts[mid_edge_idx] = x_start_idx + mid_x_idx;
187
188    count_lt_recursive(
189        &mut counts[..mid_edge_idx],
190        &edges[..mid_edge_idx],
191        &x[..mid_x_idx],
192        x_start_idx,
193    );
194
195    count_lt_recursive(
196        &mut counts[mid_edge_idx + 1..],
197        &edges[mid_edge_idx + 1..],
198        &x[mid_x_idx..],
199        x_start_idx + mid_x_idx,
200    );
201}
202
/// Count how many elements of `x` are strictly below `target`.
/// Formula is #(`x` < `target`)
///
/// Binary search over a half-open window `[lo, hi)`, halved until a single candidate
/// position remains, which is then checked directly. `x` should be sorted ascending for
/// the count to be exact.
///
/// # Arguments
/// * `x` - dataset to count against
/// * `target` - value to compare against
fn count_lt<TI: PartialOrd>(x: &[TI], target: &TI) -> usize {
    let mut hi = x.len();
    if hi == 0 {
        return 0;
    }
    let mut lo = 0;
    loop {
        let width = hi - lo;
        if width <= 1 {
            break;
        }
        let probe = lo + width / 2;
        if x[probe].lt(target) {
            lo = probe;
        } else {
            hi = probe;
        }
    }
    // the window has narrowed to the single position `lo`
    lo + usize::from(x[lo].lt(target))
}
226
227#[cfg(test)]
228mod test;