corosensei/
coroutine.rs

1use core::cell::Cell;
2use core::hint::unreachable_unchecked;
3use core::marker::PhantomData;
4use core::mem::{self, ManuallyDrop};
5use core::ptr;
6
7use crate::arch::{self, STACK_ALIGNMENT};
8#[cfg(feature = "default-stack")]
9use crate::stack::DefaultStack;
10#[cfg(windows)]
11use crate::stack::StackTebFields;
12use crate::stack::{self, StackPointer};
13use crate::trap::CoroutineTrapHandler;
14use crate::unwind::{self, initial_func_abi, CaughtPanic, ForcedUnwindErr};
15use crate::util::{self, EncodedValue};
16
/// Outcome of resuming a coroutine: it either suspended itself or finished.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CoroutineResult<Yield, Return> {
    /// The coroutine suspended itself through a `Yielder`, handing out this
    /// value.
    Yield(Yield),

    /// The coroutine's main function returned, producing this value.
    Return(Return),
}

impl<Yield, Return> CoroutineResult<Yield, Return> {
    /// Converts this result into `Some(value)` when it is a `Yield`, and into
    /// `None` when it is a `Return`.
    pub fn as_yield(self) -> Option<Yield> {
        if let Self::Yield(yielded) = self {
            Some(yielded)
        } else {
            None
        }
    }

    /// Converts this result into `Some(value)` when it is a `Return`, and
    /// into `None` when it is a `Yield`.
    pub fn as_return(self) -> Option<Return> {
        if let Self::Return(returned) = self {
            Some(returned)
        } else {
            None
        }
    }
}
44
/// A coroutine wraps a closure and allows suspending its execution more than
/// once, returning a value each time.
///
/// # Dropping a coroutine
///
/// When a coroutine is dropped, its stack must be unwound so that all objects
/// on it are properly dropped. This is done by calling `force_unwind` to unwind
/// the stack. If `force_unwind` fails then the program is aborted.
///
/// See the [`Coroutine::force_unwind`] function for more details.
///
/// # `Send`
///
/// In the general case, a coroutine can only be sent to another thread if all
/// of the data on its stack is `Send`. There is no way to guarantee this using
/// Rust language features so `Coroutine` does not implement the `Send` trait.
///
/// However if all of the code executed by a coroutine is under your control and
/// you can ensure that all types on the stack when a coroutine is suspended
/// are `Send` then it is safe to manually implement `Send` for a coroutine.
#[cfg(feature = "default-stack")]
pub struct Coroutine<Input, Yield, Return, Stack: stack::Stack = DefaultStack> {
    // Stack that the coroutine is executing on.
    stack: Stack,

    // Current stack pointer at which the coroutine state is held. This is
    // None when the coroutine has completed execution.
    stack_ptr: Option<StackPointer>,

    // Initial stack pointer value. This is used to detect whether a coroutine
    // has ever been resumed since it was created.
    //
    // This works because it is impossible for a coroutine to revert back to its
    // initial stack pointer: suspending a coroutine requires pushing several
    // values to the stack.
    initial_stack_ptr: StackPointer,

    // Function to call to drop the initial state of a coroutine if it has
    // never been resumed.
    drop_fn: unsafe fn(ptr: *mut u8),

    // We want to be covariant over Yield and Return, and contravariant
    // over Input.
    //
    // Effectively this means that we can pass a
    //   Coroutine<&'a (), &'static (), &'static ()>
    // to a function that expects a
    //   Coroutine<&'static (), &'c (), &'d ()>
    marker: PhantomData<fn(Input) -> CoroutineResult<Yield, Return>>,

