multitask/lib.rs
1//! An executor for running async tasks.
2
3#![forbid(unsafe_code)]
4#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
5
6use std::cell::Cell;
7use std::future::Future;
8use std::marker::PhantomData;
9use std::panic::{RefUnwindSafe, UnwindSafe};
10use std::fmt;
11use std::pin::Pin;
12use std::rc::Rc;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::{Arc, Mutex, RwLock};
15use std::task::{Context, Poll};
16
17use concurrent_queue::ConcurrentQueue;
18
/// A runnable future, ready for execution.
///
/// When a future is internally spawned using `async_task::spawn()` or `async_task::spawn_local()`,
/// we get back two values:
///
/// 1. an `async_task::Task<()>`, which we refer to as a `Runnable`
/// 2. an `async_task::JoinHandle<T, ()>`, which is wrapped inside a `Task<T>`
///
/// Once a `Runnable` is run, it "vanishes" and only reappears when its future is woken. When it's
/// woken up, its schedule function is called, which means the `Runnable` gets pushed into a task
/// queue in an executor.
///
/// The `()` type parameter matches the `()` passed as the last argument of every
/// `async_task::spawn*()` call in this file (async-task's per-task tag).
type Runnable = async_task::Task<()>;
31
32/// A spawned future.
33///
34/// Tasks are also futures themselves and yield the output of the spawned future.
35///
36/// When a task is dropped, its gets canceled and won't be polled again. To cancel a task a bit
37/// more gracefully and wait until it stops running, use the [`cancel()`][Task::cancel()] method.
38///
39/// Tasks that panic get immediately canceled. Awaiting a canceled task also causes a panic.
40///
41/// If a task panics, the panic will be thrown by the [`Ticker::tick()`] invocation that polled it.
42///
43/// # Examples
44///
45/// ```
46/// use blocking::block_on;
47/// use multitask::Executor;
48/// use std::thread;
49///
50/// let ex = Executor::new();
51///
52/// // Spawn a future onto the executor.
53/// let task = ex.spawn(async {
54/// println!("Hello from a task!");
55/// 1 + 2
56/// });
57///
58/// // Run an executor thread.
59/// thread::spawn(move || {
60/// let (p, u) = parking::pair();
61/// let ticker = ex.ticker(move || u.unpark());
62/// loop {
63/// if !ticker.tick() {
64/// p.park();
65/// }
66/// }
67/// });
68///
69/// // Wait for the result.
70/// assert_eq!(block_on(task), 3);
71/// ```
72#[must_use = "tasks get canceled when dropped, use `.detach()` to run them in the background"]
73#[derive(Debug)]
74pub struct Task<T>(Option<async_task::JoinHandle<T, ()>>);
75
76impl<T> Task<T> {
77 /// Detaches the task to let it keep running in the background.
78 ///
79 /// # Examples
80 ///
81 /// ```
82 /// use async_io::Timer;
83 /// use multitask::Executor;
84 /// use std::time::Duration;
85 ///
86 /// let ex = Executor::new();
87 ///
88 /// // Spawn a deamon future.
89 /// ex.spawn(async {
90 /// loop {
91 /// println!("I'm a daemon task looping forever.");
92 /// Timer::new(Duration::from_secs(1)).await;
93 /// }
94 /// })
95 /// .detach();
96 /// ```
97 pub fn detach(mut self) {
98 self.0.take().unwrap();
99 }
100
101 /// Cancels the task and waits for it to stop running.
102 ///
103 /// Returns the task's output if it was completed just before it got canceled, or [`None`] if
104 /// it didn't complete.
105 ///
106 /// While it's possible to simply drop the [`Task`] to cancel it, this is a cleaner way of
107 /// canceling because it also waits for the task to stop running.
108 ///
109 /// # Examples
110 ///
111 /// ```
112 /// use async_io::Timer;
113 /// use blocking::block_on;
114 /// use multitask::Executor;
115 /// use std::thread;
116 /// use std::time::Duration;
117 ///
118 /// let ex = Executor::new();
119 ///
120 /// // Spawn a deamon future.
121 /// let task = ex.spawn(async {
122 /// loop {
123 /// println!("Even though I'm in an infinite loop, you can still cancel me!");
124 /// Timer::new(Duration::from_secs(1)).await;
125 /// }
126 /// });
127 ///
128 /// // Run an executor thread.
129 /// thread::spawn(move || {
130 /// let (p, u) = parking::pair();
131 /// let ticker = ex.ticker(move || u.unpark());
132 /// loop {
133 /// if !ticker.tick() {
134 /// p.park();
135 /// }
136 /// }
137 /// });
138 ///
139 /// block_on(async {
140 /// Timer::new(Duration::from_secs(3)).await;
141 /// task.cancel().await;
142 /// });
143 /// ```
144 pub async fn cancel(self) -> Option<T> {
145 let mut task = self;
146 let handle = task.0.take().unwrap();
147 handle.cancel();
148 handle.await
149 }
150}
151
152impl<T> Drop for Task<T> {
153 fn drop(&mut self) {
154 if let Some(handle) = &self.0 {
155 handle.cancel();
156 }
157 }
158}
159
160impl<T> Future for Task<T> {
161 type Output = T;
162
163 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 match Pin::new(&mut self.0.as_mut().unwrap()).poll(cx) {
165 Poll::Pending => Poll::Pending,
166 Poll::Ready(output) => Poll::Ready(output.expect("task has failed")),
167 }
168 }
169}
170
/// A single-threaded executor.
///
/// Futures are spawned with [`LocalExecutor::spawn()`] and run one at a time, on the thread that
/// owns the executor, by calling [`LocalExecutor::tick()`].
#[derive(Debug)]
pub struct LocalExecutor {
    /// The task queue.
    ///
    /// Each spawned task's schedule function holds its own clone of this `Arc` and pushes the
    /// task back into the queue whenever it is woken.
    queue: Arc<ConcurrentQueue<Runnable>>,

