random_picker/
calc.rs

1use crate::*;
2use std::{hash::Hash, thread};
3
4impl<T: Clone + Eq + Hash> Config<T> {
5    /// Calculates probabilities of existences of table items in each picking result
6    /// of length `pick_amount`. In non-repetitive mode, the multi-thread tree algorithm
7    /// may be used.
8    ///
9    /// Preorder traversal is performed in each thread. Something like depth-first or
10    /// postorder algorithm may achieve higher precision (consider the error produced
11    /// while adding a small floating-point number to a much larger number, which is
12    /// the current way), but it will probably increase complexity and memory usage.
13    ///
14    /// TODO: figure out why its single-thread performance is about 7% slower than
15    /// the single-thread C++ version compiled with `clang++` without `-march=native`,
16    /// and unsafe operations can't make this Rust program faster.
17    /// It is faster than the C++ version compiled with GCC, though.
18    pub fn calc_probabilities(&self, pick_amount: usize) -> Result<Table<T>, Error> {
19        if pick_amount == 0 {
20            return Ok(self.table.keys().map(|k| (k.clone(), 0.)).collect());
21        }
22
23        if !self.repetitive {
24            if pick_amount > self.table.len() {
25                return Err(Error::InvalidAmount);
26            }
27            if pick_amount == self.table.len() {
28                return Ok(self.table.keys().map(|k| (k.clone(), 1.)).collect());
29            }
30            if self.is_fair() {
31                let prob = (pick_amount as f64) / (self.table.len() as f64);
32                return Ok(self.table.keys().map(|k| (k.clone(), prob)).collect());
33            }
34        }
35
36        // map to values within range 0. ~ 1.
37        let table: Vec<_> = {
38            let raw_table = self.vec_table()?;
39            let grid_width: f64 = raw_table.iter().map(|(_, v)| v).sum();
40            raw_table
41                .into_iter()
42                .map(|(k, v)| (k, v / grid_width))
43                .collect()
44        };
45
46        if pick_amount == 1 {
47            return Ok(table.into_iter().collect());
48        }
49        if self.repetitive {
50            return Ok(table
51                .into_iter()
52                .map(|(k, v)| (k, 1. - (1. - v).powi(pick_amount as i32)))
53                .collect());
54        }
55
56        // -------- calc for general non-repetitive cases --------
57
58        let table_val: Vec<_> = table.iter().map(|(_, v)| *v).collect();
59        let mut calc_result = table.clone();
60
61        let cnt_threads = thread::available_parallelism()
62            .map(|n| n.get())
63            .unwrap_or(4)
64            .max(table.len());
65        let cnt_calc_groups = table.len().div_ceil(cnt_threads);
66        let mut calc_groups = Vec::with_capacity(cnt_calc_groups);
67        let mut table_picked = vec![false; table.len()];
68        for i in 0..cnt_calc_groups {
69            let mut calcs = Vec::with_capacity(cnt_threads);
70            for j in 0..cnt_threads {
71                let i_th = i * cnt_threads + j;
72                if i_th >= table.len() {
73                    break;
74                }
75                table_picked[i_th] = true;
76                let calc_stack =
77                    CalcStack::new(table_val.clone(), pick_amount, table_picked.clone());
78                calcs.push((i_th, Some(calc_stack)));
79                table_picked[i_th] = false;
80            }
81            calc_groups.push(calcs);
82        }
83
84        for group in calc_groups.into_iter() {
85            let mut thread_hdls = Vec::with_capacity(cnt_threads);
86            for (i, mut calc) in group.into_iter() {
87                let calc = calc.take().unwrap();
88                thread_hdls.push(thread::spawn(move || (i, calc.calc())));
89            }
90            for hdl in thread_hdls {
91                let (i_th, sub_result) = hdl.join().map_err(|_| Error::ThreadError)?;
92                for (i, &sub_prob) in sub_result.iter().enumerate() {
93                    calc_result[i].1 += table_val[i_th] * sub_prob;
94                }
95            }
96        }
97
98        Ok(calc_result.into_iter().collect())
99    }
100}
101
/// Explicit-stack, depth-first (preorder) enumeration of every ordered sequence of picks
/// that can follow the items already marked in `table_picked`.
///
/// Each stack entry is `(item index, probability of the whole path ending at that item)`.
/// A pick chooses unpicked item `i` with probability `table[i] / rem_width`. Every
/// visited node adds its path probability to `result[item]`, so once `calc` finishes,
/// `result[i]` is the probability that item `i` appears among the remaining
/// `stack_size` picks.
#[derive(Clone, Debug)]
struct CalcStack {
    // Read-only after construction.
    table: Vec<f64>,   // grid cell width of every item; size: table.len()
    stack_size: usize, // = pick_amount - initial amount of picked items

