1use std::collections::HashMap;
7use std::collections::{btree_map::Entry, BTreeMap, BTreeSet, HashSet};
8use std::f32;
9use std::mem;
10use std::ops::Range;
11use std::time::{Duration, Instant};
12
13use rand::{
14 distributions::WeightedIndex,
15 prelude::{thread_rng, Distribution, RngCore, SeedableRng},
16 Rng,
17};
18use rand_chacha::ChaCha8Rng;
19
20use crate::*;
21
// SAFETY (applies to the three constants below): `new_unchecked` bypasses whatever
// validation the checked constructor performs.
// NOTE(review): presumably the invariant is "value is not NaN", which 0.0 and both
// infinities satisfy — confirm against the definition of `AgentValue`.

/// Zero value; used for the degenerate `0..0` range returned by `min_max_range`
/// when no Q-value has been recorded for an agent.
const VALUE_ZERO: AgentValue = unsafe { AgentValue::new_unchecked(0.) };
/// Positive infinity; used as the initial lower bound of a Q-value range so the
/// first observed value always replaces it through `min`.
const VALUE_INFINITE: AgentValue = unsafe { AgentValue::new_unchecked(std::f32::INFINITY) };
/// Negative infinity; used as the initial upper bound of a Q-value range so the
/// first observed value always replaces it through `max`.
const VALUE_NEG_INFINITE: AgentValue = unsafe { AgentValue::new_unchecked(std::f32::NEG_INFINITY) };
32
/// State of a Monte Carlo tree search run for a single root agent.
pub struct MCTS<D: Domain> {
    /// Wall-clock duration of the last call to `run`.
    time: Duration,

    /// Search configuration (visits, depth, exploration, discount half-life, ...).
    config: MCTSConfiguration,
    /// Estimator invoked by `run` on every newly created leaf node.
    state_value_estimator: Box<dyn StateValueEstimator<D> + Send>,
    /// Optional predicate polled after each iteration of `run`; planning stops
    /// as soon as it returns `true`.
    early_stop_condition: Option<Box<EarlyStopCondition>>,

    /// Agent this tree plans for.
    root_agent: AgentId,
    /// Seed the random number generator was initialised with.
    seed: u64,

    /// Root node of the search tree.
    root: Node<D>,
    /// All known nodes together with their outgoing edges.
    nodes: SeededHashMap<Node<D>, Edges<D>>,

    /// Per-agent minimum/maximum Q-value observed during backpropagation.
    q_value_ranges: BTreeMap<AgentId, Range<AgentValue>>,

    /// World state the search started from; node diffs are applied on top of it.
    initial_state: D::State,
    /// Tick at which the search started.
    start_tick: u64,

