use crate::node::NodeStats;
use super::{action_list::ActionList, edge::Edge, node::Node, state::State};
/// Tree whose nodes live in one flat vector and refer to each other by index.
///
/// Node ids are positions in `nodes`; they are handed out sequentially by `add_node`.
#[derive(Clone)]
pub struct Tree<T>
where
    T: State,
{
    /// Flat storage for every node; a node's id is its position in this vector.
    nodes: Vec<Node<T>>,
    /// Id that `add_node` assigns to the next node it creates.
    index: usize,
    /// Constant forwarded to `Node::uct_score` when children are ranked.
    c: f32,
    /// Capacity reserved for `nodes` at construction time and again on `reset`.
    default_capacity: usize,
}
impl<T> Tree<T>
where
    T: State + Clone,
{
    /// Creates an empty tree.
    ///
    /// `c` is the constant later handed to `Node::uct_score` during selection, and
    /// `default_capacity` is the number of node slots reserved up front (and again on
    /// [`Tree::reset`]).
    pub fn new(c: f32, default_capacity: usize) -> Self {
        Self {
            nodes: Vec::with_capacity(default_capacity),
            index: 0,
            c,
            default_capacity,
        }
    }

    /// Drops every node and restarts id assignment at zero.
    ///
    /// The node storage is replaced by a fresh allocation of `default_capacity` slots, so
    /// any capacity the tree grew beyond that is released.
    pub fn reset(&mut self) {
        self.index = 0;
        self.nodes = Vec::with_capacity(self.default_capacity);
    }

    /// Appends a new node and returns its id.
    ///
    /// When `parent_id` is given, the new id is also registered as a child of that node.
    ///
    /// # Panics
    ///
    /// Panics if `parent_id` does not refer to an existing node.
    pub fn add_node(
        &mut self,
        edge: Option<Edge<T::Action, usize>>,
        parent_id: Option<usize>,
    ) -> usize {
        let node_id = self.index;
        if let Some(parent_id) = parent_id {
            self.nodes[parent_id].add_child(node_id);
        }
        self.nodes.push(Node::new(edge, parent_id));
        self.index += 1;
        node_id
    }

    /// Walks down from `node_id`, applying each chosen child's action to `state`.
    ///
    /// The descent continues while `state` is not terminal and the current node has no
    /// untried action among the currently legal ones. It also stops, instead of
    /// panicking, when the current node has no child that is legal in `state`.
    ///
    /// Returns the id of the node at which the descent stopped; `state` has been advanced
    /// along the path taken.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn select(&mut self, mut node_id: usize, state: &mut T) -> usize {
        while !state.is_terminal() {
            let legal_actions = state.possible_actions();
            if !self.is_fully_expanded(node_id, &legal_actions) {
                break;
            }
            let child_id = match self.uct_select_child(node_id, &legal_actions) {
                Some(child_id) => child_id,
                // No legal child to descend into: stay on the current node.
                None => break,
            };
            let action = self
                .get_edge(child_id)
                .expect("uct_select_child only returns children that carry an edge")
                .action();
            state.apply_action(action);
            node_id = child_id;
        }
        node_id
    }

    /// Picks the child of `node_id` with the highest `Node::uct_score` among the children
    /// whose action is contained in `legal_actions`.
    ///
    /// Every legal child also has `increase_availability` called on it. Children without
    /// an edge are skipped. The first legal child is always a candidate, so the result is
    /// `None` only when no child is legal; on equal scores the earliest child wins.
    ///
    /// Relies on children being stored after their parent (which `add_node` guarantees by
    /// appending), so that parent and children can be borrowed from disjoint halves of
    /// the node vector.
    fn uct_select_child(&mut self, node_id: usize, legal_actions: &T::ActionList) -> Option<usize> {
        let exploration = self.c;
        let split_pos = node_id + 1;
        let (left, right) = self.nodes.split_at_mut(split_pos);
        let parent = &left[node_id];

        let mut best: Option<(usize, f32)> = None;
        for &child_id in parent.child_ids_ref().iter() {
            let child = &mut right[child_id - split_pos];
            let is_legal = child
                .edge()
                .map_or(false, |edge| legal_actions.has(&edge.action()));
            if !is_legal {
                continue;
            }
            // Score first, then bump availability: same order as before.
            let score = child.uct_score(exploration);
            if best.map_or(true, |(_, best_score)| score > best_score) {
                best = Some((child_id, score));
            }
            child.increase_availability();
        }
        best.map(|(child_id, _)| child_id)
    }

    /// Adds one child to `node_id` for an action popped via `Node::pop_action`.
    ///
    /// The popped action is applied to `state`, and the new child's edge records that
    /// action together with `state.turn()` as observed before the action was applied.
    ///
    /// Returns the new child's id, or `node_id` unchanged when `state` is terminal or
    /// `pop_action` yields nothing.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn expand(&mut self, node_id: usize, state: &mut T) -> usize {
        if state.is_terminal() {
            return node_id;
        }
        let legal_actions = state.possible_actions();
        let action = match self.nodes[node_id].pop_action(&legal_actions) {
            Some(action) => action,
            None => return node_id,
        };
        // The actor must be read before the action changes whose turn it is.
        let actor = state.turn();
        let edge = Edge::new(action.clone(), actor);
        state.apply_action(action);
        self.add_node(Some(edge), Some(node_id))
    }

    /// Returns the action of the child of `node_id` with the largest `num_sims()`,
    /// considering only children whose action is legal in `state`.
    ///
    /// Returns `None` when there is no such child (for example when the node has no
    /// children yet). On equal counts the later child wins, as with
    /// [`Iterator::max_by_key`].
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn best_action(&self, node_id: usize, state: &T) -> Option<T::Action> {
        let legal_actions = state.possible_actions();
        let best_child = self.nodes[node_id]
            .child_ids_ref()
            .iter()
            .copied()
            .filter(|&child_id| {
                self.get_edge(child_id)
                    .map_or(false, |edge| legal_actions.has(&edge.action()))
            })
            .max_by_key(|&child_id| self.nodes[child_id].num_sims())?;
        self.get_edge(best_child).map(|edge| edge.action())
    }

    /// Collects `(stats, action)` for every child of `node_id` whose action is legal in
    /// `state`, in child order. Children without an edge are skipped.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn child_stats(&self, node_id: usize, state: &T) -> Vec<(NodeStats, T::Action)> {
        let legal_actions = state.possible_actions();
        self.nodes[node_id]
            .child_ids_ref()
            .iter()
            .filter_map(|&child_id| {
                let action = self.get_edge(child_id)?.action();
                if legal_actions.has(&action) {
                    Some((self.nodes[child_id].stats(), action))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Forwards `reward` to `Node::update` on the node `node_id`.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn update_node(&mut self, node_id: usize, reward: f32) {
        self.nodes[node_id].update(reward);
    }

    /// Returns `true` when the node reports no untried action among `legal_actions`.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn is_fully_expanded(&self, node_id: usize, legal_actions: &T::ActionList) -> bool {
        !self.nodes[node_id].has_untried_actions(legal_actions)
    }

    /// Returns the parent id stored on the node, if any.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn get_parent_id(&self, node_id: usize) -> Option<usize> {
        self.nodes[node_id].parent_id()
    }

    /// Returns the edge stored on the node, if any.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` does not refer to an existing node.
    pub fn get_edge(&self, node_id: usize) -> Option<Edge<T::Action, usize>> {
        self.nodes[node_id].edge()
    }

    /// Returns the allocated capacity of the node storage, in nodes.
    ///
    /// NOTE(review): this is `Vec::capacity`, not the number of nodes currently in the
    /// tree; confirm with callers whether `len()` was intended.
    pub fn get_size(&self) -> usize {
        self.nodes.capacity()
    }
}