    /// Callback invoked to wake the executor up.
    ///
    /// Called right after a task is pushed into `queue`.
    callback: Callback,

    /// Make sure the type is `!Send` and `!Sync`.
    _marker: PhantomData<Rc<()>>,
}

// Explicitly mark the executor as unwind-safe so it can be used across `catch_unwind`
// boundaries.
impl UnwindSafe for LocalExecutor {}
impl RefUnwindSafe for LocalExecutor {}
186
187impl LocalExecutor {
188 /// Creates a new single-threaded executor.
189 ///
190 /// # Examples
191 ///
192 /// ```
193 /// use multitask::LocalExecutor;
194 ///
195 /// let (p, u) = parking::pair();
196 /// let ex = LocalExecutor::new(move || u.unpark());
197 /// ```
198 pub fn new(notify: impl Fn() + Send + Sync + 'static) -> LocalExecutor {
199 LocalExecutor {
200 queue: Arc::new(ConcurrentQueue::unbounded()),
201 callback: Callback::new(notify),
202 _marker: PhantomData,
203 }
204 }
205
206 /// Spawns a thread-local future onto this executor.
207 ///
208 /// Returns a [`Task`] handle for the spawned future.
209 ///
210 /// # Examples
211 ///
212 /// ```
213 /// use multitask::LocalExecutor;
214 ///
215 /// let (p, u) = parking::pair();
216 /// let ex = LocalExecutor::new(move || u.unpark());
217 ///
218 /// let task = ex.spawn(async { println!("hello") });
219 /// ```
220 pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
221 let queue = self.queue.clone();
222 let callback = self.callback.clone();
223
224 // The function that schedules a runnable task when it gets woken up.
225 let schedule = move |runnable| {
226 queue.push(runnable).unwrap();
227 callback.call();
228 };
229
230 // Create a task, push it into the queue by scheduling it, and return its `Task` handle.
231 let (runnable, handle) = async_task::spawn_local(future, schedule, ());
232 runnable.schedule();
233 Task(Some(handle))
234 }
235
236 /// Runs a single task and returns `true` if one was found.
237 ///
238 /// # Examples
239 ///
240 /// ```
241 /// use multitask::LocalExecutor;
242 ///
243 /// let (p, u) = parking::pair();
244 /// let ex = LocalExecutor::new(move || u.unpark());
245 ///
246 /// assert!(!ex.tick());
247 /// let task = ex.spawn(async { println!("hello") });
248 ///
249 /// // This prints "hello".
250 /// assert!(ex.tick());
251 /// ```
252 pub fn tick(&self) -> bool {
253 if let Ok(r) = self.queue.pop() {
254 r.run();
255 true
256 } else {
257 false
258 }
259 }
260}
261
impl Drop for LocalExecutor {
    /// Currently a no-op.
    ///
    /// NOTE(review): runnables still sitting in `queue` are not dropped here, and every spawned
    /// task's schedule function keeps its own clone of the `queue` `Arc`. Queued tasks can
    /// therefore outlive the executor; presumably this is what the TODOs below are meant to
    /// address.
    fn drop(&mut self) {
        // TODO(stjepang): Close the local queue and empty it.
        // TODO(stjepang): Cancel all remaining tasks.
    }
}
268
/// State shared between [`Executor`] and [`Ticker`].
///
/// The schedule function of every task spawned on the [`Executor`] also holds a clone of the
/// `Arc` wrapping this state.
#[derive(Debug)]
struct Global {
    /// The global queue.
    ///
    /// Newly spawned and woken tasks are pushed here by their schedule function.
    queue: ConcurrentQueue<Runnable>,