    /// Random number generator used for expansion and rollouts.
    rng: ChaCha8Rng,
}
61
/// Result of one descent of `MCTS::tree_policy`.
///
/// Every variant carries the depth reached and the path of edges walked from the
/// root, which `MCTS::run` hands to backpropagation.
enum TreePolicyOutcome<D: Domain> {
    /// A task was expanded; carries depth, the child node and the path to it.
    NodeCreated(u32, Node<D>, Vec<Edge<D>>),
    /// The next active task was invalid and invalid tasks are not allowed.
    NoValidTask(u32, Vec<Edge<D>>),
    /// The node has no unexpanded tasks and no visited children.
    NoChildNode(u32, Node<D>, Vec<Edge<D>>),
    /// The configured maximum depth was reached during selection.
    DepthLimitReached(u32, Node<D>, Vec<Edge<D>>),
}
69
70impl<D: Domain> MCTS<D> {
71 pub fn new(initial_state: D::State, root_agent: AgentId, config: MCTSConfiguration) -> Self {
73 let state_value_estimator = Box::new(DefaultPolicyEstimator {});
74 Self::new_with_tasks(
75 initial_state,
76 root_agent,
77 0,
78 Default::default(),
79 config,
80 state_value_estimator,
81 None,
82 )
83 }
84
85 pub fn new_with_tasks(
87 initial_state: D::State,
88 root_agent: AgentId,
89 start_tick: u64,
90 tasks: ActiveTasks<D>,
91 config: MCTSConfiguration,
92 state_value_estimator: Box<dyn StateValueEstimator<D> + Send>,
93 early_stop_condition: Option<Box<EarlyStopCondition>>,
94 ) -> Self {
95 let next_task =
97 get_task_for_agent(&tasks, root_agent).map(|active_task| active_task.task.clone());
98
99 let root = Node::new(NodeInner::new(
101 &initial_state,
102 start_tick,
103 Default::default(),
104 root_agent,
105 start_tick,
106 tasks,
107 ));
108
109 let mut nodes = SeededHashMap::with_capacity_and_hasher(
111 config.visits as usize + 1,
112 SeededRandomState::default(),
113 );
114
115 let root_edges = Edges::new(&root, &initial_state, next_task);
117 nodes.insert(root.clone(), root_edges);
118
119 let cur_seed = config.seed.unwrap_or_else(|| thread_rng().next_u64());
121
122 MCTS {
123 time: Duration::default(),
124 config,
125 state_value_estimator,
126 early_stop_condition,
127 seed: cur_seed,
128 root_agent,
129 root,
130 nodes,
131 q_value_ranges: Default::default(),
132 initial_state,
133 start_tick,
134 rng: ChaCha8Rng::seed_from_u64(cur_seed),
135 }
136 }
137
138 pub fn best_task_at_root(&mut self) -> Option<Box<dyn Task<D>>> {
140 let range = self.min_max_range(self.root_agent);
141 let edges = self.nodes.get(&self.root).unwrap();
142 edges
143 .best_task(self.root_agent, 0., range)
145 .or_else(|| {
147 edges.unexpanded_tasks.as_ref().and_then(|(_, tasks)| {
148 if tasks.is_empty() {
149 None
150 } else {
151 let index = self.rng.gen_range(0..tasks.len());
152 Some(tasks[index].clone())
153 }
154 })
155 })
156 }
157
158 pub fn best_task_with_history(
160 &self,
161 task_history: &HashMap<AgentId, ActiveTask<D>>,
162 ) -> Box<dyn Task<D>>
163 where
164 D: DomainWithPlanningTask,
165 {
166 log::debug!(
167 "Finding best task for {} using history {:?}",
168 self.root_agent,
169 task_history
170 );
171 let mut current_node = self.root.clone();
172 let mut edges = self.nodes.get(¤t_node).unwrap();
173 let mut depth = 0;
174 loop {
175 let node_agent = current_node.agent();
176 let node_tick = current_node.tick();
177 let edge = if edges.expanded_tasks.len() == 1 {
178 let (task, edge) = edges.expanded_tasks.iter().next().unwrap();
179 log::trace!("[{depth}] T{node_tick} {node_agent} skipping {task:?}");
180
181 edge
183 } else {
184 let executed_task = task_history.get(&node_agent);
185 let executed_task = executed_task
186 .unwrap_or_else(|| panic!("Found no task for {node_agent} is history"));
187 let task = &executed_task.task;
188 log::trace!("[{depth}] T{node_tick} {node_agent} executed {task:?}");
189
190 let edge = edges.expanded_tasks.get(task);
191
192 if edge.is_none() {
193 log::info!("{node_agent} executed unexpected {task:?} not present in search tree, returning fallback task");
194 return D::fallback_task(self.root_agent);
195 }
196
197 edge.unwrap()
198 };
199 let edge = edge.lock().unwrap();
200 current_node = edge.child();
201 edges = self.nodes.get(¤t_node).unwrap();
203
204 depth += 1;
205
206 if current_node.agent() == self.root_agent {
208 break;
209 }
210 }
211
212 let range = self.min_max_range(self.root_agent);
214 let best = edges.best_task(self.root_agent, 0., range);
215
216 if best.is_none() {
217 log::info!(
218 "No valid task for agent {}, returning fallback task",
219 self.root_agent
220 );
221 return D::fallback_task(self.root_agent);
222 }
223
224 best.unwrap().clone()
225 }
226
227 pub fn q_value_at_root(&self, agent: AgentId) -> f32 {
229 let edges = self.nodes.get(&self.root).unwrap();
230 edges.q_value((0, 0.), agent).unwrap()
231 }
232
233 pub fn run(&mut self) -> Option<Box<dyn Task<D>>> {
237 self.q_value_ranges.clear();
239
240 let start = Instant::now();
241 let max_visits = self.config.visits;
242 for i in 0..max_visits {
243 let tree_policy_outcome = self.tree_policy();
245
246 let (path, rollout_values) = match tree_policy_outcome {
249 TreePolicyOutcome::NodeCreated(depth, leaf, path) => {
250 let edges = self.nodes.get(&leaf).unwrap();
252 let rollout_values = self.state_value_estimator.estimate(
253 &mut self.rng,
254 &self.config,
255 &self.initial_state,
256 self.start_tick,
257 &leaf,
258 edges,
259 depth,
260 );
261 (path, rollout_values)
262 }
263 TreePolicyOutcome::NoValidTask(_, path) => (path, None),
264 TreePolicyOutcome::NoChildNode(_, _, path) => (path, None),
265 TreePolicyOutcome::DepthLimitReached(_, _, path) => (path, None),
266 };
267
268 self.backpropagation(path, rollout_values);
270
271 if let Some(early_stop_condition) = &self.early_stop_condition {
273 if early_stop_condition() {
274 log::info!("{:?} early stops planning after {} visits", self.agent(), i);
275 break;
276 }
277 }
278 }
279 self.time = start.elapsed();
280
281 self.best_task_at_root()
282 }
283
    /// Performs one selection/expansion descent from the root.
    ///
    /// While the accumulated depth (in ticks) is below `config.depth`:
    /// * if the current node still has unexpanded tasks, one is sampled and
    ///   expanded into a child node, and the descent ends with `NodeCreated`
    ///   (or `NoValidTask` if the next active task is invalid and invalid tasks
    ///   are not allowed);
    /// * otherwise, if the node has no visited children, the descent ends with
    ///   `NoChildNode`;
    /// * otherwise the best expanded task is followed to its child node.
    ///
    /// Returns `DepthLimitReached` when the loop exits because of the depth bound.
    fn tree_policy(&mut self) -> TreePolicyOutcome<D> {
        let agents = self.root.agents();

        let mut node = self.root.clone();

        // Capacity hint only; the path grows as needed.
        let mut path = Vec::with_capacity(self.config.depth as usize * agents.len());

        let mut depth = 0;
        while depth < self.config.depth {
            let mut edges = self.nodes.get_mut(&node).unwrap();

            // --- Expansion: the node still has tasks that were never tried. ---
            if let Some((weights, tasks)) = edges.unexpanded_tasks.as_mut() {
                let mut diff = node.diff.clone();

                // Sample one unexpanded task according to the stored weights.
                let idx = weights.sample(&mut self.rng);
                let task = tasks[idx].clone();
                let state_diff = StateDiffRef::new(&self.initial_state, &diff);
                debug_assert!(task.is_valid(node.tick, state_diff, node.active_agent));
                log::debug!(
                    "T{}\t{:?} - Expand task: {:?}",
                    node.tick,
                    node.active_agent,
                    task
                );

                // Zero the sampled task's weight so it cannot be drawn again; an
                // error here is treated as "nothing left to expand".
                if weights.update_weights(&[(idx, &0.)]).is_err() {
                    edges.unexpanded_tasks = None;
                }

                // The child's task set: every other agent's task plus the new one.
                let mut child_tasks = node
                    .tasks
                    .iter()
                    .filter(|task| task.agent != node.active_agent)
                    .cloned()
                    .collect::<BTreeSet<_>>();
                let active_task =
                    ActiveTask::new(node.active_agent, task.clone(), node.tick, state_diff);
                child_tasks.insert(active_task);
                log::trace!("\tActive Tasks ({}):", child_tasks.len());
                for active_task in &child_tasks {
                    log::trace!(
                        "\t {:?}: {:?} ends T{}",
                        active_task.agent,
                        active_task.task,
                        active_task.end
                    );
                }

                // First element in the set's order.
                // NOTE(review): presumably the task that ends soonest — confirm
                // against `ActiveTask`'s `Ord` implementation.
                let next_active_task = child_tasks.iter().next().unwrap().clone();
                log::trace!(
                    "\tNext Active Task: {:?}: {:?} ends T{}",
                    next_active_task.agent,
                    next_active_task.task,
                    next_active_task.end
                );

                let is_task_valid = next_active_task.task.is_valid(
                    next_active_task.end,
                    state_diff,
                    next_active_task.agent,
                );
                if !is_task_valid && !self.config.allow_invalid_tasks {
                    log::debug!("T{}\tNext active task {:?} is invalid and that is not allowed, aborting expansion", next_active_task.end, next_active_task.task);
                    return TreePolicyOutcome::NoValidTask(depth, path);
                }
                // Execute the task on the copied diff; an invalid (but allowed)
                // task is skipped and leaves the diff untouched.
                let after_next_task = if is_task_valid {
                    let state_diff_mut = StateDiffRefMut::new(&self.initial_state, &mut diff);
                    next_active_task.task.execute(
                        next_active_task.end,
                        state_diff_mut,
                        next_active_task.agent,
                    )
                } else {
                    None
                };

                // Without a follow-up task, and when a planning duration is
                // configured, schedule a planning task — unless the task that
                // just ended was itself a planning task.
                let after_next_task = if after_next_task.is_none() {
                    if let Some(planning_task_duration) = self.config.planning_task_duration {
                        if next_active_task
                            .task
                            .downcast_ref::<PlanningTask>()
                            .is_none()
                        {
                            let task: Box<dyn Task<D>> =
                                Box::new(PlanningTask(planning_task_duration));
                            Some(task)
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                } else {
                    after_next_task
                };

                let child_state = NodeInner::new(
                    &self.initial_state,
                    self.start_tick,
                    diff,
                    next_active_task.agent,
                    next_active_task.end,
                    child_tasks,
                );

                // Reuse an equal node if one is already in the map, otherwise
                // register the new one together with its edges.
                let child_node =
                    if let Some((existing_node, _)) = self.nodes.get_key_value(&child_state) {
                        log::trace!("\tLinking to existing node {:?}", existing_node);
                        existing_node.clone()
                    } else {
                        log::trace!("\tCreating new node {:?}", child_state);
                        let child_node = Node::new(child_state);
                        self.nodes.insert(
                            child_node.clone(),
                            Edges::new(&child_node, &self.initial_state, after_next_task),
                        );
                        child_node
                    };

                let edge = new_edge(&node, &child_node, &agents);
                // Re-borrow: the earlier mutable borrow ended before the map was
                // touched above.
                let edges = self.nodes.get_mut(&node).unwrap();
                edges.expanded_tasks.insert(task, edge.clone());

                path.push(edge);

                // Depth is measured in elapsed ticks, not in tree levels.
                depth += (child_node.tick - node.tick) as u32;
                log::debug!(
                    "T{}\tExpansion successful, node created with incoming task {:?}",
                    child_node.tick,
                    next_active_task.task
                );
                return TreePolicyOutcome::NodeCreated(depth, child_node, path);
            }

            // --- Nothing to expand and nothing visited below: dead end. ---
            if edges.child_visits() == 0 {
                log::debug!("T{}\tNode has no children, aborting expansion", node.tick);
                return TreePolicyOutcome::NoChildNode(depth, node, path);
            }

            // --- Selection: follow the best expanded task for the active agent. ---
            let range = self.min_max_range(node.active_agent);
            let edges = self.nodes.get_mut(&node).unwrap();
            let task = edges
                .best_task(node.active_agent, self.config.exploration, range)
                .expect("No valid task!");
            log::trace!(
                "T{}\t{:?} - Select task: {:?}",
                node.tick,
                node.active_agent,
                task
            );
            let edge = edges.expanded_tasks.get(&task).unwrap().clone();

            let parent_tick = node.tick;
            // The lock guard is confined to this block so it is released before
            // the edge is pushed onto the path.
            node = {
                let edge = edge.lock().unwrap();
                edge.child()
            };
            let child_tick = node.tick;
            depth += (child_tick - parent_tick) as u32;

            path.push(edge);
        }

        log::debug!(
            "T{}\tReached maximum depth {}, aborting expansion",
            node.tick,
            depth
        );
        TreePolicyOutcome::DepthLimitReached(depth, node, path)
    }
490
491 fn backpropagation(
493 &mut self,
494 mut path: Vec<Edge<D>>,
495 rollout_values: Option<BTreeMap<AgentId, f32>>,
496 ) {
497 path.drain(..).rev().for_each(|edge| {
499 let edge = &mut edge.lock().unwrap();
501 edge.visits += 1;
502 if let Some(rollout_values) = &rollout_values {
503 let parent_node = edge.parent();
504 let child_node = edge.child();
505 let visits = edge.visits;
506 let child_edges = self.nodes.get(&child_node).unwrap();
507
508 let discount_factor =
509 Self::discount_factor(child_node.tick - parent_node.tick, &self.config);
510
511 edge.q_values.iter_mut().for_each(|(&agent, q_value_ref)| {
513 let parent_current_value =
514 parent_node.current_value_or_compute(agent, &self.initial_state);
515 let child_current_value =
516 child_node.current_value_or_compute(agent, &self.initial_state);
517
518 let mut child_q_value =
520 if let Some(value) = child_edges.q_value((visits, *q_value_ref), agent) {
521 value
522 } else {
523 rollout_values.get(&agent).copied().unwrap_or_default()
524 };
525
526 child_q_value *= discount_factor;
529
530 let q_value = child_current_value - parent_current_value + child_q_value;
532
533 *q_value_ref = *q_value;
535
536 let q_value_range = self
538 .q_value_ranges
539 .entry(parent_node.active_agent)
540 .or_insert_with(|| Range {
541 start: VALUE_INFINITE,
542 end: VALUE_NEG_INFINITE,
543 });
544 q_value_range.start = q_value_range.start.min(q_value);
545 q_value_range.end = q_value_range.end.max(q_value);
546 });
547 }
548 });
549 }
550
551 fn discount_factor(duration: u64, config: &MCTSConfiguration) -> f32 {
556 2f64.powf((-(duration as f64)) / (config.discount_hl as f64)) as f32
557 }
558
    /// Returns the state the search was started from.
    pub fn initial_state(&self) -> &D::State {
        &self.initial_state
    }
563
    /// Returns the tick at which the search was started.
    pub fn start_tick(&self) -> u64 {
        self.start_tick
    }
568
    /// Returns the agent this tree plans for (the root agent).
    pub fn agent(&self) -> AgentId {
        self.root_agent
    }
573
574 pub fn min_max_range(&self, agent: AgentId) -> Range<AgentValue> {
576 self.q_value_ranges.get(&agent).cloned().unwrap_or(Range {
577 start: VALUE_ZERO,
578 end: VALUE_ZERO,
579 })
580 }
581
    /// Iterates over every node in the tree together with its outgoing edges.
    pub fn nodes(&self) -> impl Iterator<Item = (&Node<D>, &Edges<D>)> {
        self.nodes.iter()
    }
586
    /// Returns a clone of the root node handle.
    pub fn root_node(&self) -> Node<D> {
        self.root.clone()
    }
591
    /// Returns the outgoing edges of `node`, or `None` if the node is not part
    /// of this tree.
    pub fn get_edges(&self, node: &Node<D>) -> Option<&Edges<D>> {
        self.nodes.get(node)
    }
596
    /// Returns the seed the random number generator was initialised with.
    pub fn seed(&self) -> u64 {
        self.seed
    }
601
    /// Returns the number of nodes currently stored in the tree.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
606
607 pub fn edge_count(&self) -> usize {
609 self.nodes
610 .values()
611 .map(|edges| edges.expanded_tasks.len())
612 .sum()
613 }
614
    /// Returns the wall-clock duration of the last `run` (zero before any run).
    pub fn time(&self) -> Duration {
        self.time
    }
619
620 pub fn size(&self, task_size: fn(&dyn Task<D>) -> usize) -> usize {
622 let mut size = 0;
623
624 size += mem::size_of::<Self>();
625
626 for (node, edges) in &self.nodes {
627 size += node.size(task_size);
628 size += edges.size(task_size);
629 }
630
631 size += self.q_value_ranges.len() * mem::size_of::<(AgentId, Range<f32>)>();
632
633 size
634 }
635}
636
/// Stateless value estimator that performs a random rollout from the given node;
/// it is the estimator installed by `MCTS::new`.
pub struct DefaultPolicyEstimator {}
639impl<D: Domain> StateValueEstimator<D> for DefaultPolicyEstimator {
640 fn estimate(
641 &mut self,
642 rng: &mut ChaCha8Rng,
643 config: &MCTSConfiguration,
644 initial_state: &D::State,
645 start_tick: u64,
646 node: &Node<D>,
647 edges: &Edges<D>,
648 depth: u32,
649 ) -> Option<BTreeMap<AgentId, f32>> {
650 let mut diff = node.diff.clone();
651 log::debug!(
652 "T{}\tStarting rollout with cur. values: {:?}",
653 node.tick,
654 node.current_values()
655 );
656
657 let mut values: BTreeMap<AgentId, (AgentValue, f32)> = node
661 .current_values()
662 .iter()
663 .map(|(&agent, ¤t_value)| (agent, (current_value, 0f32)))
664 .collect::<BTreeMap<_, _>>();
665
666 let mut tasks = node
668 .tasks
669 .iter()
670 .filter(|task| task.agent != node.active_agent)
671 .cloned()
672 .collect::<BTreeSet<_>>();
673
674 let task = {
676 if let Some((weights, tasks)) = edges.unexpanded_tasks.as_ref() {
677 let idx = weights.sample(rng);
679 tasks[idx].clone()
680 } else {
681 log::debug!(
683 "T{}\tNo unexpanded edges in node passed to rollout",
684 node.tick
685 );
686 return None;
687 }
688 };
689 let new_active_task = ActiveTask::new(
690 node.active_agent,
691 task,
692 node.tick,
693 StateDiffRef::new(initial_state, &diff),
694 );
695 tasks.insert(new_active_task);
696 let mut agents_with_tasks = tasks
697 .iter()
698 .map(|task| task.agent)
699 .collect::<HashSet<AgentId>>();
700 let mut agents = agents_with_tasks.iter().copied().collect();
701
702 let rollout_start_tick = node.tick;
704 let mut tick = node.tick;
705 let mut depth = depth;
706 while depth < config.depth {
707 let state_diff = StateDiffRef::new(initial_state, &diff);
708
709 if tasks.is_empty() {
711 log::debug!(
712 "! T{} No more task to do in state\n{}",
713 tick,
714 D::get_state_description(state_diff)
715 );
716 break;
717 }
718
719 let active_task = tasks.iter().next().unwrap().clone();
721 tasks.remove(&active_task);
722 let active_agent = active_task.agent;
723 agents_with_tasks.remove(&active_agent);
724
725 let elapsed = active_task.end - tick;
727 tick = active_task.end;
728
729 let is_task_valid = active_task.task.is_valid(tick, state_diff, active_agent);
731 if !is_task_valid && !config.allow_invalid_tasks {
732 log::debug!(
733 "! T{} Not allowed invalid task {:?} by {:?} in state\n{}",
734 tick,
735 active_task.task,
736 active_agent,
737 D::get_state_description(state_diff)
738 );
739 break;
740 } else if is_task_valid {
741 log::trace!(
742 "✓ T{} Valid task {:?} by {:?} in state\n{}",
743 tick,
744 active_task.task,
745 active_agent,
746 D::get_state_description(state_diff)
747 );
748 } else {
749 log::trace!(
750 "✓ T{} Skipping invalid task {:?} by {:?} in state\n{}",
751 tick,
752 active_task.task,
753 active_agent,
754 D::get_state_description(state_diff)
755 );
756 }
757
758 let new_task = if is_task_valid {
760 let state_diff_mut = StateDiffRefMut::new(initial_state, &mut diff);
761 active_task.task.execute(tick, state_diff_mut, active_agent)
762 } else {
763 None
764 };
765 let new_state_diff = StateDiffRef::new(initial_state, &diff);
766
767 let new_task = if new_task.is_none() {
769 if let Some(planning_task_duration) = config.planning_task_duration {
771 if active_task.task.downcast_ref::<PlanningTask>().is_none() {
772 let task: Box<dyn Task<D>> = Box::new(PlanningTask(planning_task_duration));
774 Some(task)
775 } else {
776 None
777 }
778 } else {
779 None
780 }
781 } else {
782 new_task
783 };
784
785 if let Entry::Occupied(mut entry) = values.entry(active_agent) {
787 let (current_value, estimated_value) = entry.get_mut();
788
789 let discount =
791 MCTS::<D>::discount_factor(active_task.end - rollout_start_tick, config);
792
793 let new_current_value = D::get_current_value(tick, new_state_diff, active_agent);
795 *estimated_value += *(new_current_value - *current_value) * discount;
796 *current_value = new_current_value;
797 }
798
799 D::update_visible_agents(start_tick, tick, new_state_diff, active_agent, &mut agents);
802 for agent in agents.iter() {
803 if *agent != active_agent && !agents_with_tasks.contains(agent) {
804 tasks.insert(ActiveTask::new_idle(tick, *agent, active_agent));
805 agents_with_tasks.insert(*agent);
806 }
807 }
808
809 if agents.contains(&active_agent) {
811 let new_task = new_task.or_else(|| {
813 let tasks = D::get_tasks(tick, new_state_diff, active_agent);
815 if tasks.is_empty() {
816 return None;
817 }
818 for task in &tasks {
820 debug_assert!(task.is_valid(tick, new_state_diff, active_agent));
821 }
822 let weights = WeightedIndex::new(
824 tasks
825 .iter()
826 .map(|task| task.weight(tick, new_state_diff, active_agent)),
827 )
828 .unwrap();
829 let idx = weights.sample(rng);
831 Some(tasks[idx].clone())
832 });
833
834 if let Some(new_task) = new_task {
836 let new_active_task = ActiveTask::new(
838 active_agent,
839 new_task,
840 tick,
841 StateDiffRef::new(initial_state, &diff),
842 );
843 tasks.insert(new_active_task);
844 agents_with_tasks.insert(active_agent);
845 }
846 }
847
848 if agents_with_tasks.len() > agents.len() {
850 agents_with_tasks.retain(|id| agents.contains(id));
851 }
852
853 depth += elapsed as u32;
855 }
856
857 let q_values = values
858 .iter()
859 .map(|(agent, (_, q_value))| (*agent, *q_value))
860 .collect();
861
862 log::debug!(
863 "T{}\tRollout to T{}: q values: {:?}",
864 node.tick,
865 depth,
866 q_values
867 );
868
869 Some(q_values)
870 }
871}
872
873#[cfg(feature = "graphviz")]
875pub mod graphviz {
876 use super::*;
877 use std::hash::{Hash, Hasher};
878 use std::{
879 borrow::Cow,
880 io::{self, Write},
881 sync::{atomic::AtomicUsize, Arc},
882 };
883
884 use dot::{Arrow, Edges, GraphWalk, Id, Kind, LabelText, Labeller, Nodes, Style};
885
    /// Renders the search tree of `mcts` in Graphviz DOT format into `w`.
    ///
    /// Returns any I/O error reported by the renderer.
    pub fn plot_mcts_tree<D: Domain, W: Write>(mcts: &MCTS<D>, w: &mut W) -> io::Result<()> {
        dot::render(mcts, w)
    }
890
891 fn agent_color_hsv(agent: AgentId) -> (f32, f32, f32) {
892 use palette::IntoColor;
893 let mut hasher = std::collections::hash_map::DefaultHasher::default();
894 agent.0.hash(&mut hasher);
895 let bytes: [u8; 8] = hasher.finish().to_ne_bytes();
896 let (h, s, v) = palette::Srgb::from_components((bytes[5], bytes[6], bytes[7]))
897 .into_format::<f32>()
898 .into_hsv::<palette::encoding::Srgb>()
899 .into_components();
900
901 ((h.to_degrees() + 180.) / 360., s, v)
902 }
903
    /// Snapshot of one tree edge, prepared for plotting.
    struct Edge<D: Domain> {
        /// Node the edge starts from.
        parent: Node<D>,
        /// Node the edge leads to.
        child: Node<D>,
        /// Task associated with the edge.
        task: Box<dyn Task<D>>,
        /// Whether `task` is the best task of the parent at exploration 0.
        best: bool,
        /// Visit count of the edge.
        visits: usize,
        /// Q-value of the parent's active agent on this edge (0 if absent).
        score: f32,
        /// UCT value computed with the configured exploration constant.
        uct: f32,
        /// UCT value computed with exploration 0.
        uct_0: f32,
        /// Child value minus parent value for the parent's active agent.
        reward: f32,
    }
915
916 impl<D: Domain> Clone for Edge<D> {
917 fn clone(&self) -> Self {
918 Edge {
919 parent: self.parent.clone(),
920 child: self.child.clone(),
921 task: self.task.box_clone(),
922 best: self.best,
923 visits: self.visits,
924 score: self.score,
925 uct: self.uct,
926 uct_0: self.uct_0,
927 reward: self.reward,
928 }
929 }
930 }
931
932 static GRAPH_OUTPUT_DEPTH: AtomicUsize = AtomicUsize::new(4);
934
935 pub fn set_graph_output_depth(depth: usize) {
937 graphviz::GRAPH_OUTPUT_DEPTH.store(depth, std::sync::atomic::Ordering::Relaxed);
938 }
939 pub fn get_graph_output_depth() -> usize {
941 GRAPH_OUTPUT_DEPTH.load(std::sync::atomic::Ordering::Relaxed)
942 }
943
944 impl<D: Domain> MCTS<D> {
945 fn add_relevant_nodes(
946 &self,
947 nodes: &mut SeededHashSet<Node<D>>,
948 node: &Node<D>,
949 depth: usize,
950 ) {
951 if depth >= GRAPH_OUTPUT_DEPTH.load(std::sync::atomic::Ordering::Relaxed) {
952 return;
953 }
954
955 nodes.insert(node.clone());
956
957 let edges = self.nodes.get(node).unwrap();
958 for edge in edges.expanded_tasks.values() {
959 if let Ok(edge) = edge.try_lock() {
960 if let Some(child) = edge.child.upgrade() {
962 self.add_relevant_nodes(nodes, &child, depth + 1);
964 }
965 }
966 }
967 }
968 }
969
970 impl<'a, D: Domain> GraphWalk<'a, Node<D>, Edge<D>> for MCTS<D> {
971 fn nodes(&'a self) -> Nodes<'a, Node<D>> {
972 let mut nodes = SeededHashSet::default();
973 self.add_relevant_nodes(&mut nodes, &self.root, 0);
974
975 Nodes::Owned(nodes.iter().cloned().collect::<Vec<_>>())
976 }
977
978 fn edges(&'a self) -> Edges<'a, Edge<D>> {
979 let mut nodes = SeededHashSet::default();
980 self.add_relevant_nodes(&mut nodes, &self.root, 0);
981
982 let mut edge_vec = Vec::new();
983 nodes.iter().for_each(|node| {
984 let edges = self.nodes.get(node).unwrap();
985
986 if !edges.expanded_tasks.is_empty() {
987 let range = self.min_max_range(self.root_agent);
988 let best_task = edges
989 .best_task(node.active_agent, 0., range.clone())
990 .unwrap();
991 let visits = edges.child_visits();
992 edges.expanded_tasks.iter().for_each(|(obj, _edge)| {
993 let edge = _edge.lock().unwrap();
994
995 let parent = edge.parent();
996 let child = edge.child();
997
998 if nodes.contains(&child) {
999 let child_value = child.current_value(node.active_agent);
1000 let parent_value = parent.current_value(node.active_agent);
1001 let reward = child_value - parent_value;
1002 edge_vec.push(Edge {
1003 parent: edge.parent(),
1004 child,
1005 task: obj.clone(),
1006 best: obj == &best_task,
1007 visits: edge.visits,
1008 score: edge.q_values.get(&node.active_agent).copied().unwrap_or(0.),
1009 uct: edge.uct(
1010 node.active_agent,
1011 visits,
1012 self.config.exploration,
1013 range.clone(),
1014 ),
1015 uct_0: edge.uct(node.active_agent, visits, 0., range.clone()),
1016 reward: *reward,
1017 });
1018 }
1019 });
1020 }
1021 });
1022
1023 Edges::Owned(edge_vec)
1024 }
1025
1026 fn source(&'a self, edge: &Edge<D>) -> Node<D> {
1027 edge.parent.clone()
1028 }
1029
1030 fn target(&'a self, edge: &Edge<D>) -> Node<D> {
1031 edge.child.clone()
1032 }
1033 }
1034
1035 impl<'a, D: Domain> Labeller<'a, Node<D>, Edge<D>> for MCTS<D> {
1036 fn graph_id(&'a self) -> Id<'a> {
1037 Id::new(format!("agent_{}", self.root_agent.0)).unwrap()
1038 }
1039
1040 fn node_id(&'a self, n: &Node<D>) -> Id<'a> {
1041 Id::new(format!("_{:p}", Arc::as_ptr(n))).unwrap()
1042 }
1043
1044 fn node_label(&'a self, n: &Node<D>) -> LabelText<'a> {
1045 let edges = self.nodes.get(n).unwrap();
1046 let q_v = edges.q_value((0, 0.), n.active_agent);
1047 let state_diff = StateDiffRef::new(&self.initial_state, &n.diff);
1048 let mut state = D::get_state_description(state_diff);
1049 if !state.is_empty() {
1050 state = state.replace('\n', "<br/>");
1051 state = format!("<br/><font point-size='10'>{state}</font>");
1052 }
1053 LabelText::HtmlStr(Cow::Owned(format!(
1054 "Agent {}<br/>T: {}, Q: {}<br/>V: {:?}{state}",
1055 n.active_agent.0,
1056 n.tick,
1057 q_v.map(|q_v| format!("{:.2}", q_v))
1058 .unwrap_or_else(|| "None".to_owned()),
1059 n.current_values()
1060 .iter()
1061 .map(|(agent, value)| { (agent.0, **value) })
1062 .collect::<BTreeMap<_, _>>(),
1063 )))
1064 }
1065
1066 fn node_style(&'a self, node: &Node<D>) -> Style {
1067 if *node == self.root {
1068 Style::Bold
1069 } else {
1070 Style::Filled
1071 }
1072 }
1073
1074 fn node_color(&'a self, node: &Node<D>) -> Option<LabelText<'a>> {
1075 let root_visits = self.nodes.get(&self.root).unwrap().child_visits();
1076 let visits = self.nodes.get(node).unwrap().child_visits();
1077
1078 if *node == self.root {
1079 Some(LabelText::LabelStr(Cow::Borrowed("red")))
1080 } else {
1081 let (h, s, _v) = agent_color_hsv(node.active_agent);
1082 Some(LabelText::LabelStr(Cow::Owned(format!(
1084 "{:.3} {:.3} 1.000",
1085 h,
1086 s * (visits as f32 / root_visits as f32).min(1.0)
1087 ))))
1088 }
1089 }
1090
1091 fn edge_style(&'a self, edge: &Edge<D>) -> Style {
1092 if edge.best {
1093 Style::Bold
1094 } else {
1095 Style::Solid
1096 }
1097 }
1098
1099 fn edge_color(&'a self, edge: &Edge<D>) -> Option<LabelText<'a>> {
1100 if edge.best {
1101 Some(LabelText::LabelStr(Cow::Borrowed("red")))
1102 } else {
1103 None
1104 }
1105 }
1106
1107 fn edge_label(&'a self, edge: &Edge<D>) -> LabelText<'a> {
1108 LabelText::LabelStr(Cow::Owned(format!(
1109 "{:?}\nN: {}, R: {:.2}, Q: {:.2}\nU: {:.2} ({:.2} + {:.2})",
1110 edge.task.display_action(),
1111 edge.visits,
1112 edge.reward,
1113 edge.score,
1114 edge.uct,
1115 edge.uct_0,
1116 edge.uct - edge.uct_0
1117 )))
1118 }
1119
1120 fn edge_start_arrow(&'a self, _e: &Edge<D>) -> Arrow {
1121 Arrow::none()
1122 }
1123
1124 fn edge_end_arrow(&'a self, _e: &Edge<D>) -> Arrow {
1125 Arrow::normal()
1126 }
1127
1128 fn kind(&self) -> Kind {
1129 Kind::Digraph
1130 }
1131 }
1132}