1use std::thread::{Thread, Result, JoinHandle, current, sleep, spawn, park};
4use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
5use std::task::{Poll, Wake, Waker, Context};
6use std::sync::atomic::Ordering::SeqCst;
7use std::time::{Instant, Duration};
8use std::pin::{pin, Pin};
9use std::future::Future;
10use std::sync::Arc;
11
12#[cfg(feature = "monitor")]
13use serde::Serialize;
14
15use async_fifo::{Sender, Receiver};
16use async_fifo::non_blocking::Producer;
17
18pub mod utils;
19
20#[cfg(feature = "time")]
21pub mod time;
22
23#[cfg(feature = "monitor")]
24mod monitor;
25
/// A poll event that a runner reports through `try_tx_exec`.
/// Each variant carries the id of the task concerned.
// NOTE(review): the payloads appear to be read only by the `monitor` consumer,
// hence `dead_code` is allowed — confirm against `monitor.rs`.
#[allow(dead_code)]
enum TaskExec {
    /// Sent right before the runner polls the task.
    Polling(TaskId),
    /// Sent after a poll that returned `Poll::Ready`; the task is then dropped.
    PollReady(TaskId),
    /// Sent after a poll that returned `Poll::Pending`.
    PollPending(TaskId),
}
32
/// Declaration of a spawned task, sent to the monitor (when one is running)
/// by `Executor::spawn_with_name` before the task is dispatched.
#[cfg_attr(feature = "monitor", derive(Serialize))]
#[allow(dead_code)]
struct TaskDecl {
    /// Id allocated from `Executor::next_id`.
    id: TaskId,
    /// Name passed to `spawn_with_name` (`"[unnamed]"` for `spawn`).
    name: String,
    /// Index of the runner thread the task was sent to.
    runner: usize,
}
40
/// Identifier of a spawned task, allocated sequentially by `Executor::spawn_with_name`.
type TaskId = usize;
42
/// A type-erased unit of work for the executor: a pinned, boxed, `Send`
/// future producing `()`. Any such `'static` future converts into it via `From`.
pub struct Task {
    inner: Pin<Box<dyn Future<Output = ()> + Send>>,
}
49
50impl<F: Future<Output = ()> + Send + 'static> From<F> for Task {
51 fn from(fut: F) -> Self {
52 Self { inner: Box::pin(fut) }
53 }
54}
55
/// Number of ready-flag bits per runner (the width of the `AtomicU64` in
/// `WakerCommon`). Task index `i` maps to bit `i % FLAGS`; a waker built with
/// index `FLAGS` signals "poll the task receiver" instead.
const FLAGS: usize = 64;
57
/// Wake-up state shared by one runner thread and every waker it hands out.
struct WakerCommon {
    /// The runner thread; unparked on every wake.
    thread: Thread,
    /// Bit `s` set => poll the tasks whose index `i` satisfies `i % FLAGS == s`.
    ready_flags: AtomicU64,
    /// Set => poll the task receiver for newly spawned tasks.
    task_rx_flag: AtomicBool,
}
63
/// `Wake` implementation given to futures polled by a runner.
struct ThreadWaker {
    /// State shared with the owning runner thread.
    common: Arc<WakerCommon>,
    /// Ready-flag bit to raise when `< FLAGS`; otherwise `task_rx_flag` is raised.
    index: usize,
}
68
69impl Wake for ThreadWaker {
70 fn wake(self: Arc<Self>) {
71 Self::wake_by_ref(&self)
72 }
73
74 fn wake_by_ref(self: &Arc<Self>) {
75 if self.index < FLAGS {
76 let mask = 1 << self.index;
77 self.common.ready_flags.fetch_or(mask, SeqCst);
78 } else {
79 self.common.task_rx_flag.store(true, SeqCst);
80 }
81
82 self.common.thread.unpark();
83 }
84}
85
86fn init_waker(i: usize, common: &Arc<WakerCommon>) -> Waker {
87 let waker = ThreadWaker {
88 common: common.clone(),
89 index: i,
90 };
91
92 Waker::from(Arc::new(waker))
93}
94
/// A task paired with its id, as sent from `Executor::spawn_with_name` to a runner.
struct TaskData {
    task: Task,
    id: TaskId,
}
99
100fn try_tx_exec(tx: &Option<Producer<(TaskExec, Instant)>>, exec: TaskExec) {
101 if let Some(tx) = tx {
102 tx.send((exec, Instant::now()));
103 }
104}
105
106fn runner(
107 mut rx_tasks: Receiver<TaskData>,
108 tx_mon_exec: Option<Producer<(TaskExec, Instant)>>,
109) {
110 let common = Arc::new(WakerCommon {
111 thread: current(),
112 ready_flags: AtomicU64::new(0),
113 task_rx_flag: AtomicBool::new(true),
114 });
115
116 let new_task_waker = init_waker(FLAGS, &common);
117
118 let mut receiver = pin!(rx_tasks.recv());
119 let mut tasks: Vec<Option<TaskData>> = Vec::new();
120 let mut can_receive = true;
121 let mut num_started = 0;
122 let mut num_ended = 0;
123
124 while num_started != num_ended || can_receive {
125 let mut ready_flags = common.ready_flags.swap(0, SeqCst);
127 let task_rx_flag = common.task_rx_flag.swap(false, SeqCst);
128
129 if ready_flags == 0 && !task_rx_flag {
131 park();
133 continue;
134 }
135
136 if task_rx_flag {
137 let mut context = Context::from_waker(&new_task_waker);
138 can_receive = loop {
139 let task = match Future::poll(receiver.as_mut(), &mut context) {
140 Poll::Ready(Ok(new_task)) => new_task,
141 Poll::Ready(Err(_)) => break false,
142 Poll::Pending => break true,
143 };
144
145 let len = tasks.len();
146 let i = tasks.iter().position(Option::is_none).unwrap_or(len);
147
148 match i < len {
149 true => tasks[i] = Some(task),
150 false => tasks.push(Some(task)),
151 }
152
153 ready_flags |= 1 << (i % FLAGS);
154 num_started += 1;
155 };
156 }
157
158
159 for (i, maybe_task) in tasks.iter_mut().enumerate() {
160 let slot = i % FLAGS;
161 let ready = (ready_flags & (1 << slot)) != 0;
162
163 if !ready { continue; }
164 let Some(data) = maybe_task else { continue };
165
166 let waker = init_waker(slot, &common);
167 let fut = data.task.inner.as_mut();
168 let mut context = Context::from_waker(&waker);
169
170 try_tx_exec(&tx_mon_exec, TaskExec::Polling(data.id));
171
172 if let Poll::Ready(()) = Future::poll(fut, &mut context) {
173 try_tx_exec(&tx_mon_exec, TaskExec::PollReady(data.id));
174 num_ended += 1;
175 *maybe_task = None;
176 } else {
177 try_tx_exec(&tx_mon_exec, TaskExec::PollPending(data.id));
178 }
179 }
180 }
181
182 println!("thread down, ran {num_ended} task(s)");
183}
184
/// Multi-threaded executor: every runner thread polls the tasks sent to it
/// over its own channel.
pub struct Executor {
    /// One task sender per runner thread; index `i` feeds the thread in `handles[i]`.
    tx_tasks: Vec<Sender<TaskData>>,
    /// Join handles of the runner threads.
    handles: Vec<JoinHandle<()>>,
    /// Next task id; also selects the runner (`id % tx_tasks.len()`).
    next_id: AtomicUsize,
    /// Sender for task declarations to the monitor, when a monitor is running.
    #[cfg(feature = "monitor")]
    tx_mon_decl: Option<Producer<TaskDecl>>,
}
193
194impl Executor {
195 #[allow(unused_mut)]
197 #[allow(unused_variables)]
198 pub fn new(threads: usize, monitor_port: Option<u16>) -> Self {
199 let mut tx_tasks = Vec::new();
200 let mut handles = Vec::new();
201 let mut tx_mon_exec = None;
202 let mut monitor_task: Option<Task> = None;
203
204 #[cfg(feature = "monitor")]
205 let tx_mon_decl = if let Some(port) = monitor_port {
206 type Fifo = async_fifo::fifo::DefaultBlockSize;
207 let (tx_exec, rx_exec) = Fifo::non_blocking();
208 let (tx_name, rx_name) = Fifo::non_blocking();
209 let task = monitor::server(port, rx_exec, rx_name);
210 monitor_task = Some(task.into());
211 tx_mon_exec = Some(tx_exec);
212 Some(tx_name)
213 } else {
214 None
215 };
216
217 for _ in 0..threads {
218 let (tx, rx) = async_fifo::new();
219 let tx_mon_exec = tx_mon_exec.clone();
220 handles.push(spawn(|| runner(rx, tx_mon_exec)));
221 tx_tasks.push(tx);
222 }
223
224 let this = Self {
225 tx_tasks,
226 handles,
227 next_id: AtomicUsize::new(0),
228 #[cfg(feature = "monitor")]
229 tx_mon_decl,
230 };
231
232 if let Some(monitor_task) = monitor_task {
233 this.spawn_with_name(monitor_task, "monitor-server");
234 }
235
236 this
237 }
238
239 pub fn spawn<T: Into<Task>>(&self, task: T) {
241 self.spawn_with_name(task, "[unnamed]")
242 }
243
244 pub fn spawn_with_name<T: Into<Task>, S: Into<String>>(&self, task: T, _name: S) {
246 let task = task.into();
247 let id = self.next_id.fetch_add(1, SeqCst);
248
249 let data = TaskData {
250 task,
251 id: id,
252 };
253
254 let i = id % self.tx_tasks.len();
255
256 #[cfg(feature = "monitor")]
257 if let Some(tx_mon_decl) = &self.tx_mon_decl {
258 let decl = TaskDecl {
259 id,
260 name: _name.into(),
261 runner: i,
262 };
263
264 tx_mon_decl.send(decl);
265 }
266
267 self.tx_tasks[i].send(data);
268 }
269
270 pub fn join(mut self) -> Result<()> {
272 self.tx_tasks.drain(..);
273
274 for handle in self.handles {
275 handle.join()?;
276 }
277
278 Ok(())
279 }
280
281 pub fn join_arc(this: Arc<Self>) -> Result<()> {
282 let exec = loop {
283 sleep(Duration::from_millis(100));
284
285 if Arc::strong_count(&this) == 1 {
286 let exec = Arc::into_inner(this);
287 break exec.expect("Failed to recover Executor");
288 }
289 };
290
291 exec.join()
292 }
293}
294
/// Manual monitor demo.
///
/// Every 3 s it spawns a task whose `Timer` future busy-waits inside `poll`
/// for one second, blocking whichever runner thread it lands on. The
/// executor is created with monitor port 9090, which only takes effect with
/// the `monitor` feature.
///
/// The spawn loop never ends, so the test is ignored by default. Run it with
/// `cargo test test_bad_timer -- --ignored`.
#[test]
#[ignore = "never terminates: manual monitor demo"]
fn test_bad_timer() {
    /// Deliberately broken timer: spins in `poll` until it expires instead of
    /// returning `Pending` and arranging a wake-up.
    struct Timer {
        expiration: Instant,
    }

    impl Future for Timer {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll<()> {
            while Instant::now() < self.expiration {}
            Poll::Ready(())
        }
    }

    const ONE_SEC: Duration = Duration::from_secs(1);

    /// Busy-waiting stand-in for a sleep; shadows `std::thread::sleep` in this test.
    fn sleep(duration: Duration) -> Timer {
        Timer {
            expiration: Instant::now() + duration,
        }
    }

    let exec = Executor::new(4, Some(9090));

    for _ in 0.. {
        let task = async {
            sleep(ONE_SEC).await;
            println!("Done");
        };

        exec.spawn_with_name(task, "test");
        std::thread::sleep(std::time::Duration::from_secs(3));
    }
}
331
/// Manual monitor demo: spawns a trivial task named `"test"` every 3 s,
/// forever. It uses monitor port 9090, which only takes effect with the
/// `monitor` feature.
///
/// Never terminates, so it is ignored by default. Run it with
/// `cargo test test_monitor -- --ignored`.
#[test]
#[ignore = "never terminates: manual monitor demo"]
fn test_monitor() {
    let exec = Executor::new(4, Some(9090));

    for _ in 0.. {
        let task = async {
            println!("Done");
        };

        exec.spawn_with_name(task, "test");
        std::thread::sleep(std::time::Duration::from_secs(3));
    }
}
345
/// 256 tasks spread over 8 runners each pull 8192 items from one shared
/// channel. After `join`, every task must have completed exactly once.
#[test]
fn test_runners() {
    const TASKS: usize = 256;
    const ITEMS_PER_TASK: usize = 8192;

    let (tx_data, rx_data) = async_fifo::fifo::LargeBlockSize::channel();

    let executor = Executor::new(8, None);
    let finished = Arc::new(AtomicUsize::new(0));

    for _ in 0..TASKS {
        let finished = Arc::clone(&finished);
        let mut rx = rx_data.clone();

        executor.spawn(async move {
            rx.recv_array::<ITEMS_PER_TASK>().await.unwrap();
            finished.fetch_add(1, SeqCst);
        });
    }

    // Exactly enough items for every task to fill its array once.
    let payload = [(); TASKS * ITEMS_PER_TASK];
    tx_data.send_iter(payload.iter().copied());

    executor.join().unwrap();

    assert_eq!(finished.load(SeqCst), TASKS);
}