// Tree policies for Monte Carlo tree search: the `TreePolicy` trait, the
// `UCTPolicy` and `AlphaGoPolicy` implementations, and the `PolicyRng`
// helper they use to break ties between equally scored moves.
extern crate rand;
use self::rand::{XorShiftRng, Rng, SeedableRng};

use std;
use super::*;
use search_tree::*;

/// Strategy that decides which child move a search descends into.
///
/// A policy is bound to an `MCTS` specification whose `TreePolicy`
/// associated type is the policy itself, and it must be `Sync + Sized`.
pub trait TreePolicy<Spec: MCTS<TreePolicy=Self>>: Sync + Sized {
    /// Per-move value the policy reads back through
    /// `MoveInfo::move_evaluation` (the implementations in this file use
    /// `()` and `f64`).
    type MoveEvaluation: Sync + Send;
    /// Policy state that the implementations in this file reach via
    /// `handle.thread_data().policy_data`; it must be default-constructible.
    type ThreadLocalData: Default;

    /// Picks one of the candidate `moves` and returns a reference to it.
    ///
    /// `moves` must be `Clone` so that an implementation can traverse the
    /// candidates more than once (for example, once to total the visit
    /// counts and once to score each move).
    fn choose_child<'a, MoveIter>(&self, moves: MoveIter, handle: SearchHandle<Spec>) -> &'a MoveInfo<Spec>
        where MoveIter: Iterator<Item=&'a MoveInfo<Spec>> + Clone;
    /// Hook for sanity-checking a slice of move evaluations.
    ///
    /// The default implementation does nothing and therefore accepts any
    /// input; implementors may override it to panic on invalid values.
    fn validate_evaluations(&self, _evalns: &[Self::MoveEvaluation]) {}
}

/// Tree policy configuration holding a single positive exploration constant.
#[derive(Debug, Clone)]
pub struct UCTPolicy {
    // Weight applied to the exploration term when scoring a move.
    exploration_constant: f64,
}

impl UCTPolicy {
    /// Builds a policy with the given exploration constant.
    ///
    /// # Panics
    ///
    /// Panics unless `exploration_constant > 0.0`. NaN fails that comparison
    /// and is therefore rejected as well.
    pub fn new(exploration_constant: f64) -> Self {
        assert!(
            exploration_constant > 0.0,
            "exploration constant is {} (must be positive)",
            exploration_constant
        );
        UCTPolicy { exploration_constant: exploration_constant }
    }

    /// Returns the exploration constant supplied to `new`.
    pub fn exploration_constant(&self) -> f64 {
        let UCTPolicy { exploration_constant } = *self;
        exploration_constant
    }
}

/// Number of entries precomputed in `AlphaGoPolicy`'s reciprocal table.
const RECIPROCAL_TABLE_LEN: usize = 128;

/// Tree policy configuration holding an exploration constant together with
/// a lookup table of reciprocals for small visit counts.
#[derive(Clone, Debug)]
pub struct AlphaGoPolicy {
    // Weight applied to the exploration term when scoring a move.
    exploration_constant: f64,
    // `reciprocals[x]` is `1.0 / x` for `1 <= x < RECIPROCAL_TABLE_LEN`;
    // index 0 holds 2.0 instead of a (non-finite) true reciprocal.
    reciprocals: Vec<f64>,
}

impl AlphaGoPolicy {
    /// Builds a policy with the given exploration constant and fills the
    /// reciprocal table.
    ///
    /// # Panics
    ///
    /// Panics unless `exploration_constant > 0.0`. NaN fails that comparison
    /// and is therefore rejected as well.
    pub fn new(exploration_constant: f64) -> Self {
        assert!(exploration_constant > 0.0,
            "exploration constant is {} (must be positive)",
            exploration_constant);
        let reciprocals = (0..RECIPROCAL_TABLE_LEN)
            .map(|x| if x == 0 {
                // Substitute for 1/0 so the lookup stays finite at x == 0.
                2.0
            } else {
                1.0 / x as f64
            })
            .collect();
        Self {exploration_constant, reciprocals}
    }

    /// Returns the exploration constant supplied to `new`.
    pub fn exploration_constant(&self) -> f64 {
        self.exploration_constant
    }

    /// Returns the multiplier used in place of dividing by `x`.
    ///
    /// Values covered by the table are looked up (so `x == 0` yields 2.0);
    /// anything at or beyond the end of the table is computed as `1.0 / x`.
    fn reciprocal(&self, x: usize) -> f64 {
        // A checked lookup gives the same results as the former
        // `x < RECIPROCAL_TABLE_LEN` test followed by `get_unchecked`,
        // without needing an `unsafe` block.
        match self.reciprocals.get(x) {
            Some(&cached) => cached,
            None => 1.0 / x as f64,
        }
    }
}

impl<Spec: MCTS<TreePolicy=Self>> TreePolicy<Spec> for UCTPolicy
{
    type MoveEvaluation = ();
    type ThreadLocalData = PolicyRng;

