1use crate::*;
2use std::{hash::Hash, thread};
3
4impl<T: Clone + Eq + Hash> Config<T> {
5 pub fn calc_probabilities(&self, pick_amount: usize) -> Result<Table<T>, Error> {
19 if pick_amount == 0 {
20 return Ok(self.table.keys().map(|k| (k.clone(), 0.)).collect());
21 }
22
23 if !self.repetitive {
24 if pick_amount > self.table.len() {
25 return Err(Error::InvalidAmount);
26 }
27 if pick_amount == self.table.len() {
28 return Ok(self.table.keys().map(|k| (k.clone(), 1.)).collect());
29 }
30 if self.is_fair() {
31 let prob = (pick_amount as f64) / (self.table.len() as f64);
32 return Ok(self.table.keys().map(|k| (k.clone(), prob)).collect());
33 }
34 }
35
36 let table: Vec<_> = {
38 let raw_table = self.vec_table()?;
39 let grid_width: f64 = raw_table.iter().map(|(_, v)| v).sum();
40 raw_table
41 .into_iter()
42 .map(|(k, v)| (k, v / grid_width))
43 .collect()
44 };
45
46 if pick_amount == 1 {
47 return Ok(table.into_iter().collect());
48 }
49 if self.repetitive {
50 return Ok(table
51 .into_iter()
52 .map(|(k, v)| (k, 1. - (1. - v).powi(pick_amount as i32)))
53 .collect());
54 }
55
56 let table_val: Vec<_> = table.iter().map(|(_, v)| *v).collect();
59 let mut calc_result = table.clone();
60
61 let cnt_threads = thread::available_parallelism()
62 .map(|n| n.get())
63 .unwrap_or(4)
64 .max(table.len());
65 let cnt_calc_groups = table.len().div_ceil(cnt_threads);
66 let mut calc_groups = Vec::with_capacity(cnt_calc_groups);
67 let mut table_picked = vec![false; table.len()];
68 for i in 0..cnt_calc_groups {
69 let mut calcs = Vec::with_capacity(cnt_threads);
70 for j in 0..cnt_threads {
71 let i_th = i * cnt_threads + j;
72 if i_th >= table.len() {
73 break;
74 }
75 table_picked[i_th] = true;
76 let calc_stack =
77 CalcStack::new(table_val.clone(), pick_amount, table_picked.clone());
78 calcs.push((i_th, Some(calc_stack)));
79 table_picked[i_th] = false;
80 }
81 calc_groups.push(calcs);
82 }
83
84 for group in calc_groups.into_iter() {
85 let mut thread_hdls = Vec::with_capacity(cnt_threads);
86 for (i, mut calc) in group.into_iter() {
87 let calc = calc.take().unwrap();
88 thread_hdls.push(thread::spawn(move || (i, calc.calc())));
89 }
90 for hdl in thread_hdls {
91 let (i_th, sub_result) = hdl.join().map_err(|_| Error::ThreadError)?;
92 for (i, &sub_prob) in sub_result.iter().enumerate() {
93 calc_result[i].1 += table_val[i_th] * sub_prob;
94 }
95 }
96 }
97
98 Ok(calc_result.into_iter().collect())
99 }
100}
101
/// Depth-first enumerator over every ordered sequence of weighted picks made
/// without replacement. Picks come only from entries of `table` that are not
/// marked in `table_picked`.
///
/// Each pick chooses entry `j` with probability `table[j] / rem_width`, where
/// `rem_width` is the total weight of the entries still available. After
/// [`CalcStack::calc`] finishes, `result[j]` holds the total probability that
/// entry `j` is chosen in one of the `stack_size` remaining picks.
#[derive(Clone, Debug)]
struct CalcStack {
    /// Weight of each entry.
    table: Vec<f64>,
    /// Number of picks still to make; the maximum depth of `stack`.
    stack_size: usize,
    /// Current path: (entry index, probability of the path up to and including this pick).
    stack: Vec<(usize, f64)>,
    /// Entries that are currently unavailable (pre-picked or on the current path).
    table_picked: Vec<bool>,
    /// Total weight of the entries not marked in `table_picked`.
    rem_width: f64,
    /// Accumulated inclusion probability per entry.
    result: Vec<f64>,
}

