use std::{
collections::BTreeSet,
fmt, mem,
ops::Range,
sync::{Arc, Mutex},
};
use crate::{AgentId, AgentValue, Domain, Node, SeededHashMap, StateDiffRef, Task, WeakNode};
use rand::distributions::WeightedIndex;
/// Tasks that have not been expanded into edges yet, paired with the weighted
/// index built over them (one `f32` weight per task, in the same order as the
/// `Vec`). `None` when there is nothing to expand.
type UnexpandedTasks<D> = Option<(WeightedIndex<f32>, Vec<Box<dyn Task<D>>>)>;
/// The outgoing edges of a node, split into tasks that still have to be
/// expanded and tasks that already have an associated [`Edge`].
pub struct Edges<D: Domain> {
/// Tasks without an edge yet, together with their sampling weights.
pub(crate) unexpanded_tasks: UnexpandedTasks<D>,
/// Tasks that already have an edge, keyed by the task itself.
pub(crate) expanded_tasks: SeededHashMap<Box<dyn Task<D>>, Edge<D>>,
}
impl<D: Domain> fmt::Debug for Edges<D> {
fn fmt(&self, f: &'_ mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Edges")
.field("unexpanded_tasks", &self.unexpanded_tasks)
.field("expanded_tasks", &self.expanded_tasks)
.finish()
}
}
impl<'a, D: Domain> IntoIterator for &'a Edges<D> {
type Item = (&'a Box<dyn Task<D>>, &'a Edge<D>);
type IntoIter = std::collections::hash_map::Iter<'a, Box<dyn Task<D>>, Edge<D>>;
fn into_iter(self) -> Self::IntoIter {
self.expanded_tasks.iter()
}
}
impl<D: Domain> Edges<D> {
    /// Creates the edges of `node`, with every task initially unexpanded.
    ///
    /// If `next_task` is given and is valid for the node's tick, state and
    /// active agent, it becomes the only candidate (with weight `1`).
    /// Otherwise the candidates are whatever `D::get_tasks` returns, weighted
    /// by `Task::weight`. When there are no candidates, `unexpanded_tasks`
    /// is `None`.
    ///
    /// # Panics
    ///
    /// Panics if `WeightedIndex::new` rejects the task weights.
    pub fn new(
        node: &Node<D>,
        initial_state: &D::State,
        next_task: Option<Box<dyn Task<D>>>,
    ) -> Self {
        // Every query below looks at the same state view, so build it once.
        let state_diff = StateDiffRef::new(initial_state, &node.diff);

        let unexpanded_tasks = match next_task {
            Some(task) if task.is_valid(node.tick, state_diff, node.active_agent) => {
                let weights = WeightedIndex::new(std::iter::once(1.0_f32))
                    .expect("a single positive weight is always a valid distribution");
                // `task` is already owned here, no clone needed.
                Some((weights, vec![task]))
            }
            _ => {
                let tasks = D::get_tasks(node.tick, state_diff, node.active_agent);
                if tasks.is_empty() {
                    None
                } else {
                    for task in &tasks {
                        debug_assert!(task.is_valid(node.tick, state_diff, node.active_agent));
                    }
                    let weights = WeightedIndex::new(
                        tasks
                            .iter()
                            .map(|task| task.weight(node.tick, state_diff, node.active_agent)),
                    )
                    .expect("task weights must form a valid weighted distribution");
                    Some((weights, tasks))
                }
            }
        };

        Edges {
            unexpanded_tasks,
            expanded_tasks: Default::default(),
        }
    }

    /// Returns the sum of the visit counts of all expanded edges.
    ///
    /// # Panics
    ///
    /// Panics if an edge mutex is poisoned.
    pub fn child_visits(&self) -> usize {
        self.expanded_tasks
            .values()
            .map(|edge| edge.lock().unwrap().visits)
            .sum()
    }

