cosync/
lib.rs

#![doc = include_str!("../README.md")]
#![deny(rust_2018_idioms)]
#![deny(missing_docs)]
#![deny(rustdoc::all)]

// this is vendored code from the `futures-rs` crate, to avoid
// having a huge dependency when we only need a little bit
mod futures;
use crate::futures::{enter::enter, waker_ref, ArcWake, FuturesUnordered};

use std::{
    collections::VecDeque,
    fmt,
    future::Future,
    marker::PhantomData,
    ops,
    pin::Pin,
    ptr::NonNull,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Mutex, Weak,
    },
    task::{Context, Poll},
    thread::{self, Thread},
};

/// A single-threaded, sequential, parameterized async task queue.
///
/// This executor allows you to queue multiple tasks in sequence, and to
/// queue tasks within other tasks. Tasks run in the order they
/// are queued.
///
/// You can queue a task by using [queue](Cosync::queue), by creating a [CosyncQueueHandle]
/// and calling [queue](CosyncQueueHandle::queue), or, within a task, by calling
/// [queue](CosyncInput::queue) on [CosyncInput].
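///
/// For example, tasks run one at a time, in the order they were queued (this mirrors the
/// examples on [run_blocking](Cosync::run_blocking) and [run_until_stall](Cosync::run_until_stall)):
///
/// ```
/// # use cosync::Cosync;
/// let mut cosync: Cosync<i32> = Cosync::new();
/// cosync.queue(move |mut input| async move {
///     *input.get() = 1;
/// });
/// cosync.queue(move |mut input| async move {
///     // the first task has already completed by the time this runs
///     assert_eq!(*input.get(), 1);
///     *input.get() = 2;
/// });
///
/// let mut value = 0;
/// cosync.run_until_stall(&mut value);
/// assert_eq!(value, 2);
/// ```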
#[derive(Debug)]
pub struct Cosync<T: ?Sized> {
    pool: FuturesUnordered<FutureObject>,
    incoming: Arc<Mutex<VecDeque<FutureObject>>>,
    data: Box<Option<NonNull<T>>>,
    kill_box: Arc<()>,
}

impl<T: 'static + ?Sized> Cosync<T> {
    /// Create a new, empty queue of tasks.
    pub fn new() -> Self {
        Self {
            pool: FuturesUnordered::new(),
            incoming: Default::default(),
            data: Box::new(None),
            kill_box: Arc::new(()),
        }
    }

    /// Returns the number of tasks queued. This *includes* the task currently being executed. Use
    /// [is_executing] to see if there is a task currently being executed (i.e., it returned `Pending`
    /// at some point in its execution).
    ///
    /// [is_executing]: Self::is_executing
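    ///
    /// For example, a task that has been queued but not yet started still counts towards the length:
    ///
    /// ```
    /// # use cosync::Cosync;
    /// let mut cosync: Cosync<i32> = Cosync::new();
    /// cosync.queue(move |_input| async move {});
    /// assert_eq!(cosync.len(), 1);
    /// ```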
    pub fn len(&self) -> usize {
        let one = if self.is_executing() { 1 } else { 0 };

        one + self.incoming.lock().unwrap().len()
    }

    /// Returns true if no futures are being executed *and* there are no futures in the queue.
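    ///
    /// A freshly created `Cosync` is empty; queueing a task makes it non-empty:
    ///
    /// ```
    /// # use cosync::Cosync;
    /// let mut cosync: Cosync<i32> = Cosync::new();
    /// assert!(cosync.is_empty());
    /// cosync.queue(move |_input| async move {});
    /// assert!(!cosync.is_empty());
    /// ```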
    pub fn is_empty(&self) -> bool {
        !self.is_executing() && self.incoming.lock().unwrap().is_empty()
    }

    /// Returns true if the `cosync` has a `Pending` future. It is possible for
    /// the `cosync` to have no `Pending` future, but to still have tasks queued.
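    ///
    /// For example, a task that has yielded via [sleep_ticks] is still being executed:
    ///
    /// ```
    /// # use cosync::{sleep_ticks, Cosync};
    /// let mut cosync: Cosync<i32> = Cosync::new();
    /// cosync.queue(move |_input| async move {
    ///     sleep_ticks(1).await;
    /// });
    ///
    /// let mut value = 0;
    /// cosync.run_until_stall(&mut value);
    /// assert!(cosync.is_executing());
    /// ```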
    pub fn is_executing(&self) -> bool {
        !self.pool.is_empty()
    }

    /// Creates a queue handle which can be used to spawn tasks.
    pub fn create_queue_handle(&self) -> CosyncQueueHandle<T> {
        let heap_ptr = &*self.data as *const Option<_>;

        CosyncQueueHandle {
            heap_ptr,
            incoming: self.incoming.clone(),
            kill_box: Arc::downgrade(&self.kill_box),
        }
    }

    /// Adds a new task to the queue.
    pub fn queue<Task, Out>(&mut self, task: Task)
    where
        Task: FnOnce(CosyncInput<T>) -> Out + Send + 'static,
        Out: Future<Output = ()> + Send,
    {
        let queue_handle = self.create_queue_handle();

        queue_handle.queue(task)
    }

    /// Run all tasks in the queue to completion. You probably want `run_until_stall`.
    ///
    /// ```
    /// # use cosync::Cosync;
    ///
    /// let mut cosync: Cosync<i32> = Cosync::new();
    /// cosync.queue(move |mut input| async move {
    ///     let mut input = input.get();
    ///     *input = 10;
    /// });
    ///
    /// let mut value = 0;
    /// cosync.run_blocking(&mut value);
    /// assert_eq!(value, 10);
    /// ```
    ///
    /// The function will block the calling thread until *all* tasks in the pool
    /// are complete, including any spawned while running existing tasks.
    pub fn run_blocking(&mut self, parameter: &mut T) {
        // hoist the T:
        unsafe {
            *self.data = Some(NonNull::new_unchecked(parameter as *mut _));
        }

        run_executor(|cx| self.poll_pool(cx));

        // we null out here so we don't do bad things
        *self.data = None;
    }

    /// Runs all tasks in the queue and returns if no more progress can be made
    /// on any task.
    ///
    /// ```
    /// use cosync::{sleep_ticks, Cosync};
    ///
    /// let mut cosync = Cosync::new();
    /// cosync.queue(move |mut input| async move {
    ///     *input.get() = 10;
    ///     // this will make the executor stall for one call;
    ///     // calling `run_until_stall` an additional time
    ///     // will complete this 1 tick sleep.
    ///     sleep_ticks(1).await;
    ///
    ///     *input.get() = 20;
    /// });
    ///
    /// let mut value = 0;
    /// cosync.run_until_stall(&mut value);
    /// assert_eq!(value, 10);
    /// cosync.run_until_stall(&mut value);
    /// assert_eq!(value, 20);
    /// ```
    ///
    /// This function will not block the calling thread and will return the moment
    /// that there are no tasks left for which progress can be made;
    /// remaining incomplete tasks can be driven further by calling this method
    /// or [run_blocking](Cosync::run_blocking) again. While the function is running,
    /// all tasks in the pool will try to make progress.
    pub fn run_until_stall(&mut self, parameter: &mut T) {
        // hoist the T:
        unsafe {
            *self.data = Some(NonNull::new_unchecked(parameter as *mut _));
        }

        poll_executor(|ctx| {
            let _output = self.poll_pool(ctx);
        });

        // null it
        *self.data = None;
    }

    // Make maximal progress on the entire pool of spawned tasks, returning `Ready`
    // if the pool is empty and `Pending` if no further progress can be made.
    fn poll_pool(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        loop {
            let ret = self.poll_pool_once(cx);

            // `Ready(None)` means the pool is empty and nothing is queued, so we're done;
            // `Ready(Some(_))` means a task finished, so loop and try to pull in the next one.
            match ret {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(()),
                _ => {}
            }
        }
    }

    // Try to make minimal progress on the pool of spawned tasks
    fn poll_pool_once(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        // grab our next task...
        if self.pool.is_empty() {
            if let Some(task) = self.incoming.lock().unwrap().pop_front() {
                self.pool.push(task)
            }
        }

        // try to execute the next ready future
        Pin::new(&mut self.pool).poll_next(cx)
    }
}

/// A handle to spawn tasks.
///
/// # Examples
/// ```
/// # use cosync::Cosync;
/// let mut cosync = Cosync::new();
/// let handler = cosync.create_queue_handle();
///
/// // make a thread and join it...
/// std::thread::spawn(move || {
///     handler.queue(|mut input| async move {
///         *input.get() = 20;
///     });
/// })
/// .join()
/// .unwrap();
///
/// let mut value = 1;
/// cosync.run_blocking(&mut value);
/// assert_eq!(value, 20);
/// ```
#[derive(Debug)]
pub struct CosyncQueueHandle<T: ?Sized> {
    heap_ptr: *const Option<NonNull<T>>,
    incoming: Arc<Mutex<VecDeque<FutureObject>>>,
    kill_box: Weak<()>,
}

impl<T: 'static + ?Sized> CosyncQueueHandle<T> {
    /// Adds a new task to the queue.
    pub fn queue<Task, Out>(&self, task: Task)
    where
        Task: FnOnce(CosyncInput<T>) -> Out + Send + 'static,
        Out: Future<Output = ()> + Send,
    {
        queue_task(task, self.kill_box.clone(), self.heap_ptr, &self.incoming);
    }
}

// safety:
// we guarantee with a kill counter that the main `.get` of CosyncInput
// never dereferences invalid data, and the guard is only created on the
// same thread that runs the Cosync, so we should never have simultaneous
// multithreaded access.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T: ?Sized> Send for CosyncQueueHandle<T> {}
unsafe impl<T: ?Sized> Sync for CosyncQueueHandle<T> {}

impl<T: ?Sized> Clone for CosyncQueueHandle<T> {
    fn clone(&self) -> Self {
        Self {
            heap_ptr: self.heap_ptr,
            incoming: self.incoming.clone(),
            kill_box: self.kill_box.clone(),
        }
    }
}

/// A guarded pointer used to create a [CosyncInputGuard] via [get], and to queue more tasks via
/// [queue].
///
/// [queue]: Self::queue
/// [get]: Self::get
#[derive(Debug)]
pub struct CosyncInput<T: ?Sized>(CosyncQueueHandle<T>);

impl<T: 'static + ?Sized> CosyncInput<T> {
    /// Gets the underlying [CosyncInputGuard].
    pub fn get(&mut self) -> CosyncInputGuard<'_, T> {
        // if you hit this assert, it means that you somehow moved the `CosyncInput` out of
        // the closure, and then dropped the `Cosync`. Why would you do that? Don't do that.
        assert!(
            Weak::strong_count(&self.0.kill_box) == 1,
            "cosync was dropped improperly"
        );

        // we can always dereference this data, as we maintain
        // that it's always present.
        let o = unsafe {
            (&*self.0.heap_ptr)
                .expect("cosync was not initialized this run correctly")
                .as_mut()
        };

        CosyncInputGuard(o, PhantomData)
    }

    /// Queues a new task. This goes to the back of the queue.
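    ///
    /// For example, a task queued from within a running task runs after everything that is
    /// already in the queue:
    ///
    /// ```
    /// # use cosync::Cosync;
    /// let mut cosync: Cosync<i32> = Cosync::new();
    /// cosync.queue(move |mut input| async move {
    ///     *input.get() = 10;
    ///     input.queue(move |mut input| async move {
    ///         *input.get() += 1;
    ///     });
    /// });
    ///
    /// let mut value = 0;
    /// cosync.run_until_stall(&mut value);
    /// assert_eq!(value, 11);
    /// ```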
    pub fn queue<Task, Out>(&self, task: Task)
    where
        Task: FnOnce(CosyncInput<T>) -> Out + Send + 'static,
        Out: Future<Output = ()> + Send,
    {
        self.0.queue(task)
    }

    /// Creates a queue handle which can be used to spawn tasks.
    pub fn create_queue_handle(&self) -> CosyncQueueHandle<T> {
        self.0.clone()
    }
}

// safety:
// we create `CosyncInput` per task, and it doesn't escape our closure.
// therefore, its `*const` field should only be accessible when we know
// it's valid.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T: ?Sized> Send for CosyncInput<T> {}
unsafe impl<T: ?Sized> Sync for CosyncInput<T> {}

/// A guarded pointer.
///
/// This exists to prevent holding onto the `CosyncInputGuard` over `.await` calls. It will need to
/// be fetched again from [CosyncInput] after each await.
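///
/// For example, re-fetch the guard after an await:
///
/// ```
/// # use cosync::{sleep_ticks, Cosync};
/// # let mut cosync: Cosync<i32> = Cosync::new();
/// cosync.queue(move |mut input| async move {
///     *input.get() = 1; // the guard here is a temporary, dropped at the end of the statement
///     sleep_ticks(1).await;
///     *input.get() = 2; // fetch a fresh guard after the await
/// });
/// # let mut value = 0;
/// # cosync.run_until_stall(&mut value);
/// # cosync.run_until_stall(&mut value);
/// # assert_eq!(value, 2);
/// ```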
pub struct CosyncInputGuard<'a, T: ?Sized>(&'a mut T, PhantomData<*const u8>);

impl<'a, T: ?Sized> ops::Deref for CosyncInputGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.0
    }
}

impl<'a, T: ?Sized> ops::DerefMut for CosyncInputGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0
    }
}

impl<T: 'static> Default for Cosync<T> {
    fn default() -> Self {
        Self::new()
    }
}

struct FutureObject(Pin<Box<dyn Future<Output = ()> + 'static>>);
impl Future for FutureObject {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.0).poll(cx)
    }
}

impl fmt::Debug for FutureObject {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FutureObject").finish()
    }
}

pub(crate) struct ThreadNotify {
    /// The (single) executor thread.
    pub thread: Thread,
    /// A flag to ensure a wakeup (i.e. `unpark()`) is not "forgotten"
    /// before the next `park()`, which may otherwise happen if the code
    /// being executed as part of the future(s) being polled makes use of
    /// park / unpark calls of its own, i.e. we cannot assume that no other
    /// code uses park / unpark on the executing `thread`.
    pub unparked: AtomicBool,
}

impl ArcWake for ThreadNotify {
    fn wake_by_ref(this: &Arc<Self>) {
        // Make sure the wakeup is remembered until the next `park()`.
        let unparked = this.unparked.swap(true, Ordering::Relaxed);
        if !unparked {
            // If the thread has not been unparked yet, it must be done
            // now. If it was actually parked, it will run again,
            // otherwise the token made available by `unpark`
            // may be consumed before reaching `park()`, but `unparked`
            // ensures it is not forgotten.
            this.thread.unpark();
        }
    }
}

thread_local! {
    static CURRENT_THREAD_NOTIFY: Arc<ThreadNotify> = Arc::new(ThreadNotify {
        thread: thread::current(),
        unparked: AtomicBool::new(false),
    });
}

// Set up and run a basic single-threaded executor loop, invoking `work_on_future`
// on each turn.
fn run_executor<T, F>(mut work_on_future: F) -> T
where
    F: FnMut(&mut Context<'_>) -> Poll<T>,
{
    let _enter = enter().expect(
        "cannot execute a `Cosync` executor from within \
         another executor",
    );

    CURRENT_THREAD_NOTIFY.with(|thread_notify| {
        let waker = waker_ref::waker_ref(thread_notify);
        let mut cx = Context::from_waker(&waker);
        loop {
            if let Poll::Ready(t) = work_on_future(&mut cx) {
                return t;
            }
            // Consume the wakeup that occurred while executing `work_on_future`, if any.
            let unparked = thread_notify.unparked.swap(false, Ordering::Acquire);
            if !unparked {
                // No wakeup occurred. It may occur now, right before parking,
                // but in that case the token made available by `unpark()`
                // is guaranteed to still be available and `park()` is a no-op.
                thread::park();
                // When the thread is unparked, `unparked` will have been set
                // and needs to be unset before the next call to `work_on_future`
                // to avoid a redundant loop iteration.
                thread_notify.unparked.store(false, Ordering::Release);
            }
        }
    })
}

fn poll_executor<T, F: FnMut(&mut Context<'_>) -> T>(mut f: F) -> T {
    let _enter = enter().expect(
        "cannot execute a `Cosync` executor from within \
         another executor",
    );

    CURRENT_THREAD_NOTIFY.with(|thread_notify| {
        let waker = waker_ref::waker_ref(thread_notify);
        let mut cx = Context::from_waker(&waker);
        f(&mut cx)
    })
}

/// Adds a new task to the queue.
fn queue_task<T: 'static + ?Sized, Task, Out>(
    task: Task,
    kill_box: Weak<()>,
    heap_ptr: *const Option<NonNull<T>>,
    incoming: &Arc<Mutex<VecDeque<FutureObject>>>,
) where
    Task: FnOnce(CosyncInput<T>) -> Out + Send + 'static,
    Out: Future<Output = ()> + Send,
{
    // force the future to move...
    let task = task;
    let sec = CosyncInput(CosyncQueueHandle {
        heap_ptr,
        incoming: incoming.clone(),
        kill_box,
    });

    let our_cb = Box::pin(async move {
        task(sec).await;
    });

    incoming.lock().unwrap().push_back(FutureObject(our_cb));
}

/// Sleep the `Cosync` for a given number of calls to `run_until_stall`.
///
/// If you run `run_until_stall` once per tick in your main loop, then
/// this will sleep for that number of ticks.
/// If you run `run_blocking`, the sleep is simply polled repeatedly within
/// that single call until it completes.
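///
/// For example, a one-tick sleep makes a task span two calls to `run_until_stall`
/// (this mirrors the example on [run_until_stall](Cosync::run_until_stall)):
///
/// ```
/// # use cosync::{sleep_ticks, Cosync};
/// let mut cosync: Cosync<i32> = Cosync::new();
/// cosync.queue(move |mut input| async move {
///     *input.get() = 1;
///     sleep_ticks(1).await;
///     *input.get() = 2;
/// });
///
/// let mut value = 0;
/// cosync.run_until_stall(&mut value);
/// assert_eq!(value, 1);
/// cosync.run_until_stall(&mut value);
/// assert_eq!(value, 2);
/// ```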
pub fn sleep_ticks(ticks: usize) -> SleepForTick {
    SleepForTick::new(ticks)
}

/// A helper struct which registers a sleep for a given number of ticks.
#[derive(Clone, Copy, Debug)]
#[doc(hidden)] // so users only see `sleep_ticks` above.
pub struct SleepForTick(pub usize);

impl SleepForTick {
    /// Sleep for the number of ticks provided.
    pub fn new(ticks: usize) -> Self {
        Self(ticks)
    }
}

impl Future for SleepForTick {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.0 == 0 {
            Poll::Ready(())
        } else {
            self.0 -= 1;

            // temp: this is relatively expensive.
            // we should be able to just register this at will
            cx.waker().wake_by_ref();

            Poll::Pending
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    static_assertions::assert_not_impl_all!(CosyncInputGuard<'_, i32>: Send);

    #[test]
    fn ordering() {
        let mut cosync = Cosync::new();

        let mut value = 0;
        cosync.queue(|_i| async move {
            println!("actual task body!");
        });
        cosync.run_until_stall(&mut value);
    }

    #[test]
    #[allow(clippy::needless_late_init)]
    fn pool_is_sequential() {
        // notice that value is declared here
        let mut value;

        let mut executor: Cosync<i32> = Cosync::new();
        executor.queue(move |mut input| async move {
            let mut input = input.get();

            assert_eq!(*input, 10);
            *input = 10;
        });

        executor.queue(move |mut input| async move {
            assert_eq!(*input.get(), 10);

            // this will make the executor sleep, stall,
            // and exit out of this tick
            // we call `run_until_stall` an additional time,
            // so we'll complete this 1 tick sleep.
            let sleep = SleepForTick(1);
            sleep.await;

            let input = &mut *input.get();
            assert_eq!(*input, 30);
            *input = 0;
        });

        // initialized here, after tasks are made
        // (so code is correctly being deferred)
        value = 10;
        executor.run_until_stall(&mut value);
        value = 30;
        executor.run_until_stall(&mut value);
        assert_eq!(value, 0);
    }

    #[test]
    fn run_until_stalled_stalls() {
        let mut cosync = Cosync::new();

        cosync.queue(move |mut input| async move {
            *input.get() = 10;
            // this will make the executor stall for a call
            // we call `run_until_stall` an additional time,
            // so we'll complete this 1 tick sleep.
            sleep_ticks(1).await;

            *input.get() = 20;
        });

        let mut value = 0;
        cosync.run_until_stall(&mut value);
        assert_eq!(value, 10);
        cosync.run_until_stall(&mut value);
        assert_eq!(value, 20);
    }

    #[test]
    #[allow(clippy::needless_late_init)]
    fn pool_remains_sequential() {
        // notice that value is declared here
        let mut value;

        let mut executor: Cosync<i32> = Cosync::new();
        executor.queue(move |mut input| async move {
            println!("starting task 1");
            *input.get() = 10;

            sleep_ticks(100).await;

            *input.get() = 20;
        });

        executor.queue(move |mut input| async move {
            assert_eq!(*input.get(), 20);
        });

        value = 0;
        executor.run_until_stall(&mut value);
    }

    #[test]
    #[allow(clippy::needless_late_init)]
    fn pool_is_still_sequential() {
        // notice that value is declared here
        let mut value;

        let mut executor: Cosync<i32> = Cosync::new();
        executor.queue(move |mut input| async move {
            println!("starting task 1");
            *input.get() = 10;

            input.queue(move |mut input| async move {
                println!("starting task 3");
                assert_eq!(*input.get(), 20);

                *input.get() = 30;
            });
        });

        executor.queue(move |mut input| async move {
            println!("starting task 2");
            *input.get() = 20;
        });

        // initialized here, after tasks are made
        // (so code is correctly being deferred)
        value = 0;
        executor.run_until_stall(&mut value);
        assert_eq!(value, 30);
    }

    #[test]
    #[allow(clippy::needless_late_init)]
    fn cosync_can_be_moved() {
        // notice that value is declared here
        let mut value;

        let mut executor: Cosync<i32> = Cosync::new();
        executor.queue(move |mut input| async move {
            println!("starting task 1");
            *input.get() = 10;

            sleep_ticks(1).await;

            *input.get() = 20;
        });

        // initialized here, after tasks are made
        // (so code is correctly being deferred)
        value = 0;
        executor.run_until_stall(&mut value);
        assert_eq!(value, 10);

        // move it somewhere else..
        let mut executor = Box::new(executor);
        executor.run_until_stall(&mut value);

        assert_eq!(value, 20);
    }

    #[test]
    #[should_panic(expected = "cosync was dropped improperly")]
    fn ub_on_move_is_prevented() {
        let (sndr, rx) = std::sync::mpsc::channel();
        let mut executor: Cosync<i32> = Cosync::new();

        executor.queue(move |input| async move {
            let sndr: std::sync::mpsc::Sender<_> = sndr;
            sndr.send(input).unwrap();
        });

        let mut value = 0;
        executor.run_blocking(&mut value);
        drop(executor);

        // the executor was dropped. whoopsie!
        let mut v = rx.recv().unwrap();
        *v.get() = 20;
    }

    #[test]
    fn threading() {
        let mut cosync = Cosync::new();
        let handler = cosync.create_queue_handle();

        // make a thread and join it...
        std::thread::spawn(move || {
            handler.queue(|mut input| async move {
                *input.get() = 20;
            });
        })
        .join()
        .unwrap();

        let mut value = 1;
        cosync.run_blocking(&mut value);
        assert_eq!(value, 20);
    }

    #[test]
    fn trybuild() {
        let t = trybuild::TestCases::new();
        t.compile_fail("tests/try_build/*.rs");
    }

    #[test]
    fn dynamic_dispatch() {
        trait DynDispatch {
            fn test(&self) -> i32;
        }

        impl DynDispatch for i32 {
            fn test(&self) -> i32 {
                *self
            }
        }

        impl DynDispatch for &'static str {
            fn test(&self) -> i32 {
                self.parse().unwrap()
            }
        }

        let mut cosync: Cosync<dyn DynDispatch> = Cosync::new();
        cosync.queue(|mut input: CosyncInput<dyn DynDispatch>| async move {
            {
                let inner: &mut dyn DynDispatch = &mut *input.get();
                assert_eq!(inner.test(), 3);
            }

            sleep_ticks(1).await;

            {
                let inner: &mut dyn DynDispatch = &mut *input.get();
                assert_eq!(inner.test(), 3);
            }
        });

        cosync.run_until_stall(&mut 3);
        cosync.run_until_stall(&mut "3");
    }

    #[test]
    fn unsized_type() {
        let mut cosync: Cosync<str> = Cosync::new();

        cosync.queue(|mut input| async move {
            let input_guard = input.get();
            let inner_str: &str = &input_guard;
            println!("inner str = {}", inner_str);
        });
    }

    #[test]
    fn can_move_non_copy() {
        let mut cosync: Cosync<i32> = Cosync::new();

        let my_vec = vec![10];

        cosync.queue(|_input| async move {
            let mut vec = my_vec;
            vec.push(10);

            assert_eq!(*vec, [10, 10]);
        });
    }
}