impl CalcStack {
    /// Prepares an enumeration of the picks that remain after the entries
    /// already marked in `table_picked`.
    ///
    /// # Panics
    ///
    /// Panics if `table` and `table_picked` differ in length.
    fn new(table: Vec<f64>, pick_amount: usize, table_picked: Vec<bool>) -> Self {
        assert!(table.len() == table_picked.len());
        let entry_count = table.len();

        // Each pre-picked entry consumes one of the `pick_amount` picks. Once
        // none remain, the scan stops early; `rem_width` is then left partial.
        // That is harmless, because with zero picks left it is never read.
        let mut stack_size = pick_amount;
        let mut rem_width = 0.;
        for (&weight, &picked) in table.iter().zip(table_picked.iter()) {
            if !picked {
                rem_width += weight;
                continue;
            }
            if stack_size == 0 {
                break;
            }
            stack_size -= 1;
        }

        Self {
            stack: Vec::with_capacity(stack_size),
            result: vec![0.; entry_count],
            table,
            stack_size,
            table_picked,
            rem_width,
        }
    }

    /// Walks the whole pick tree and returns the per-entry inclusion probabilities.
    fn calc(mut self) -> Vec<f64> {
        // Prefer descending, then moving to the next sibling, then backtracking.
        // Stop when none of the three moves is possible.
        while self.go_down() || self.go_right() || self.go_up_right() {}
        self.result
    }

    /// Pushes the lowest-index available entry as a new, deeper pick.
    /// Returns `false` if the path is already full or nothing is available.
    #[inline(always)]
    fn go_down(&mut self) -> bool {
        if self.stack.len() >= self.stack_size {
            return false;
        }
        let Some(chosen) = self.next_unpicked(0) else {
            return false;
        };

        let parent_prob = self.stack.last().map_or(1., |&(_, p)| p);
        let prob = parent_prob * self.table[chosen] / self.rem_width;

        self.stack.push((chosen, prob));
        self.table_picked[chosen] = true;
        self.rem_width -= self.table[chosen];
        self.result[chosen] += prob;
        true
    }

    /// Replaces the deepest pick with the next available entry of higher index.
    /// Returns `false` if the path is empty or no such entry exists.
    #[inline(always)]
    fn go_right(&mut self) -> bool {
        let Some(&(current, _)) = self.stack.last() else {
            return false;
        };
        let Some(sibling) = self.next_unpicked(current + 1) else {
            return false;
        };

        let depth = self.stack.len();
        let parent_prob = depth.checked_sub(2).map_or(1., |idx| self.stack[idx].1);
        // Weight available before `current` was chosen at this depth.
        let parent_rem_width = self.rem_width + self.table[current];
        let prob = parent_prob * self.table[sibling] / parent_rem_width;

        self.stack[depth - 1] = (sibling, prob);
        self.table_picked[current] = false;
        self.table_picked[sibling] = true;
        self.rem_width = parent_rem_width - self.table[sibling];
        self.result[sibling] += prob;
        true
    }

    /// Pops picks until one of the remaining levels can move to a sibling.
    /// Returns `false` once the path has been fully unwound.
    fn go_up_right(&mut self) -> bool {
        loop {
            let Some((released, _)) = self.stack.pop() else {
                return false;
            };
            self.table_picked[released] = false;
            self.rem_width += self.table[released];
            if self.go_right() {
                return true;
            }
        }
    }

    /// Index of the first available entry at or after `least_index`, if any.
    fn next_unpicked(&self, least_index: usize) -> Option<usize> {
        self.table_picked
            .get(least_index..)?
            .iter()
            .position(|&picked| !picked)
            .map(|offset| offset + least_index)
    }
}