    // Traversal state.
    stack: Vec<(usize, f64)>, // (item index, path probability); maximum size: stack_size
    table_picked: Vec<bool>,  // picked flag of every item; size: table.len()
    rem_width: f64,           // current sum of grid cell widths of unpicked items

    result: Vec<f64>, // accumulated probability per item; size: table.len()
}

impl CalcStack {
    /// Builds a traversal rooted at the state described by `table_picked`.
    ///
    /// `pick_amount` counts the items that are already picked, so the tree depth
    /// (`stack_size`) is `pick_amount` minus the number of `true` entries. This must be
    /// the only way a `CalcStack` is constructed.
    ///
    /// # Panics
    /// If `table` and `table_picked` differ in length.
    fn new(table: Vec<f64>, pick_amount: usize, table_picked: Vec<bool>) -> Self {
        assert!(table.len() == table_picked.len());

        let mut depth = pick_amount;
        let mut unpicked_width = 0.;
        for (&width, &picked) in table.iter().zip(&table_picked) {
            if !picked {
                unpicked_width += width;
                continue;
            }
            // More items are pre-picked than `pick_amount` allows: the caller is wrong,
            // so stop scanning (depth is 0 then, and `calc` yields all zeros).
            let Some(rest) = depth.checked_sub(1) else {
                break;
            };
            depth = rest;
        }

        let result = vec![0.; table.len()];
        Self {
            table,
            stack_size: depth,
            stack: Vec::with_capacity(depth),
            table_picked,
            rem_width: unpicked_width,
            result,
        }
    }

    /// Runs the whole traversal and returns the accumulated per-item probabilities.
    fn calc(mut self) -> Vec<f64> {
        // Descend while possible, otherwise step to the next sibling, otherwise backtrack
        // to the nearest ancestor that still has a next sibling; stop when none remains.
        while self.go_down() || self.go_right() || self.go_up_right() {}
        self.result
    }

    /// Pushes the first unpicked item as a child of the current node.
    /// Returns `false` if the stack is already full or no item is left.
    #[inline(always)]
    fn go_down(&mut self) -> bool {
        if self.stack.len() >= self.stack_size {
            return false;
        }
        let Some(next) = self.next_unpicked(0) else {
            return false;
        };

        let parent_prob = match self.stack.last() {
            Some(&(_, p)) => p,
            None => 1.,
        };
        // NOTE(review): if every unpicked width is 0, `rem_width` is 0 and this is 0/0
        // (NaN); presumably zero-width items never reach here — confirm upstream.
        let prob = parent_prob * self.table[next] / self.rem_width;

        self.table_picked[next] = true;
        self.rem_width -= self.table[next];
        self.result[next] += prob;
        self.stack.push((next, prob));
        true
    }

    /// Replaces the top node with its next sibling: the next unpicked item with a larger
    /// index. Returns `false` if the stack is empty or no such sibling exists.
    #[inline(always)]
    fn go_right(&mut self) -> bool {
        let Some(&(prev, _)) = self.stack.last() else {
            return false;
        };
        let Some(next) = self.next_unpicked(prev + 1) else {
            return false;
        };

        let depth = self.stack.len();
        let parent_prob = depth.checked_sub(2).map_or(1., |k| self.stack[k].1);
        // Width available to the parent, i.e. before `prev` was picked.
        let parent_rem_width = self.rem_width + self.table[prev];
        let prob = parent_prob * self.table[next] / parent_rem_width;

        self.table_picked[prev] = false;
        self.table_picked[next] = true;
        self.rem_width = parent_rem_width - self.table[next];
        self.result[next] += prob;
        if let Some(top) = self.stack.last_mut() {
            *top = (next, prob);
        }
        true
    }

    /// Backtracks: pops levels (un-picking their items) until some level can move to a
    /// next sibling. Returns `false` once the stack is empty, i.e. the traversal is done.
    fn go_up_right(&mut self) -> bool {
        loop {
            let Some((prev, _)) = self.stack.pop() else {
                return false;
            };
            self.table_picked[prev] = false;
            self.rem_width += self.table[prev];
            if self.go_right() {
                return true;
            }
        }
    }

    /// Returns the smallest index `>= least_index` whose item is not picked, if any.
    fn next_unpicked(&self, least_index: usize) -> Option<usize> {
        self.table_picked
            .get(least_index..)?
            .iter()
            .position(|&picked| !picked)
            .map(|offset| least_index + offset)
    }
}