1use num::{Float, FromPrimitive, Integer, NumCast, ToPrimitive};
14use std::collections::VecDeque;
15use rayon::prelude::*;
16
17struct IntegerRange<U>
19where
20 U: Integer + Copy
21{
22 current: U,
23 end: U,
24}
25
26impl<U> Iterator for IntegerRange<U>
27where
28 U: Integer + Copy
29{
30 type Item = U;
31
32 fn next(&mut self) -> Option<U> {
34 if self.current < self.end {
35 let next = self.current;
36 self.current = self.current + U::one();
37 Some(next)
38 } else {
39 None
40 }
41 }
42}
43
44fn range_u<U: Integer + Copy>(start: U, end: U) -> IntegerRange<U> {
46 IntegerRange {current: start, end}
47}
48
/// A partial (or complete) sample together with the running statistics
/// needed to extend it without rescanning `values`.
#[derive(Clone)]
struct Combination<T, U> {
    /// The values chosen so far.
    values: Vec<U>,
    /// Sum of `values`, kept as a float.
    running_sum: T,
    /// Sum of squared deviations of `values` from their mean.
    running_m2: T,
}
56
/// Counts the pairs `(i, j)` with `scale_min <= i <= j <= scale_max`.
///
/// For an inclusive scale of width `w = scale_max - scale_min + 1` this is
/// the triangular number `w * (w + 1) / 2`.
///
/// # Returns
///
/// * `0` when `scale_max < scale_min` (the scale is empty, so no pair exists).
/// * `i32::MAX` when the true count does not fit in an `i32`.
/// * The exact count otherwise.
pub fn count_initial_combinations(scale_min: i32, scale_max: i32) -> i32 {
    // Widen to i64 so neither the width nor the product can overflow for
    // inputs whose result still fits in an i32.
    let range_size = i64::from(scale_max) - i64::from(scale_min) + 1;
    if range_size <= 0 {
        return 0;
    }

    // The width can be as large as 2^32, so the product may exceed i64;
    // saturating keeps it above i32::MAX, which the clamp below handles.
    let pairs = range_size.saturating_mul(range_size + 1) / 2;

    // Clamped to i32::MAX first, so the narrowing cast cannot truncate.
    pairs.min(i64::from(i32::MAX)) as i32
}
72
73
74pub fn dfs_parallel<T, U>(
91 mean: T,
92 sd: T,
93 n: U,
94 scale_min: U,
95 scale_max: U,
96 rounding_error_mean: T,
97 rounding_error_sd: T,
98) -> Vec<Vec<U>>
99where
100 T: Float + FromPrimitive + Send + Sync, U: Integer + NumCast + ToPrimitive + Copy + Send + Sync,
102{
103 let n_float = T::from(U::to_i32(&n).unwrap()).unwrap();
105
106 let target_sum = mean * n_float;
108 let rounding_error_sum = rounding_error_mean * n_float;
109
110 let target_sum_upper = target_sum + rounding_error_sum;
111 let target_sum_lower = target_sum - rounding_error_sum;
112 let sd_upper = sd + rounding_error_sd;
113 let sd_lower = sd - rounding_error_sd;
114
115 let n_usize = U::to_usize(&n).unwrap();
117
118 let scale_min_sum_t: Vec<T> = (0..n_usize)
120 .map(|x| T::from(scale_min).unwrap() * T::from(x).unwrap())
121 .collect();
122
123 let scale_max_sum_t: Vec<T> = (0..n_usize)
124 .map(|x| T::from(scale_max).unwrap() * T::from(x).unwrap())
125 .collect();
126
127 let n_minus_1 = n - U::one();
128 let scale_max_plus_1 = scale_max + U::one();
129
130 let combinations = range_u(scale_min, scale_max_plus_1)
133 .flat_map(|i| {
134 range_u(i, scale_max_plus_1).map(move |j| {
135 let initial_combination = vec![i, j];
136
137 let i_float = T::from(i).unwrap();
140 let j_float = T::from(j).unwrap();
141 let sum = i_float + j_float;
142 let current_mean = sum / T::from(2).unwrap();
143
144 let diff_i = i_float - current_mean;
145 let diff_j = j_float - current_mean;
146 let current_m2 = diff_i * diff_i + diff_j * diff_j;
147
148 (initial_combination, sum, current_m2)
149 })
150 })
151 .collect::<Vec<_>>();
152
153 combinations.par_iter()
155 .flat_map(|(combo, running_sum, running_m2)| {
156 dfs_branch(
157 combo.clone(),
158 *running_sum,
159 *running_m2,
160 n_usize,
161 target_sum_upper,
162 target_sum_lower,
163 sd_upper,
164 sd_lower,
165 &scale_min_sum_t,
166 &scale_max_sum_t,
167 n_minus_1,
168 scale_max_plus_1,
169 )
170 })
171 .collect()
172}
173
174#[inline]
176#[allow(clippy::too_many_arguments)]
177fn dfs_branch<T, U>(
178 start_combination: Vec<U>,
179 running_sum_init: T,
180 running_m2_init: T,
181 n: usize, target_sum_upper: T,
183 target_sum_lower: T,
184 sd_upper: T,
185 sd_lower: T,
186 scale_min_sum_t: &[T],
187 scale_max_sum_t: &[T],
188 _n_minus_1: U,
189 scale_max_plus_1: U,
190) -> Vec<Vec<U>>
191where
192 T: Float + FromPrimitive + Send + Sync,
193 U: Integer + NumCast + ToPrimitive + Copy + Send + Sync,
194{
195 let mut stack = VecDeque::with_capacity(n * 2); let mut results = Vec::new();
197
198 stack.push_back(Combination {
199 values: start_combination.clone(),
200 running_sum: running_sum_init,
201 running_m2: running_m2_init,
202 });
203
204 while let Some(current) = stack.pop_back() {
205 if current.values.len() >= n {
206 let n_minus_1_float = T::from(n - 1).unwrap();
207 let current_std = (current.running_m2 / n_minus_1_float).sqrt();
208 if current_std >= sd_lower {
209 results.push(current.values);
210 }
211 continue;
212 }
213
214 let current_len = current.values.len();
216 let n_left = n - current_len - 1; let next_n = current_len + 1;
218
219 let current_mean = current.running_sum / T::from(current_len).unwrap();
221
222 let last_value = current.values[current_len - 1];
224
225 for next_value in range_u(last_value, scale_max_plus_1) {
226 let next_value_as_t = T::from(next_value).unwrap();
227 let next_sum = current.running_sum + next_value_as_t;
228
229 if n_left < scale_min_sum_t.len() {
231 let minmean = next_sum + scale_min_sum_t[n_left];
232 if minmean > target_sum_upper {
233 break; }
235
236 if n_left < scale_max_sum_t.len() {
238 let maxmean = next_sum + scale_max_sum_t[n_left];
239 if maxmean < target_sum_lower {
240 continue;
241 }
242
243 let next_mean = next_sum / T::from(next_n).unwrap();
244 let delta = next_value_as_t - current_mean;
245 let delta2 = next_value_as_t - next_mean;
246 let next_m2 = current.running_m2 + delta * delta2;
247
248 let min_sd = (next_m2 / T::from(n - 1).unwrap()).sqrt();
249 if min_sd <= sd_upper {
250 let mut new_values = current.values.clone();
251 new_values.push(next_value);
252 stack.push_back(Combination {
253 values: new_values,
254 running_sum: next_sum,
255 running_m2: next_m2,
256 });
257 }
258 }
259 }
260 }
261 }
262 results
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// The pair count for an inclusive scale is the triangular number of
    /// its width: widths 3 and 4 give 6 and 10.
    #[test]
    fn test_count_initial_combinations() {
        let expectations = [((1, 3), 6), ((1, 4), 10)];

        for &((scale_min, scale_max), expected) in expectations.iter() {
            assert_eq!(count_initial_combinations(scale_min, scale_max), expected);
        }
    }
}
275