    /// Shards of the global queue created by tickers.
    ///
    /// A ticker registers its shard here when it is created and removes it when dropped, so
    /// that other tickers can steal from it.
    shards: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,

    /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
    ///
    /// Lets `notify()` skip locking `sleepers` when no notification is needed.
    notified: AtomicBool,

    /// A list of sleeping tickers.
    sleepers: Mutex<Sleepers>,
}
284
285impl Global {
286 /// Notifies a sleeping ticker.
287 #[inline]
288 fn notify(&self) {
289 if !self
290 .notified
291 .compare_and_swap(false, true, Ordering::SeqCst)
292 {
293 let callback = self.sleepers.lock().unwrap().notify();
294 if let Some(cb) = callback {
295 cb.call();
296 }
297 }
298 }
299}
300
301/// A list of sleeping tickers.
302#[derive(Debug)]
303struct Sleepers {
304 /// Number of sleeping tickers (both notified and unnotified).
305 count: usize,
306
307 /// Callbacks of sleeping unnotified tickers.
308 ///
309 /// A sleeping ticker is notified when its callback is missing from this list.
310 callbacks: Vec<Callback>,
311}
312
313impl Sleepers {
314 /// Inserts a new sleeping ticker.
315 fn insert(&mut self, callback: &Callback) {
316 self.count += 1;
317 self.callbacks.push(callback.clone());
318 }
319
320 /// Re-inserts a sleeping ticker's callback if it was notified.
321 ///
322 /// Returns `true` if the ticker was notified.
323 fn update(&mut self, callback: &Callback) -> bool {
324 if self.callbacks.iter().all(|cb| cb != callback) {
325 self.callbacks.push(callback.clone());
326 true
327 } else {
328 false
329 }
330 }
331
332 /// Removes a previously inserted sleeping ticker.
333 fn remove(&mut self, callback: &Callback) {
334 self.count -= 1;
335 for i in (0..self.callbacks.len()).rev() {
336 if &self.callbacks[i] == callback {
337 self.callbacks.remove(i);
338 return;
339 }
340 }
341 }
342
343 /// Returns `true` if a sleeping ticker is notified or no tickers are sleeping.
344 fn is_notified(&self) -> bool {
345 self.count == 0 || self.count > self.callbacks.len()
346 }
347
348 /// Returns notification callback for a sleeping ticker.
349 ///
350 /// If a ticker was notified already or there are no tickers, `None` will be returned.
351 fn notify(&mut self) -> Option<Callback> {
352 if self.callbacks.len() == self.count {
353 self.callbacks.pop()
354 } else {
355 None
356 }
357 }
358}
359
/// A multi-threaded executor.
///
/// Futures are spawned with [`Executor::spawn()`] and run by [`Ticker`]s created with
/// [`Executor::ticker()`], typically one per executor thread.
#[derive(Debug)]
pub struct Executor {
    /// State shared with tickers and with the schedule functions of spawned tasks.
    global: Arc<Global>,
}

