use crate::evaluate::evaluate;
use crate::generate_instructions::generate_instructions_from_move_pair;
use crate::instruction::StateInstructions;
use crate::state::{MoveChoice, State};
use rand::distributions::WeightedIndex;
use rand::prelude::*;
use rand::thread_rng;
use std::collections::HashMap;
use std::time::Duration;
/// Logistic function with a fixed slope of `0.0125`.
///
/// Maps `0.0` to `0.5`, large positive inputs towards `1.0` and large
/// negative inputs towards `0.0`.
fn sigmoid(x: f32) -> f32 {
    // Negated slope applied to the input before exponentiation.
    const NEGATED_SLOPE: f32 = -0.0125;
    let decay = (NEGATED_SLOPE * x).exp();
    1.0 / (1.0 + decay)
}
/// One node of the search tree grown by `perform_mcts`.
///
/// A node stores per-move statistics for both sides and, for every move pair
/// that has been expanded, the child nodes generated for that pair.
#[derive(Debug)]
pub struct Node {
/// `false` when built by `Node::new`; `perform_mcts` sets it to `true` on the
/// top-level node. `backpropagate` stops climbing at a node with this flag set.
pub root: bool,
/// Raw pointer to the node whose `expand` created this one. It is null as
/// returned by `Node::new` and is dereferenced by `expand` and `backpropagate`
/// on non-root nodes.
pub parent: *mut Node,
/// Expanded children, keyed by `(s1 option index, s2 option index)`. Each
/// entry holds one node per `StateInstructions` generated for that move pair.
pub children: HashMap<(usize, usize), Vec<Node>>,
/// Incremented by one on every `backpropagate` call made on this node.
pub times_visited: i64,
/// Instructions attached to this node by `expand`; their `instruction_list`
/// is applied to the state on the way into the node and reversed on the way
/// back out, and `percentage` is the sampling weight used by `sample_node`.
pub instructions: StateInstructions,
/// Index into the parent's `s1_options` of the move that led to this node.
pub s1_choice: usize,
/// Index into the parent's `s2_options` of the move that led to this node.
pub s2_choice: usize,
/// Statistics for every side-one option available at this node.
pub s1_options: Vec<MoveNode>,
/// Statistics for every side-two option available at this node.
pub s2_options: Vec<MoveNode>,
}
impl Node {
/// Creates a non-root node with a null parent, no children, default
/// instructions and zero visits. Every option of both sides is cloned into a
/// `MoveNode` whose score and visit count start at zero.
fn new(s1_options: Vec<MoveChoice>, s2_options: Vec<MoveChoice>) -> Node {
let s1_options_vec = s1_options
.iter()
.map(|x| MoveNode {
move_choice: x.clone(),
total_score: 0.0,
visits: 0,
})
.collect();
let s2_options_vec = s2_options
.iter()
.map(|x| MoveNode {
move_choice: x.clone(),
total_score: 0.0,
visits: 0,
})
.collect();
Node {
root: false,
parent: std::ptr::null_mut(),
instructions: StateInstructions::default(),
times_visited: 0,
children: HashMap::new(),
s1_choice: 0,
s2_choice: 0,
s1_options: s1_options_vec,
s2_options: s2_options_vec,
}
}
/// Returns the index in `side_map` of the entry with the largest
/// `MoveNode::ucb1` value, computed against this node's `times_visited`.
///
/// The comparison is strict, so ties keep the earliest index. An empty slice
/// yields `0`.
pub fn maximize_ucb_for_side(&self, side_map: &[MoveNode]) -> usize {
let mut choice = 0;
let mut best_ucb1 = f32::MIN;
for (index, node) in side_map.iter().enumerate() {
let this_ucb1 = node.ucb1(self.times_visited);
if this_ucb1 > best_ucb1 {
best_ucb1 = this_ucb1;
choice = index;
}
}
choice
}
/// Walks down the tree from this node until it reaches a node whose chosen
/// move pair has no entry in `children`.
///
/// At each node the UCB1-maximising option is picked independently for each
/// side. If that pair already has children, one of them is sampled with
/// `sample_node`, its instructions are applied to `state`, and the walk
/// continues from that child.
///
/// Returns the pointer to the node where the walk stopped together with the
/// `(s1, s2)` option indices chosen there.
///
/// # Safety
///
/// Child nodes are reached through raw pointers into `children`; the caller
/// must ensure no other reference to those nodes is live. The returned pointer
/// aliases a node of this tree and is only valid while the tree is alive and
/// the vector holding that node is not reallocated or dropped.
pub unsafe fn selection(&mut self, state: &mut State) -> (*mut Node, usize, usize) {
// Taken up front so it can be returned from the `None` arm below.
let return_node = self as *mut Node;
let s1_mc_index = self.maximize_ucb_for_side(&self.s1_options);
let s2_mc_index = self.maximize_ucb_for_side(&self.s2_options);
let child_vector = self.children.get_mut(&(s1_mc_index, s2_mc_index));
match child_vector {
Some(child_vector) => {
let child_vec_ptr = child_vector as *mut Vec<Node>;
let chosen_child = self.sample_node(child_vec_ptr);
// Move `state` into the chosen child before descending further.
state.apply_instructions(&(*chosen_child).instructions.instruction_list);
(*chosen_child).selection(state)
}
// This pair has never been expanded: stop here.
None => (return_node, s1_mc_index, s2_mc_index),
}
}
/// Picks one node from `move_vector` at random, weighting each node by its
/// `instructions.percentage`, and returns a raw pointer to it.
///
/// # Panics
///
/// Panics if `WeightedIndex::new` rejects the weights (the result is
/// unwrapped) — presumably for an empty vector or unusable weights; see the
/// `rand` documentation.
///
/// # Safety
///
/// `move_vector` must point to a valid `Vec<Node>` that is not accessed
/// through any other reference for the duration of the call.
unsafe fn sample_node(&self, move_vector: *mut Vec<Node>) -> *mut Node {
let mut rng = thread_rng();
let weights: Vec<f64> = (*move_vector)
.iter()
.map(|x| x.instructions.percentage as f64)
.collect();
let dist = WeightedIndex::new(weights).unwrap();
let chosen_node = &mut (*move_vector)[dist.sample(&mut rng)];
let chosen_node_ptr = chosen_node as *mut Node;
chosen_node_ptr
}
/// Creates the children for the move pair `(s1_move_index, s2_move_index)`,
/// samples one of them, applies its instructions to `state` and returns a
/// pointer to it.
///
/// No children are created, and a pointer to `self` is returned with `state`
/// untouched, when either:
/// * `state.battle_is_over()` is non-zero and this node is not the root, or
/// * both selected moves are `MoveChoice::None`.
///
/// Any existing `children` entry for the same pair is replaced.
///
/// # Panics
///
/// Panics if either index is out of bounds for the corresponding options
/// vector, or if `sample_node` panics.
///
/// # Safety
///
/// On a non-root node `parent` is dereferenced, so it must point to a live
/// `Node`. The returned pointer refers either to `self` or to an element of the
/// vector stored in `children`, and is valid only while that storage is alive.
pub unsafe fn expand(
&mut self,
state: &mut State,
s1_move_index: usize,
s2_move_index: usize,
) -> *mut Node {
let s1_move = &self.s1_options[s1_move_index].move_choice;
let s2_move = &self.s2_options[s2_move_index].move_choice;
if (state.battle_is_over() != 0.0 && !self.root)
|| (s1_move == &MoveChoice::None && s2_move == &MoveChoice::None)
{
return self as *mut Node;
}
// True only for the root and its direct children. `||` short-circuits, so
// the root's null `parent` is never dereferenced.
let should_branch_on_damage = self.root || (*self.parent).root;
let mut new_instructions =
generate_instructions_from_move_pair(state, s1_move, s2_move, should_branch_on_damage);
let mut this_pair_vec = Vec::with_capacity(2);
for state_instructions in new_instructions.drain(..) {
// Apply the outcome temporarily to read the options available after it,
// then reverse it (presumably restoring `state`) before the next outcome.
state.apply_instructions(&state_instructions.instruction_list);
let (s1_options, s2_options) = state.get_all_options();
state.reverse_instructions(&state_instructions.instruction_list);
let mut new_node = Node::new(s1_options, s2_options);
new_node.parent = self;
new_node.instructions = state_instructions;
new_node.s1_choice = s1_move_index;
new_node.s2_choice = s2_move_index;
this_pair_vec.push(new_node);
}
// The pointer is taken after all pushes; the vector is then moved into the
// map without further growth.
let new_node_ptr = self.sample_node(&mut this_pair_vec);
state.apply_instructions(&(*new_node_ptr).instructions.instruction_list);
self.children
.insert((s1_move_index, s2_move_index), this_pair_vec);
new_node_ptr
}
/// Records `score` along the path from this node up to the root and rewinds
/// `state` along the way.
///
/// Each call increments `times_visited`. For a non-root node it then adds
/// `score` to the parent's side-one `MoveNode` at `s1_choice` and
/// `1.0 - score` to the parent's side-two `MoveNode` at `s2_choice`, bumps
/// both visit counts, reverses this node's instructions on `state`, and
/// recurses into the parent. The root only has its `times_visited` incremented.
///
/// # Safety
///
/// On every non-root node along the path `parent` must point to a live `Node`
/// that is not borrowed elsewhere.
pub unsafe fn backpropagate(&mut self, score: f32, state: &mut State) {
self.times_visited += 1;
if self.root {
return;
}
let parent_s1_movenode = &mut (*self.parent).s1_options[self.s1_choice];
parent_s1_movenode.total_score += score;
parent_s1_movenode.visits += 1;
let parent_s2_movenode = &mut (*self.parent).s2_options[self.s2_choice];
// Side two is credited with the complement of side one's score.
parent_s2_movenode.total_score += 1.0 - score;
parent_s2_movenode.visits += 1;
state.reverse_instructions(&self.instructions.instruction_list);
(*self.parent).backpropagate(score, state);
}
/// Scores `state` without simulating any further moves.
///
/// When `state.battle_is_over()` returns `0.0`, the result is
/// `sigmoid(evaluate(state) - root_eval)`. Otherwise a return of `-1.0` is
/// mapped to `0.0` and any other value is returned unchanged.
///
/// NOTE(review): presumably `1.0` / `-1.0` mean a side-one / side-two win —
/// confirm against `State::battle_is_over`.
pub fn rollout(&mut self, state: &mut State, root_eval: &f32) -> f32 {
let battle_is_over = state.battle_is_over();
if battle_is_over == 0.0 {
let eval = evaluate(state);
sigmoid(eval - root_eval)
} else {
if battle_is_over == -1.0 {
0.0
} else {
battle_is_over
}
}
}
}
/// Running statistics for a single move option of one side at a `Node`.
#[derive(Debug)]
pub struct MoveNode {
    /// The move these statistics belong to.
    pub move_choice: MoveChoice,
    /// Sum of all scores credited to this move during backpropagation.
    pub total_score: f32,
    /// Number of scores that have been credited to this move.
    pub visits: i64,
}

