npc_engine_core/
mcts.rs

1/*
2 *  SPDX-License-Identifier: Apache-2.0 OR MIT
3 *  © 2020-2022 ETH Zurich and other contributors, see AUTHORS.txt for details
4 */
5
6use std::collections::HashMap;
7use std::collections::{btree_map::Entry, BTreeMap, BTreeSet, HashSet};
8use std::f32;
9use std::mem;
10use std::ops::Range;
11use std::time::{Duration, Instant};
12
13use rand::{
14    distributions::WeightedIndex,
15    prelude::{thread_rng, Distribution, RngCore, SeedableRng},
16    Rng,
17};
18use rand_chacha::ChaCha8Rng;
19
20use crate::*;
21
22// TODO: Consider replacing Seeded hashmaps with btreemaps
23
24// TODO: Once is_nan() and unwrap() are const, remove unsafe
25
26// SAFETY: 0 is not NaN
27const VALUE_ZERO: AgentValue = unsafe { AgentValue::new_unchecked(0.) };
28// SAFETY: INFINITY is not NaN
29const VALUE_INFINITE: AgentValue = unsafe { AgentValue::new_unchecked(std::f32::INFINITY) };
30// SAFETY: NEG_INFINITY is not NaN
31const VALUE_NEG_INFINITE: AgentValue = unsafe { AgentValue::new_unchecked(std::f32::NEG_INFINITY) };
32
33/// The state of a running planner instance.
34pub struct MCTS<D: Domain> {
35    // Statistics
36    time: Duration,
37
38    // Config
39    config: MCTSConfiguration,
40    state_value_estimator: Box<dyn StateValueEstimator<D> + Send>,
41    early_stop_condition: Option<Box<EarlyStopCondition>>,
42
43    // Run-specific parameters
44    root_agent: AgentId,
45    seed: u64,
46
47    // Nodes
48    root: Node<D>,
49    nodes: SeededHashMap<Node<D>, Edges<D>>,
50
51    // Globals
52    q_value_ranges: BTreeMap<AgentId, Range<AgentValue>>,
53
54    // State before planning
55    initial_state: D::State,
56    start_tick: u64,
57
58    // Rng
59    rng: ChaCha8Rng,
60}
61
62/// The possible outcomes from a tree policy pass.
63enum TreePolicyOutcome<D: Domain> {
64    NodeCreated(u32, Node<D>, Vec<Edge<D>>), // depth, new node, path
65    NoValidTask(u32, Vec<Edge<D>>),          // depth, path
66    NoChildNode(u32, Node<D>, Vec<Edge<D>>), // depth, node, path
67    DepthLimitReached(u32, Node<D>, Vec<Edge<D>>), // depth, new node, path
68}
69
70impl<D: Domain> MCTS<D> {
71    /// Instantiates a new search tree for the given state, with idle tasks for all agents and starting at tick 0.
72    pub fn new(initial_state: D::State, root_agent: AgentId, config: MCTSConfiguration) -> Self {
73        let state_value_estimator = Box::new(DefaultPolicyEstimator {});
74        Self::new_with_tasks(
75            initial_state,
76            root_agent,
77            0,
78            Default::default(),
79            config,
80            state_value_estimator,
81            None,
82        )
83    }
84
85    /// Instantiates a new search tree for the given state, with active tasks for all agents and starting at a given tick.
86    pub fn new_with_tasks(
87        initial_state: D::State,
88        root_agent: AgentId,
89        start_tick: u64,
90        tasks: ActiveTasks<D>,
91        config: MCTSConfiguration,
92        state_value_estimator: Box<dyn StateValueEstimator<D> + Send>,
93        early_stop_condition: Option<Box<EarlyStopCondition>>,
94    ) -> Self {
95        // Check whether there is a task for this agent already
96        let next_task =
97            get_task_for_agent(&tasks, root_agent).map(|active_task| active_task.task.clone());
98
99        // Create new root node
100        let root = Node::new(NodeInner::new(
101            &initial_state,
102            start_tick,
103            Default::default(),
104            root_agent,
105            start_tick,
106            tasks,
107        ));
108
109        // Prepare nodes, reserve the maximum amount we could need
110        let mut nodes = SeededHashMap::with_capacity_and_hasher(
111            config.visits as usize + 1,
112            SeededRandomState::default(),
113        );
114
115        // Insert new root node
116        let root_edges = Edges::new(&root, &initial_state, next_task);
117        nodes.insert(root.clone(), root_edges);
118
119        // Compute seed
120        let cur_seed = config.seed.unwrap_or_else(|| thread_rng().next_u64());
121
122        MCTS {
123            time: Duration::default(),
124            config,
125            state_value_estimator,
126            early_stop_condition,
127            seed: cur_seed,
128            root_agent,
129            root,
130            nodes,
131            q_value_ranges: Default::default(),
132            initial_state,
133            start_tick,
134            rng: ChaCha8Rng::seed_from_u64(cur_seed),
135        }
136    }
137
138    /// Returns the best task, using exploration value of 0.
139    pub fn best_task_at_root(&mut self) -> Option<Box<dyn Task<D>>> {
140        let range = self.min_max_range(self.root_agent);
141        let edges = self.nodes.get(&self.root).unwrap();
142        edges
143            // Get best expanded tasks.
144            .best_task(self.root_agent, 0., range)
145            // If none, sample unexpanded tasks.
146            .or_else(|| {
147                edges.unexpanded_tasks.as_ref().and_then(|(_, tasks)| {
148                    if tasks.is_empty() {
149                        None
150                    } else {
151                        let index = self.rng.gen_range(0..tasks.len());
152                        Some(tasks[index].clone())
153                    }
154                })
155            })
156    }
157
158    /// Returns the best task, following a given recent task history, in case planning tasks are used.
159    pub fn best_task_with_history(
160        &self,
161        task_history: &HashMap<AgentId, ActiveTask<D>>,
162    ) -> Box<dyn Task<D>>
163    where
164        D: DomainWithPlanningTask,
165    {
166        log::debug!(
167            "Finding best task for {} using history {:?}",
168            self.root_agent,
169            task_history
170        );
171        let mut current_node = self.root.clone();
172        let mut edges = self.nodes.get(&current_node).unwrap();
173        let mut depth = 0;
174        loop {
175            let node_agent = current_node.agent();
176            let node_tick = current_node.tick();
177            let edge = if edges.expanded_tasks.len() == 1 {
178                let (task, edge) = edges.expanded_tasks.iter().next().unwrap();
179                log::trace!("[{depth}] T{node_tick} {node_agent} skipping {task:?}");
180
181                // Skip non-branching nodes
182                edge
183            } else {
184                let executed_task = task_history.get(&node_agent);
185                let executed_task = executed_task
186                    .unwrap_or_else(|| panic!("Found no task for {node_agent} is history"));
187                let task = &executed_task.task;
188                log::trace!("[{depth}] T{node_tick} {node_agent} executed {task:?}");
189
190                let edge = edges.expanded_tasks.get(task);
191
192                if edge.is_none() {
193                    log::info!("{node_agent} executed unexpected {task:?} not present in search tree, returning fallback task");
194                    return D::fallback_task(self.root_agent);
195                }
196
197                edge.unwrap()
198            };
199            let edge = edge.lock().unwrap();
200            current_node = edge.child();
201            // log::debug!("NEW_CUR_NODE: {current_node:?} {:p}", Arc::as_ptr(current_node));
202            edges = self.nodes.get(&current_node).unwrap();
203
204            depth += 1;
205
206            // Stop if we reach our own node again
207            if current_node.agent() == self.root_agent {
208                break;
209            }
210        }
211
212        // Return best task, using exploration value of 0
213        let range = self.min_max_range(self.root_agent);
214        let best = edges.best_task(self.root_agent, 0., range);
215
216        if best.is_none() {
217            log::info!(
218                "No valid task for agent {}, returning fallback task",
219                self.root_agent
220            );
221            return D::fallback_task(self.root_agent);
222        }
223
224        best.unwrap().clone()
225    }
226
227    /// Returns the q-value at root.
228    pub fn q_value_at_root(&self, agent: AgentId) -> f32 {
229        let edges = self.nodes.get(&self.root).unwrap();
230        edges.q_value((0, 0.), agent).unwrap()
231    }
232
233    /// Executes the MCTS search.
234    ///
235    /// Returns the current best task, if there is at least one task for the root node.
236    pub fn run(&mut self) -> Option<Box<dyn Task<D>>> {
237        // Reset globals
238        self.q_value_ranges.clear();
239
240        let start = Instant::now();
241        let max_visits = self.config.visits;
242        for i in 0..max_visits {
243            // Execute tree policy, if expansion resulted in no node, do nothing
244            let tree_policy_outcome = self.tree_policy();
245
246            // Only if the tree policy resulted in a node expansion, we execute the default policy,
247            // but in any case we update the visit count.
248            let (path, rollout_values) = match tree_policy_outcome {
249                TreePolicyOutcome::NodeCreated(depth, leaf, path) => {
250                    // Execute default policy
251                    let edges = self.nodes.get(&leaf).unwrap();
252                    let rollout_values = self.state_value_estimator.estimate(
253                        &mut self.rng,
254                        &self.config,
255                        &self.initial_state,
256                        self.start_tick,
257                        &leaf,
258                        edges,
259                        depth,
260                    );
261                    (path, rollout_values)
262                }
263                TreePolicyOutcome::NoValidTask(_, path) => (path, None),
264                TreePolicyOutcome::NoChildNode(_, _, path) => (path, None),
265                TreePolicyOutcome::DepthLimitReached(_, _, path) => (path, None),
266            };
267
268            // Backpropagate results
269            self.backpropagation(path, rollout_values);
270
271            // Early stopping if told so by some user-defined condition
272            if let Some(early_stop_condition) = &self.early_stop_condition {
273                if early_stop_condition() {
274                    log::info!("{:?} early stops planning after {} visits", self.agent(), i);
275                    break;
276                }
277            }
278        }
279        self.time = start.elapsed();
280
281        self.best_task_at_root()
282    }
283
284    /// MCTS tree policy. Executes the `selection` and `expansion` phases.
285    fn tree_policy(&mut self) -> TreePolicyOutcome<D> {
286        let agents = self.root.agents();
287
288        let mut node = self.root.clone();
289
290        // Path through the tree, including root and leaf
291        let mut path = Vec::with_capacity(self.config.depth as usize * agents.len());
292
293        // Execute selection until at most `depth`, expressed as number of ticks
294        let mut depth = 0;
295        while depth < self.config.depth {
296            let mut edges = self.nodes.get_mut(&node).unwrap();
297
298            // -------------------------
299            // Expansion
300            // -------------------------
301            // If weights are non-empty, the node has not been fully expanded
302            if let Some((weights, tasks)) = edges.unexpanded_tasks.as_mut() {
303                // Clone a new diff from the current one to be used for the newly expanded node
304                let mut diff = node.diff.clone();
305
306                // Select expansion task randomly
307                let idx = weights.sample(&mut self.rng);
308                let task = tasks[idx].clone();
309                let state_diff = StateDiffRef::new(&self.initial_state, &diff);
310                debug_assert!(task.is_valid(node.tick, state_diff, node.active_agent));
311                log::debug!(
312                    "T{}\t{:?} - Expand task: {:?}",
313                    node.tick,
314                    node.active_agent,
315                    task
316                );
317
318                // Set weight of chosen task to zero to mark it as expanded.
319                // As updating weights returns an error if all weights are zero,
320                // we have to handle this by setting unexpanded_tasks to None if we get an error.
321                if weights.update_weights(&[(idx, &0.)]).is_err() {
322                    // All weights being zero implies the node is fully expanded
323                    edges.unexpanded_tasks = None;
324                }
325
326                // Clone active tasks for child node, removing task of active agent
327                let mut child_tasks = node
328                    .tasks
329                    .iter()
330                    .filter(|task| task.agent != node.active_agent)
331                    .cloned()
332                    .collect::<BTreeSet<_>>();
333                // Create and insert new active task for the active agent and the selected task
334                let active_task =
335                    ActiveTask::new(node.active_agent, task.clone(), node.tick, state_diff);
336                child_tasks.insert(active_task);
337                log::trace!("\tActive Tasks ({}):", child_tasks.len());
338                for active_task in &child_tasks {
339                    log::trace!(
340                        "\t  {:?}: {:?} ends T{}",
341                        active_task.agent,
342                        active_task.task,
343                        active_task.end
344                    );
345                }
346
347                // Get task that finishes in the next node
348                let next_active_task = child_tasks.iter().next().unwrap().clone();
349                log::trace!(
350                    "\tNext Active Task: {:?}: {:?} ends T{}",
351                    next_active_task.agent,
352                    next_active_task.task,
353                    next_active_task.end
354                );
355
356                // If it is not valid, abort this expansion
357                let is_task_valid = next_active_task.task.is_valid(
358                    next_active_task.end,
359                    state_diff,
360                    next_active_task.agent,
361                );
362                if !is_task_valid && !self.config.allow_invalid_tasks {
363                    log::debug!("T{}\tNext active task {:?} is invalid and that is not allowed, aborting expansion", next_active_task.end, next_active_task.task);
364                    return TreePolicyOutcome::NoValidTask(depth, path);
365                }
366                // Execute the task which finishes in the next node
367                let after_next_task = if is_task_valid {
368                    let state_diff_mut = StateDiffRefMut::new(&self.initial_state, &mut diff);
369                    next_active_task.task.execute(
370                        next_active_task.end,
371                        state_diff_mut,
372                        next_active_task.agent,
373                    )
374                } else {
375                    None
376                };
377
378                // If we do not have a forced follow-up task...
379                let after_next_task = if after_next_task.is_none() {
380                    // And we have a forced planning task, handle it
381                    if let Some(planning_task_duration) = self.config.planning_task_duration {
382                        if next_active_task
383                            .task
384                            .downcast_ref::<PlanningTask>()
385                            .is_none()
386                        {
387                            // the incoming task was not planning, so the next one should be
388                            let task: Box<dyn Task<D>> =
389                                Box::new(PlanningTask(planning_task_duration));
390                            Some(task)
391                        } else {
392                            None
393                        }
394                    } else {
395                        None
396                    }
397                } else {
398                    after_next_task
399                };
400
401                // Create expanded node state
402                // let was_planning = task.downcast_ref::<Plan>().is_some();
403                let child_state = NodeInner::new(
404                    &self.initial_state,
405                    self.start_tick,
406                    diff,
407                    next_active_task.agent,
408                    next_active_task.end,
409                    child_tasks,
410                );
411
412                // Check if child node exists already
413                let child_node =
414                    if let Some((existing_node, _)) = self.nodes.get_key_value(&child_state) {
415                        // Link existing child node
416                        log::trace!("\tLinking to existing node {:?}", existing_node);
417                        existing_node.clone()
418                    } else {
419                        // Create and insert new child node
420                        log::trace!("\tCreating new node {:?}", child_state);
421                        let child_node = Node::new(child_state);
422                        self.nodes.insert(
423                            child_node.clone(),
424                            Edges::new(&child_node, &self.initial_state, after_next_task),
425                        );
426                        child_node
427                    };
428
429                // Create edge from parent to child
430                let edge = new_edge(&node, &child_node, &agents);
431                let edges = self.nodes.get_mut(&node).unwrap();
432                edges.expanded_tasks.insert(task, edge.clone());
433
434                // Push edge to path
435                path.push(edge);
436
437                depth += (child_node.tick - node.tick) as u32;
438                log::debug!(
439                    "T{}\tExpansion successful, node created with incoming task {:?}",
440                    child_node.tick,
441                    next_active_task.task
442                );
443                return TreePolicyOutcome::NodeCreated(depth, child_node, path);
444            }
445
446            // There is no child to this node, still return last node to ensure increase of visit count for this path
447            if edges.child_visits() == 0 {
448                log::debug!("T{}\tNode has no children, aborting expansion", node.tick);
449                return TreePolicyOutcome::NoChildNode(depth, node, path);
450            }
451
452            // -------------------------
453            // Selection
454            // -------------------------
455            // Node is fully expanded, perform selection
456            let range = self.min_max_range(node.active_agent);
457            let edges = self.nodes.get_mut(&node).unwrap();
458            let task = edges
459                .best_task(node.active_agent, self.config.exploration, range)
460                .expect("No valid task!");
461            log::trace!(
462                "T{}\t{:?} - Select task: {:?}",
463                node.tick,
464                node.active_agent,
465                task
466            );
467            let edge = edges.expanded_tasks.get(&task).unwrap().clone();
468
469            // New node is the current child node
470            let parent_tick = node.tick;
471            node = {
472                let edge = edge.lock().unwrap();
473                edge.child()
474            };
475            let child_tick = node.tick;
476            depth += (child_tick - parent_tick) as u32;
477
478            // Push edge to path
479            path.push(edge);
480        }
481
482        // We reached maximum depth, still return last node to ensure increase of visit count for this path
483        log::debug!(
484            "T{}\tReached maximum depth {}, aborting expansion",
485            node.tick,
486            depth
487        );
488        TreePolicyOutcome::DepthLimitReached(depth, node, path)
489    }
490
491    /// MCTS backpropagation phase. If rollout values are None, just increment the visits.
492    fn backpropagation(
493        &mut self,
494        mut path: Vec<Edge<D>>,
495        rollout_values: Option<BTreeMap<AgentId, f32>>,
496    ) {
497        // Backtracking
498        path.drain(..).rev().for_each(|edge| {
499            // Increment child node visit count
500            let edge = &mut edge.lock().unwrap();
501            edge.visits += 1;
502            if let Some(rollout_values) = &rollout_values {
503                let parent_node = edge.parent();
504                let child_node = edge.child();
505                let visits = edge.visits;
506                let child_edges = self.nodes.get(&child_node).unwrap();
507
508                let discount_factor =
509                    Self::discount_factor(child_node.tick - parent_node.tick, &self.config);
510
511                // Iterate all agents on edge
512                edge.q_values.iter_mut().for_each(|(&agent, q_value_ref)| {
513                    let parent_current_value =
514                        parent_node.current_value_or_compute(agent, &self.initial_state);
515                    let child_current_value =
516                        child_node.current_value_or_compute(agent, &self.initial_state);
517
518                    // Get q value from child, or rollout value if leaf node, or 0 if not in rollout
519                    let mut child_q_value =
520                        if let Some(value) = child_edges.q_value((visits, *q_value_ref), agent) {
521                            value
522                        } else {
523                            rollout_values.get(&agent).copied().unwrap_or_default()
524                        };
525
526                    // Apply discount, there is no risk of double-discounting as if the parent and the child node
527                    // have the same tick, the discount value will be 1.0
528                    child_q_value *= discount_factor;
529
530                    // Use Bellman Equation
531                    let q_value = child_current_value - parent_current_value + child_q_value;
532
533                    // Update q value for edge
534                    *q_value_ref = *q_value;
535
536                    // Update global q value range for agent
537                    let q_value_range = self
538                        .q_value_ranges
539                        .entry(parent_node.active_agent)
540                        .or_insert_with(|| Range {
541                            start: VALUE_INFINITE,
542                            end: VALUE_NEG_INFINITE,
543                        });
544                    q_value_range.start = q_value_range.start.min(q_value);
545                    q_value_range.end = q_value_range.end.max(q_value);
546                });
547            }
548        });
549    }
550
551    /// Calculates the discount factor for the tick duration.
552    ///
553    /// This basically calculates a half-life decay factor for the given duration.
554    /// This means the discount factor will be 0.5 if the given ticks are equal to the configured half-life in the MCTS.
555    fn discount_factor(duration: u64, config: &MCTSConfiguration) -> f32 {
556        2f64.powf((-(duration as f64)) / (config.discount_hl as f64)) as f32
557    }
558
559    /// Returns the initial state at the root of the planning tree.
560    pub fn initial_state(&self) -> &D::State {
561        &self.initial_state
562    }
563
564    /// Returns the tick at the root of the planning tree.
565    pub fn start_tick(&self) -> u64 {
566        self.start_tick
567    }
568
569    /// Returns the agent the tree searches for.
570    pub fn agent(&self) -> AgentId {
571        self.root_agent
572    }
573
574    /// Returns the range of minimum and maximum global values.
575    pub fn min_max_range(&self, agent: AgentId) -> Range<AgentValue> {
576        self.q_value_ranges.get(&agent).cloned().unwrap_or(Range {
577            start: VALUE_ZERO,
578            end: VALUE_ZERO,
579        })
580    }
581
582    /// Returns an iterator over all nodes and edges in the tree.
583    pub fn nodes(&self) -> impl Iterator<Item = (&Node<D>, &Edges<D>)> {
584        self.nodes.iter()
585    }
586
587    /// Returns the root node of the search tree.
588    pub fn root_node(&self) -> Node<D> {
589        self.root.clone()
590    }
591
592    /// Returns the edges associated to a given node.
593    pub fn get_edges(&self, node: &Node<D>) -> Option<&Edges<D>> {
594        self.nodes.get(node)
595    }
596
597    /// Returns the seed of the tree.
598    pub fn seed(&self) -> u64 {
599        self.seed
600    }
601
602    /// Returns the number of nodes.
603    pub fn node_count(&self) -> usize {
604        self.nodes.len()
605    }
606
607    /// Returns the number of nodes.
608    pub fn edge_count(&self) -> usize {
609        self.nodes
610            .values()
611            .map(|edges| edges.expanded_tasks.len())
612            .sum()
613    }
614
615    /// Returns the duration of the last run.
616    pub fn time(&self) -> Duration {
617        self.time
618    }
619
620    /// Returns an estimation of the memory footprint of the MCTS struct.
621    pub fn size(&self, task_size: fn(&dyn Task<D>) -> usize) -> usize {
622        let mut size = 0;
623
624        size += mem::size_of::<Self>();
625
626        for (node, edges) in &self.nodes {
627            size += node.size(task_size);
628            size += edges.size(task_size);
629        }
630
631        size += self.q_value_ranges.len() * mem::size_of::<(AgentId, Range<f32>)>();
632
633        size
634    }
635}
636
/// Default [`StateValueEstimator`] of the MCTS: estimates the value of a freshly expanded node
/// by simulating a rollout from it.
///
/// The estimator holds no state of its own; build it with `DefaultPolicyEstimator {}`.
pub struct DefaultPolicyEstimator {}
639impl<D: Domain> StateValueEstimator<D> for DefaultPolicyEstimator {
640    fn estimate(
641        &mut self,
642        rng: &mut ChaCha8Rng,
643        config: &MCTSConfiguration,
644        initial_state: &D::State,
645        start_tick: u64,
646        node: &Node<D>,
647        edges: &Edges<D>,
648        depth: u32,
649    ) -> Option<BTreeMap<AgentId, f32>> {
650        let mut diff = node.diff.clone();
651        log::debug!(
652            "T{}\tStarting rollout with cur. values: {:?}",
653            node.tick,
654            node.current_values()
655        );
656
657        // In this map we collect at the same time both:
658        // - the current value (measured from state and replaced in the course of simulation)
659        // - the Q value (initially 0, updated in the course of simulation)
660        let mut values: BTreeMap<AgentId, (AgentValue, f32)> = node
661            .current_values()
662            .iter()
663            .map(|(&agent, &current_value)| (agent, (current_value, 0f32)))
664            .collect::<BTreeMap<_, _>>();
665
666        // Clone active tasks for child node, removing task of active agent
667        let mut tasks = node
668            .tasks
669            .iter()
670            .filter(|task| task.agent != node.active_agent)
671            .cloned()
672            .collect::<BTreeSet<_>>();
673
674        // Sample a task for the node's unexpanded list, and put it in the queue
675        let task = {
676            if let Some((weights, tasks)) = edges.unexpanded_tasks.as_ref() {
677                // Select task randomly
678                let idx = weights.sample(rng);
679                tasks[idx].clone()
680            } else {
681                // No unexpanded edges, q values are 0
682                log::debug!(
683                    "T{}\tNo unexpanded edges in node passed to rollout",
684                    node.tick
685                );
686                return None;
687            }
688        };
689        let new_active_task = ActiveTask::new(
690            node.active_agent,
691            task,
692            node.tick,
693            StateDiffRef::new(initial_state, &diff),
694        );
695        tasks.insert(new_active_task);
696        let mut agents_with_tasks = tasks
697            .iter()
698            .map(|task| task.agent)
699            .collect::<HashSet<AgentId>>();
700        let mut agents = agents_with_tasks.iter().copied().collect();
701
702        // Create the state we need to perform the simulation
703        let rollout_start_tick = node.tick;
704        let mut tick = node.tick;
705        let mut depth = depth;
706        while depth < config.depth {
707            let state_diff = StateDiffRef::new(initial_state, &diff);
708
709            // If there is no more task to do, return what we have so far
710            if tasks.is_empty() {
711                log::debug!(
712                    "! T{} No more task to do in state\n{}",
713                    tick,
714                    D::get_state_description(state_diff)
715                );
716                break;
717            }
718
719            // Pop first task that is completed
720            let active_task = tasks.iter().next().unwrap().clone();
721            tasks.remove(&active_task);
722            let active_agent = active_task.agent;
723            agents_with_tasks.remove(&active_agent);
724
725            // Compute elapsed time and update tick
726            let elapsed = active_task.end - tick;
727            tick = active_task.end;
728
729            // If task is invalid, stop rollout
730            let is_task_valid = active_task.task.is_valid(tick, state_diff, active_agent);
731            if !is_task_valid && !config.allow_invalid_tasks {
732                log::debug!(
733                    "! T{} Not allowed invalid task {:?} by {:?} in state\n{}",
734                    tick,
735                    active_task.task,
736                    active_agent,
737                    D::get_state_description(state_diff)
738                );
739                break;
740            } else if is_task_valid {
741                log::trace!(
742                    "✓ T{} Valid task {:?} by {:?} in state\n{}",
743                    tick,
744                    active_task.task,
745                    active_agent,
746                    D::get_state_description(state_diff)
747                );
748            } else {
749                log::trace!(
750                    "✓ T{} Skipping invalid task {:?} by {:?} in state\n{}",
751                    tick,
752                    active_task.task,
753                    active_agent,
754                    D::get_state_description(state_diff)
755                );
756            }
757
758            // Execute the task
759            let new_task = if is_task_valid {
760                let state_diff_mut = StateDiffRefMut::new(initial_state, &mut diff);
761                active_task.task.execute(tick, state_diff_mut, active_agent)
762            } else {
763                None
764            };
765            let new_state_diff = StateDiffRef::new(initial_state, &diff);
766
767            // If we do not have a forced follow-up task...
768            let new_task = if new_task.is_none() {
769                // And we have a forced planning task, handle it
770                if let Some(planning_task_duration) = config.planning_task_duration {
771                    if active_task.task.downcast_ref::<PlanningTask>().is_none() {
772                        // the incoming task was not planning, so the next one should be
773                        let task: Box<dyn Task<D>> = Box::new(PlanningTask(planning_task_duration));
774                        Some(task)
775                    } else {
776                        None
777                    }
778                } else {
779                    None
780                }
781            } else {
782                new_task
783            };
784
785            // if the values for the agent executing the task are being tracked, update them
786            if let Entry::Occupied(mut entry) = values.entry(active_agent) {
787                let (current_value, estimated_value) = entry.get_mut();
788
789                // Compute discount
790                let discount =
791                    MCTS::<D>::discount_factor(active_task.end - rollout_start_tick, config);
792
793                // Update estimated value with discounted difference in current values
794                let new_current_value = D::get_current_value(tick, new_state_diff, active_agent);
795                *estimated_value += *(new_current_value - *current_value) * discount;
796                *current_value = new_current_value;
797            }
798
799            // Update the list of tasks, only considering visible agents,
800            // excluding the active agent (a new task for it will be added later)
801            D::update_visible_agents(start_tick, tick, new_state_diff, active_agent, &mut agents);
802            for agent in agents.iter() {
803                if *agent != active_agent && !agents_with_tasks.contains(agent) {
804                    tasks.insert(ActiveTask::new_idle(tick, *agent, active_agent));
805                    agents_with_tasks.insert(*agent);
806                }
807            }
808
809            // If active agent is visible, insert its next task, otherwise we forget about it
810            if agents.contains(&active_agent) {
811                // If no new task is available, select one randomly
812                let new_task = new_task.or_else(|| {
813                    // Get possible tasks
814                    let tasks = D::get_tasks(tick, new_state_diff, active_agent);
815                    if tasks.is_empty() {
816                        return None;
817                    }
818                    // Safety-check that all tasks are valid
819                    for task in &tasks {
820                        debug_assert!(task.is_valid(tick, new_state_diff, active_agent));
821                    }
822                    // Get the weight for each task
823                    let weights = WeightedIndex::new(
824                        tasks
825                            .iter()
826                            .map(|task| task.weight(tick, new_state_diff, active_agent)),
827                    )
828                    .unwrap();
829                    // Select task randomly
830                    let idx = weights.sample(rng);
831                    Some(tasks[idx].clone())
832                });
833
834                // If still none is available, stop caring about this agent
835                if let Some(new_task) = new_task {
836                    // Insert new task
837                    let new_active_task = ActiveTask::new(
838                        active_agent,
839                        new_task,
840                        tick,
841                        StateDiffRef::new(initial_state, &diff),
842                    );
843                    tasks.insert(new_active_task);
844                    agents_with_tasks.insert(active_agent);
845                }
846            }
847
848            // Make sure we do not keep track of the agents outside of the horizon
849            if agents_with_tasks.len() > agents.len() {
850                agents_with_tasks.retain(|id| agents.contains(id));
851            }
852
853            // Update depth
854            depth += elapsed as u32;
855        }
856
857        let q_values = values
858            .iter()
859            .map(|(agent, (_, q_value))| (*agent, *q_value))
860            .collect();
861
862        log::debug!(
863            "T{}\tRollout to T{}: q values: {:?}",
864            node.tick,
865            depth,
866            q_values
867        );
868
869        Some(q_values)
870    }
871}
872
/// When the `graphviz` feature is enabled, provides plotting of the search tree.
#[cfg(feature = "graphviz")]
pub mod graphviz {
    use super::*;
    use std::hash::{Hash, Hasher};
    use std::{
        borrow::Cow,
        io::{self, Write},
        sync::{atomic::AtomicUsize, Arc},
    };

