use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::time::{Duration, Instant};
use std::{fs, io};
use backpropagation::{Backpropagation, WinRate};
use expansion::{Expansion, Greedy};
use final_scorer::{FinalScorer, NumVisits};
use petgraph::dot::Dot;
use petgraph::graph::NodeIndex;
use petgraph::{Direction, Graph};
use playout::Playout;
use selection::{Selection, UCT};
use super::random::Random;
use super::{InternalSearch, SearchError, SearchScore};
use crate::prelude::{Game, GameError};
pub mod backpropagation;
pub mod expansion;
pub mod final_scorer;
pub mod playout;
pub mod selection;
/// Errors that can be produced while building or scoring a Monte Carlo search tree.
///
/// The enum is generic over the game `G` and the five strategy types so that each
/// strategy's associated `Error` type can be carried through without conversion.
#[derive(Debug, thiserror::Error)]
pub enum MonteCarloTreeSearchError<const N: usize, G, S, E, P, B, F, T>
where
G: Game<N>,
G::State: Clone,
G::Player: Clone,
S: Selection<N, G, T>,
E: Expansion<N, G, T>,
P: Playout<N, G, T>,
B: Backpropagation<N, G, T>,
F: FinalScorer<N, G, T>,
T: Debug,
{
/// An error reported by the game implementation itself.
#[error(transparent)]
GameError(G::Error),
/// No payout is available for the player at the given index.
// NOTE(review): not constructed in this file; presumably raised by the strategy submodules.
#[error("No payout for player at index {0}")]
NoPayoutForPlayer(usize),
/// There were no actions to choose from.
// NOTE(review): not constructed in this file; presumably raised by the strategy submodules.
#[error("No actions available to choose from")]
NoActions,
/// Backpropagation reached a node that is not the root yet has no incoming edge.
#[error("Non-root node has no parent: {0:?}")]
ParentlessNode(NodeIndex),
/// Backpropagation reached a node with more than one incoming edge.
#[error("Node has multiple parents: {0:?}")]
MultipleParents(NodeIndex),
/// No edge was found between the two given nodes (first -> second).
#[error("No edge exists between nodes ({0:?}, {1:?})")]
NoEdge(NodeIndex, NodeIndex),
/// An I/O failure, e.g. while creating the Graphviz output directory or writing a `.dot` file.
#[error(transparent)]
IoError(#[from] io::Error),
/// An error returned by the selection strategy.
#[error(transparent)]
SelectionError(S::Error),
/// An error returned by the expansion strategy.
#[error(transparent)]
ExpansionError(E::Error),
/// An error returned by the playout strategy.
#[error(transparent)]
PlayoutError(P::Error),
/// An error returned by the backpropagation strategy.
#[error(transparent)]
BackpropagationError(B::Error),
/// An error returned by the final scorer.
#[error(transparent)]
FinalScorerError(F::Error),
}
/// Registers [`MonteCarloTreeSearchError`] as a [`SearchError`].
///
/// The impl body is intentionally empty: no methods are provided here. Compared to the
/// enum definition, this impl additionally requires every strategy type to be `Debug`.
impl<const N: usize, G, S, E, P, B, F, T> SearchError
for MonteCarloTreeSearchError<N, G, S, E, P, B, F, T>
where
G: Game<N>,
G::State: Clone,
G::Player: Clone,
S: Debug + Selection<N, G, T>,
E: Debug + Expansion<N, G, T>,
P: Debug + Playout<N, G, T>,
B: Debug + Backpropagation<N, G, T>,
F: Debug + FinalScorer<N, G, T>,
T: Debug,
{
}
impl<const N: usize, G, X, S, E, P, B, F, T> From<X>
for MonteCarloTreeSearchError<N, G, S, E, P, B, F, T>
where
G: Game<N, Error = X>,
G::State: Clone,
G::Player: Clone,
X: GameError,
S: Debug + Selection<N, G, T>,
E: Debug + Expansion<N, G, T>,
P: Debug + Playout<N, G, T>,
B: Debug + Backpropagation<N, G, T>,
F: Debug + FinalScorer<N, G, T>,
T: Debug,
{
fn from(value: G::Error) -> Self {
Self::GameError(value)
}
}
/// Where Graphviz `.dot` dumps of the search tree are written, and how often.
///
/// Both variants carry the parent directory under which a numbered run directory is created.
#[derive(Debug)]
enum Graphviz {
    /// Dump the tree once, after the search loop has finished.
    Final(PathBuf),
    /// Dump the tree after every search iteration.
    EveryIter(PathBuf),
}

impl Graphviz {
    /// Creates a fresh, numbered sub-directory for this run and returns its path.
    ///
    /// The parent directory is created if it does not exist yet. Existing
    /// sub-directories whose names parse as `u64` are scanned, and the new directory
    /// is named one past the largest number found (or `0` when there is none),
    /// zero-padded to at least five digits. Plain files and directories with
    /// non-numeric or non-UTF-8 names are ignored.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating or listing the parent directory,
    /// or while creating the new sub-directory.
    fn create_new_internal_dir(&self) -> io::Result<PathBuf> {
        let parent_dir = match self {
            Graphviz::Final(dir) | Graphviz::EveryIter(dir) => dir,
        };
        fs::create_dir_all(parent_dir)?;

        // Largest numeric directory name seen so far, if any.
        let mut highest: Option<u64> = None;
        for entry in fs::read_dir(parent_dir)? {
            let path = entry?.path();
            if !path.is_dir() {
                continue;
            }
            let run_number = path
                .file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| name.parse::<u64>().ok());
            if let Some(number) = run_number {
                highest = Some(highest.map_or(number, |seen| seen.max(number)));
            }
        }

        let next = match highest {
            Some(number) => number + 1,
            None => 0,
        };
        let newdir = parent_dir.join(format!("{:05}", next));
        fs::create_dir(&newdir)?;
        Ok(newdir)
    }
}
/// A Monte Carlo tree search whose phases are pluggable strategies.
///
/// The type parameters are, in order, the selection (`S`), expansion (`E`),
/// playout (`P`), backpropagation (`B`) and final-scoring (`F`) strategies, followed by
/// the payout value type `T` stored in each tree node.
#[derive(Debug)]
pub struct MonteCarloTreeSearch<
S = UCT,
E = Greedy,
P = Random,
B = WinRate,
F = NumVisits,
T = f64,
> {
/// Time budget for one search; the search loop stops once it has elapsed.
max_time: Duration,
/// Optional cap on the number of iterations, compared against the root's visit count.
max_iters: Option<u32>,
/// Strategy used to pick the node to work on in each iteration.
selection: S,
/// Strategy used to expand the selected node.
expansion: E,
/// Strategy used to play out from the selected node.
playout: P,
/// Strategy used to record a playout result on each node along the path to the root.
backpropagation: B,
/// Strategy used to turn the finished tree into scores for the root's children.
final_scorer: F,
/// Optional Graphviz dump configuration; `None` disables dumping.
graphviz: Option<Graphviz>,
/// Ties the payout type `T` to the search without storing a value of it.
_phantom_payout: PhantomData<T>,
}
impl MonteCarloTreeSearch {
pub fn builder() -> MonteCarloTreeSearchBuilder {
MonteCarloTreeSearchBuilder::default()
}
}
// Marker impl with an empty body: registers every `MonteCarloTreeSearch` instantiation
// as an `InternalSearch`, with no bounds on the strategy or payout types.
impl<S, E, P, B, F, T> InternalSearch for MonteCarloTreeSearch<S, E, P, B, F, T> {}
/// Graph type used for the search tree of game `G`: node weights are [`Node`]s holding
/// game states and payouts of type `T`, edge weights are the game's actions.
type MCTSGraph<const N: usize, G, T> =
Graph<Node<N, <G as Game<N>>::State, T>, <G as Game<N>>::Action>;
impl<S, E, P, B, F, T> MonteCarloTreeSearch<S, E, P, B, F, T>
where
    E: Debug,
    P: Debug,
    B: Debug,
    F: Debug,
{
    /// Runs the search from `state` and returns the resulting tree together with the
    /// index of its root node.
    ///
    /// Each iteration runs the four strategies in order:
    ///
    /// 1. the selection strategy picks a node, starting from the root;
    /// 2. the expansion strategy is applied to that node;
    /// 3. the playout strategy produces a result from that node;
    /// 4. the backpropagation strategy is applied to that node and to every ancestor up
    ///    to the root, and each of those nodes has its visit counter incremented.
    ///
    /// Iterations continue until `max_time` has elapsed or, when `max_iters` is set,
    /// until the root has been visited that many times. If `max_time` is so large that
    /// the deadline cannot be represented as an `Instant`, the search runs without a
    /// time limit and only `max_iters` can stop it.
    ///
    /// When Graphviz output is configured, a new numbered directory is created and the
    /// tree is written to it as a `.dot` file, either after every iteration
    /// (`iteration_<visits>.dot`) or once at the end (`final.dot`).
    ///
    /// # Errors
    ///
    /// Returns the wrapped error of whichever strategy failed, an `IoError` if the
    /// Graphviz directory or files cannot be written, `ParentlessNode` if
    /// backpropagation reaches a non-root node without a parent, and `MultipleParents`
    /// if it reaches a node with more than one incoming edge.
    ///
    /// # Panics
    ///
    /// Panics if a node's visit counter overflows `u32`.
    #[allow(clippy::type_complexity)]
    pub fn playout_graph<const N: usize, G>(
        &self,
        game: &G,
        state: &G::State,
    ) -> Result<(MCTSGraph<N, G, T>, NodeIndex), MonteCarloTreeSearchError<N, G, S, E, P, B, F, T>>
    where
        G: Game<N>,
        G::State: Clone,
        G::Player: Clone,
        G::Action: Debug,
        S: Selection<N, G, T>,
        E: Expansion<N, G, T>,
        P: Playout<N, G, T>,
        B: Backpropagation<N, G, T>,
        F: FinalScorer<N, G, T>,
        T: Debug + Copy + Default,
    {
        // `Instant + Duration` panics when the sum is not representable (for example
        // with `Duration::MAX`); `checked_add` yields `None` instead, which is treated
        // as "no deadline".
        let deadline = Instant::now().checked_add(self.max_time);
        let mut tree = Graph::<Node<N, G::State, T>, G::Action>::new();
        let root = tree.add_node(Node::new(state.clone()));
        log::info!(
            "running MCTS search for {} seconds",
            self.max_time.as_secs_f32()
        );
        let graphvizdir = if let Some(graphviz) = &self.graphviz {
            Some(graphviz.create_new_internal_dir()?)
        } else {
            None
        };
        while deadline.map_or(true, |end| Instant::now() < end)
            && self.max_iters.map_or(true, |max| tree[root].visits < max)
        {
            log::debug!("running the selection strategy");
            let new_node_idx = self
                .selection
                .select(game, &tree, root)
                .map_err(MonteCarloTreeSearchError::SelectionError)?;
            log::debug!("running the expansion strategy");
            self.expansion
                .expand(game, &mut tree, new_node_idx)
                .map_err(MonteCarloTreeSearchError::ExpansionError)?;
            log::debug!("running the playout strategy");
            let res = self
                .playout
                .until_end(game, &tree[new_node_idx])
                .map_err(MonteCarloTreeSearchError::PlayoutError)?;
            log::debug!("running the backpropagation strategy");
            let mut bp_node_idx = new_node_idx;
            loop {
                // Only "zero, one, or more than one parent" matters, so peek at no more
                // than two incoming neighbours instead of collecting them all. The
                // scope ends the shared borrow of `tree` before it is mutated below.
                let (parent, has_other_parents) = {
                    let mut parents = tree.neighbors_directed(bp_node_idx, Direction::Incoming);
                    (parents.next(), parents.next().is_some())
                };
                let bp_node = &mut tree[bp_node_idx];
                self.backpropagation
                    .backpropagate(bp_node, &res)
                    .map_err(MonteCarloTreeSearchError::BackpropagationError)?;
                bp_node.visits = bp_node
                    .visits
                    .checked_add(1)
                    .expect("Overflow while incrementing node visits");
                match (parent, has_other_parents) {
                    // The root is the only node allowed to have no parent.
                    (None, _) if bp_node_idx == root => break,
                    (None, _) => {
                        return Err(MonteCarloTreeSearchError::ParentlessNode(bp_node_idx))
                    }
                    (Some(parent_idx), false) => bp_node_idx = parent_idx,
                    (Some(_), true) => {
                        return Err(MonteCarloTreeSearchError::MultipleParents(bp_node_idx))
                    }
                }
            }
            log::trace!("completed one search from root");
            if let (Some(Graphviz::EveryIter(_)), Some(dir)) = (&self.graphviz, &graphvizdir) {
                // The root's visit count doubles as the iteration number.
                let file = dir
                    .as_path()
                    .join(format!("iteration_{:07}.dot", tree[root].visits));
                fs::write(file, format!("{:?}", Dot::new(&tree)))?;
            }
        }
        if let (Some(Graphviz::Final(_)), Some(dir)) = (&self.graphviz, &graphvizdir) {
            let file = dir.as_path().join("final.dot");
            fs::write(file, format!("{:?}", Dot::new(&tree)))?;
        }
        log::info!("Completed building the playout graph");
        Ok((tree, root))
    }
}
impl<const N: usize, G, S, E, P, B, F, T> SearchScore<N, G>
    for MonteCarloTreeSearch<S, E, P, B, F, T>