    /// Returns the expanded task whose edge has the highest UCT score for
    /// `agent`, or `None` if no task has been expanded yet.
    ///
    /// Each edge is locked and scored exactly once. A NaN score is ranked
    /// below every other score instead of aborting the selection.
    ///
    /// # Panics
    ///
    /// Panics if an edge mutex is poisoned.
    pub fn best_task(
        &self,
        agent: AgentId,
        exploration: f32,
        range: Range<AgentValue>,
    ) -> Option<Box<dyn Task<D>>> {
        let visits = self.child_visits();
        self.expanded_tasks
            .iter()
            .map(|(task, edge)| {
                let score = edge
                    .lock()
                    .unwrap()
                    .uct(agent, visits, exploration, range.clone());
                // NaN cannot be ordered; treat it as the worst possible score.
                let score = if score.is_nan() {
                    f32::NEG_INFINITY
                } else {
                    score
                };
                (task, score)
            })
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(task, _)| task.clone())
    }

    /// Returns the visit-weighted average of the expanded edges' q-values
    /// for `agent`.
    ///
    /// `fallback` is the `(visits, q_value)` pair used for any edge whose
    /// mutex cannot be acquired with `try_lock`. Returns `None` when there
    /// are no expanded edges or when the total visit count is zero (the
    /// average is undefined in that case).
    pub fn q_value(&self, fallback: (usize, f32), agent: AgentId) -> Option<f32> {
        self.expanded_tasks
            .values()
            .map(|edge| {
                edge.try_lock()
                    .map(|edge| {
                        (
                            edge.visits,
                            edge.q_values.get(&agent).copied().unwrap_or_default(),
                        )
                    })
                    .unwrap_or(fallback)
            })
            .fold(None, |acc, (visits, value)| match acc {
                Some((sum, count)) => Some((sum + visits as f32 * value, count + visits)),
                None => Some((visits as f32 * value, visits)),
            })
            // Avoid 0 / 0 = NaN when none of the edges has been visited.
            .filter(|&(_, count)| count > 0)
            .map(|(sum, count)| sum / count as f32)
    }

    /// Returns the number of tasks that already have an edge.
    pub fn expanded_count(&self) -> usize {
        self.expanded_tasks.len()
    }

    /// Returns the number of tasks that have not been expanded yet.
    pub fn unexpanded_count(&self) -> usize {
        self.unexpanded_tasks
            .as_ref()
            .map_or(0, |(_, tasks)| tasks.len())
    }

    /// Returns the total number of tasks, expanded and unexpanded.
    pub fn branching_factor(&self) -> usize {
        self.expanded_count() + self.unexpanded_count()
    }

    /// Returns a new handle to the edge associated with `task`, if the task
    /// has been expanded.
    // The map is keyed by `Box<dyn Task<D>>`, so the lookup takes the boxed key.
    #[allow(clippy::borrowed_box)]
    pub fn get_edge(&self, task: &Box<dyn Task<D>>) -> Option<Edge<D>> {
        self.expanded_tasks.get(task).cloned()
    }

    /// Estimates the memory used by these edges, in bytes.
    ///
    /// The estimate is the size of `Self`, plus `task_size` of every task
    /// (expanded or not), plus the `size()` of every expanded edge.
    ///
    /// # Panics
    ///
    /// Panics if an edge mutex is poisoned.
    pub fn size(&self, task_size: fn(&dyn Task<D>) -> usize) -> usize {
        let mut size = mem::size_of::<Self>();
        if let Some((_, tasks)) = self.unexpanded_tasks.as_ref() {
            for task in tasks {
                size += task_size(&**task);
            }
        }
        for (task, edge) in &self.expanded_tasks {
            size += task_size(&**task);
            size += edge.lock().unwrap().size();
        }
        size
    }
}
/// Shared, mutex-protected handle to the data of a single edge.
pub type Edge<D> = Arc<Mutex<EdgeInner<D>>>;
/// The data stored for one edge between a parent node and a child node.
pub struct EdgeInner<D: Domain> {
/// Weak reference to the node this edge starts from.
pub(crate) parent: WeakNode<D>,
/// Weak reference to the node this edge leads to.
pub(crate) child: WeakNode<D>,
/// Visit counter of this edge.
pub(crate) visits: usize,
/// Q-value of this edge for each agent.
pub(crate) q_values: SeededHashMap<AgentId, f32>,
}
impl<D: Domain> fmt::Debug for EdgeInner<D> {
fn fmt(&self, f: &'_ mut fmt::Formatter) -> fmt::Result {
f.debug_struct("EdgeInner")
.field("parent", &self.parent)
.field("child", &self.child)
.field("visits", &self.visits)
.field("q_values", &self.q_values)
.finish()
}
}
pub(crate) fn new_edge<D: Domain>(
parent: &Node<D>,
child: &Node<D>,
agents: &BTreeSet<AgentId>,
) -> Edge<D> {
Arc::new(Mutex::new(EdgeInner {
parent: Node::downgrade(parent),
child: Node::downgrade(child),
visits: Default::default(),
q_values: agents.iter().map(|agent| (*agent, 0.)).collect(),
}))
}
impl<D: Domain> EdgeInner<D> {
    /// Computes the UCT score of this edge from the point of view of
    /// `parent_agent`.
    ///
    /// The score is the edge's q-value for `parent_agent`, normalized by
    /// `range`, plus `exploration` times
    /// `sqrt(ln(parent_child_visits) / visits)`. Both denominators are
    /// clamped to at least `f32::EPSILON`. Returns `0.` if the edge has no
    /// q-value for `parent_agent`.
    pub fn uct(
        &self,
        parent_agent: AgentId,
        parent_child_visits: usize,
        exploration: f32,
        range: Range<AgentValue>,
    ) -> f32 {
        if let Some(q_value) = self.q_values.get(&parent_agent) {
            let exploitation_value =
                (q_value - *range.start) / (*(range.end - range.start)).max(f32::EPSILON);
            // ln(0) is -inf and its square root is NaN, so a parent without
            // any visits is treated as having one (ln(1) == 0).
            let parent_visits = parent_child_visits.max(1) as f32;
            let exploration_value =
                (parent_visits.ln() / (self.visits as f32).max(f32::EPSILON)).sqrt();
            exploitation_value + exploration * exploration_value
        } else {
            0.
        }
    }

    /// Returns the visit count of this edge.
    pub fn visits(&self) -> usize {
        self.visits
    }

    /// Returns the q-value of this edge for `agent`, or `0.` if none is stored.
    pub fn q_value(&self, agent: AgentId) -> f32 {
        self.q_values.get(&agent).copied().unwrap_or(0.)
    }

    /// Returns the node this edge leads to.
    ///
    /// # Panics
    ///
    /// Panics if the child node has already been dropped.
    pub fn child(&self) -> Node<D> {
        self.child
            .upgrade()
            .expect("child node of an edge was dropped while the edge is still in use")
    }

    /// Returns the node this edge starts from.
    ///
    /// # Panics
    ///
    /// Panics if the parent node has already been dropped.
    pub fn parent(&self) -> Node<D> {
        self.parent
            .upgrade()
            .expect("parent node of an edge was dropped while the edge is still in use")
    }

    /// Estimates the memory used by this edge, in bytes: the size of `Self`
    /// plus one `(AgentId, f32)` pair per stored q-value.
    pub fn size(&self) -> usize {
        mem::size_of::<Self>() + self.q_values.len() * mem::size_of::<(AgentId, f32)>()
    }
}