impl MoveNode {
    /// UCB1 value of this move given the visit count of the node that owns it.
    ///
    /// Returns `f32::INFINITY` for a move that has never been visited, so
    /// unvisited moves are always preferred. Otherwise returns the average
    /// score plus the exploration term `sqrt(2 * ln(parent_visits) / visits)`.
    pub fn ucb1(&self, parent_visits: i64) -> f32 {
        if self.visits == 0 {
            return f32::INFINITY;
        }
        let visits = self.visits as f32;
        let exploitation = self.total_score / visits;
        let exploration = (2.0 * (parent_visits as f32).ln() / visits).sqrt();
        exploitation + exploration
    }

    /// Mean score credited to this move, or `0.0` if it was never visited.
    ///
    /// The zero-visit guard avoids a `0.0 / 0.0` division (NaN) and matches
    /// `MctsSideResult::average_score`.
    pub fn average_score(&self) -> f32 {
        if self.visits == 0 {
            return 0.0;
        }
        self.total_score / self.visits as f32
    }
}
/// Statistics of one move option as reported in an `MctsResult`.
///
/// `perform_mcts` fills it from the root node's `MoveNode` entries.
#[derive(Clone)]
pub struct MctsSideResult {
    /// The move these statistics belong to.
    pub move_choice: MoveChoice,
    /// Sum of the scores credited to the move.
    pub total_score: f32,
    /// Number of scores credited to the move.
    pub visits: i64,
}

