1use ansi_term::Style;
7use npc_engine_core::{
8 ActiveTask, ActiveTasks, AgentId, DefaultPolicyEstimator, Domain, DomainWithPlanningTask,
9 EarlyStopCondition, IdleTask, MCTSConfiguration, PlanningTask, StateDiffRef, StateDiffRefMut,
10 StateValueEstimator, Task, MCTS,
11};
12use std::{
13 collections::HashMap,
14 hash::Hash,
15 sync::{
16 atomic::{AtomicU64, Ordering},
17 Arc,
18 },
19 thread::{self, JoinHandle},
20 time::Duration,
21};
22
23use crate::GlobalDomain;
24
25fn highlight_style() -> Style {
26 ansi_term::Style::new().bold().fg(ansi_term::Colour::Green)
27}
28fn highlight_tick(tick: u64) -> String {
29 let tick_text = format!("T{}", tick);
30 highlight_style().paint(&tick_text).to_string()
31}
32
33fn highlight_agent(agent_id: AgentId) -> String {
34 let tick_text = format!("{}", agent_id);
35 highlight_style().paint(&tick_text).to_string()
36}
37
38pub trait ExecutableDomain: Domain {
40 fn apply_diff(diff: Self::Diff, state: &mut Self::State);
42}
43impl<
45 S: std::fmt::Debug + Sized + Clone + Hash + Eq,
46 DA: std::fmt::Debug + Default,
47 D: Domain<State = S, Diff = Option<S>, DisplayAction = DA>,
48 > ExecutableDomain for D
49{
50 fn apply_diff(diff: Self::Diff, state: &mut Self::State) {
51 if let Some(diff) = diff {
52 *state = diff;
53 }
54 }
55}
56
57pub trait ExecutorState<D: Domain> {
65 fn create_state_value_estimator(&self) -> Box<dyn StateValueEstimator<D> + Send> {
67 Box::new(DefaultPolicyEstimator {})
68 }
69 fn post_action_execute_hook(
71 &mut self,
72 _state: &D::State,
73 _diff: &D::Diff,
74 _active_task: &ActiveTask<D>,
75 _queue: &mut ActiveTasks<D>,
76 ) {
77 }
78 fn post_mcts_run_hook(&mut self, _mcts: &MCTS<D>, _last_active_task: &ActiveTask<D>) {}
80}
81
82pub trait ExecutorStateLocal<D: Domain> {
90 fn create_initial_state(&self) -> D::State;
92 fn init_task_queue(&self, state: &D::State) -> ActiveTasks<D>;
94 fn keep_agent(&self, _tick: u64, _state: &D::State, _agent: AgentId) -> bool {
96 true
97 }
98}
99
100pub trait ExecutorStateGlobal<D: GlobalDomain> {
108 fn create_initial_state(&self) -> D::GlobalState;
110 fn init_task_queue(&self, state: &D::GlobalState) -> ActiveTasks<D>;
112 fn keep_agent(&self, _tick: u64, _state: &D::GlobalState, _agent: AgentId) -> bool {
114 true
115 }
116 fn keep_execution(&self, _tick: u64, _queue: &ActiveTasks<D>, _state: &D::GlobalState) -> bool {
118 true
119 }
120 fn post_step_hook(&self, _tick: u64, _state: &D::GlobalState) {}
122}
123
124struct ExecutionQueue<D>
129where
130 D: Domain,
131{
132 task_queue: ActiveTasks<D>,
134}
135impl<D> ExecutionQueue<D>
136where
137 D: Domain,
138{
139 pub fn new(task_queue: ActiveTasks<D>) -> Self {
140 Self { task_queue }
141 }
142
143 pub fn is_empty(&self) -> bool {
144 self.task_queue.is_empty()
145 }
146
147 pub fn size(&self) -> usize {
148 self.task_queue.len()
149 }
150
151 pub fn pop_first_task(&mut self) -> ActiveTask<D> {
152 let active_task = self.task_queue.iter().next().unwrap().clone();
154 self.task_queue.remove(&active_task);
155 active_task
156 }
157
158 pub fn execute_task<S, C>(
159 &mut self,
160 active_task: &ActiveTask<D>,
161 state: &D::State,
162 executor_state: &mut S,
163 mut new_agents_cb: C,
164 ) -> (D::Diff, Option<Box<dyn Task<D>>>)
165 where
166 S: ExecutorState<D>,
167 C: FnMut(&Vec<ActiveTask<D>>),
168 {
169 let active_agent = active_task.agent;
170 let tick = active_task.end;
171
172 let mut diff = D::Diff::default();
174 let state_diff = StateDiffRef::<D>::new(state, &diff);
175 if log::log_enabled!(log::Level::Info) {
176 let highlight_style = highlight_style();
177 let tick_text = format!("T{}", tick);
178 let task_name = format!("{:?}", active_task.task);
179 let agent_id_text = format!("A{}", active_agent.0);
180 log::info!(
181 "\n{}, State:\n{}\n{} task to be executed: {}",
182 highlight_style.paint(&tick_text),
183 D::get_state_description(state_diff),
184 highlight_style.paint(&agent_id_text),
185 highlight_style.paint(&task_name)
186 );
187 }
188
189 let is_task_valid = active_task.task.is_valid(tick, state_diff, active_agent);
191 if is_task_valid {
192 log::info!("Valid task, executing...");
193 let state_diff_mut = StateDiffRefMut::new(state, &mut diff);
194 let new_task = active_task.task.execute(tick, state_diff_mut, active_agent);
195 let mut new_agents_tasks = D::get_new_agents(StateDiffRef::new(state, &diff))
197 .into_iter()
198 .map(|new_agent| ActiveTask::new_idle(tick, new_agent, active_agent))
199 .collect();
200 new_agents_cb(&new_agents_tasks);
202 self.task_queue.extend(new_agents_tasks.drain(..));
204 executor_state.post_action_execute_hook(
206 state,
207 &diff,
208 active_task,
209 &mut self.task_queue,
210 );
211 (diff, new_task)
212 } else {
213 log::info!("Invalid task!");
214 (diff, None)
215 }
216 }
217
218 pub fn queue_task(
219 &mut self,
220 tick: u64,
221 active_agent: AgentId,
222 new_task: Box<dyn Task<D>>,
223 state: &D::State,
224 ) -> ActiveTask<D> {
225 let diff = D::Diff::default();
226 let state_diff = StateDiffRef::new(state, &diff);
227 if log::log_enabled!(log::Level::Info) {
228 let new_active_task = ActiveTask::new(active_agent, new_task.clone(), tick, state_diff);
229 log::info!(
230 "Queuing new task for {} until {}: {:?}",
231 highlight_agent(active_agent),
232 highlight_tick(new_active_task.end),
233 new_task
234 );
235 }
236 let new_active_task = ActiveTask::new(active_agent, new_task, tick, state_diff);
237 self.task_queue.insert(new_active_task.clone());
238 new_active_task
239 }
240}
241
242pub struct SimpleExecutor<'a, D, S>
249where
250 D: ExecutableDomain,
251 D::State: Clone,
252 S: ExecutorState<D> + ExecutorStateLocal<D>,
253{
254 mcts_config: MCTSConfiguration,
256 executor_state: &'a mut S,
258 state: D::State,
260 queue: ExecutionQueue<D>,
262}
263impl<'a, D, S> SimpleExecutor<'a, D, S>
264where
265 D: ExecutableDomain,
266 D::State: Clone,
267 S: ExecutorState<D> + ExecutorStateLocal<D>,
268{
269 pub fn new(mcts_config: MCTSConfiguration, executor_state: &'a mut S) -> Self {
271 let state = executor_state.create_initial_state();
272 let task_queue = executor_state.init_task_queue(&state);
273 let queue = ExecutionQueue::new(task_queue);
274 Self {
275 mcts_config,
276 state,
277 queue,
278 executor_state,
279 }
280 }
281
282 pub fn step(&mut self) -> bool {
284 if self.queue.is_empty() {
285 return false;
286 }
287
288 let active_task = self.queue.pop_first_task();
290 let active_agent = active_task.agent;
291 let tick = active_task.end;
292
293 if !self
295 .executor_state
296 .keep_agent(tick, &self.state, active_agent)
297 {
298 return true;
299 }
300
301 let (diff, new_task) =
303 self.queue
304 .execute_task(&active_task, &self.state, self.executor_state, |_| {});
305 D::apply_diff(diff, &mut self.state);
306
307 let new_task = new_task.unwrap_or_else(|| {
309 log::info!("No subsequent task, planning!");
310 let mut mcts = self.new_mcts(tick, active_agent);
311 let new_task = mcts.run().unwrap_or_else(|| Box::new(IdleTask));
312 self.executor_state.post_mcts_run_hook(&mcts, &active_task);
313 new_task
314 });
315
316 self.queue
318 .queue_task(tick, active_agent, new_task, &self.state);
319
320 true
321 }
322
323 fn new_mcts(&self, tick: u64, active_agent: AgentId) -> MCTS<D> {
324 MCTS::<D>::new_with_tasks(
325 self.state.clone(),
326 active_agent,
327 tick,
328 self.queue.task_queue.clone(),
329 self.mcts_config.clone(),
330 self.executor_state.create_state_value_estimator(),
331 None,
332 )
333 }
334}
335
336pub fn run_simple_executor<D, S>(mcts_config: &MCTSConfiguration, executor_state: &mut S)
338where
339 D: ExecutableDomain,
340 D::State: Clone,
341 S: ExecutorState<D> + ExecutorStateLocal<D>,
342{
343 let mut executor = SimpleExecutor::<D, S>::new(mcts_config.clone(), executor_state);
345 loop {
346 if !executor.step() {
347 break;
348 }
349 }
350}
351
352pub struct ThreadedExecutor<'a, D, S>
360where
361 D: DomainWithPlanningTask + GlobalDomain,
362 D::State: Clone + Send,
363 D::Diff: Send + Sync,
364 S: ExecutorState<D> + ExecutorStateGlobal<D>,
365{
366 mcts_config: MCTSConfiguration,
368 executor_state: &'a mut S,
370 state: D::GlobalState,
372 queue: ExecutionQueue<D>,
374 task_history: HashMap<AgentId, ActiveTask<D>>,
379 threads: HashMap<AgentId, JoinHandle<MCTS<D>>>,
381 tick: Arc<AtomicU64>,
383}
384impl<'a, D, S> ThreadedExecutor<'a, D, S>
385where
386 D: DomainWithPlanningTask + GlobalDomain,
387 D::State: Clone + Send,
388 D::Diff: Send + Sync,
389 S: ExecutorState<D> + ExecutorStateGlobal<D>,
390{
391 pub fn new(mcts_config: MCTSConfiguration, executor_state: &'a mut S) -> Self {
393 let state = executor_state.create_initial_state();
394 let task_queue = executor_state.init_task_queue(&state);
395 let task_history = task_queue
396 .iter()
397 .map(|active_task| (active_task.agent, active_task.clone()))
398 .collect();
399 let queue = ExecutionQueue::new(task_queue);
400 Self {
401 mcts_config,
402 state,
403 queue,
404 task_history,
405 threads: Default::default(),
406 tick: Arc::new(AtomicU64::new(0)),
407 executor_state,
408 }
409 }
410
411 fn new_mcts(&self, tick: u64, active_agent: AgentId) -> MCTS<D> {
412 let planning_task_duration = self
413 .mcts_config
414 .planning_task_duration
415 .expect("Planning task must have non-zero duration for threaded executor");
416 let tick_atomic = self.tick.clone();
417 let early_stop_condition: Option<Box<EarlyStopCondition>> = Some(Box::new(move || {
418 tick_atomic.load(Ordering::Relaxed) >= tick + planning_task_duration.get() - 1
419 }));
420 MCTS::<D>::new_with_tasks(
421 D::derive_local_state(&self.state, active_agent),
422 active_agent,
423 tick,
424 self.queue.task_queue.clone(),
425 self.mcts_config.clone(),
426 self.executor_state.create_state_value_estimator(),
427 early_stop_condition,
428 )
429 }
430
431 fn block_on_planning(&mut self, tick: u64) {
434 let active_tasks = self.queue.task_queue.clone();
436 for active_task in active_tasks
437 .iter()
438 .filter(|task| task.end <= tick && task.task.downcast_ref::<PlanningTask>().is_some())
439 {
440 let active_agent = active_task.agent;
441 debug_assert!(active_task.end == tick,
442 "Processing an active planning task at tick {tick} but it should have been processed at tick {}.", active_task.end
443 );
444
445 let thread = self.threads.remove(&active_agent);
447 assert!(thread.is_some(),
448 "There is no planning thread for {active_agent} even though there is an active_task for it."
449 );
450 let thread = thread.unwrap();
451
452 let mcts = thread.join();
454 assert!(
455 mcts.is_ok(),
456 "Could not join planning thread of {active_agent}! Probably it panicked!"
457 );
458 let mcts = mcts.unwrap();
459 self.executor_state.post_mcts_run_hook(&mcts, active_task);
460
461 if log::log_enabled!(log::Level::Info) {
463 log::info!(
464 "{} - {} finished planning. Looking for best task...",
465 highlight_tick(tick),
466 highlight_agent(active_agent)
467 );
468 }
469 let best_task = mcts.best_task_with_history(&self.task_history);
470 log::info!("Best Task: {best_task:?}");
471
472 self.queue.task_queue.remove(active_task);
473 let local_state = D::derive_local_state(&self.state, active_agent);
474 let new_active_task =
475 self.queue
476 .queue_task(tick, active_agent, best_task.clone(), &local_state);
477 self.task_history.insert(active_agent, new_active_task);
478 }
479 }
480
481 fn execute_finished_tasks(&mut self, tick: u64) {
483 let active_tasks = self.queue.task_queue.clone();
484 for active_task in active_tasks.iter().filter(|task| task.end <= tick) {
485 self.queue.task_queue.remove(active_task);
487 let active_agent = active_task.agent;
488 debug_assert!(
489 active_task.end == tick,
490 "Processing an active task at tick {tick} but it ended at tick {}.",
491 active_task.end
492 );
493
494 if !self
496 .executor_state
497 .keep_agent(tick, &self.state, active_agent)
498 {
499 continue;
500 }
501
502 let local_state = D::derive_local_state(&self.state, active_agent);
504 let (diff, new_task) = self.queue.execute_task(
505 active_task,
506 &local_state,
507 self.executor_state,
508 |new_agents_tasks| {
509 for new_task in new_agents_tasks.iter() {
510 self.task_history.insert(new_task.agent, new_task.clone());
511 }
512 },
513 );
514 D::apply(&mut self.state, &local_state, &diff);
515 let local_state = D::derive_local_state(&self.state, active_agent);
516
517 let new_task = new_task.unwrap_or_else(|| {
519 Box::new(PlanningTask(
520 self.mcts_config.planning_task_duration.unwrap(),
521 ))
522 });
523
524 let end_tick = self
526 .queue
527 .queue_task(tick, active_agent, new_task.clone(), &local_state)
528 .end;
529
530 if new_task.downcast_ref::<PlanningTask>().is_some() {
532 let mut mcts = self.new_mcts(tick, active_agent);
533 if log::log_enabled!(log::Level::Info) {
534 log::info!(
535 "{} - {} starts planning until {}.",
536 highlight_tick(tick),
537 active_agent,
538 highlight_tick(end_tick)
539 );
540 log::trace!("Active Tasks:");
541 for active_task in &self.queue.task_queue {
542 log::trace!(
543 "{}: {} {:?}",
544 active_task.agent,
545 highlight_tick(active_task.end),
546 active_task.task
547 );
548 }
549 }
550 let handle = thread::Builder::new()
551 .name(format!("plan-{}", active_agent.0))
552 .spawn(move || {
553 mcts.run();
556 mcts
557 })
558 .unwrap();
559 self.threads.insert(active_task.agent, handle);
560 }
561 }
562 }
563
564 pub fn step(&mut self) -> bool {
568 if self.queue.is_empty() {
569 return false;
570 }
571
572 let tick = self.tick.load(Ordering::Relaxed);
573 if !self
574 .executor_state
575 .keep_execution(tick, &self.queue.task_queue, &self.state)
576 {
577 return false;
578 }
579 self.block_on_planning(tick);
580 self.execute_finished_tasks(tick);
581 self.executor_state.post_step_hook(tick, &self.state);
582
583 self.tick.fetch_add(1, Ordering::Relaxed);
584 true
585 }
586
587 pub fn stop(&mut self) {
589 self.tick.store(u64::MAX, Ordering::Relaxed);
591 self.threads.drain().for_each(|(_, thread)| {
593 let _ = thread.join();
594 });
595 }
596
597 pub fn state(&self) -> &D::GlobalState {
599 &self.state
600 }
601
602 pub fn agents_count(&self) -> usize {
604 self.queue.size()
605 }
606}
607
608pub fn run_threaded_executor<D, S>(
612 mcts_config: &MCTSConfiguration,
613 executor_state: &mut S,
614 step_duration: Duration,
615) where
616 D: DomainWithPlanningTask + GlobalDomain,
617 D::State: Clone + Send,
618 D::Diff: Send + Sync,
619 S: ExecutorState<D> + ExecutorStateGlobal<D>,
620{
621 let mut executor = ThreadedExecutor::<D, S>::new(mcts_config.clone(), executor_state);
623 loop {
624 if !executor.step() {
625 break;
626 }
627 thread::sleep(step_duration);
628 }
629}
630
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;
    use core::time;
    use npc_engine_core::{
        ActiveTask, ActiveTasks, AgentId, AgentValue, Behavior, Domain, IdleTask,
        MCTSConfiguration, StateDiffRef, Task,
    };
    use std::{collections::BTreeSet, num::NonZeroU64, thread};

    /// Steps a threaded executor over a domain with no state and a single idle-only behavior,
    /// checking that planning threads are spawned and collected without panicking.
    #[test]
    fn threaded_executor_trivial_domain() {
        #[derive(Debug)]
        enum DisplayAction {
            Idle,
            Plan,
        }
        impl Default for DisplayAction {
            fn default() -> Self {
                Self::Idle
            }
        }

        struct TrivialDomain;
        impl Domain for TrivialDomain {
            type State = ();
            type Diff = ();
            type DisplayAction = DisplayAction;

            fn list_behaviors() -> &'static [&'static dyn Behavior<Self>] {
                &[&TrivialBehavior]
            }

            fn get_current_value(
                _tick: u64,
                _state_diff: StateDiffRef<Self>,
                _agent: AgentId,
            ) -> AgentValue {
                AgentValue::new(0.).unwrap()
            }

            fn update_visible_agents(
                _start_tick: u64,
                _tick: u64,
                _state_diff: StateDiffRef<Self>,
                agent: AgentId,
                agents: &mut BTreeSet<AgentId>,
            ) {
                agents.insert(agent);
            }

            fn display_action_task_planning() -> Self::DisplayAction {
                DisplayAction::Plan
            }
        }
        impl GlobalDomain for TrivialDomain {
            type GlobalState = ();
            fn derive_local_state(
                _global_state: &Self::GlobalState,
                _agent: AgentId,
            ) -> Self::State {
            }
            fn apply(
                _global_state: &mut Self::GlobalState,
                _local_state: &Self::State,
                _diff: &Self::Diff,
            ) {
            }
        }
        impl DomainWithPlanningTask for TrivialDomain {}

        #[derive(Copy, Clone, Debug)]
        struct TrivialBehavior;
        impl Behavior<TrivialDomain> for TrivialBehavior {
            fn add_own_tasks(
                &self,
                _tick: u64,
                _state_diff: StateDiffRef<TrivialDomain>,
                _agent: AgentId,
                tasks: &mut Vec<Box<dyn Task<TrivialDomain>>>,
            ) {
                tasks.push(Box::new(IdleTask));
            }

            fn is_valid(
                &self,
                _tick: u64,
                _state_diff: StateDiffRef<TrivialDomain>,
                _agent: AgentId,
            ) -> bool {
                true
            }
        }

        struct TrivialExecutorState;
        impl ExecutorStateGlobal<TrivialDomain> for TrivialExecutorState {
            fn create_initial_state(&self) {}
            fn init_task_queue(&self, _: &()) -> ActiveTasks<TrivialDomain> {
                vec![ActiveTask::new_with_end(0, AgentId(0), Box::new(IdleTask))]
                    .into_iter()
                    .collect()
            }
        }
        impl ExecutorState<TrivialDomain> for TrivialExecutorState {}

        // `init` panics if a global logger is already installed (e.g. by another test in the
        // same binary); `try_init` tolerates that.
        let _ = env_logger::try_init();
        let mcts_config = MCTSConfiguration {
            allow_invalid_tasks: false,
            visits: 5,
            depth: 100,
            exploration: 1.414,
            discount_hl: 30.,
            seed: None,
            planning_task_duration: Some(NonZeroU64::new(10).unwrap()),
        };
        let mut executor_state = TrivialExecutorState;
        let mut executor = ThreadedExecutor::new(mcts_config, &mut executor_state);
        let one_millis = time::Duration::from_millis(1);
        for _ in 0..5 {
            executor.step();
            thread::sleep(one_millis);
        }
        // Join the planning threads still in flight so they do not outlive the test.
        executor.stop();
    }
}