where
    G: Game<N>,
    G::State: Clone,
    G::Player: Clone,
    G::Action: Clone + Debug + Hash + Eq,
    S: Selection<N, G, T> + Debug,
    E: Expansion<N, G, T> + Debug,
    P: Playout<N, G, T> + Debug,
    B: Backpropagation<N, G, T> + Debug,
    F: FinalScorer<N, G, T> + Debug,
    T: Default + Copy + Debug + PartialOrd,
{
    type Error = MonteCarloTreeSearchError<N, G, S, E, P, B, F, T>;
    type Score = F::Score;

    /// Scores the actions leading from `state` to the children of the search root.
    ///
    /// The playout graph is built first. The final scorer then rates each child of the
    /// root for the active player, and each rated child is mapped back to the action
    /// on the edge connecting the root to it.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the playout graph, from querying the active
    /// player, or from the final scorer. Returns `NoEdge` if a scored node is not
    /// connected to the root.
    fn score_actions(
        &self,
        game: &G,
        state: &G::State,
    ) -> Result<HashMap<G::Action, Self::Score>, Self::Error> {
        let (tree, root) = self.playout_graph(game, state)?;
        log::debug!("Obtaining action scores from playout graph");
        let player_index = game.active_player_index(state)?;

        let children = tree.neighbors_directed(root, Direction::Outgoing);
        let node_scores = self
            .final_scorer
            .node_scores(player_index, children, &tree)
            .map_err(MonteCarloTreeSearchError::FinalScorerError)?;

        let mut action_scores = HashMap::new();
        for (child, score) in node_scores {
            let edge = match tree.find_edge(root, child) {
                Some(edge) => edge,
                None => return Err(MonteCarloTreeSearchError::NoEdge(root, child)),
            };
            action_scores.insert(tree[edge].clone(), score);
        }
        Ok(action_scores)
    }
}
/// A single node of the search tree: a game state plus the statistics gathered for it.
#[derive(Debug, Clone)]
pub struct Node<const N: usize, S, T> {
/// The game state this node represents.
state: S,
/// Number of times backpropagation has passed through this node.
visits: u32,
/// `N` payout values of type `T`, initialised to `T::default()`.
// NOTE(review): presumably one entry per player, indexed by player index — confirm.
payouts: [T; N],
}
/// Reports whether the node at `nodeidx` has as many outgoing edges in `tree` as there
/// are actions available in its state.
///
/// Only the two counts are compared; the actions themselves are not matched against
/// the edges.
///
/// # Errors
///
/// Returns the game's error if the available actions cannot be determined.
fn is_fully_expanded<const N: usize, G: Game<N>, T>(
    nodeidx: NodeIndex,
    game: &G,
    tree: &Graph<Node<N, G::State, T>, G::Action>,
) -> Result<bool, G::Error> {
    let available = game.available_actions(&tree[nodeidx].state)?.len();
    let expanded = tree
        .neighbors_directed(nodeidx, Direction::Outgoing)
        .count();
    Ok(available == expanded)
}
impl<const N: usize, S, T> Node<N, S, T>
where
T: Default + Copy,
{
pub fn new(state: S) -> Self {
Self {
state,
visits: 0,
payouts: [T::default(); N],
}
}
pub fn state(&self) -> &S {
&self.state
}
pub fn visits(&self) -> u32 {
self.visits
}
pub fn payouts(&self) -> &[T; N] {
&self.payouts
}
}
/// Builder for [`MonteCarloTreeSearch`].
///
/// Every field is optional. Values left unset are filled in by `build`: the time budget
/// falls back to five seconds and each strategy falls back to its `Default` value.
#[derive(Debug)]
pub struct MonteCarloTreeSearchBuilder<S = UCT, E = Greedy, P = Random, B = WinRate, F = NumVisits>
{
/// Time budget for the search; `None` means "use the default".
max_time: Option<Duration>,
/// Iteration cap for the search; `None` means no cap.
max_iters: Option<u32>,
/// Selection strategy, if one was supplied.
selection: Option<S>,
/// Expansion strategy, if one was supplied.
expansion: Option<E>,
/// Playout strategy, if one was supplied.
playout: Option<P>,
/// Backpropagation strategy, if one was supplied.
backpropagation: Option<B>,
/// Final scorer, if one was supplied.
final_scorer: Option<F>,
/// Graphviz dump configuration; `None` disables dumping.
graphviz: Option<Graphviz>,
}
impl<S, E, P, B, F> Default for MonteCarloTreeSearchBuilder<S, E, P, B, F> {
fn default() -> Self {
Self {
max_time: None,
max_iters: None,
selection: None,
expansion: None,
playout: None,
backpropagation: None,
final_scorer: None,
graphviz: None,
}
}
}
impl<S, E, P, B, F> MonteCarloTreeSearchBuilder<S, E, P, B, F> {
    /// Sets the time budget of the search.
    pub fn max_time(self, max_time: Duration) -> Self {
        Self {
            max_time: Some(max_time),
            ..self
        }
    }

