1use std::thread::{Thread, Result, JoinHandle, current, sleep, spawn, park};
4use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
5use std::task::{Poll, Wake, Waker, Context};
6use std::sync::atomic::Ordering::SeqCst;
7use std::time::{Instant, Duration};
8use std::pin::{pin, Pin};
9use std::future::Future;
10use std::sync::Arc;
11
12#[cfg(feature = "monitor")]
13use serde::Serialize;
14
15use async_fifo::{Sender, Receiver};
16use async_fifo::non_blocking::Producer;
17
18pub mod utils;
19
20#[cfg(feature = "monitor")]
21mod monitor;
22
/// Identifier assigned to every spawned task (taken from `Executor::next_id`).
type TaskId = usize;

/// Poll events a runner reports on its monitor channel for one task.
#[allow(dead_code)]
enum TaskExec {
    /// A poll of the task is about to start.
    Polling(TaskId),
    /// The poll returned `Poll::Ready`; the runner then drops the task.
    PollReady(TaskId),
    /// The poll returned `Poll::Pending`.
    PollPending(TaskId),
}

/// Announcement of a newly spawned task, sent to the monitor when it is enabled.
#[cfg_attr(feature = "monitor", derive(Serialize))]
#[allow(dead_code)]
struct TaskDecl {
    /// Executor-wide task id.
    id: TaskId,
    /// Name given to `Executor::spawn_with_name`.
    name: String,
    /// Index of the runner thread the task was dispatched to.
    runner: usize,
}
39
/// A type-erased, heap-pinned future producing `()` that can be sent to a runner thread.
///
/// Any `Future<Output = ()> + Send + 'static` converts into a `Task` via `From`/`Into`.
pub struct Task {
    inner: Pin<Box<dyn Future<Output = ()> + Send>>,
}

impl<F> From<F> for Task
where
    F: Future<Output = ()> + Send + 'static,
{
    /// Boxes and pins `future`; the future itself is not polled here.
    fn from(future: F) -> Self {
        let pinned: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(future);
        Task { inner: pinned }
    }
}
52
/// Number of readiness bits in `WakerCommon::ready_flags` (one per bit of a `u64`).
const FLAGS: usize = 64;

/// State shared between one runner thread and every waker it hands out.
struct WakerCommon {
    /// The runner thread to unpark on wake.
    thread: Thread,
    /// One readiness bit per task slot (slot = task index % `FLAGS`).
    ready_flags: AtomicU64,
    /// Raised by the waker whose index is outside the bit range (index >= `FLAGS`).
    task_rx_flag: AtomicBool,
}

/// A waker that marks `index` as ready in its `WakerCommon` and unparks the runner.
struct ThreadWaker {
    common: Arc<WakerCommon>,
    index: usize,
}

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        // Waking by value does exactly what waking by reference does.
        <Self as Wake>::wake_by_ref(&self);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        let common = &*self.common;
        match self.index {
            slot if slot < FLAGS => {
                common.ready_flags.fetch_or(1 << slot, SeqCst);
            }
            _ => common.task_rx_flag.store(true, SeqCst),
        }
        // Unpark unconditionally so the runner re-reads the flags even if it
        // is just about to park.
        common.thread.unpark();
    }
}