    // Coroutine must be !Send.
    /// ```compile_fail
    /// fn send<T: Send>() {}
    /// send::<corosensei::Coroutine<(), ()>>();
    /// ```
    marker2: PhantomData<*mut ()>,
}
102
103/// A coroutine wraps a closure and allows suspending its execution more than
104/// once, returning a value each time.
105///
106/// # Dropping a coroutine
107///
108/// When a coroutine is dropped, its stack must be unwound so that all object on
109/// it are properly dropped. This is done by calling `force_unwind` to unwind
110/// the stack. If `force_unwind` fails then the program is aborted.
111///
112/// See the [`Coroutine::force_unwind`] function for more details.
113///
114/// # `Send`
115///
116/// In the general case, a coroutine can only be sent to another if all of the
117/// data on its stack is `Send`. There is no way to guarantee this using Rust
118/// language features so `Coroutine` does not implement the `Send` trait.
119///
120/// However if all of the code executed by a coroutine is under your control and
121/// you can ensure that all types on the stack when a coroutine is suspended
122/// are `Send` then it is safe to manually implement `Send` for a coroutine.
123#[cfg(not(feature = "default-stack"))]
124pub struct Coroutine<Input, Yield, Return, Stack: stack::Stack> {
125    stack: Stack,
126    stack_ptr: Option<StackPointer>,
127    initial_stack_ptr: StackPointer,
128    drop_fn: unsafe fn(ptr: *mut u8),
129    marker: PhantomData<fn(Input) -> CoroutineResult<Yield, Return>>,
130    marker2: PhantomData<*mut ()>,
131}
132
133// Coroutines can be Sync if the stack is Sync.
134unsafe impl<Input, Yield, Return, Stack: stack::Stack + Sync> Sync
135    for Coroutine<Input, Yield, Return, Stack>
136{
137}
138
#[cfg(feature = "default-stack")]
impl<Input, Yield, Return> Coroutine<Input, Yield, Return, DefaultStack> {
    /// Creates a coroutine that will run `f` on a new [`DefaultStack`] built
    /// with `Default::default()`.
    ///
    /// Nothing runs until the first call to [`Coroutine::resume`]. From then
    /// on `f` executes until it returns, and it may suspend itself along the
    /// way with [`Yielder::suspend`]. Use [`Coroutine::with_stack`] to supply
    /// a stack of your own.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
        F: 'static,
        Input: 'static,
        Yield: 'static,
        Return: 'static,
    {
        let stack = DefaultStack::default();
        Self::with_stack(stack, f)
    }
}
157
impl<Input, Yield, Return, Stack: stack::Stack> Coroutine<Input, Yield, Return, Stack> {
    /// Creates a new coroutine which will execute `f` on the given stack.
    ///
    /// This function returns a coroutine which, when resumed, will execute
    /// `f` to completion. When desired `f` can suspend itself via
    /// [`Yielder::suspend`].
    ///
    /// `f` is not called here: it is written onto `stack` and only starts
    /// running on the first call to [`Coroutine::resume`]. If the coroutine is
    /// dropped before it is ever resumed, `f` is dropped without being called.
    pub fn with_stack<F>(stack: Stack, f: F) -> Self
    where
        F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
        F: 'static,
        Input: 'static,
        Yield: 'static,
        Return: 'static,
    {
        // The ABI of the initial function is either "C" or "C-unwind" depending
        // on whether the "asm-unwind" feature is enabled.
        initial_func_abi! {
            // Entry point executed on the coroutine stack by the first resume.
            // It never returns (`-> !`): when the closure finishes, control is
            // switched back to the parent together with the result.
            unsafe fn coroutine_func<Input, Yield, Return, F>(
                input: EncodedValue,
                parent_link: &mut StackPointer,
                func: *mut F,
            ) -> !
            where
                F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
            {
                // The yielder is a #[repr(transparent)] wrapper around the
                // parent link on the stack.
                let yielder = &*(parent_link as *mut StackPointer as *const Yielder<Input, Yield>);

                // Read the function from the stack. This moves the closure out
                // of its stack slot; the slot is only dropped through `drop_fn`
                // when the coroutine has never started, so it is not dropped
                // twice.
                debug_assert_eq!(func as usize % mem::align_of::<F>(), 0);
                let f = func.read();

                // This is the input from the first call to resume(). It is not
                // possible for a forced unwind to reach this point because we
                // check if a coroutine has been resumed at least once before
                // generating a forced unwind.
                let input : Result<Input, ForcedUnwindErr> = util::decode_val(input);
                let input = match input {
                    Ok(input) => input,
                    #[cfg_attr(feature = "asm-unwind", allow(unreachable_patterns))]
                    Err(_) => unreachable_unchecked(),
                };

                // Run the body of the generator, catching any panics.
                let result = unwind::catch_unwind_at_root(|| f(yielder, input));

                // Return the result (or any caught panic) to the parent
                // context. `ManuallyDrop` keeps it from being dropped here: the
                // parent takes ownership by decoding it in `resume_inner`.
                let mut result = ManuallyDrop::new(result);
                arch::switch_and_reset(util::encode_val(&mut result), yielder.stack_ptr.as_ptr());
            }
        }

        // Drop function to free the initial state of the coroutine.
        // Monomorphised for `F` and stored in the coroutine so that
        // `force_unwind_slow` can drop a never-resumed closure via a type-erased
        // pointer.
        unsafe fn drop_fn<T>(ptr: *mut u8) {
            ptr::drop_in_place(ptr as *mut T);
        }

        unsafe {
            // Set up the stack so that the coroutine starts executing
            // coroutine_func. Write the given function object to the stack so
            // its address is passed to coroutine_func on the first resume.
            let stack_ptr = arch::init_stack(&stack, coroutine_func::<Input, Yield, Return, F>, f);

            Self {
                stack,
                stack_ptr: Some(stack_ptr),
                // Remembered so that `started()` can tell whether the
                // coroutine has been resumed yet.
                initial_stack_ptr: stack_ptr,
                drop_fn: drop_fn::<F>,
                marker: PhantomData,
                marker2: PhantomData,
            }
        }
    }