// Explicitly mark the executor as unwind-safe so it can be used across `catch_unwind`
// boundaries.
impl UnwindSafe for Executor {}
impl RefUnwindSafe for Executor {}
368
369impl Executor {
370 /// Creates a new multi-threaded executor.
371 ///
372 /// # Examples
373 ///
374 /// ```
375 /// use multitask::Executor;
376 ///
377 /// let ex = Executor::new();
378 /// ```
379 pub fn new() -> Executor {
380 Executor {
381 global: Arc::new(Global {
382 queue: ConcurrentQueue::unbounded(),
383 shards: RwLock::new(Vec::new()),
384 notified: AtomicBool::new(true),
385 sleepers: Mutex::new(Sleepers {
386 count: 0,
387 callbacks: Vec::new(),
388 }),
389 }),
390 }
391 }
392
393 /// Spawns a future onto this executor.
394 ///
395 /// Returns a [`Task`] handle for the spawned future.
396 ///
397 /// # Examples
398 ///
399 /// ```
400 /// use multitask::Executor;
401 ///
402 /// let ex = Executor::new();
403 /// let task = ex.spawn(async { println!("hello") });
404 /// ```
405 pub fn spawn<T: Send + 'static>(
406 &self,
407 future: impl Future<Output = T> + Send + 'static,
408 ) -> Task<T> {
409 let global = self.global.clone();
410
411 // The function that schedules a runnable task when it gets woken up.
412 let schedule = move |runnable| {
413 global.queue.push(runnable).unwrap();
414 global.notify();
415 };
416
417 // Create a task, push it into the queue by scheduling it, and return its `Task` handle.
418 let (runnable, handle) = async_task::spawn(future, schedule, ());
419 runnable.schedule();
420 Task(Some(handle))
421 }
422
423 /// Creates a new ticker for executing tasks.
424 ///
425 /// In a multi-threaded executor, each executor thread will create its own ticker and then keep
426 /// calling [`Ticker::tick()`] in a loop.
427 ///
428 /// # Examples
429 ///
430 /// ```
431 /// use blocking::block_on;
432 /// use multitask::Executor;
433 /// use std::thread;
434 ///
435 /// let ex = Executor::new();
436 ///
437 /// // Create two executor threads.
438 /// for _ in 0..2 {
439 /// let (p, u) = parking::pair();
440 /// let ticker = ex.ticker(move || u.unpark());
441 /// thread::spawn(move || {
442 /// loop {
443 /// if !ticker.tick() {
444 /// p.park();
445 /// }
446 /// }
447 /// });
448 /// }
449 ///
450 /// // Spawn a future and wait for one of the threads to run it.
451 /// let task = ex.spawn(async { 1 + 2 });
452 /// assert_eq!(block_on(task), 3);
453 /// ```
454 pub fn ticker(&self, notify: impl Fn() + Send + Sync + 'static) -> Ticker {
455 // Create a ticker and put its stealer handle into the executor.
456 let ticker = Ticker {
457 global: Arc::new(self.global.clone()),
458 shard: Arc::new(ConcurrentQueue::bounded(512)),
459 callback: Callback::new(notify),
460 sleeping: Cell::new(false),
461 ticks: Cell::new(0),
462 };
463 self.global
464 .shards
465 .write()
466 .unwrap()
467 .push(ticker.shard.clone());
468 ticker
469 }
470}
471
472impl Default for Executor {
473 fn default() -> Executor {
474 Executor::new()
475 }
476}
477
/// Runs tasks in a multi-threaded executor.
///
/// Created by [`Executor::ticker()`]. Each ticker owns a bounded shard of the global queue and a
/// callback that is invoked when the sleeping ticker gets notified.
#[derive(Debug)]
pub struct Ticker {
    /// State shared with the executor: global queue, registered shards, and sleepers.
    ///
    /// NOTE(review): `Executor::ticker()` wraps a clone of the executor's `Arc<Global>` in a
    /// second `Arc`; a plain `Arc<Global>` would appear to suffice — confirm before changing.
    global: Arc<Arc<Global>>,