    use dot::{Arrow, Edges, GraphWalk, Id, Kind, LabelText, Labeller, Nodes, Style};

    /// Renders the search tree in graphviz's dot format into `w`.
    ///
    /// Only the nodes within [`get_graph_output_depth`] levels of the tree (counting the
    /// root as the first level) are written, together with the edges between them.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by `dot::render`, e.g. when writing to `w` fails.
    pub fn plot_mcts_tree<D: Domain, W: Write>(mcts: &MCTS<D>, w: &mut W) -> io::Result<()> {
        dot::render(mcts, w)
    }

    /// Derives a color for `agent` from a hash of its id, as `(hue, saturation, value)`.
    ///
    /// The hash uses `DefaultHasher::default()`, so the color is stable for a given build.
    /// The hue is shifted and rescaled from degrees to a fraction of a full turn.
    /// NOTE(review): assumes `palette` reports hue degrees in (-180, 180], which would
    /// make the returned hue lie in (0, 1] — confirm against the palette version used.
    fn agent_color_hsv(agent: AgentId) -> (f32, f32, f32) {
        use palette::IntoColor;
        let mut hasher = std::collections::hash_map::DefaultHasher::default();
        agent.0.hash(&mut hasher);
        let bytes: [u8; 8] = hasher.finish().to_ne_bytes();
        let (h, s, v) = palette::Srgb::from_components((bytes[5], bytes[6], bytes[7]))
            .into_format::<f32>()
            .into_hsv::<palette::encoding::Srgb>()
            .into_components();

        ((h.to_degrees() + 180.) / 360., s, v)
    }