    /// Sets the maximum number of search iterations.
    pub fn max_iters(self, max_iters: u32) -> Self {
        Self {
            max_iters: Some(max_iters),
            ..self
        }
    }

    /// Replaces the selection strategy, changing the builder's selection type to `T`.
    ///
    /// Any previously supplied selection strategy is discarded; all other settings
    /// are carried over.
    pub fn selection<T>(self, selection: T) -> MonteCarloTreeSearchBuilder<T, E, P, B, F> {
        let Self {
            max_time,
            max_iters,
            expansion,
            playout,
            backpropagation,
            final_scorer,
            graphviz,
            ..
        } = self;
        MonteCarloTreeSearchBuilder {
            max_time,
            max_iters,
            selection: Some(selection),
            expansion,
            playout,
            backpropagation,
            final_scorer,
            graphviz,
        }
    }

    /// Replaces the expansion strategy, changing the builder's expansion type to `T`.
    ///
    /// Any previously supplied expansion strategy is discarded; all other settings
    /// are carried over.
    pub fn expansion<T>(self, expansion: T) -> MonteCarloTreeSearchBuilder<S, T, P, B, F> {
        let Self {
            max_time,
            max_iters,
            selection,
            playout,
            backpropagation,
            final_scorer,
            graphviz,
            ..
        } = self;
        MonteCarloTreeSearchBuilder {
            max_time,
            max_iters,
            selection,
            expansion: Some(expansion),
            playout,
            backpropagation,
            final_scorer,
            graphviz,
        }
    }

