1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
extern crate rand;
use self::rand::{thread_rng, ThreadRng, Rng};
use self::rand::distributions::{IndependentSample, Range};
/// Sampler that draws indices from an [`AliasTable`] using the random number
/// generator it owns.
///
/// The `Rng` requirement is stated on the `impl` block that needs it rather than
/// on the struct, so the type can be named (e.g. stored in other structs) without
/// repeating the bound.
pub struct AliasMethod<RNG> {
    rng: RNG,
}
/// Convenience constructor: returns an [`AliasMethod`] that draws its randomness
/// from `thread_rng()`.
pub fn alias_method() -> AliasMethod<ThreadRng> {
    let rng: ThreadRng = thread_rng();
    AliasMethod::new(rng)
}
/// Precomputed table used by `AliasMethod::random` to draw an index in constant time.
///
/// Sampling picks a column uniformly from `0..len`. Column `i` then yields `i`
/// itself when a uniform draw falls within `prob[i]`, and `alias[i]` otherwise.
/// Build one with `new_alias_table`.
#[derive(Debug, Clone, PartialEq)]
pub struct AliasTable {
    /// Number of columns (the number of weights the table was built from).
    len: i32,
    /// Per-column threshold for returning the column's own index.
    prob: Vec<f64>,
    /// Per-column fallback index returned when the column's own index is rejected.
    alias: Vec<usize>,
}
impl<RNG: Rng> AliasMethod<RNG> {
    /// Wraps `rng` in a sampler.
    pub fn new(rng: RNG) -> Self {
        AliasMethod { rng }
    }

    /// Draws one index from the distribution encoded by `alias_table`.
    ///
    /// A column `n` is chosen uniformly from `0..alias_table.len`. It yields `n`
    /// itself when a uniform `u` in `[0, 1)` is below `prob[n]`, and `alias[n]`
    /// otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `alias_table.len` is not positive (the column range would be
    /// empty), or if `prob`/`alias` are shorter than `len`.
    pub fn random(&mut self, alias_table: &AliasTable) -> usize {
        // Keep the draw order (u first, then the column) fixed so seeded RNGs
        // reproduce the same sequence of samples.
        let u = self.rng.next_f64();
        let n = Range::new(0, alias_table.len).ind_sample(&mut self.rng) as usize;
        // Strict `<`: u lies in [0, 1), so P(u < p) == p exactly. With `<=`, a draw
        // of exactly 0.0 would return a column whose probability is 0, i.e. an
        // item that was given zero weight.
        if u < alias_table.prob[n] {
            n
        } else {
            alias_table.alias[n]
        }
    }
}
/// Builds an [`AliasTable`] for drawing index `i` with probability
/// `weights[i] / sum(weights)`.
///
/// The weights are rescaled so their mean is 1.0. "Small" columns (scaled
/// weight < 1) are then topped up with mass donated by "large" columns (scaled
/// weight > 1), and the donor is recorded as the small column's alias.
///
/// # Errors
///
/// * `"weights must be finite and non-negative."` if any weight is negative, NaN
///   or infinite.
/// * `"too many weights."` if `weights.len()` does not fit in an `i32`.
/// * `"sum of weights is 0."` if the weights sum to zero (including an empty slice).
/// * `"sum of weights is not finite."` if adding the weights overflows.
pub fn new_alias_table(weights: &[f64]) -> Result<AliasTable, &'static str> {
    use std::convert::TryFrom;

    if weights.iter().any(|w| !w.is_finite() || *w < 0.0) {
        return Err("weights must be finite and non-negative.");
    }
    let n = weights.len();
    let len = i32::try_from(n).map_err(|_| "too many weights.")?;
    let sum: f64 = weights.iter().sum();
    if sum == 0.0 {
        return Err("sum of weights is 0.");
    }
    if !sum.is_finite() {
        return Err("sum of weights is not finite.");
    }

    // Divide before multiplying so a huge (but finite) weight cannot overflow.
    let mut prob: Vec<f64> = weights.iter().map(|w| (w / sum) * (n as f64)).collect();
    // Every column starts as its own alias. A column that never gets paired
    // (already exactly full, or off by rounding error) can then only yield itself.
    let mut alias: Vec<usize> = (0..n).collect();

    // Columns that are exactly 1.0 are already full and need no pairing.
    let mut small: Vec<usize> = Vec::with_capacity(n);
    let mut large: Vec<usize> = Vec::with_capacity(n);
    for (i, &p) in prob.iter().enumerate() {
        if p < 1.0 {
            small.push(i);
        } else if p > 1.0 {
            large.push(i);
        }
    }

    loop {
        let (j, k) = match (small.last(), large.last()) {
            (Some(&j), Some(&k)) => (j, k),
            _ => break,
        };
        // Column j is now final: it yields j with prob[j], otherwise the donor k.
        small.pop();
        alias[j] = k;
        prob[k] -= 1.0 - prob[j];
        // Once the donor drops below 1.0 it becomes a column that needs topping up.
        if prob[k] < 1.0 {
            large.pop();
            small.push(k);
        }
    }

    // Whatever remains is within rounding error of 1.0. Make it exact so the
    // sampler never falls through to an alias for these columns.
    for i in small.into_iter().chain(large) {
        prob[i] = 1.0;
    }

    Ok(AliasTable { len, prob, alias })
}
/// Checks that `new_alias_table` builds well-formed tables whose implied
/// distribution matches the normalised input weights, and that it rejects
/// inputs whose weights sum to zero.
#[test]
fn test_new_alias_table() {
    let params = [vec![1.0, 1.0], vec![1.0, 1.0, 8.0]];
    for sample_weights in params.iter() {
        let table = match new_alias_table(sample_weights) {
            Ok(table) => table,
            Err(e) => panic!("error : {}", e),
        };
        let n = sample_weights.len();
        assert_eq!(table.len as usize, n);
        assert_eq!(table.prob.len(), n);
        assert_eq!(table.alias.len(), n);
        assert!(table.alias.iter().all(|&a| a < n));

        // Each column is picked with probability 1/n. Column i contributes
        // prob[i] to index i and the remainder to alias[i].
        let mut mass = vec![0.0; n];
        for (i, (&p, &a)) in table.prob.iter().zip(table.alias.iter()).enumerate() {
            mass[i] += p;
            mass[a] += 1.0 - p;
        }
        let sum: f64 = sample_weights.iter().sum();
        for (m, w) in mass.iter().zip(sample_weights.iter()) {
            assert!((m / n as f64 - w / sum).abs() < 1e-9);
        }
    }
    assert!(new_alias_table(&vec![0.0, 0.0]).is_err());
    assert!(new_alias_table(&Vec::<f64>::new()).is_err());
}