/// Builds a `Waker` for readiness index `i` sharing `common`.
///
/// `i < FLAGS` sets bit `i` of `ready_flags`; any larger `i` raises `task_rx_flag`.
fn init_waker(i: usize, common: &Arc<WakerCommon>) -> Waker {
    let shared = Arc::clone(common);
    Waker::from(Arc::new(ThreadWaker { common: shared, index: i }))
}
91
92struct TaskData {
93 task: Task,
94 id: TaskId,
95}
96
97fn try_tx_exec(tx: &Option<Producer<(TaskExec, Instant)>>, exec: TaskExec) {
98 if let Some(tx) = tx {
99 tx.send((exec, Instant::now()));
100 }
101}
102
103fn runner(
104 mut rx_tasks: Receiver<TaskData>,
105 tx_mon_exec: Option<Producer<(TaskExec, Instant)>>,
106) {
107 let common = Arc::new(WakerCommon {
108 thread: current(),
109 ready_flags: AtomicU64::new(0),
110 task_rx_flag: AtomicBool::new(true),
111 });
112
113 let new_task_waker = init_waker(FLAGS, &common);
114
115 let mut receiver = pin!(rx_tasks.recv());
116 let mut tasks: Vec<Option<TaskData>> = Vec::new();
117 let mut can_receive = true;
118 let mut num_started = 0;
119 let mut num_ended = 0;
120
121 while num_started != num_ended || can_receive {
122 let mut ready_flags = common.ready_flags.swap(0, SeqCst);
124 let task_rx_flag = common.task_rx_flag.swap(false, SeqCst);
125
126 if ready_flags == 0 && !task_rx_flag {
128 park();
130 continue;
131 }
132
133 if task_rx_flag {
134 let mut context = Context::from_waker(&new_task_waker);
135 can_receive = loop {
136 let task = match Future::poll(receiver.as_mut(), &mut context) {
137 Poll::Ready(Ok(new_task)) => new_task,
138 Poll::Ready(Err(_)) => break false,
139 Poll::Pending => break true,
140 };
141
142 let len = tasks.len();
143 let i = tasks.iter().position(Option::is_none).unwrap_or(len);
144
145 match i < len {
146 true => tasks[i] = Some(task),
147 false => tasks.push(Some(task)),
148 }
149
150 ready_flags |= 1 << (i % FLAGS);
151 num_started += 1;
152 };
153 }
154
155
156 for (i, maybe_task) in tasks.iter_mut().enumerate() {
157 let slot = i % FLAGS;
158 let ready = (ready_flags & (1 << slot)) != 0;
159
160 if !ready { continue; }
161 let Some(data) = maybe_task else { continue };
162
163 let waker = init_waker(slot, &common);
164 let fut = data.task.inner.as_mut();
165 let mut context = Context::from_waker(&waker);
166
167 try_tx_exec(&tx_mon_exec, TaskExec::Polling(data.id));
168
169 if let Poll::Ready(()) = Future::poll(fut, &mut context) {
170 try_tx_exec(&tx_mon_exec, TaskExec::PollReady(data.id));
171 num_ended += 1;
172 *maybe_task = None;
173 } else {
174 try_tx_exec(&tx_mon_exec, TaskExec::PollPending(data.id));
175 }
176 }
177 }
178
179 println!("thread down, ran {num_ended} task(s)");
180}
181
182pub struct Executor {
184 tx_tasks: Vec<Sender<TaskData>>,
185 handles: Vec<JoinHandle<()>>,
186 next_id: AtomicUsize,
187 #[cfg(feature = "monitor")]
188 tx_mon_decl: Option<Producer<TaskDecl>>,
189}
190
191impl Executor {
192 #[allow(unused_mut)]
194 #[allow(unused_variables)]
195 pub fn new(threads: usize, monitor_port: Option<u16>) -> Self {
196 let mut tx_tasks = Vec::new();
197 let mut handles = Vec::new();
198 let mut tx_mon_exec = None;
199 let mut monitor_task: Option<Task> = None;
200
201 #[cfg(feature = "monitor")]
202 let tx_mon_decl = if let Some(port) = monitor_port {
203 type Fifo = async_fifo::fifo::DefaultBlockSize;
204 let (tx_exec, rx_exec) = Fifo::non_blocking();
205 let (tx_name, rx_name) = Fifo::non_blocking();
206 let task = monitor::server(port, rx_exec, rx_name);
207 monitor_task = Some(task.into());
208 tx_mon_exec = Some(tx_exec);
209 Some(tx_name)
210 } else {
211 None
212 };
213
214 for _ in 0..threads {
215 let (tx, rx) = async_fifo::new();
216 let tx_mon_exec = tx_mon_exec.clone();
217 handles.push(spawn(|| runner(rx, tx_mon_exec)));
218 tx_tasks.push(tx);
219 }
220
221 let this = Self {
222 tx_tasks,
223 handles,
224 next_id: AtomicUsize::new(0),
225 #[cfg(feature = "monitor")]
226 tx_mon_decl,
227 };
228
229 if let Some(monitor_task) = monitor_task {
230 this.spawn_with_name(monitor_task, "monitor-server");
231 }
232
233 this
234 }
235
236 pub fn spawn<T: Into<Task>>(&self, task: T) {
238 self.spawn_with_name(task, "[unnamed]")
239 }
240
241 pub fn spawn_with_name<T: Into<Task>, S: Into<String>>(&self, task: T, _name: S) {
243 let task = task.into();
244 let id = self.next_id.fetch_add(1, SeqCst);
245
246 let data = TaskData {
247 task,
248 id: id,
249 };
250
251 let i = id % self.tx_tasks.len();
252
253 #[cfg(feature = "monitor")]
254 if let Some(tx_mon_decl) = &self.tx_mon_decl {
255 let decl = TaskDecl {
256 id,
257 name: _name.into(),
258 runner: i,
259 };
260
261 tx_mon_decl.send(decl);
262 }
263
264 self.tx_tasks[i].send(data);
265 }
266
267 pub fn join(mut self) -> Result<()> {
269 self.tx_tasks.drain(..);
270
271 for handle in self.handles {
272 handle.join()?;
273 }
274
275 Ok(())
276 }
277
278 pub fn join_arc(this: Arc<Self>) -> Result<()> {
279 let exec = loop {
280 sleep(Duration::from_millis(100));
281
282 if Arc::strong_count(&this) == 1 {
283 let exec = Arc::into_inner(this);
284 break exec.expect("Failed to recover Executor");
285 }
286 };
287
288 exec.join()
289 }
290}
291
/// Manual demo: every 3 s, forever, spawns a task whose "timer" busy-waits
/// inside `poll` for 1 s, blocking its runner thread the whole time.
///
/// Never terminates, so it is ignored by default. Run it explicitly with
/// `cargo test -- --ignored test_bad_timer`.
#[test]
#[ignore = "never terminates: manual executor/monitor demo (port 9090)"]
fn test_bad_timer() {
    use std::time::Instant;

    struct Timer {
        expiration: Instant,
    }

    impl Future for Timer {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll<()> {
            // Deliberately bad: spins instead of registering a wakeup.
            while Instant::now() < self.expiration {}
            Poll::Ready(())
        }
    }

    const ONE_SEC: Duration = Duration::from_secs(1);

    fn sleep(duration: Duration) -> Timer {
        Timer {
            expiration: Instant::now() + duration,
        }
    }

    let exec = Executor::new(4, Some(9090));

    for _ in 0.. {
        let task = async {
            sleep(ONE_SEC).await;
            println!("Done");
        };

        exec.spawn_with_name(task, "test");
        std::thread::sleep(std::time::Duration::from_secs(3));
    }
}
328
/// Manual demo: spawns a trivial named task every 3 s, forever, so the monitor
/// (port 9090, `monitor` feature) has something to display.
///
/// Never terminates, so it is ignored by default. Run it explicitly with
/// `cargo test -- --ignored test_monitor`.
#[test]
#[ignore = "never terminates: manual monitor demo (port 9090)"]
fn test_monitor() {
    let exec = Executor::new(4, Some(9090));

    for _ in 0.. {
        let task = async {
            println!("Done");
        };

        exec.spawn_with_name(task, "test");
        std::thread::sleep(std::time::Duration::from_secs(3));
    }
}
342
/// Spawns 256 tasks on 8 runners.
///
/// Each task waits for 8192 unit items from a shared channel. The test checks
/// that `join` returns only after every task has bumped the counter.
#[test]
fn test_runners() {
    const TASKS: usize = 256;
    const ITEMS_PER_TASK: usize = 8192;

    let (tx_data, rx_data) = async_fifo::fifo::LargeBlockSize::channel();

    let executor = Executor::new(8, None);
    let counter = Arc::new(AtomicUsize::new(0));

    for _ in 0..TASKS {
        let done = Arc::clone(&counter);
        let mut rx = rx_data.clone();

        executor.spawn(async move {
            rx.recv_array::<ITEMS_PER_TASK>().await.unwrap();
            done.fetch_add(1, SeqCst);
        });
    }

    // Exactly enough items for every task to fill one array.
    let items = [(); TASKS * ITEMS_PER_TASK];
    tx_data.send_iter(items.iter().copied());

    executor.join().unwrap();

    assert_eq!(counter.load(SeqCst), TASKS);
}