    /// A shard of the global queue.
    ///
    /// Also registered in `Global::shards` so other tickers can steal from it.
    shard: Arc<ConcurrentQueue<Runnable>>,

    /// Callback invoked to wake this ticker up.
    callback: Callback,

    /// Set to `true` when in sleeping state.
    ///
    /// States a ticker can be in:
    /// 1) Woken.
    /// 2a) Sleeping and unnotified.
    /// 2b) Sleeping and notified.
    sleeping: Cell<bool>,

    /// Bumped every time a task is run.
    ticks: Cell<usize>,
}

// Explicitly mark the ticker as unwind-safe so it can be used across `catch_unwind` boundaries.
impl UnwindSafe for Ticker {}
impl RefUnwindSafe for Ticker {}
504
impl Ticker {
    /// Moves the ticker into sleeping and unnotified state.
    ///
    /// Returns `false` if the ticker was already sleeping and unnotified.
    ///
    /// A ticker that was already sleeping but has since been notified (its callback was popped
    /// from the sleepers list) gets its callback re-inserted, and `true` is returned.
    fn sleep(&self) -> bool {
        let mut sleepers = self.global.sleepers.lock().unwrap();

        if self.sleeping.get() {
            // Already sleeping, check if notified.
            if !sleepers.update(&self.callback) {
                return false;
            }
        } else {
            // Move to sleeping state.
            sleepers.insert(&self.callback);
        }

        // Recompute the global `notified` flag from the updated sleepers list while the
        // `sleepers` lock is still held. The previous value returned by `swap` is unused.
        self.global
            .notified
            .swap(sleepers.is_notified(), Ordering::SeqCst);

        self.sleeping.set(true);
        true
    }

    /// Moves the ticker into woken state.
    ///
    /// Returns `false` if the ticker was already woken.
    fn wake(&self) -> bool {
        if self.sleeping.get() {
            // Unregister from the sleepers list and recompute the global `notified` flag
            // under the same lock.
            let mut sleepers = self.global.sleepers.lock().unwrap();
            sleepers.remove(&self.callback);

            self.global
                .notified
                .swap(sleepers.is_notified(), Ordering::SeqCst);
        }

        // `replace` yields the previous state, i.e. `true` only if the ticker was sleeping.
        self.sleeping.replace(false)
    }

    /// Runs a single task and returns `true` if one was found.
    ///
    /// When a search comes up empty, the ticker moves into sleeping state and searches again.
    /// `false` is returned only once the ticker is already sleeping and unnotified and the
    /// search still found nothing. After that, the ticker's callback is invoked when it gets
    /// notified, and the caller is expected to call `tick()` again (see the examples on
    /// [`Executor::ticker()`]).
    pub fn tick(&self) -> bool {
        loop {
            match self.search() {
                None => {
                    // Move to sleeping and unnotified state.
                    if !self.sleep() {
                        // If already sleeping and unnotified, return.
                        return false;
                    }
                }
                Some(r) => {
                    // Wake up.
                    self.wake();

                    // Notify another ticker now to pick up where this ticker left off, just in
                    // case running the task takes a long time.
                    self.global.notify();

                    // Bump the ticker.
                    let ticks = self.ticks.get();
                    self.ticks.set(ticks.wrapping_add(1));

                    // Steal tasks from the global queue to ensure fair task scheduling.
                    if ticks % 64 == 0 {
                        steal(&self.global.queue, &self.shard);
                    }

                    // Run the task.
                    r.run();

                    return true;
                }
            }
        }
    }