    /// Resumes execution of this coroutine.
    ///
    /// This function will transfer execution to the coroutine and resume from
    /// where it last left off.
    ///
    /// If the coroutine calls [`Yielder::suspend`] then this function returns
    /// [`CoroutineResult::Yield`] with the value passed to `suspend`.
    ///
    /// If the coroutine returns then this function returns
    /// [`CoroutineResult::Return`] with the return value of the coroutine.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine has already finished executing.
    ///
    /// If the coroutine itself panics during execution then the panic will be
    /// propagated to this caller.
    pub fn resume(&mut self, val: Input) -> CoroutineResult<Yield, Return> {
        unsafe {
            let stack_ptr = self
                .stack_ptr
                .expect("attempt to resume a completed coroutine");

            // If the coroutine terminated then a caught panic may have been
            // returned, in which case we must resume unwinding.
            match self.resume_inner(stack_ptr, Ok(val)) {
                CoroutineResult::Yield(val) => CoroutineResult::Yield(val),
                CoroutineResult::Return(result) => {
                    CoroutineResult::Return(unwind::maybe_resume_unwind(result))
                }
            }
        }
    }

    /// Common code for resuming execution of a coroutine.
    ///
    /// `input` is either a value for the coroutine or (without the
    /// "asm-unwind" feature) a request to forcibly unwind it. On return,
    /// `self.stack_ptr` holds the coroutine's new stack pointer, or `None` if
    /// it terminated.
    unsafe fn resume_inner(
        &mut self,
        stack_ptr: StackPointer,
        input: Result<Input, ForcedUnwindErr>,
    ) -> CoroutineResult<Yield, Result<Return, CaughtPanic>> {
        // Pre-emptively set the stack pointer to None in case
        // switch_and_link unwinds.
        self.stack_ptr = None;

        // Ownership of `input` moves to the coroutine, which decodes it on the
        // other side, so it must not be dropped here.
        let mut input = ManuallyDrop::new(input);
        let (result, stack_ptr) =
            arch::switch_and_link(util::encode_val(&mut input), stack_ptr, self.stack.base());
        self.stack_ptr = stack_ptr;

        // Decode the returned value depending on whether the coroutine
        // terminated.
        if stack_ptr.is_some() {
            CoroutineResult::Yield(util::decode_val(result))
        } else {
            CoroutineResult::Return(util::decode_val(result))
        }
    }