impl MctsSideResult {
    /// Mean score of the move; `0.0` when it was never visited.
    pub fn average_score(&self) -> f32 {
        match self.visits {
            0 => 0.0,
            visit_count => self.total_score / visit_count as f32,
        }
    }
}
/// Summary returned by `perform_mcts` once the search has stopped.
pub struct MctsResult {
/// One entry per side-one option of the root node, in the root's option order.
pub s1: Vec<MctsSideResult>,
/// One entry per side-two option of the root node, in the root's option order.
pub s2: Vec<MctsSideResult>,
/// The root node's `times_visited` when the search stopped.
pub iteration_count: i64,
}
/// Runs a single search iteration from `root_node`: selection, expansion,
/// rollout and backpropagation, in that order.
///
/// `state` is advanced while descending the tree and rewound again by
/// `backpropagate`.
fn do_mcts(root_node: &mut Node, state: &mut State, root_eval: &f32) {
    // SAFETY: `selection` and `expand` only return pointers to `root_node`
    // itself or to nodes stored in `children` vectors inside the tree owned by
    // `root_node`. The tree is exclusively borrowed for this call and no other
    // reference to those nodes is held while the pointers are dereferenced.
    unsafe {
        let (leaf, s1_move, s2_move) = root_node.selection(state);
        let expanded = (*leaf).expand(state, s1_move, s2_move);
        let score = (*expanded).rollout(state, root_eval);
        (*expanded).backpropagate(score, state);
    }
}
pub fn perform_mcts(
state: &mut State,
side_one_options: Vec<MoveChoice>,
side_two_options: Vec<MoveChoice>,
max_time: Duration,
) -> MctsResult {
let mut root_node = Node::new(side_one_options, side_two_options);
root_node.root = true;
let root_eval = evaluate(state);
let start_time = std::time::Instant::now();
while start_time.elapsed() < max_time {
for _ in 0..1000 {
do_mcts(&mut root_node, state, &root_eval);
}
if root_node.times_visited == 10_000_000 {
break;
}
}
let result = MctsResult {
s1: root_node
.s1_options
.iter()
.map(|v| MctsSideResult {
move_choice: v.move_choice.clone(),
total_score: v.total_score,
visits: v.visits,
})
.collect(),
s2: root_node
.s2_options
.iter()
.map(|v| MctsSideResult {
move_choice: v.move_choice.clone(),
total_score: v.total_score,
visits: v.visits,
})
.collect(),
iteration_count: root_node.times_visited,
};
result
}