    /// Returns the candidate with the highest UCT score.
    ///
    /// With `N` the summed visit count of all candidates and `n` a
    /// candidate's own visit count, the score is
    /// `exploration_constant * 2 * sqrt(ln(N + 1) / n) + sum_rewards / n`.
    /// A candidate with `n == 0` scores positive infinity. Ties are resolved
    /// by the thread's `PolicyRng`.
    ///
    /// # Panics
    ///
    /// Panics if the selection yields no candidate (e.g. `moves` is empty).
    fn choose_child<'a, MoveIter>(&self, moves: MoveIter, mut handle: SearchHandle<Spec>) -> &'a MoveInfo<Spec>
        where MoveIter: Iterator<Item=&'a MoveInfo<Spec>> + Clone
    {
        // First pass over the candidates: total visits, offset by one
        // before taking the logarithm.
        let visit_count: u64 = moves.clone().map(|m| m.visits()).sum();
        let log_parent_visits = ((visit_count + 1) as f64).ln();

        // Second pass: score every candidate and keep the best one.
        handle.thread_data().policy_data.select_by_key(moves, |candidate| {
            let reward_total = candidate.sum_rewards();
            // http://mcts.ai/pubs/mcts-survey-master.pdf
            match candidate.visits() {
                0 => std::f64::INFINITY,
                n => {
                    let n = n as f64;
                    let exploration = 2.0 * (log_parent_visits / n).sqrt();
                    let exploitation = reward_total as f64 / n;
                    self.exploration_constant * exploration + exploitation
                }
            }
        }).unwrap()
    }
}

impl<Spec: MCTS<TreePolicy=Self>> TreePolicy<Spec> for AlphaGoPolicy
{
    type MoveEvaluation = f64;
    type ThreadLocalData = PolicyRng;

    /// Returns the candidate with the highest score.
    ///
    /// With `N` the summed visit count of all candidates plus one, the score
    /// of a candidate is
    /// `(sum_rewards + exploration_constant * sqrt(N) * evaluation) * reciprocal(visits)`,
    /// where `reciprocal` is the policy's table-backed `1 / visits`
    /// (2.0 for an unvisited candidate). Ties are resolved by the thread's
    /// `PolicyRng`.
    ///
    /// # Panics
    ///
    /// Panics if the selection yields no candidate (e.g. `moves` is empty).
    fn choose_child<'a, MoveIter>(&self, moves: MoveIter, mut handle: SearchHandle<Spec>) -> &'a MoveInfo<Spec>
        where MoveIter: Iterator<Item=&'a MoveInfo<Spec>> + Clone
    {
        // First pass: the exploration scale shared by every candidate.
        let parent_visits: u64 = moves.clone().map(|m| m.visits()).sum();
        let root_of_visits = ((parent_visits + 1) as f64).sqrt();
        let explore_scale = self.exploration_constant * root_of_visits;

        // Second pass: score every candidate and keep the best one.
        handle.thread_data().policy_data.select_by_key(moves, |candidate| {
            let reward_total = candidate.sum_rewards() as f64;
            let visits = candidate.visits() as usize;
            let prior = *candidate.move_evaluation() as f64;
            let numerator = reward_total + explore_scale * prior;
            numerator * self.reciprocal(visits)
        }).unwrap()
    }

    /// Checks that `evalns` looks like a probability distribution.
    ///
    /// # Panics
    ///
    /// Panics if any entry is below `-1e-6`, or if the slice is non-empty
    /// and its sum differs from 1 by 0.1 or more. NaN entries fail both
    /// checks.
    fn validate_evaluations(&self, evalns: &[f64]) {
        for &weight in evalns.iter() {
            assert!(
                weight >= -1e-6,
                "Move evaluation is {} (must be non-negative)",
                weight
            );
        }
        // An empty slice has nothing to sum, so it is accepted as-is.
        if evalns.is_empty() {
            return;
        }
        let total: f64 = evalns.iter().sum();
        assert!(
            (total - 1.0).abs() < 0.1,
            "Sum of evaluations is {} (should sum to 1)",
            total
        );
    }
}

/// Random number source used by the tree policies in this file to break
/// ties between equally scored moves.
#[derive(Clone)]
pub struct PolicyRng {
    // Underlying generator; `PolicyRng::new` seeds it with a fixed value.
    rng: XorShiftRng
}

impl PolicyRng {
    pub fn new() -> Self {
        let rng = SeedableRng::from_seed([1, 2, 3, 4]);
        Self {rng}
    }

    pub fn select_by_key<T, Iter, KeyFn>(&mut self, elts: Iter, mut key_fn: KeyFn) -> Option<T>
        where Iter: Iterator<Item=T>, KeyFn: FnMut(&T) -> f64
    {
        let mut choice = None;
        let mut num_optimal: u32 = 0;
        let mut best_so_far: f64 = std::f64::NEG_INFINITY;
        for elt in elts {
            let score = key_fn(&elt);
            if score > best_so_far {
                choice = Some(elt);
                num_optimal = 1;
                best_so_far = score;
            } else if score == best_so_far {
                num_optimal += 1;
                if self.rng.gen_weighted_bool(num_optimal) {
                    choice = Some(elt);
                }
            }
        }
        choice
    }
}

impl Default for PolicyRng {
    fn default() -> Self {
        Self::new()
    }
}