    /// Returns whether this coroutine has been resumed at least once.
    ///
    /// Note that this also returns `true` once the coroutine has completed
    /// (its stored stack pointer is `None`). That includes a coroutine whose
    /// initial closure was dropped by `force_unwind` without ever being
    /// resumed.
    pub fn started(&self) -> bool {
        self.stack_ptr != Some(self.initial_stack_ptr)
    }

    /// Returns whether this coroutine has finished executing.
    ///
    /// A coroutine that has returned from its initial function can no longer
    /// be resumed.
    pub fn done(&self) -> bool {
        self.stack_ptr.is_none()
    }

    /// Forcibly marks the coroutine as having completed, even if it is
    /// currently suspended in the middle of a function.
    ///
    /// # Safety
    ///
    /// This is equivalent to a `longjmp` all the way back to the initial
    /// function of the coroutine, so the same rules apply.
    ///
    /// This can only be done safely if there are no objects currently on the
    /// coroutine's stack that need to execute `Drop` code.
    pub unsafe fn force_reset(&mut self) {
        self.stack_ptr = None;
    }

    /// Unwinds the coroutine stack, dropping any live objects that are
    /// currently on the stack. This is automatically called when the coroutine
    /// is dropped.
    ///
    /// If the coroutine has already completed then this function is a no-op.
    ///
    /// If the coroutine is currently suspended on a `Yielder::suspend` call
    /// then unwinding it requires the `unwind` feature to be enabled and
    /// for the crate to be compiled with `-C panic=unwind`.
    ///
    /// # Panics
    ///
    /// This function panics if the coroutine could not be fully unwound. This
    /// can happen for one of two reasons:
    /// - The `ForcedUnwind` panic that is used internally was caught and not
    ///   rethrown.
    /// - This crate was compiled without the `unwind` feature and the
    ///   coroutine is currently suspended in the yielder (`started && !done`).
    pub fn force_unwind(&mut self) {
        // If the coroutine has already terminated then there is nothing to do.
        if let Some(stack_ptr) = self.stack_ptr {
            self.force_unwind_slow(stack_ptr);
        }
    }

    /// Slow path of `force_unwind` when the coroutine is known to not have
    /// terminated yet.
    #[cold]
    fn force_unwind_slow(&mut self, stack_ptr: StackPointer) {
        // If the coroutine has not started yet then we just need to drop the
        // initial object.
        if !self.started() {
            unsafe {
                arch::drop_initial_obj(self.stack.base(), stack_ptr, self.drop_fn);
            }
            self.stack_ptr = None;
            return;
        }

        // If the coroutine is suspended then we need the standard library so
        // that we can unwind the stack. This also requires that the code be
        // compiled with -C panic=unwind.
        #[cfg(feature = "unwind")]
        {
            extern crate std;

            // The ForcedUnwind payload is tagged with this coroutine's stack
            // pointer so that it can be told apart from unrelated panics below.
            let forced_unwind = unwind::ForcedUnwind(stack_ptr);
            let result = unwind::catch_forced_unwind(|| {
                #[cfg(not(feature = "asm-unwind"))]
                let result = unsafe { self.resume_inner(stack_ptr, Err(forced_unwind)) };
                #[cfg(feature = "asm-unwind")]
                let result = unsafe { self.resume_with_exception(stack_ptr, forced_unwind) };
                match result {
                    CoroutineResult::Yield(_) | CoroutineResult::Return(Ok(_)) => Ok(()),
                    #[cfg_attr(feature = "asm-unwind", allow(unreachable_patterns))]
                    CoroutineResult::Return(Err(e)) => Err(e),
                }
            });

            match result {
                // The coroutine came back without our ForcedUnwind escaping,
                // i.e. it was caught inside the coroutine and not rethrown.
                Ok(_) => panic!("the ForcedUnwind panic was caught and not rethrown"),
                Err(e) => {
                    // Our own ForcedUnwind made it back out: the stack is fully
                    // unwound, and `self.stack_ptr` was already set to None
                    // before switching into the coroutine.
                    if let Some(forced_unwind) = e.downcast_ref::<unwind::ForcedUnwind>() {
                        if forced_unwind.0 == stack_ptr {
                            return;
                        }
                    }

                    // Any other panic payload is propagated to the caller.
                    std::panic::resume_unwind(e);
                }
            }
        }

        #[cfg(not(feature = "unwind"))]
        panic!("can't unwind a suspended coroutine without the \"unwind\" feature");
    }

