1use std::{
7 collections::BTreeSet,
8 fmt, mem,
9 ops::Range,
10 sync::{Arc, Mutex},
11};
12
13use crate::{AgentId, AgentValue, Domain, Node, SeededHashMap, StateDiffRef, Task, WeakNode};
14
15use rand::distributions::WeightedIndex;
16
/// Candidate tasks that have not been turned into edges, stored as a weighted
/// index paired with the boxed tasks whose weights it was built from.
/// `None` stands for "no candidate tasks".
type UnexpandedTasks<D> = Option<(WeightedIndex<f32>, Vec<Box<dyn Task<D>>>)>;
21
/// The set of outgoing edges of a node, split into tasks that already have an
/// edge and candidate tasks that do not.
pub struct Edges<D: Domain> {
    /// Candidate tasks and the weighted index over them, as built by `Edges::new`;
    /// `None` when `D::get_tasks` returned no tasks.
    pub(crate) unexpanded_tasks: UnexpandedTasks<D>,
    /// Edges that exist so far, keyed by the task they correspond to.
    pub(crate) expanded_tasks: SeededHashMap<Box<dyn Task<D>>, Edge<D>>,
}
27impl<D: Domain> fmt::Debug for Edges<D> {
28 fn fmt(&self, f: &'_ mut fmt::Formatter) -> fmt::Result {
29 f.debug_struct("Edges")
30 .field("unexpanded_tasks", &self.unexpanded_tasks)
31 .field("expanded_tasks", &self.expanded_tasks)
32 .finish()
33 }
34}
35
36impl<'a, D: Domain> IntoIterator for &'a Edges<D> {
37 type Item = (&'a Box<dyn Task<D>>, &'a Edge<D>);
38 type IntoIter = std::collections::hash_map::Iter<'a, Box<dyn Task<D>>, Edge<D>>;
39
40 fn into_iter(self) -> Self::IntoIter {
41 self.expanded_tasks.iter()
42 }
43}
44
45impl<D: Domain> Edges<D> {
46 pub fn new(
48 node: &Node<D>,
49 initial_state: &D::State,
50 next_task: Option<Box<dyn Task<D>>>,
51 ) -> Self {
52 let unexpanded_tasks = match next_task {
53 Some(task)
54 if task.is_valid(
55 node.tick,
56 StateDiffRef::new(initial_state, &node.diff),
57 node.active_agent,
58 ) =>
59 {
60 let weights = WeightedIndex::new((&[1.]).iter().map(Clone::clone)).unwrap();
61
62 Some((weights, vec![task.clone()]))
64 }
65 _ => {
66 let tasks = D::get_tasks(
68 node.tick,
69 StateDiffRef::new(initial_state, &node.diff),
70 node.active_agent,
71 );
72 if tasks.is_empty() {
73 return Edges {
75 unexpanded_tasks: None,
76 expanded_tasks: Default::default(),
77 };
78 }
79
80 let state_diff = StateDiffRef::new(initial_state, &node.diff);
82 for task in &tasks {
83 debug_assert!(task.is_valid(node.tick, state_diff, node.active_agent));
84 }
85
86 let weights = WeightedIndex::new(
88 tasks
89 .iter()
90 .map(|task| task.weight(node.tick, state_diff, node.active_agent)),
91 )
92 .unwrap();
93
94 Some((weights, tasks))
95 }
96 };
97
98 Edges {
99 unexpanded_tasks,
100 expanded_tasks: Default::default(),
101 }
102 }
103
104 pub fn child_visits(&self) -> usize {
106 self.expanded_tasks
107 .values()
108 .map(|edge| edge.lock().unwrap().visits)
109 .sum()
110 }
111
112 pub fn best_task(
114 &self,
115 agent: AgentId,
116 exploration: f32,
117 range: Range<AgentValue>,
118 ) -> Option<Box<dyn Task<D>>> {
119 let visits = self.child_visits();
120 self.expanded_tasks
121 .iter()
122 .max_by(|(_, a), (_, b)| {
123 let a = a.lock().unwrap();
124 let b = b.lock().unwrap();
125 a.uct(agent, visits, exploration, range.clone())
126 .partial_cmp(&b.uct(agent, visits, exploration, range.clone()))
127 .unwrap()
128 })
129 .map(|(k, _)| k.clone())
130 }
131
132 pub fn q_value(&self, fallback: (usize, f32), agent: AgentId) -> Option<f32> {
136 self.expanded_tasks
137 .values()
138 .map(|edge| {
139 edge.try_lock()
140 .map(|edge| {
141 (
142 edge.visits,
143 edge.q_values.get(&agent).copied().unwrap_or_default(),
144 )
145 })
146 .unwrap_or(fallback)
147 })
148 .fold(None, |acc, (visits, value)| match acc {
149 Some((sum, count)) => Some((sum + visits as f32 * value, count + visits)),
150 None => Some((visits as f32 * value, visits)),
151 })
152 .map(|(sum, count)| sum / count as f32)
153 }
154
155 pub fn expanded_count(&self) -> usize {
157 self.expanded_tasks.len()
158 }
159
160 pub fn unexpanded_count(&self) -> usize {
162 self.unexpanded_tasks
163 .as_ref()
164 .map_or(0, |(_, tasks)| tasks.len())
165 }
166
167 pub fn branching_factor(&self) -> usize {
169 self.expanded_count() + self.unexpanded_count()
170 }
171
172 #[allow(clippy::borrowed_box)]
174 pub fn get_edge(&self, task: &Box<dyn Task<D>>) -> Option<Edge<D>> {
175 self.expanded_tasks.get(task).cloned()
176 }
177
178 pub fn size(&self, task_size: fn(&dyn Task<D>) -> usize) -> usize {
180 let mut size = 0;
181
182 size += mem::size_of::<Self>();
183
184 if let Some((_, tasks)) = self.unexpanded_tasks.as_ref() {
185 for task in tasks {
186 size += task_size(&**task);
187 }
188 }
189
190 for (task, edge) in &self.expanded_tasks {
191 size += task_size(&**task);
192 size += edge.lock().unwrap().size();
193 }
194
195 size
196 }
197}
198
/// Shared, mutex-guarded handle to an `EdgeInner`.
pub type Edge<D> = Arc<Mutex<EdgeInner<D>>>;
201
/// The data stored on an edge between a parent node and a child node.
pub struct EdgeInner<D: Domain> {
    /// Weak reference to the node this edge starts from.
    pub(crate) parent: WeakNode<D>,
    /// Weak reference to the node this edge leads to.
    pub(crate) child: WeakNode<D>,
    /// Visit counter; starts at zero in `new_edge`.
    pub(crate) visits: usize,
    /// Q-value per agent; `new_edge` initialises it to zero for every agent it is given.
    pub(crate) q_values: SeededHashMap<AgentId, f32>,
}
209
210impl<D: Domain> fmt::Debug for EdgeInner<D> {
211 fn fmt(&self, f: &'_ mut fmt::Formatter) -> fmt::Result {
212 f.debug_struct("EdgeInner")
213 .field("parent", &self.parent)
214 .field("child", &self.child)
215 .field("visits", &self.visits)
216 .field("q_values", &self.q_values)
217 .finish()
218 }
219}
220
221pub(crate) fn new_edge<D: Domain>(
223 parent: &Node<D>,
224 child: &Node<D>,
225 agents: &BTreeSet<AgentId>,
226) -> Edge<D> {
227 Arc::new(Mutex::new(EdgeInner {
228 parent: Node::downgrade(parent),
229 child: Node::downgrade(child),
230 visits: Default::default(),
231 q_values: agents.iter().map(|agent| (*agent, 0.)).collect(),
232 }))
233}
234
235impl<D: Domain> EdgeInner<D> {
236 pub fn uct(
238 &self,
239 parent_agent: AgentId,
240 parent_child_visits: usize,
241 exploration: f32,
242 range: Range<AgentValue>,
243 ) -> f32 {
244 if let Some(q_value) = self.q_values.get(&parent_agent) {
246 let exploitation_value =
248 (q_value - *range.start) / (*(range.end - range.start)).max(f32::EPSILON);
249 let exploration_value =
250 ((parent_child_visits as f32).ln() / (self.visits as f32).max(f32::EPSILON)).sqrt();
251 exploitation_value + exploration * exploration_value
252 } else {
253 0.
254 }
255 }
256
257 pub fn visits(&self) -> usize {
259 self.visits
260 }
261
262 pub fn q_value(&self, agent: AgentId) -> f32 {
264 self.q_values.get(&agent).copied().unwrap_or(0.)
265 }
266
267 pub fn child(&self) -> Node<D> {
269 self.child.upgrade().unwrap()
270 }
271
272 pub fn parent(&self) -> Node<D> {
274 self.parent.upgrade().unwrap()
275 }
276
277 pub fn size(&self) -> usize {
279 let mut size = 0;
280
281 size += mem::size_of::<Self>();
282 size += self.q_values.len() * mem::size_of::<(AgentId, f32)>();
283
284 size
285 }
286}