manual_executor/
executor.rs

1use crate::Key;
2use crate::TaskWake;
3use core::{
4  future::Future,
5  ops::DerefMut,
6  pin::Pin,
7  task::{Context, Waker},
8};
9use std::{
10  collections::BTreeMap,
11  sync::{Arc, Mutex},
12};
13
14type Task = Pin<Box<dyn Future<Output = ()> + Send>>;
15
16pub struct ManualExecutor {
17  tasks: Mutex<BTreeMap<Key, Task>>,
18}
19
20impl ManualExecutor {
21  pub fn new() -> Arc<Self> {
22    Arc::new(Self {
23      tasks: Default::default(),
24    })
25  }
26
27  pub fn task_count(&self) -> usize {
28    let tasks = self.tasks.lock().unwrap();
29    tasks.len()
30  }
31
32  pub fn spawn_wake(self: &Arc<Self>, task: impl Future<Output = ()> + Send + 'static) -> Key {
33    let key = self.spawn(task);
34    self.wake(key);
35    key
36  }
37
38  pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> Key {
39    let key = Key::new();
40    let mut tasks = self.tasks.lock();
41    let tasks = tasks.as_mut().unwrap();
42    assert!(tasks.insert(key, Box::pin(task)).is_none());
43    key
44  }
45
46  pub fn wake(self: &Arc<Self>, key: Key) {
47    let mut tasks = self.tasks.lock().unwrap();
48    let tasks = tasks.deref_mut();
49
50    let Some(mut task) = tasks.remove(&key) else {
51      return;
52    };
53
54    let pending = self.poll_task(key, &mut task);
55    if pending {
56      assert!(tasks.insert(key, task).is_none());
57    }
58  }
59
60  pub fn wake_all(self: &Arc<Self>) {
61    let mut tasks = self.tasks.lock().unwrap();
62    let tasks = tasks.deref_mut();
63    let mut tasks_new = BTreeMap::new();
64
65    while let Some((key, mut task)) = tasks.pop_first() {
66      let pending = self.poll_task(key, &mut task);
67      if pending {
68        assert!(tasks_new.insert(key, task).is_none());
69      }
70    }
71
72    *tasks = tasks_new;
73  }
74
75  fn poll_task(self: &Arc<Self>, key: Key, task: &mut Task) -> bool {
76    let wake = TaskWake::new(self.clone(), key);
77    let waker = Waker::from(wake.clone());
78    let mut context = Context::from_waker(&waker);
79
80    let poll_result = task.as_mut().poll(&mut context);
81    poll_result.is_pending()
82  }
83}
84
#[cfg(test)]
mod tests {
  use super::*;
  use std::{collections::HashSet, sync::Mutex};

  /// Labels recorded by tasks once they have run.
  type Seen = Arc<Mutex<HashSet<&'static str>>>;

  /// Returns a task that inserts `label` into `seen` when it is polled.
  fn record(seen: &Seen, label: &'static str) -> impl Future<Output = ()> + Send + 'static {
    let seen = Arc::clone(seen);
    async move {
      seen.lock().unwrap().insert(label);
    }
  }

  #[test]
  fn test_wake_all() {
    let executor = ManualExecutor::new();
    let seen: Seen = Default::default();

    executor.spawn(record(&seen, "a"));
    assert_eq!(executor.task_count(), 1);
    executor.spawn(record(&seen, "b"));
    assert_eq!(executor.task_count(), 2);
    // `spawn` only registers; nothing has been polled yet.
    assert!(seen.lock().unwrap().is_empty());

    executor.wake_all();
    assert_eq!(executor.task_count(), 0);

    let expected: HashSet<_> = ["a", "b"].into();
    assert_eq!(*seen.lock().unwrap(), expected);
  }

  #[test]
  fn test_wake_spawn() {
    let executor = ManualExecutor::new();
    let seen: Seen = Default::default();

    executor.spawn_wake(record(&seen, "a"));
    assert_eq!(executor.task_count(), 0);
    executor.spawn_wake(record(&seen, "b"));
    assert_eq!(executor.task_count(), 0);

    let expected: HashSet<_> = ["a", "b"].into();
    assert_eq!(*seen.lock().unwrap(), expected);
  }

  #[test]
  fn test_pending_task_is_kept() {
    let executor = ManualExecutor::new();
    let key = executor.spawn_wake(std::future::pending::<()>());
    assert_eq!(executor.task_count(), 1);

    // Re-polling a task that never completes must leave it registered.
    executor.wake(key);
    executor.wake_all();
    assert_eq!(executor.task_count(), 1);
  }
}