    /// Variant of `resume_inner` that throws an exception in the context of
    /// the coroutine instead of passing a value.
    ///
    /// Used by `force_unwind`.
    #[cfg(feature = "asm-unwind")]
    unsafe fn resume_with_exception(
        &mut self,
        stack_ptr: StackPointer,
        forced_unwind: unwind::ForcedUnwind,
    ) -> CoroutineResult<Yield, Result<Return, CaughtPanic>> {
        // Pre-emptively set the stack pointer to None in case
        // switch_and_throw unwinds.
        self.stack_ptr = None;

        let (result, stack_ptr) =
            arch::switch_and_throw(forced_unwind, stack_ptr, self.stack.base());
        self.stack_ptr = stack_ptr;

        // Decode the returned value depending on whether the coroutine
        // terminated.
        if stack_ptr.is_some() {
            CoroutineResult::Yield(util::decode_val(result))
        } else {
            CoroutineResult::Return(util::decode_val(result))
        }
    }

    /// Extracts the stack from a coroutine that has finished executing.
    ///
    /// This allows the stack to be re-used for another coroutine.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine has not finished executing (see
    /// [`Coroutine::done`]).
    // `mut self` is only needed on Windows, where the stack's TEB fields are
    // updated before the stack is handed back.
    #[allow(unused_mut)]
    pub fn into_stack(mut self) -> Stack {
        assert!(
            self.done(),
            "cannot extract stack from an incomplete coroutine"
        );

        #[cfg(windows)]
        unsafe {
            arch::update_stack_teb_fields(&mut self.stack);
        }

        // Move the stack out and skip `Drop` for `self`, so the stack is not
        // dropped a second time.
        unsafe {
            let stack = ptr::read(&self.stack);
            mem::forget(self);
            stack
        }
    }

