npc_engine_core/
edge.rs

1/*
2 *  SPDX-License-Identifier: Apache-2.0 OR MIT
3 *  © 2020-2022 ETH Zurich and other contributors, see AUTHORS.txt for details
4 */
5
6use std::{
7    collections::BTreeSet,
8    fmt, mem,
9    ops::Range,
10    sync::{Arc, Mutex},
11};
12
13use crate::{AgentId, AgentValue, Domain, Node, SeededHashMap, StateDiffRef, Task, WeakNode};
14
15use rand::distributions::WeightedIndex;
16
/// The tasks left to expand in a given node.
///
/// `None` if all tasks are expanded; `Edges::new` also uses `None` when the
/// domain returns no task at all. Otherwise a pair made of a weighted index
/// and the tasks it was built from: as built by `Edges::new`, the i-th weight
/// belongs to the i-th task of the vector.
type UnexpandedTasks<D> = Option<(WeightedIndex<f32>, Vec<Box<dyn Task<D>>>)>;
21
/// The outgoing edges from a node, possibly partially expanded.
pub struct Edges<D: Domain> {
    /// Tasks that have no edge yet, together with their weights; `None` when nothing is left.
    pub(crate) unexpanded_tasks: UnexpandedTasks<D>,
    /// Edges created so far, keyed by the task each edge belongs to.
    pub(crate) expanded_tasks: SeededHashMap<Box<dyn Task<D>>, Edge<D>>,
}
27impl<D: Domain> fmt::Debug for Edges<D> {
28    fn fmt(&self, f: &'_ mut fmt::Formatter) -> fmt::Result {
29        f.debug_struct("Edges")
30            .field("unexpanded_tasks", &self.unexpanded_tasks)
31            .field("expanded_tasks", &self.expanded_tasks)
32            .finish()
33    }
34}
35
36impl<'a, D: Domain> IntoIterator for &'a Edges<D> {
37    type Item = (&'a Box<dyn Task<D>>, &'a Edge<D>);
38    type IntoIter = std::collections::hash_map::Iter<'a, Box<dyn Task<D>>, Edge<D>>;
39
40    fn into_iter(self) -> Self::IntoIter {
41        self.expanded_tasks.iter()
42    }
43}
44
45impl<D: Domain> Edges<D> {
46    /// Creates new edges, with optionally a forced task that will be the sole edge.
47    pub fn new(
48        node: &Node<D>,
49        initial_state: &D::State,
50        next_task: Option<Box<dyn Task<D>>>,
51    ) -> Self {
52        let unexpanded_tasks = match next_task {
53            Some(task)
54                if task.is_valid(
55                    node.tick,
56                    StateDiffRef::new(initial_state, &node.diff),
57                    node.active_agent,
58                ) =>
59            {
60                let weights = WeightedIndex::new((&[1.]).iter().map(Clone::clone)).unwrap();
61
62                // Set existing child weights, only option
63                Some((weights, vec![task.clone()]))
64            }
65            _ => {
66                // Get possible tasks
67                let tasks = D::get_tasks(
68                    node.tick,
69                    StateDiffRef::new(initial_state, &node.diff),
70                    node.active_agent,
71                );
72                if tasks.is_empty() {
73                    // no task, return empty edges
74                    return Edges {
75                        unexpanded_tasks: None,
76                        expanded_tasks: Default::default(),
77                    };
78                }
79
80                // Safety-check that all tasks are valid
81                let state_diff = StateDiffRef::new(initial_state, &node.diff);
82                for task in &tasks {
83                    debug_assert!(task.is_valid(node.tick, state_diff, node.active_agent));
84                }
85
86                // Get the weight for each task
87                let weights = WeightedIndex::new(
88                    tasks
89                        .iter()
90                        .map(|task| task.weight(node.tick, state_diff, node.active_agent)),
91                )
92                .unwrap();
93
94                Some((weights, tasks))
95            }
96        };
97
98        Edges {
99            unexpanded_tasks,
100            expanded_tasks: Default::default(),
101        }
102    }
103
104    /// Returns the sum of all visits to the edges of this nodes.
105    pub fn child_visits(&self) -> usize {
106        self.expanded_tasks
107            .values()
108            .map(|edge| edge.lock().unwrap().visits)
109            .sum()
110    }
111
112    /// Finds the best task with the given `exploration` factor and normalization `range`.
113    pub fn best_task(
114        &self,
115        agent: AgentId,
116        exploration: f32,
117        range: Range<AgentValue>,
118    ) -> Option<Box<dyn Task<D>>> {
119        let visits = self.child_visits();
120        self.expanded_tasks
121            .iter()
122            .max_by(|(_, a), (_, b)| {
123                let a = a.lock().unwrap();
124                let b = b.lock().unwrap();
125                a.uct(agent, visits, exploration, range.clone())
126                    .partial_cmp(&b.uct(agent, visits, exploration, range.clone()))
127                    .unwrap()
128            })
129            .map(|(k, _)| k.clone())
130    }
131
132    /// Returns the weighted average q value of all child edges.
133    ///
134    /// The `fallback` value is used for self-referential edges.
135    pub fn q_value(&self, fallback: (usize, f32), agent: AgentId) -> Option<f32> {
136        self.expanded_tasks
137            .values()
138            .map(|edge| {
139                edge.try_lock()
140                    .map(|edge| {
141                        (
142                            edge.visits,
143                            edge.q_values.get(&agent).copied().unwrap_or_default(),
144                        )
145                    })
146                    .unwrap_or(fallback)
147            })
148            .fold(None, |acc, (visits, value)| match acc {
149                Some((sum, count)) => Some((sum + visits as f32 * value, count + visits)),
150                None => Some((visits as f32 * value, visits)),
151            })
152            .map(|(sum, count)| sum / count as f32)
153    }
154
155    /// Returns the number of already-expanded edges.
156    pub fn expanded_count(&self) -> usize {
157        self.expanded_tasks.len()
158    }
159
160    /// Returns the number of not-yet-expanded edges.
161    pub fn unexpanded_count(&self) -> usize {
162        self.unexpanded_tasks
163            .as_ref()
164            .map_or(0, |(_, tasks)| tasks.len())
165    }
166
167    /// Returns how many edges there are, the sum of the expanded and not-yet expanded counts.
168    pub fn branching_factor(&self) -> usize {
169        self.expanded_count() + self.unexpanded_count()
170    }
171
172    /// Returns the expanded edge associated to a task, None if it does not exist.
173    #[allow(clippy::borrowed_box)]
174    pub fn get_edge(&self, task: &Box<dyn Task<D>>) -> Option<Edge<D>> {
175        self.expanded_tasks.get(task).cloned()
176    }
177
178    /// The memory footprint of this struct.
179    pub fn size(&self, task_size: fn(&dyn Task<D>) -> usize) -> usize {
180        let mut size = 0;
181
182        size += mem::size_of::<Self>();
183
184        if let Some((_, tasks)) = self.unexpanded_tasks.as_ref() {
185            for task in tasks {
186                size += task_size(&**task);
187            }
188        }
189
190        for (task, edge) in &self.expanded_tasks {
191            size += task_size(&**task);
192            size += edge.lock().unwrap().size();
193        }
194
195        size
196    }
197}
198
/// Strong atomic reference counted edge.
///
/// The edge data sits behind a `Mutex`, which must be locked to read or modify it.
pub type Edge<D> = Arc<Mutex<EdgeInner<D>>>;
201
/// The data associated with an edge.
pub struct EdgeInner<D: Domain> {
    /// Weak reference to the parent node of the edge.
    pub(crate) parent: WeakNode<D>,
    /// Weak reference to the child node of the edge.
    pub(crate) child: WeakNode<D>,
    /// Number of visits to this edge.
    pub(crate) visits: usize,
    /// Q-value of this edge for each agent.
    pub(crate) q_values: SeededHashMap<AgentId, f32>,
}
209
210impl<D: Domain> fmt::Debug for EdgeInner<D> {
211    fn fmt(&self, f: &'_ mut fmt::Formatter) -> fmt::Result {
212        f.debug_struct("EdgeInner")
213            .field("parent", &self.parent)
214            .field("child", &self.child)
215            .field("visits", &self.visits)
216            .field("q_values", &self.q_values)
217            .finish()
218    }
219}
220
221/// Creates a new edge between a parent and a child.
222pub(crate) fn new_edge<D: Domain>(
223    parent: &Node<D>,
224    child: &Node<D>,
225    agents: &BTreeSet<AgentId>,
226) -> Edge<D> {
227    Arc::new(Mutex::new(EdgeInner {
228        parent: Node::downgrade(parent),
229        child: Node::downgrade(child),
230        visits: Default::default(),
231        q_values: agents.iter().map(|agent| (*agent, 0.)).collect(),
232    }))
233}
234
235impl<D: Domain> EdgeInner<D> {
236    /// Calculates the current UCT value for the edge.
237    pub fn uct(
238        &self,
239        parent_agent: AgentId,
240        parent_child_visits: usize,
241        exploration: f32,
242        range: Range<AgentValue>,
243    ) -> f32 {
244        // If parent is not present, this node is being reused and the parent leaves the horizon. Score doesn't matter
245        if let Some(q_value) = self.q_values.get(&parent_agent) {
246            // Normalize the exploitation factor so it doesn't overshadow the exploration
247            let exploitation_value =
248                (q_value - *range.start) / (*(range.end - range.start)).max(f32::EPSILON);
249            let exploration_value =
250                ((parent_child_visits as f32).ln() / (self.visits as f32).max(f32::EPSILON)).sqrt();
251            exploitation_value + exploration * exploration_value
252        } else {
253            0.
254        }
255    }
256
257    /// Returns the number of visits to this edge
258    pub fn visits(&self) -> usize {
259        self.visits
260    }
261
262    /// Get the q-value of a given agent, 0 if not present
263    pub fn q_value(&self, agent: AgentId) -> f32 {
264        self.q_values.get(&agent).copied().unwrap_or(0.)
265    }
266
267    /// Returns the linked child node.
268    pub fn child(&self) -> Node<D> {
269        self.child.upgrade().unwrap()
270    }
271
272    /// Returns the linked parent node.
273    pub fn parent(&self) -> Node<D> {
274        self.parent.upgrade().unwrap()
275    }
276
277    /// The memory footprint of this struct.
278    pub fn size(&self) -> usize {
279        let mut size = 0;
280
281        size += mem::size_of::<Self>();
282        size += self.q_values.len() * mem::size_of::<(AgentId, f32)>();
283
284        size
285    }
286}