// closure_core/lib.rs

//! CLOSURE: complete listing of original samples of underlying raw evidence
//!
//! Crate closure-core implements the CLOSURE technique for efficiently reconstructing
//! all possible distributions of raw data from summary statistics. It is not
//! about the Rust feature called closure.
//!
//! The crate is mostly meant to serve as a backend for the R package [unsum](https://lhdjung.github.io/unsum/).
//! The only API users are likely to need is `dfs_parallel()`.
//!
//! Most of the code was written by Claude 3.5, translating Python code by Nathanael Larigaldie.
12
13use num::{Float, FromPrimitive, Integer, NumCast, ToPrimitive};
14use std::collections::VecDeque;
15use rayon::prelude::*;
16
/// Implements range over Rint-friendly generic integer type U
///
/// A half-open range `[current, end)`. Iteration is supplied by the
/// `Iterator` impl; per Rust convention the trait bounds live on that
/// impl rather than on the struct definition, so the plain data type
/// stays usable (and nameable) without bounds.
struct IntegerRange<U> {
    current: U, // next value to be yielded
    end: U,     // exclusive upper bound
}
25
26impl<U> Iterator for IntegerRange<U>
27where 
28    U: Integer + Copy
29{
30    type Item = U;
31
32    /// Increment over U type integers
33    fn next(&mut self) -> Option<U> {
34        if self.current < self.end {
35            let next = self.current;
36            self.current = self.current + U::one();
37            Some(next) 
38        } else {
39            None
40        }
41    }
42}
43
44/// Creates an iterator over the space of U type integers 
45fn range_u<U: Integer + Copy>(start: U, end: U) -> IntegerRange<U> {
46    IntegerRange {current: start, end}
47}
48
// Define the Combination struct
//
// One node of the depth-first search: a partial, non-decreasing list of
// candidate values together with the running aggregates that allow each
// branch to be pruned without rescanning the values.
#[derive(Clone)]
struct Combination<T, U> {
    values: Vec<U>,    // values chosen so far (kept in non-decreasing order)
    running_sum: T,    // sum of `values`, maintained incrementally
    running_m2: T,     // running sum of squared deviations from the mean
                       // (updated Welford-style via delta * delta2 in dfs_branch)
}
56
/// Count first set of integers
///
/// The first set of integers that can be formed
/// given a range defined by `scale_min` and `scale_max`.
/// This function calculates the number of unique pairs (i, j) where i and j are integers
/// within the specified range, and i <= j.
/// # Arguments
/// * `scale_min` - The minimum value of the scale.
/// * `scale_max` - The maximum value of the scale.
/// # Returns
/// The total number of unique combinations of integers within the specified range,
/// or 0 when `scale_max < scale_min` (an empty scale admits no pairs).
pub fn count_initial_combinations(scale_min: i32, scale_max: i32) -> i32 {
    let range_size = scale_max - scale_min + 1;
    // Guard against an inverted/empty scale: the old formula returned
    // nonsense here (e.g. (4, 1) gave 1). No values means no pairs.
    if range_size <= 0 {
        return 0;
    }
    // Unordered pairs with repetition from `range_size` values:
    // the triangular number r * (r + 1) / 2.
    (range_size * (range_size + 1)) / 2
}
72
73
74/// Generate all valid combinations
75/// 
76/// `dfs_parallel()` computes all valid combinations of integers that
77/// match the given summary statistics.
78///
79/// # Arguments
80/// * `mean` - The mean of the target distribution.
81/// * `sd` - The standard deviation of the target distribution.
82/// * `n` - The number of elements in the target distribution.
83/// * `scale_min` - The minimum value of the scale.
84/// * `scale_max` - The maximum value of the scale.
85/// * `rounding_error_mean` - The rounding error for the mean.
86/// * `rounding_error_sd` - The rounding error for the standard deviation.
87/// # Returns
88/// A vector of vectors, where each inner vector represents a valid combination of integers
89/// that matches the given summary statistics.
90pub fn dfs_parallel<T, U>(
91    mean: T,
92    sd: T,
93    n: U,
94    scale_min: U,
95    scale_max: U,
96    rounding_error_mean: T,
97    rounding_error_sd: T,
98) -> Vec<Vec<U>>
99where
100    T: Float + FromPrimitive + Send + Sync, // suggest renaming to F to indicate float type?
101    U: Integer + NumCast + ToPrimitive + Copy + Send + Sync,
102{
103    // Convert integer `n` to float to enable multiplication with other floats
104    let n_float = T::from(U::to_i32(&n).unwrap()).unwrap();
105    
106    // Target sum calculations
107    let target_sum = mean * n_float;
108    let rounding_error_sum = rounding_error_mean * n_float;
109    
110    let target_sum_upper = target_sum + rounding_error_sum;
111    let target_sum_lower = target_sum - rounding_error_sum;
112    let sd_upper = sd + rounding_error_sd;
113    let sd_lower = sd - rounding_error_sd;
114
115    // Convert to usize for range operations
116    let n_usize = U::to_usize(&n).unwrap();
117    
118    // Precomputing scale sums directly on T types 
119    let scale_min_sum_t: Vec<T> = (0..n_usize)
120        .map(|x| T::from(scale_min).unwrap() * T::from(x).unwrap())
121        .collect();
122    
123    let scale_max_sum_t: Vec<T> = (0..n_usize)
124        .map(|x| T::from(scale_max).unwrap() * T::from(x).unwrap())
125        .collect();
126    
127    let n_minus_1 = n - U::one();
128    let scale_max_plus_1 = scale_max + U::one();
129
130   // instead of generating the initial combinations using concrete types, we're keeping them in U
131    // and T using the iterator for U 
132    let combinations = range_u(scale_min, scale_max_plus_1)
133    .flat_map(|i| {
134        range_u(i, scale_max_plus_1).map(move |j| {
135            let initial_combination = vec![i, j];
136
137            // turn the integer type into the float type
138            // again, might be good for readability to rename T to F
139            let i_float = T::from(i).unwrap();
140            let j_float = T::from(j).unwrap();
141            let sum = i_float + j_float;
142            let current_mean = sum / T::from(2).unwrap();
143
144            let diff_i = i_float - current_mean;
145            let diff_j = j_float - current_mean;
146            let current_m2 = diff_i * diff_i + diff_j * diff_j;
147
148            (initial_combination, sum, current_m2)
149        })
150    })
151    .collect::<Vec<_>>();
152
153    // Process combinations in parallel
154    combinations.par_iter()
155        .flat_map(|(combo, running_sum, running_m2)| {
156            dfs_branch(
157                combo.clone(),
158                *running_sum,
159                *running_m2,
160                n_usize,
161                target_sum_upper,
162                target_sum_lower,
163                sd_upper,
164                sd_lower,
165                &scale_min_sum_t,
166                &scale_max_sum_t,
167                n_minus_1,
168                scale_max_plus_1,
169            )
170        })
171        .collect()
172}
173
// Collect all valid combinations from a starting point
//
// Iterative depth-first search (explicit stack, no recursion) over
// non-decreasing extensions of `start_combination`. A branch is cut as soon
// as its running sum or running variance can no longer land inside the
// target band. Returns every complete combination that survives the checks.
//
// NOTE(review): the final SD computation divides by n - 1, so this assumes
// n >= 2 — confirm callers never pass n == 1.
#[inline]
#[allow(clippy::too_many_arguments)]
fn dfs_branch<T, U>(
    start_combination: Vec<U>,
    running_sum_init: T,
    running_m2_init: T,
    n: usize,  // Use usize for the length
    target_sum_upper: T,
    target_sum_lower: T,
    sd_upper: T,
    sd_lower: T,
    scale_min_sum_t: &[T],  // scale_min * k for k remaining slots
    scale_max_sum_t: &[T],  // scale_max * k for k remaining slots
    _n_minus_1: U,          // currently unused; kept for signature stability
    scale_max_plus_1: U,    // exclusive upper bound for candidate values
) -> Vec<Vec<U>>
where
    T: Float + FromPrimitive + Send + Sync,
    U: Integer + NumCast + ToPrimitive + Copy + Send + Sync,
{
    let mut stack = VecDeque::with_capacity(n * 2); // Preallocate with reasonable capacity
    let mut results = Vec::new();
    
    // Seed the stack with the starting pair and its precomputed aggregates.
    stack.push_back(Combination {
        values: start_combination.clone(),
        running_sum: running_sum_init,
        running_m2: running_m2_init,
    });
    
    while let Some(current) = stack.pop_back() {
        // Leaf: the combination is complete — run the final SD check.
        if current.values.len() >= n {
            let n_minus_1_float = T::from(n - 1).unwrap();
            // Sample standard deviation (n - 1 denominator).
            let current_std = (current.running_m2 / n_minus_1_float).sqrt();
            // Only the lower bound is checked here; the upper bound was
            // already enforced when this node was pushed (min_sd below).
            if current_std >= sd_lower {
                results.push(current.values);
            }
            continue;
        }

        // Calculate remaining items to add
        let current_len = current.values.len();
        let n_left = n - current_len - 1; // How many more items after the next one
        let next_n = current_len + 1;

        // Get current mean
        let current_mean = current.running_sum / T::from(current_len).unwrap();

        // Get the last value; candidates start here so combinations stay
        // non-decreasing, which avoids generating permutations twice.
        let last_value = current.values[current_len - 1];

        for next_value in range_u(last_value, scale_max_plus_1) {
            let next_value_as_t = T::from(next_value).unwrap();
            let next_sum = current.running_sum + next_value_as_t;
            
            // Safe indexing with bounds check (using usize for indexing)
            if n_left < scale_min_sum_t.len() {
                // Smallest total still reachable: even filling every
                // remaining slot with scale_min overshoots -> prune.
                let minmean = next_sum + scale_min_sum_t[n_left];
                if minmean > target_sum_upper {
                    // Candidates are increasing, so all later ones overshoot too.
                    break; // Early termination - better than take_while!
                }
                
                // Safe indexing with bounds check (using usize for indexing)
                if n_left < scale_max_sum_t.len() {
                    // Largest total still reachable: even scale_max everywhere
                    // undershoots -> this candidate is too small, try the next.
                    let maxmean = next_sum + scale_max_sum_t[n_left];
                    if maxmean < target_sum_lower {
                        continue;
                    }
                    
                    // Welford-style incremental M2 update:
                    // delta uses the old mean, delta2 the new mean.
                    let next_mean = next_sum / T::from(next_n).unwrap();
                    let delta = next_value_as_t - current_mean;
                    let delta2 = next_value_as_t - next_mean;
                    let next_m2 = current.running_m2 + delta * delta2;
                    
                    // M2 never decreases as values are added, so dividing the
                    // partial M2 by (n - 1) bounds the final SD from below;
                    // branches already above sd_upper can be cut now.
                    let min_sd = (next_m2 / T::from(n - 1).unwrap()).sqrt();
                    if min_sd <= sd_upper {
                        let mut new_values = current.values.clone();
                        new_values.push(next_value);
                        stack.push_back(Combination {
                            values: new_values,
                            running_sum: next_sum,
                            running_m2: next_m2,
                        });
                    }
                }
            }
        }
    }
    results
}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// `count_initial_combinations` returns the triangular number of the
    /// scale's size: r * (r + 1) / 2 for r = scale_max - scale_min + 1.
    #[test]
    fn test_count_initial_combinations() {
        assert_eq!(count_initial_combinations(1, 3), 6);
        assert_eq!(count_initial_combinations(1, 4), 10);
        // A single-point scale admits exactly one pair: (min, min).
        assert_eq!(count_initial_combinations(5, 5), 1);
        // The count depends only on the width of the range, not its position.
        assert_eq!(
            count_initial_combinations(0, 2),
            count_initial_combinations(7, 9)
        );
    }
}
275