    /// Returns a [`CoroutineTrapHandler`] which can be used to handle traps that
    /// occur inside the coroutine. Examples of traps that can be handled are
    /// invalid memory accesses and stack overflows.
    ///
    /// The returned [`CoroutineTrapHandler`] can be used in a trap handler to
    /// force the trapping coroutine to return with a specific value, after
    /// which it is considered to have completed and can no longer be resumed.
    ///
    /// Needless to say, this is extremely unsafe and must be used with extreme
    /// care. See [`CoroutineTrapHandler::setup_trap_handler`] for the exact
    /// safety requirements.
    pub fn trap_handler(&self) -> CoroutineTrapHandler<Return> {
        CoroutineTrapHandler {
            stack_base: self.stack.base(),
            stack_limit: self.stack.limit(),
            marker: PhantomData,
        }
    }
}
463
464impl<Input, Yield, Return, Stack: stack::Stack> Drop for Coroutine<Input, Yield, Return, Stack> {
465    fn drop(&mut self) {
466        let guard = scopeguard::guard((), |()| {
467            // We can't catch panics in #![no_std], force an abort using
468            // a double-panic.
469            panic!("cannot propagte coroutine panic with #![no_std]");
470        });
471        self.force_unwind();
472        mem::forget(guard);
473
474        #[cfg(windows)]
475        unsafe {
476            arch::update_stack_teb_fields(&mut self.stack);
477        }
478    }
479}
480
/// `Yielder` is an interface provided to a coroutine which allows it to suspend
/// itself and pass values in and out of the coroutine.
///
/// Multiple references can be created to the same `Yielder`, but these cannot
/// be moved to another thread.
// `#[repr(transparent)]` is load-bearing: `coroutine_func` (in
// `Coroutine::with_stack`) creates the `&Yielder` by casting a pointer to the
// parent-link `StackPointer` on the coroutine stack. That cast is only valid
// because `Yielder` has the same layout as its single non-zero-sized field.
#[repr(transparent)]
pub struct Yielder<Input, Yield> {
    // Internally the Yielder is just the parent link on the stack which is
    // updated every time resume() is called.
    //
    // Being a `Cell`, it also makes `Yielder` `!Sync`, so `&Yielder` is `!Send`
    // and cannot leave the thread the coroutine runs on.
    stack_ptr: Cell<StackPointer>,
    // Ties the Yielder to its Input/Yield types without storing them
    // (contravariant in Yield, covariant in Input).
    marker: PhantomData<fn(Yield) -> Input>,
}
493
impl<Input, Yield> Yielder<Input, Yield> {
    /// Suspends the execution of a currently running coroutine.
    ///
    /// This function will switch control back to the original caller of
    /// [`Coroutine::resume`]. This function will then return once the
    /// [`Coroutine::resume`] function is called again.
    ///
    /// `val` is handed to that caller as [`CoroutineResult::Yield`], and the
    /// value passed to the next [`Coroutine::resume`] is returned from here.
    ///
    /// If the coroutine is instead forcibly unwound while suspended here (see
    /// [`Coroutine::force_unwind`]), this call is expected to unwind with the
    /// internal `ForcedUnwind` panic rather than return. That panic must not be
    /// caught and swallowed.
    pub fn suspend(&self, val: Yield) -> Input {
        unsafe {
            // Ownership of `val` passes to the resumer, which decodes it on the
            // other side, so it must not be dropped here.
            let mut val = ManuallyDrop::new(val);
            let result = arch::switch_yield(util::encode_val(&mut val), self.stack_ptr.as_ptr());
            // The decoded value is either the next input or a forced-unwind
            // request; `maybe_force_unwind` turns the latter into an unwind.
            unwind::maybe_force_unwind(util::decode_val(result))
        }
    }

