// manual_executor/executor.rs
use crate::Key;
use crate::TaskWake;
use core::{
    future::Future,
    ops::DerefMut,
    pin::Pin,
    task::{Context, Waker},
};
use std::{
    collections::BTreeMap,
    sync::{Arc, Mutex, MutexGuard},
};

14type Task = Pin<Box<dyn Future<Output = ()> + Send>>;
15
16pub struct ManualExecutor {
17 tasks: Mutex<BTreeMap<Key, Task>>,
18}
19
20impl ManualExecutor {
21 pub fn new() -> Arc<Self> {
22 Arc::new(Self {
23 tasks: Default::default(),
24 })
25 }
26
27 pub fn task_count(&self) -> usize {
28 let tasks = self.tasks.lock().unwrap();
29 tasks.len()
30 }
31
32 pub fn spawn_wake(self: &Arc<Self>, task: impl Future<Output = ()> + Send + 'static) -> Key {
33 let key = self.spawn(task);
34 self.wake(key);
35 key
36 }
37
38 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> Key {
39 let key = Key::new();
40 let mut tasks = self.tasks.lock();
41 let tasks = tasks.as_mut().unwrap();
42 assert!(tasks.insert(key, Box::pin(task)).is_none());
43 key
44 }
45
46 pub fn wake(self: &Arc<Self>, key: Key) {
47 let mut tasks = self.tasks.lock().unwrap();
48 let tasks = tasks.deref_mut();
49
50 let Some(mut task) = tasks.remove(&key) else {
51 return;
52 };
53
54 let pending = self.poll_task(key, &mut task);
55 if pending {
56 assert!(tasks.insert(key, task).is_none());
57 }
58 }
59
60 pub fn wake_all(self: &Arc<Self>) {
61 let mut tasks = self.tasks.lock().unwrap();
62 let tasks = tasks.deref_mut();
63 let mut tasks_new = BTreeMap::new();
64
65 while let Some((key, mut task)) = tasks.pop_first() {
66 let pending = self.poll_task(key, &mut task);
67 if pending {
68 assert!(tasks_new.insert(key, task).is_none());
69 }
70 }
71
72 *tasks = tasks_new;
73 }
74
75 fn poll_task(self: &Arc<Self>, key: Key, task: &mut Task) -> bool {
76 let wake = TaskWake::new(self.clone(), key);
77 let waker = Waker::from(wake.clone());
78 let mut context = Context::from_waker(&waker);
79
80 let poll_result = task.as_mut().poll(&mut context);
81 poll_result.is_pending()
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use once_cell::sync::Lazy;
    use std::{collections::HashSet, sync::Mutex};

    /// Label set shared between a test and the tasks it spawns.
    type Labels = Arc<Mutex<HashSet<&'static str>>>;

    /// Returns a task that inserts `label` into `labels` on its first poll and
    /// then completes.
    fn record(labels: &Labels, label: &'static str) -> impl Future<Output = ()> + Send + 'static {
        let labels = Arc::clone(labels);
        async move {
            labels.lock().unwrap().insert(label);
        }
    }

    #[test]
    fn test_wake_all() {
        static MANUAL_EXECUTOR: Lazy<Arc<ManualExecutor>> = Lazy::new(ManualExecutor::new);

        let labels = Labels::default();

        MANUAL_EXECUTOR.spawn(record(&labels, "a"));
        assert_eq!(MANUAL_EXECUTOR.task_count(), 1);

        MANUAL_EXECUTOR.spawn(record(&labels, "b"));
        assert_eq!(MANUAL_EXECUTOR.task_count(), 2);

        // `spawn` alone must not poll anything.
        assert!(labels.lock().unwrap().is_empty());

        MANUAL_EXECUTOR.wake_all();
        assert_eq!(MANUAL_EXECUTOR.task_count(), 0);

        let expected: HashSet<_> = ["a", "b"].into();
        assert_eq!(*labels.lock().unwrap(), expected);
    }

    #[test]
    fn test_wake_spawn() {
        static MANUAL_EXECUTOR: Lazy<Arc<ManualExecutor>> = Lazy::new(ManualExecutor::new);

        let labels = Labels::default();

        // `spawn_wake` polls immediately, and these tasks finish on their first poll.
        MANUAL_EXECUTOR.spawn_wake(record(&labels, "a"));
        assert_eq!(MANUAL_EXECUTOR.task_count(), 0);

        MANUAL_EXECUTOR.spawn_wake(record(&labels, "b"));
        assert_eq!(MANUAL_EXECUTOR.task_count(), 0);

        let expected: HashSet<_> = ["a", "b"].into();
        assert_eq!(*labels.lock().unwrap(), expected);
    }
}