multitask/
lib.rs

1//! An executor for running async tasks.
2
3#![forbid(unsafe_code)]
4#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
5
6use std::cell::Cell;
7use std::future::Future;
8use std::marker::PhantomData;
9use std::panic::{RefUnwindSafe, UnwindSafe};
10use std::fmt;
11use std::pin::Pin;
12use std::rc::Rc;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::{Arc, Mutex, RwLock};
15use std::task::{Context, Poll};
16
17use concurrent_queue::ConcurrentQueue;
18
19/// A runnable future, ready for execution.
20///
21/// When a future is internally spawned using `async_task::spawn()` or `async_task::spawn_local()`,
22/// we get back two values:
23///
24/// 1. an `async_task::Task<()>`, which we refer to as a `Runnable`
25/// 2. an `async_task::JoinHandle<T, ()>`, which is wrapped inside a `Task<T>`
26///
27/// Once a `Runnable` is run, it "vanishes" and only reappears when its future is woken. When it's
28/// woken up, its schedule function is called, which means the `Runnable` gets pushed into a task
29/// queue in an executor.
30type Runnable = async_task::Task<()>;
31
32/// A spawned future.
33///
34/// Tasks are also futures themselves and yield the output of the spawned future.
35///
36/// When a task is dropped, its gets canceled and won't be polled again. To cancel a task a bit
37/// more gracefully and wait until it stops running, use the [`cancel()`][Task::cancel()] method.
38///
39/// Tasks that panic get immediately canceled. Awaiting a canceled task also causes a panic.
40///
41/// If a task panics, the panic will be thrown by the [`Ticker::tick()`] invocation that polled it.
42///
43/// # Examples
44///
45/// ```
46/// use blocking::block_on;
47/// use multitask::Executor;
48/// use std::thread;
49///
50/// let ex = Executor::new();
51///
52/// // Spawn a future onto the executor.
53/// let task = ex.spawn(async {
54///     println!("Hello from a task!");
55///     1 + 2
56/// });
57///
58/// // Run an executor thread.
59/// thread::spawn(move || {
60///     let (p, u) = parking::pair();
61///     let ticker = ex.ticker(move || u.unpark());
62///     loop {
63///         if !ticker.tick() {
64///             p.park();
65///         }
66///     }
67/// });
68///
69/// // Wait for the result.
70/// assert_eq!(block_on(task), 3);
71/// ```
72#[must_use = "tasks get canceled when dropped, use `.detach()` to run them in the background"]
73#[derive(Debug)]
74pub struct Task<T>(Option<async_task::JoinHandle<T, ()>>);
75
76impl<T> Task<T> {
77    /// Detaches the task to let it keep running in the background.
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// use async_io::Timer;
83    /// use multitask::Executor;
84    /// use std::time::Duration;
85    ///
86    /// let ex = Executor::new();
87    ///
88    /// // Spawn a deamon future.
89    /// ex.spawn(async {
90    ///     loop {
91    ///         println!("I'm a daemon task looping forever.");
92    ///         Timer::new(Duration::from_secs(1)).await;
93    ///     }
94    /// })
95    /// .detach();
96    /// ```
97    pub fn detach(mut self) {
98        self.0.take().unwrap();
99    }
100
101    /// Cancels the task and waits for it to stop running.
102    ///
103    /// Returns the task's output if it was completed just before it got canceled, or [`None`] if
104    /// it didn't complete.
105    ///
106    /// While it's possible to simply drop the [`Task`] to cancel it, this is a cleaner way of
107    /// canceling because it also waits for the task to stop running.
108    ///
109    /// # Examples
110    ///
111    /// ```
112    /// use async_io::Timer;
113    /// use blocking::block_on;
114    /// use multitask::Executor;
115    /// use std::thread;
116    /// use std::time::Duration;
117    ///
118    /// let ex = Executor::new();
119    ///
120    /// // Spawn a deamon future.
121    /// let task = ex.spawn(async {
122    ///     loop {
123    ///         println!("Even though I'm in an infinite loop, you can still cancel me!");
124    ///         Timer::new(Duration::from_secs(1)).await;
125    ///     }
126    /// });
127    ///
128    /// // Run an executor thread.
129    /// thread::spawn(move || {
130    ///     let (p, u) = parking::pair();
131    ///     let ticker = ex.ticker(move || u.unpark());
132    ///     loop {
133    ///         if !ticker.tick() {
134    ///             p.park();
135    ///         }
136    ///     }
137    /// });
138    ///
139    /// block_on(async {
140    ///     Timer::new(Duration::from_secs(3)).await;
141    ///     task.cancel().await;
142    /// });
143    /// ```
144    pub async fn cancel(self) -> Option<T> {
145        let mut task = self;
146        let handle = task.0.take().unwrap();
147        handle.cancel();
148        handle.await
149    }
150}
151
152impl<T> Drop for Task<T> {
153    fn drop(&mut self) {
154        if let Some(handle) = &self.0 {
155            handle.cancel();
156        }
157    }
158}
159
160impl<T> Future for Task<T> {
161    type Output = T;
162
163    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        match Pin::new(&mut self.0.as_mut().unwrap()).poll(cx) {
165            Poll::Pending => Poll::Pending,
166            Poll::Ready(output) => Poll::Ready(output.expect("task has failed")),
167        }
168    }
169}
170
171/// A single-threaded executor.
172#[derive(Debug)]
173pub struct LocalExecutor {
174    /// The task queue.
175    queue: Arc<ConcurrentQueue<Runnable>>,
176
177    /// Callback invoked to wake the executor up.
178    callback: Callback,
179
180    /// Make sure the type is `!Send` and `!Sync`.
181    _marker: PhantomData<Rc<()>>,
182}
183
184impl UnwindSafe for LocalExecutor {}
185impl RefUnwindSafe for LocalExecutor {}
186
187impl LocalExecutor {
188    /// Creates a new single-threaded executor.
189    ///
190    /// # Examples
191    ///
192    /// ```
193    /// use multitask::LocalExecutor;
194    ///
195    /// let (p, u) = parking::pair();
196    /// let ex = LocalExecutor::new(move || u.unpark());
197    /// ```
198    pub fn new(notify: impl Fn() + Send + Sync + 'static) -> LocalExecutor {
199        LocalExecutor {
200            queue: Arc::new(ConcurrentQueue::unbounded()),
201            callback: Callback::new(notify),
202            _marker: PhantomData,
203        }
204    }
205
206    /// Spawns a thread-local future onto this executor.
207    ///
208    /// Returns a [`Task`] handle for the spawned future.
209    ///
210    /// # Examples
211    ///
212    /// ```
213    /// use multitask::LocalExecutor;
214    ///
215    /// let (p, u) = parking::pair();
216    /// let ex = LocalExecutor::new(move || u.unpark());
217    ///
218    /// let task = ex.spawn(async { println!("hello") });
219    /// ```
220    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
221        let queue = self.queue.clone();
222        let callback = self.callback.clone();
223
224        // The function that schedules a runnable task when it gets woken up.
225        let schedule = move |runnable| {
226            queue.push(runnable).unwrap();
227            callback.call();
228        };
229
230        // Create a task, push it into the queue by scheduling it, and return its `Task` handle.
231        let (runnable, handle) = async_task::spawn_local(future, schedule, ());
232        runnable.schedule();
233        Task(Some(handle))
234    }
235
236    /// Runs a single task and returns `true` if one was found.
237    ///
238    /// # Examples
239    ///
240    /// ```
241    /// use multitask::LocalExecutor;
242    ///
243    /// let (p, u) = parking::pair();
244    /// let ex = LocalExecutor::new(move || u.unpark());
245    ///
246    /// assert!(!ex.tick());
247    /// let task = ex.spawn(async { println!("hello") });
248    ///
249    /// // This prints "hello".
250    /// assert!(ex.tick());
251    /// ```
252    pub fn tick(&self) -> bool {
253        if let Ok(r) = self.queue.pop() {
254            r.run();
255            true
256        } else {
257            false
258        }
259    }
260}
261
262impl Drop for LocalExecutor {
263    fn drop(&mut self) {
264        // TODO(stjepang): Close the local queue and empty it.
265        // TODO(stjepang): Cancel all remaining tasks.
266    }
267}
268
269/// State shared between [`Executor`] and [`Ticker`].
270#[derive(Debug)]
271struct Global {
272    /// The global queue.
273    queue: ConcurrentQueue<Runnable>,
274
275    /// Shards of the global queue created by tickers.
276    shards: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
277
278    /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
279    notified: AtomicBool,
280
281    /// A list of sleeping tickers.
282    sleepers: Mutex<Sleepers>,
283}
284
285impl Global {
286    /// Notifies a sleeping ticker.
287    #[inline]
288    fn notify(&self) {
289        if !self
290            .notified
291            .compare_and_swap(false, true, Ordering::SeqCst)
292        {
293            let callback = self.sleepers.lock().unwrap().notify();
294            if let Some(cb) = callback {
295                cb.call();
296            }
297        }
298    }
299}
300
301/// A list of sleeping tickers.
302#[derive(Debug)]
303struct Sleepers {
304    /// Number of sleeping tickers (both notified and unnotified).
305    count: usize,
306
307    /// Callbacks of sleeping unnotified tickers.
308    ///
309    /// A sleeping ticker is notified when its callback is missing from this list.
310    callbacks: Vec<Callback>,
311}
312
313impl Sleepers {
314    /// Inserts a new sleeping ticker.
315    fn insert(&mut self, callback: &Callback) {
316        self.count += 1;
317        self.callbacks.push(callback.clone());
318    }
319
320    /// Re-inserts a sleeping ticker's callback if it was notified.
321    ///
322    /// Returns `true` if the ticker was notified.
323    fn update(&mut self, callback: &Callback) -> bool {
324        if self.callbacks.iter().all(|cb| cb != callback) {
325            self.callbacks.push(callback.clone());
326            true
327        } else {
328            false
329        }
330    }
331
332    /// Removes a previously inserted sleeping ticker.
333    fn remove(&mut self, callback: &Callback) {
334        self.count -= 1;
335        for i in (0..self.callbacks.len()).rev() {
336            if &self.callbacks[i] == callback {
337                self.callbacks.remove(i);
338                return;
339            }
340        }
341    }
342
343    /// Returns `true` if a sleeping ticker is notified or no tickers are sleeping.
344    fn is_notified(&self) -> bool {
345        self.count == 0 || self.count > self.callbacks.len()
346    }
347
348    /// Returns notification callback for a sleeping ticker.
349    ///
350    /// If a ticker was notified already or there are no tickers, `None` will be returned.
351    fn notify(&mut self) -> Option<Callback> {
352        if self.callbacks.len() == self.count {
353            self.callbacks.pop()
354        } else {
355            None
356        }
357    }
358}
359
360/// A multi-threaded executor.
361#[derive(Debug)]
362pub struct Executor {
363    global: Arc<Global>,
364}
365
366impl UnwindSafe for Executor {}
367impl RefUnwindSafe for Executor {}
368
369impl Executor {
370    /// Creates a new multi-threaded executor.
371    ///
372    /// # Examples
373    ///
374    /// ```
375    /// use multitask::Executor;
376    ///
377    /// let ex = Executor::new();
378    /// ```
379    pub fn new() -> Executor {
380        Executor {
381            global: Arc::new(Global {
382                queue: ConcurrentQueue::unbounded(),
383                shards: RwLock::new(Vec::new()),
384                notified: AtomicBool::new(true),
385                sleepers: Mutex::new(Sleepers {
386                    count: 0,
387                    callbacks: Vec::new(),
388                }),
389            }),
390        }
391    }
392
393    /// Spawns a future onto this executor.
394    ///
395    /// Returns a [`Task`] handle for the spawned future.
396    ///
397    /// # Examples
398    ///
399    /// ```
400    /// use multitask::Executor;
401    ///
402    /// let ex = Executor::new();
403    /// let task = ex.spawn(async { println!("hello") });
404    /// ```
405    pub fn spawn<T: Send + 'static>(
406        &self,
407        future: impl Future<Output = T> + Send + 'static,
408    ) -> Task<T> {
409        let global = self.global.clone();
410
411        // The function that schedules a runnable task when it gets woken up.
412        let schedule = move |runnable| {
413            global.queue.push(runnable).unwrap();
414            global.notify();
415        };
416
417        // Create a task, push it into the queue by scheduling it, and return its `Task` handle.
418        let (runnable, handle) = async_task::spawn(future, schedule, ());
419        runnable.schedule();
420        Task(Some(handle))
421    }
422
423    /// Creates a new ticker for executing tasks.
424    ///
425    /// In a multi-threaded executor, each executor thread will create its own ticker and then keep
426    /// calling [`Ticker::tick()`] in a loop.
427    ///
428    /// # Examples
429    ///
430    /// ```
431    /// use blocking::block_on;
432    /// use multitask::Executor;
433    /// use std::thread;
434    ///
435    /// let ex = Executor::new();
436    ///
437    /// // Create two executor threads.
438    /// for _ in 0..2 {
439    ///     let (p, u) = parking::pair();
440    ///     let ticker = ex.ticker(move || u.unpark());
441    ///     thread::spawn(move || {
442    ///         loop {
443    ///             if !ticker.tick() {
444    ///                 p.park();
445    ///             }
446    ///         }
447    ///     });
448    /// }
449    ///
450    /// // Spawn a future and wait for one of the threads to run it.
451    /// let task = ex.spawn(async { 1 + 2 });
452    /// assert_eq!(block_on(task), 3);
453    /// ```
454    pub fn ticker(&self, notify: impl Fn() + Send + Sync + 'static) -> Ticker {
455        // Create a ticker and put its stealer handle into the executor.
456        let ticker = Ticker {
457            global: Arc::new(self.global.clone()),
458            shard: Arc::new(ConcurrentQueue::bounded(512)),
459            callback: Callback::new(notify),
460            sleeping: Cell::new(false),
461            ticks: Cell::new(0),
462        };
463        self.global
464            .shards
465            .write()
466            .unwrap()
467            .push(ticker.shard.clone());
468        ticker
469    }
470}
471
472impl Default for Executor {
473    fn default() -> Executor {
474        Executor::new()
475    }
476}
477
478/// Runs tasks in a multi-threaded executor.
479#[derive(Debug)]
480pub struct Ticker {
481    /// The global queue.
482    global: Arc<Arc<Global>>,
483
484    /// A shard of the global queue.
485    shard: Arc<ConcurrentQueue<Runnable>>,
486
487    /// Callback invoked to wake this ticker up.
488    callback: Callback,
489
490    /// Set to `true` when in sleeping state.
491    ///
492    /// States a ticker can be in:
493    /// 1) Woken.
494    /// 2a) Sleeping and unnotified.
495    /// 2b) Sleeping and notified.
496    sleeping: Cell<bool>,
497
498    /// Bumped every time a task is run.
499    ticks: Cell<usize>,
500}
501
502impl UnwindSafe for Ticker {}
503impl RefUnwindSafe for Ticker {}
504
505impl Ticker {
506    /// Moves the ticker into sleeping and unnotified state.
507    ///
508    /// Returns `false` if the ticker was already sleeping and unnotified.
509    fn sleep(&self) -> bool {
510        let mut sleepers = self.global.sleepers.lock().unwrap();
511
512        if self.sleeping.get() {
513            // Already sleeping, check if notified.
514            if !sleepers.update(&self.callback) {
515                return false;
516            }
517        } else {
518            // Move to sleeping state.
519            sleepers.insert(&self.callback);
520        }
521
522        self.global
523            .notified
524            .swap(sleepers.is_notified(), Ordering::SeqCst);
525
526        self.sleeping.set(true);
527        true
528    }
529
530    /// Moves the ticker into woken state.
531    ///
532    /// Returns `false` if the ticker was already woken.
533    fn wake(&self) -> bool {
534        if self.sleeping.get() {
535            let mut sleepers = self.global.sleepers.lock().unwrap();
536            sleepers.remove(&self.callback);
537
538            self.global
539                .notified
540                .swap(sleepers.is_notified(), Ordering::SeqCst);
541        }
542
543        self.sleeping.replace(false)
544    }
545
546    /// Runs a single task and returns `true` if one was found.
547    pub fn tick(&self) -> bool {
548        loop {
549            match self.search() {
550                None => {
551                    // Move to sleeping and unnotified state.
552                    if !self.sleep() {
553                        // If already sleeping and unnotified, return.
554                        return false;
555                    }
556                }
557                Some(r) => {
558                    // Wake up.
559                    self.wake();
560
561                    // Notify another ticker now to pick up where this ticker left off, just in
562                    // case running the task takes a long time.
563                    self.global.notify();
564
565                    // Bump the ticker.
566                    let ticks = self.ticks.get();
567                    self.ticks.set(ticks.wrapping_add(1));
568
569                    // Steal tasks from the global queue to ensure fair task scheduling.
570                    if ticks % 64 == 0 {
571                        steal(&self.global.queue, &self.shard);
572                    }
573
574                    // Run the task.
575                    r.run();
576
577                    return true;
578                }
579            }
580        }
581    }
582
583    /// Finds the next task to run.
584    fn search(&self) -> Option<Runnable> {
585        if let Ok(r) = self.shard.pop() {
586            return Some(r);
587        }
588
589        // Try stealing from the global queue.
590        if let Ok(r) = self.global.queue.pop() {
591            steal(&self.global.queue, &self.shard);
592            return Some(r);
593        }
594
595        // Try stealing from other shards.
596        let shards = self.global.shards.read().unwrap();
597
598        // Pick a random starting point in the iterator list and rotate the list.
599        let n = shards.len();
600        let start = fastrand::usize(..n);
601        let iter = shards.iter().chain(shards.iter()).skip(start).take(n);
602
603        // Remove this ticker's shard.
604        let iter = iter.filter(|shard| !Arc::ptr_eq(shard, &self.shard));
605
606        // Try stealing from each shard in the list.
607        for shard in iter {
608            steal(shard, &self.shard);
609            if let Ok(r) = self.shard.pop() {
610                return Some(r);
611            }
612        }
613
614        None
615    }
616}
617
618impl Drop for Ticker {
619    fn drop(&mut self) {
620        // Wake and unregister the ticker.
621        self.wake();
622        self.global
623            .shards
624            .write()
625            .unwrap()
626            .retain(|shard| !Arc::ptr_eq(shard, &self.shard));
627
628        // Re-schedule remaining tasks in the shard.
629        while let Ok(r) = self.shard.pop() {
630            r.schedule();
631        }
632        // Notify another ticker to start searching for tasks.
633        self.global.notify();
634
635        // TODO(stjepang): Cancel all remaining tasks.
636    }
637}
638
639/// Steals some items from one queue into another.
640fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
641    // Half of `src`'s length rounded up.
642    let mut count = (src.len() + 1) / 2;
643
644    if count > 0 {
645        // Don't steal more than fits into the queue.
646        if let Some(cap) = dest.capacity() {
647            count = count.min(cap - dest.len());
648        }
649
650        // Steal tasks.
651        for _ in 0..count {
652            if let Ok(t) = src.pop() {
653                assert!(dest.push(t).is_ok());
654            } else {
655                break;
656            }
657        }
658    }
659}
660
/// A cheaply cloneable callback function.
///
/// Clones share a single allocation. Equality is identity: two `Callback`s compare equal only
/// when one is a clone of the other.
#[derive(Clone)]
struct Callback(Arc<Box<dyn Fn() + Send + Sync>>);

impl Callback {
    /// Wraps `f` so it can be shared and compared by identity.
    fn new(f: impl Fn() + Send + Sync + 'static) -> Callback {
        let boxed: Box<dyn Fn() + Send + Sync> = Box::new(f);
        Callback(Arc::new(boxed))
    }

    /// Invokes the wrapped function.
    fn call(&self) {
        let f: &(dyn Fn() + Send + Sync) = &**self.0;
        f()
    }
}

impl PartialEq for Callback {
    /// Pointer equality on the shared allocation.
    fn eq(&self, other: &Callback) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}

impl Eq for Callback {}

impl fmt::Debug for Callback {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("<callback>")
    }
}