    /// Executes some code on the stack of the parent context (the one who
    /// last resumed the current coroutine).
    ///
    /// This is particularly useful when executing on a coroutine with limited
    /// stack space: stack-heavy operations can be performed in a way that
    /// avoids stack overflows on the coroutine stack.
    ///
    /// # Panics
    ///
    /// Any panics in the provided closure are automatically propagated back up
    /// to the caller of this function.
    pub fn on_parent_stack<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
        // The F: Send bound here is somewhat subtle but important. It exists to
        // prevent references to the Yielder from being passed into the parent
        // thread.
        F: Send,
    {
        // Get the top of the parent stack. It is derived from the address
        // stored in the parent link minus the architecture-specific
        // `arch::PARENT_LINK_OFFSET`.
        let stack_ptr = unsafe {
            StackPointer::new_unchecked(self.stack_ptr.get().get() - arch::PARENT_LINK_OFFSET)
        };

        // Create a virtual stack that starts below the parent stack.
        let stack = unsafe { ParentStack::new(stack_ptr) };

        on_stack(stack, f)
    }
}
538
/// Executes some code on the given stack.
///
/// This is useful when running with limited stack space: stack-intensive
/// computation can be executed on a separate stack with more space.
///
/// `f` runs to completion before this function returns, and its return value
/// is returned to the caller.
///
/// # Panics
///
/// Any panics in the provided closure are automatically propagated back up to
/// the caller of this function.
pub fn on_stack<F, R>(stack: impl stack::Stack, f: F) -> R
where
    F: FnOnce() -> R,
{
    // Union to hold both the function and its result.
    // The closure is moved in before switching stacks, and the result is then
    // written over the same storage, so only one field is live at a time.
    union FuncOrResult<F, R> {
        func: ManuallyDrop<F>,
        result: ManuallyDrop<Result<R, CaughtPanic>>,
    }

    initial_func_abi! {
        // Runs on the new stack: takes the closure out of the union, runs it,
        // and stores the outcome back into the same union.
        unsafe fn wrapper<F, R>(ptr: *mut u8)
        where
            F: FnOnce() -> R,
        {
            // Read the function out of the union.
            let data = &mut *(ptr as *mut FuncOrResult<F, R>);
            let func = ManuallyDrop::take(&mut data.func);

            // Call it. Panics are caught here and stored as a `CaughtPanic`,
            // to be re-raised by `maybe_resume_unwind` back on the caller's
            // stack.
            let result = unwind::catch_unwind_at_root(func);

            // And write the result back to the union.
            data.result = ManuallyDrop::new(result);
        }
    }

    unsafe {
        let mut data = FuncOrResult {
            func: ManuallyDrop::new(f),
        };

        // Call the wrapper function on the new stack.
        arch::on_stack(&mut data as *mut _ as *mut u8, stack, wrapper::<F, R>);

        // Re-throw any panics if one was caught. `wrapper` always writes
        // `result` before returning, so reading it here is valid.
        unwind::maybe_resume_unwind(ManuallyDrop::take(&mut data.result))
    }
}
587
588/// Custom stack implementation used by `on_parent_stack`. This is a private
589/// type because it is generally unsafe to use:
590struct ParentStack {
591    /// Base address of the stack, below any existing data on the parent stack.
592    stack_base: StackPointer,
593
594    /// Stack pointer value of the parent stack. This is not the same as
595    /// `stack_base` since the latter has been aligned to `STACK_ALIGNMENT`.
596    ///
597    /// This is needed on Windows to access the saved TEB fields on the parent
598    /// stack.
599    #[cfg(windows)]
600    stack_ptr: StackPointer,
601}
602
603impl ParentStack {
604    #[inline]
605    unsafe fn new(stack_ptr: StackPointer) -> Self {
606        let stack_base = StackPointer::new_unchecked(stack_ptr.get() & !(STACK_ALIGNMENT - 1));
607        Self {
608            stack_base,
609            #[cfg(windows)]
610            stack_ptr,
611        }
612    }
613}
614
615unsafe impl stack::Stack for ParentStack {
616    #[inline]
617    fn base(&self) -> StackPointer {
618        self.stack_base
619    }
620
621    // We can get away with a dummy implementation here because we never expose
622    // the coroutine type to the user. This is only used for creating a
623    // CoroutineTrapHandler.
624    #[inline]
625    fn limit(&self) -> StackPointer {
626        self.stack_base
627    }
628
629    #[inline]
630    #[cfg(windows)]
631    fn teb_fields(&self) -> StackTebFields {
632        unsafe { arch::read_parent_stack_teb_fields(self.stack_ptr) }
633    }
634
635    #[inline]
636    #[cfg(windows)]
637    fn update_teb_fields(&mut self, stack_limit: usize, guaranteed_stack_bytes: usize) {
638        unsafe {
639            arch::update_parent_stack_teb_fields(
640                self.stack_ptr,
641                stack_limit,
642                guaranteed_stack_bytes,
643            );
644        }
645    }
646}