    /// Escapes the characters that are special inside graphviz HTML-like labels, so that
    /// arbitrary text can be embedded in an `HtmlStr` label.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for c in text.chars() {
            match c {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                _ => escaped.push(c),
            }
        }
        escaped
    }

    /// A plotted edge of the search tree, with a snapshot of the statistics it displays.
    struct Edge<D: Domain> {
        /// Node the edge starts from.
        parent: Node<D>,
        /// Node the edge leads to.
        child: Node<D>,
        /// Task executed along this edge.
        task: Box<dyn Task<D>>,
        /// Whether `task` is the best task of `parent`; best edges are drawn bold and red.
        best: bool,
        /// Visit count of this edge.
        visits: usize,
        /// Q-value of the parent's active agent on this edge, or 0 if it has none.
        score: f32,
        /// UCT value computed with the configured exploration constant.
        uct: f32,
        /// UCT value computed with an exploration constant of 0.
        uct_0: f32,
        /// Current value of the child minus that of the parent, for the parent's active agent.
        reward: f32,
    }

    // Implemented manually so that no `D: Clone` bound is required, as
    // `#[derive(Clone)]` would add one.
    impl<D: Domain> Clone for Edge<D> {
        fn clone(&self) -> Self {
            Edge {
                parent: self.parent.clone(),
                child: self.child.clone(),
                task: self.task.box_clone(),
                best: self.best,
                visits: self.visits,
                score: self.score,
                uct: self.uct,
                uct_0: self.uct_0,
                reward: self.reward,
            }
        }
    }

    /// Maximum depth of the plotted tree, in tree levels counting the root as the first
    /// level (1 plots only the root, 0 plots nothing).
    static GRAPH_OUTPUT_DEPTH: AtomicUsize = AtomicUsize::new(4);

    /// Sets the maximum depth of the plotted tree, in tree levels including the root.
    pub fn set_graph_output_depth(depth: usize) {
        GRAPH_OUTPUT_DEPTH.store(depth, std::sync::atomic::Ordering::Relaxed);
    }

    /// Gets the maximum depth of the plotted tree, in tree levels including the root.
    pub fn get_graph_output_depth() -> usize {
        GRAPH_OUTPUT_DEPTH.load(std::sync::atomic::Ordering::Relaxed)
    }

    impl<D: Domain> MCTS<D> {
        /// Inserts into `nodes` the node `node` (located at `depth`) and, recursively, the
        /// nodes reachable from it, stopping once [`get_graph_output_depth`] is reached.
        ///
        /// # Panics
        ///
        /// Panics if a visited node has no entry in `self.nodes`.
        fn add_relevant_nodes(
            &self,
            nodes: &mut SeededHashSet<Node<D>>,
            node: &Node<D>,
            depth: usize,
        ) {
            if depth >= get_graph_output_depth() {
                return;
            }

            nodes.insert(node.clone());

            let edges = self.nodes.get(node).unwrap();
            for edge in edges.expanded_tasks.values() {
                // The guard is held while recursing into the child, so an edge that is
                // already locked (e.g. met again along the current path) is skipped
                // rather than re-entered; this prevents unbounded recursion.
                // NOTE(review): assumes a std-style mutex whose `try_lock` fails instead
                // of blocking when the lock is already held.
                if let Ok(edge) = edge.try_lock() {
                    // `child` is a weak reference; skip children that no longer exist.
                    if let Some(child) = edge.child.upgrade() {
                        // TODO: Priority queue
                        self.add_relevant_nodes(nodes, &child, depth + 1);
                    }
                }
            }
        }
    }

    impl<'a, D: Domain> GraphWalk<'a, Node<D>, Edge<D>> for MCTS<D> {
        fn nodes(&'a self) -> Nodes<'a, Node<D>> {
            let mut nodes = SeededHashSet::default();
            self.add_relevant_nodes(&mut nodes, &self.root, 0);

            Nodes::Owned(nodes.iter().cloned().collect::<Vec<_>>())
        }

        fn edges(&'a self) -> Edges<'a, Edge<D>> {
            let mut nodes = SeededHashSet::default();
            self.add_relevant_nodes(&mut nodes, &self.root, 0);

            let mut edge_vec = Vec::new();
            for node in nodes.iter() {
                let edges = self.nodes.get(node).unwrap();
                // Unexpanded nodes have no outgoing edges and no best task to highlight.
                if edges.expanded_tasks.is_empty() {
                    continue;
                }

                let agent = node.active_agent;
                // NOTE(review): the value range of the root agent is used for every node,
                // whichever agent is active there — confirm this is intended.
                let range = self.min_max_range(self.root_agent);
                let best_task = edges.best_task(agent, 0., range.clone()).unwrap();
                let visits = edges.child_visits();

                for (task, edge_lock) in edges.expanded_tasks.iter() {
                    let edge = edge_lock.lock().unwrap();

                    let parent = edge.parent();
                    let child = edge.child();

                    // Skip edges leading to nodes that are not plotted.
                    if !nodes.contains(&child) {
                        continue;
                    }

                    let reward = child.current_value(agent) - parent.current_value(agent);
                    edge_vec.push(Edge {
                        parent,
                        child,
                        task: task.clone(),
                        best: task == &best_task,
                        visits: edge.visits,
                        score: edge.q_values.get(&agent).copied().unwrap_or(0.),
                        uct: edge.uct(agent, visits, self.config.exploration, range.clone()),
                        uct_0: edge.uct(agent, visits, 0., range.clone()),
                        reward: *reward,
                    });
                }
            }

            Edges::Owned(edge_vec)
        }

        fn source(&'a self, edge: &Edge<D>) -> Node<D> {
            edge.parent.clone()
        }

        fn target(&'a self, edge: &Edge<D>) -> Node<D> {
            edge.child.clone()
        }
    }

    impl<'a, D: Domain> Labeller<'a, Node<D>, Edge<D>> for MCTS<D> {
        fn graph_id(&'a self) -> Id<'a> {
            Id::new(format!("agent_{}", self.root_agent.0)).unwrap()
        }

        fn node_id(&'a self, n: &Node<D>) -> Id<'a> {
            // `{:p}` yields `0x…` hex digits, so the id stays a valid dot identifier.
            Id::new(format!("_{:p}", Arc::as_ptr(n))).unwrap()
        }

        fn node_label(&'a self, n: &Node<D>) -> LabelText<'a> {
            let edges = self.nodes.get(n).unwrap();
            let q_v = edges.q_value((0, 0.), n.active_agent);
            let state_diff = StateDiffRef::new(&self.initial_state, &n.diff);
            let mut state = D::get_state_description(state_diff);
            if !state.is_empty() {
                // This is an HTML-like label: escape the plain-text description first,
                // then turn its line breaks into `<br/>` tags.
                state = escape_html(&state).replace('\n', "<br/>");
                state = format!("<br/><font point-size='10'>{state}</font>");
            }
            LabelText::HtmlStr(Cow::Owned(format!(
                "Agent {}<br/>T: {}, Q: {}<br/>V: {:?}{state}",
                n.active_agent.0,
                n.tick,
                q_v.map(|q_v| format!("{:.2}", q_v))
                    .unwrap_or_else(|| "None".to_owned()),
                n.current_values()
                    .iter()
                    .map(|(agent, value)| (agent.0, **value))
                    .collect::<BTreeMap<_, _>>(),
            )))
        }

        fn node_style(&'a self, node: &Node<D>) -> Style {
            if *node == self.root {
                Style::Bold
            } else {
                Style::Filled
            }
        }

        fn node_color(&'a self, node: &Node<D>) -> Option<LabelText<'a>> {
            if *node == self.root {
                return Some(LabelText::LabelStr(Cow::Borrowed("red")));
            }

            let root_visits = self.nodes.get(&self.root).unwrap().child_visits();
            let visits = self.nodes.get(node).unwrap().child_visits();

            // Hue identifies the agent; saturation scales with the node's share of the
            // root's visits, capped at 1.
            let (h, s, _v) = agent_color_hsv(node.active_agent);
            Some(LabelText::LabelStr(Cow::Owned(format!(
                "{:.3} {:.3} 1.000",
                h,
                s * (visits as f32 / root_visits as f32).min(1.0)
            ))))
        }

        fn edge_style(&'a self, edge: &Edge<D>) -> Style {
            if edge.best {
                Style::Bold
            } else {
                Style::Solid
            }
        }

        fn edge_color(&'a self, edge: &Edge<D>) -> Option<LabelText<'a>> {
            if edge.best {
                Some(LabelText::LabelStr(Cow::Borrowed("red")))
            } else {
                None
            }
        }

        fn edge_label(&'a self, edge: &Edge<D>) -> LabelText<'a> {
            // U is shown as: uct (uct without exploration + remaining exploration part).
            LabelText::LabelStr(Cow::Owned(format!(
                "{:?}\nN: {}, R: {:.2}, Q: {:.2}\nU: {:.2} ({:.2} + {:.2})",
                edge.task.display_action(),
                edge.visits,
                edge.reward,
                edge.score,
                edge.uct,
                edge.uct_0,
                edge.uct - edge.uct_0
            )))
        }

        fn edge_start_arrow(&'a self, _e: &Edge<D>) -> Arrow {
            Arrow::none()
        }

        fn edge_end_arrow(&'a self, _e: &Edge<D>) -> Arrow {
            Arrow::normal()
        }

        fn kind(&self) -> Kind {
            Kind::Digraph
        }
    }
}