    /// Replaces the playout strategy, changing the builder's playout type to `T`.
    ///
    /// Any previously supplied playout strategy is discarded; all other settings
    /// are carried over.
    pub fn playout<T>(self, playout: T) -> MonteCarloTreeSearchBuilder<S, E, T, B, F> {
        let Self {
            max_time,
            max_iters,
            selection,
            expansion,
            backpropagation,
            final_scorer,
            graphviz,
            ..
        } = self;
        MonteCarloTreeSearchBuilder {
            max_time,
            max_iters,
            selection,
            expansion,
            playout: Some(playout),
            backpropagation,
            final_scorer,
            graphviz,
        }
    }

    /// Replaces the backpropagation strategy, changing the builder's backpropagation
    /// type to `T`.
    ///
    /// Any previously supplied backpropagation strategy is discarded; all other
    /// settings are carried over.
    pub fn backpropagation<T>(
        self,
        backpropagation: T,
    ) -> MonteCarloTreeSearchBuilder<S, E, P, T, F> {
        let Self {
            max_time,
            max_iters,
            selection,
            expansion,
            playout,
            final_scorer,
            graphviz,
            ..
        } = self;
        MonteCarloTreeSearchBuilder {
            max_time,
            max_iters,
            selection,
            expansion,
            playout,
            backpropagation: Some(backpropagation),
            final_scorer,
            graphviz,
        }
    }