    /// Finds the next task to run.
    ///
    /// Looks in this ticker's own shard first, then in the global queue (also moving a batch of
    /// global tasks into the shard), and finally steals from the other tickers' shards. Those are
    /// visited in rotated order starting at a random index.
    fn search(&self) -> Option<Runnable> {
        if let Ok(r) = self.shard.pop() {
            return Some(r);
        }

        // Try stealing from the global queue.
        if let Ok(r) = self.global.queue.pop() {
            steal(&self.global.queue, &self.shard);
            return Some(r);
        }

        // Try stealing from other shards. The read lock is held for the rest of the search.
        let shards = self.global.shards.read().unwrap();

        // Pick a random starting point in the iterator list and rotate the list.
        // NOTE(review): an empty range would presumably make `fastrand::usize` panic. `n >= 1`
        // holds as long as this ticker's own shard is registered, which `Executor::ticker()`
        // does on creation and `Drop for Ticker` undoes.
        let n = shards.len();
        let start = fastrand::usize(..n);
        let iter = shards.iter().chain(shards.iter()).skip(start).take(n);

        // Remove this ticker's shard.
        let iter = iter.filter(|shard| !Arc::ptr_eq(shard, &self.shard));

        // Try stealing from each shard in the list.
        for shard in iter {
            steal(shard, &self.shard);
            if let Ok(r) = self.shard.pop() {
                return Some(r);
            }
        }

        None
    }
}
617
618impl Drop for Ticker {
619 fn drop(&mut self) {
620 // Wake and unregister the ticker.
621 self.wake();
622 self.global
623 .shards
624 .write()
625 .unwrap()
626 .retain(|shard| !Arc::ptr_eq(shard, &self.shard));
627
628 // Re-schedule remaining tasks in the shard.
629 while let Ok(r) = self.shard.pop() {
630 r.schedule();
631 }
632 // Notify another ticker to start searching for tasks.
633 self.global.notify();
634
635 // TODO(stjepang): Cancel all remaining tasks.
636 }
637}
638
639/// Steals some items from one queue into another.
640fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
641 // Half of `src`'s length rounded up.
642 let mut count = (src.len() + 1) / 2;
643
644 if count > 0 {
645 // Don't steal more than fits into the queue.
646 if let Some(cap) = dest.capacity() {
647 count = count.min(cap - dest.len());
648 }
649
650 // Steal tasks.
651 for _ in 0..count {
652 if let Ok(t) = src.pop() {
653 assert!(dest.push(t).is_ok());
654 } else {
655 break;
656 }
657 }
658 }
659}
660
/// A cloneable callback function.
///
/// Clones share the same underlying closure. Equality is identity of that shared allocation:
/// two `Callback`s compare equal only if one was cloned (directly or transitively) from the
/// other.
#[derive(Clone)]
struct Callback(Arc<Box<dyn Fn() + Send + Sync>>);

impl Callback {
    /// Wraps `f` so it can be shared and cloned cheaply.
    fn new(f: impl Fn() + Send + Sync + 'static) -> Callback {
        let boxed: Box<dyn Fn() + Send + Sync> = Box::new(f);
        Callback(Arc::new(boxed))
    }

    /// Invokes the wrapped closure.
    fn call(&self) {
        let f: &(dyn Fn() + Send + Sync) = &**self.0;
        f();
    }
}

impl PartialEq for Callback {
    fn eq(&self, other: &Callback) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        Arc::ptr_eq(lhs, rhs)
    }
}

impl Eq for Callback {}

impl fmt::Debug for Callback {
    /// Prints the placeholder `<callback>`, since the wrapped closure is not `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("<callback>");
        out.finish()
    }
}