1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
extern crate rand;

use self::rand::{thread_rng, ThreadRng, Rng};
use self::rand::distributions::{IndependentSample, Range};


/// Weighted index sampler driven by a random number generator of type `RNG`.
///
/// The `Rng` requirement lives on the `impl` block that needs it, not on the
/// type itself, so the struct can be named and stored without extra bounds.
#[derive(Debug)]
pub struct AliasMethod<RNG> {
    /// Generator consumed by `random`.
    rng: RNG,
}


/// Convenience constructor: an `AliasMethod` whose randomness comes from the
/// thread-local generator returned by `thread_rng()`.
pub fn alias_method() -> AliasMethod<ThreadRng> {
    let generator: ThreadRng = thread_rng();
    AliasMethod::<ThreadRng>::new(generator)
}


/// Precomputed table used by `AliasMethod::random` to sample an index.
///
/// Sampling picks a column `i` uniformly, keeps it with probability
/// `prob[i]`, and otherwise returns `alias[i]`.
#[derive(Debug, Clone, PartialEq)]
pub struct AliasTable {
    /// Number of columns; `new_alias_table` sets it to the number of weights.
    len: i32,
    /// Probability of keeping column `i` instead of jumping to its alias.
    prob: Vec<f64>,
    /// Index returned for column `i` when that column is not kept.
    alias: Vec<usize>,
}


impl<RNG: Rng> AliasMethod<RNG> {

    /// Creates a new `AliasMethod` that draws its randomness from `rng`.
    pub fn new(rng: RNG) -> Self {
        AliasMethod { rng }
    }

    /// Samples an index according to the weights `alias_table` was built from.
    ///
    /// Draws a uniform `u` with `next_f64`, picks a column `n` uniformly from
    /// `0..len`, then returns `n` if `u < prob[n]` and `alias[n]` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `alias_table.len` is not positive, because rand's
    /// `Range::new(low, high)` requires `low < high`.
    pub fn random(&mut self, alias_table: &AliasTable) -> usize {
        let u = self.rng.next_f64();
        let range = Range::new(0, alias_table.len);
        let n = range.ind_sample(&mut self.rng) as usize;

        // `next_f64` yields values in [0, 1), so a strict `<` gives column `n`
        // exactly probability `prob[n]`. With `<=`, a column whose probability
        // is 0.0 (a zero weight) could still be returned when `u == 0.0`.
        if u < alias_table.prob[n] {
            n
        } else {
            alias_table.alias[n]
        }
    }
}


/// Creates a new AliasTable struct for sampling indices in proportion to
/// `weights`.
///
/// Each weight is rescaled so the scaled values average 1.0. Columns whose
/// scaled value is below 1.0 ("small") are paired with columns above 1.0
/// ("large"). The small column keeps its own probability and aliases the
/// large one, and the large column gives up the residual space. Both passes
/// are linear in `weights.len()`.
///
/// # Errors
///
/// * `"weights must be finite and non-negative."` if any weight is negative,
///   NaN or infinite.
/// * `"too many weights."` if the count does not fit in an `i32`.
/// * `"sum of weights is 0."` if `weights` is empty or every weight is zero.
/// * `"sum of weights is not finite."` if summing the weights overflows.
///
/// # Panics
///
/// Panics only if the pairing invariants checked in the loop are violated.
/// That would indicate a bug in this function, not bad input.
pub fn new_alias_table(weights: &[f64]) -> Result<AliasTable, &'static str> {
    // Negative, NaN or infinite weights give meaningless (or NaN) scaled
    // probabilities, so reject them before building anything.
    if weights.iter().any(|w| !w.is_finite() || *w < 0.0) {
        return Err("weights must be finite and non-negative.");
    }

    let n = weights.len();
    // `AliasTable::len` is an `i32`; refuse inputs that would wrap on the cast.
    if n > i32::max_value() as usize {
        return Err("too many weights.");
    }

    let sum = weights.iter().fold(0.0, |acc, x| acc + x);
    if sum == 0.0 {
        return Err("sum of weights is 0.");
    }
    if !sum.is_finite() {
        return Err("sum of weights is not finite.");
    }

    // Divide before multiplying so `w * n` cannot overflow for huge weights.
    let scale = n as f64;
    let mut prob: Vec<f64> = weights.iter().map(|w| w / sum * scale).collect();
    // Every column aliases itself until paired. Columns that never get paired
    // (scaled value exactly 1.0) therefore never point at an arbitrary index.
    let mut alias: Vec<usize> = (0..n).collect();

    // Work stacks of column indices, both processed LIFO.
    let mut small: Vec<usize> = Vec::with_capacity(n);
    let mut large: Vec<usize> = Vec::with_capacity(n);
    for (i, &p) in prob.iter().enumerate() {
        if p < 1.0 {
            small.push(i);
        } else if p > 1.0 {
            large.push(i);
        }
    }

    loop {
        let (j, k) = match (small.last(), large.last()) {
            (Some(&j), Some(&k)) => (j, k),
            _ => break,
        };

        if 1.0 < prob[j] {
            panic!("MUST: {} <= 1", prob[j]);
        }
        if prob[k] < 1.0 {
            panic!("MUST: 1 <= {}", prob[k]);
        }

        small.pop();
        alias[j] = k;
        prob[k] -= 1.0 - prob[j]; // - residual weight handed to column j
        if prob[k] < 1.0 {
            large.pop();
            small.push(k);
        }
    }

    // Floating-point rounding can leave columns on either stack whose scaled
    // value is only approximately 1.0. Each of them owns its whole column.
    for &i in small.iter().chain(large.iter()) {
        prob[i] = 1.0;
        alias[i] = i;
    }

    Ok(AliasTable {
        len: n as i32,
        prob,
        alias,
    })
}


/// Builds tables for a few weight vectors and checks that each one is a
/// well-formed alias table that reproduces the normalised weights.
#[test]
fn test_new_alias_table() {
    let params = [
        vec![1.0, 1.0],
        vec![1.0, 1.0, 8.0],
    ];
    for sample_weights in params.iter() {
        let table = match new_alias_table(sample_weights) {
            Ok(table) => table,
            Err(e) => panic!("error : {}", e),
        };
        let n = sample_weights.len();
        assert_eq!(table.len as usize, n);
        assert_eq!(table.prob.len(), n);
        assert_eq!(table.alias.len(), n);

        // Column i contributes prob[i] / n to itself and the rest of its 1 / n
        // share to alias[i]. Summed over all columns, this must equal
        // w_i / sum(w).
        let total: f64 = sample_weights.iter().sum();
        let share = 1.0 / n as f64;
        let mut mass = vec![0.0; n];
        for i in 0..n {
            mass[i] += table.prob[i] * share;
            mass[table.alias[i]] += (1.0 - table.prob[i]) * share;
        }
        for (got, w) in mass.iter().zip(sample_weights.iter()) {
            let want = w / total;
            assert!(
                (got - want).abs() < 1e-9,
                "got {:?} for weights {:?}",
                mass,
                sample_weights
            );
        }
    }

    // An empty weight list has nothing to sample from and must be rejected.
    let empty: Vec<f64> = Vec::new();
    assert!(new_alias_table(&empty).is_err());
}