    /// Replaces the final scorer, changing the builder's scorer type to `T`.
    ///
    /// Any previously supplied final scorer is discarded; all other settings are
    /// carried over.
    pub fn final_scorer<T>(self, final_scorer: T) -> MonteCarloTreeSearchBuilder<S, E, P, B, T> {
        let Self {
            max_time,
            max_iters,
            selection,
            expansion,
            playout,
            backpropagation,
            graphviz,
            ..
        } = self;
        MonteCarloTreeSearchBuilder {
            max_time,
            max_iters,
            selection,
            expansion,
            playout,
            backpropagation,
            final_scorer: Some(final_scorer),
            graphviz,
        }
    }

    /// Requests a single Graphviz dump of the finished tree under directory `p`.
    ///
    /// Overrides any earlier Graphviz setting.
    pub fn graphviz_final<T: Into<PathBuf>>(self, p: T) -> Self {
        Self {
            graphviz: Some(Graphviz::Final(p.into())),
            ..self
        }
    }

    /// Requests a Graphviz dump of the tree after every iteration under directory `p`.
    ///
    /// Overrides any earlier Graphviz setting.
    pub fn graphviz_every<T: Into<PathBuf>>(self, p: T) -> Self {
        Self {
            graphviz: Some(Graphviz::EveryIter(p.into())),
            ..self
        }
    }
}
impl<S, E, P, B, F> MonteCarloTreeSearchBuilder<S, E, P, B, F>
where
S: Default,
E: Default,
P: Default,
B: Default,
F: Default,
{
pub fn build<T>(self) -> MonteCarloTreeSearch<S, E, P, B, F, T> {
MonteCarloTreeSearch {
max_time: self.max_time.unwrap_or(Duration::from_secs(5)),
max_iters: self.max_iters,
selection: self.selection.unwrap_or_else(|| S::default()),
expansion: self.expansion.unwrap_or_else(|| E::default()),
playout: self.playout.unwrap_or_else(|| P::default()),
backpropagation: self.backpropagation.unwrap_or_else(|| B::default()),
final_scorer: self.final_scorer.unwrap_or_else(|| F::default()),
graphviz: self.graphviz,
_phantom_payout: Default::default(),
}
}
}