corosensei/coroutine.rs
1use core::cell::Cell;
2use core::hint::unreachable_unchecked;
3use core::marker::PhantomData;
4use core::mem::{self, ManuallyDrop};
5use core::ptr;
6
7use crate::arch::{self, STACK_ALIGNMENT};
8#[cfg(feature = "default-stack")]
9use crate::stack::DefaultStack;
10#[cfg(windows)]
11use crate::stack::StackTebFields;
12use crate::stack::{self, StackPointer};
13use crate::trap::CoroutineTrapHandler;
14use crate::unwind::{self, initial_func_abi, CaughtPanic, ForcedUnwindErr};
15use crate::util::{self, EncodedValue};
16
/// Outcome of a single call to [`Coroutine::resume`].
///
/// The coroutine either suspended itself through its [`Yielder`], handing
/// back a `Yield` value, or ran its main function to completion, handing back
/// a `Return` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CoroutineResult<Yield, Return> {
    /// The coroutine suspended itself with `Yielder::suspend`, passing this
    /// value out.
    Yield(Yield),

    /// The coroutine's main function finished and returned this value.
    Return(Return),
}
26
27impl<Yield, Return> CoroutineResult<Yield, Return> {
28 /// Returns the `Yield` value as an `Option<Yield>`.
29 pub fn as_yield(self) -> Option<Yield> {
30 match self {
31 CoroutineResult::Yield(val) => Some(val),
32 CoroutineResult::Return(_) => None,
33 }
34 }
35
36 /// Returns the `Return` value as an `Option<Return>`.
37 pub fn as_return(self) -> Option<Return> {
38 match self {
39 CoroutineResult::Yield(_) => None,
40 CoroutineResult::Return(val) => Some(val),
41 }
42 }
43}
44
/// A coroutine wraps a closure and allows suspending its execution more than
/// once, returning a value each time.
///
/// # Dropping a coroutine
///
/// When a coroutine is dropped, its stack must be unwound so that all objects
/// on it are properly dropped. This is done by calling `force_unwind` to
/// unwind the stack. If `force_unwind` fails then the program is aborted.
///
/// See the [`Coroutine::force_unwind`] function for more details.
///
/// # `Send`
///
/// In the general case, a coroutine can only be sent to another thread if all
/// of the data on its stack is `Send`. There is no way to guarantee this using
/// Rust language features so `Coroutine` does not implement the `Send` trait.
///
/// However if all of the code executed by a coroutine is under your control and
/// you can ensure that all types on the stack when a coroutine is suspended
/// are `Send`, then it is safe to implement `Send` manually. Because of Rust's
/// orphan rules this cannot be done on `Coroutine` itself from outside this
/// crate; wrap the coroutine in a type you own and write
/// `unsafe impl Send` for that wrapper instead.
#[cfg(feature = "default-stack")]
pub struct Coroutine<Input, Yield, Return, Stack: stack::Stack = DefaultStack> {
    // Stack that the coroutine is executing on.
    stack: Stack,

    // Current stack pointer at which the coroutine state is held. This is
    // None when the coroutine has completed execution. It is also set to None
    // for the duration of a resume, so that the coroutine is treated as
    // completed if the context switch unwinds.
    stack_ptr: Option<StackPointer>,

    // Initial stack pointer value. This is used to detect whether a coroutine
    // has ever been resumed since it was created.
    //
    // This works because it is impossible for a coroutine to revert back to its
    // initial stack pointer: suspending a coroutine requires pushing several
    // values to the stack.
    initial_stack_ptr: StackPointer,

    // Function to call to drop the initial state of a coroutine if it has
    // never been resumed.
    drop_fn: unsafe fn(ptr: *mut u8),

    // We want to be covariant over Yield and Return, and contravariant
    // over Input.
    //
    // Effectively this means that we can pass a
    //   Coroutine<&'a (), &'static (), &'static ()>
    // to a function that expects a
    //   Coroutine<&'static (), &'c (), &'d ()>
    marker: PhantomData<fn(Input) -> CoroutineResult<Yield, Return>>,

    // Coroutine must be !Send.
    /// ```compile_fail
    /// fn send<T: Send>() {}
    /// send::<corosensei::Coroutine<(), ()>>();
    /// ```
    marker2: PhantomData<*mut ()>,
}
102
103/// A coroutine wraps a closure and allows suspending its execution more than
104/// once, returning a value each time.
105///
106/// # Dropping a coroutine
107///
108/// When a coroutine is dropped, its stack must be unwound so that all object on
109/// it are properly dropped. This is done by calling `force_unwind` to unwind
110/// the stack. If `force_unwind` fails then the program is aborted.
111///
112/// See the [`Coroutine::force_unwind`] function for more details.
113///
114/// # `Send`
115///
116/// In the general case, a coroutine can only be sent to another if all of the
117/// data on its stack is `Send`. There is no way to guarantee this using Rust
118/// language features so `Coroutine` does not implement the `Send` trait.
119///
120/// However if all of the code executed by a coroutine is under your control and
121/// you can ensure that all types on the stack when a coroutine is suspended
122/// are `Send` then it is safe to manually implement `Send` for a coroutine.
123#[cfg(not(feature = "default-stack"))]
124pub struct Coroutine<Input, Yield, Return, Stack: stack::Stack> {
125 stack: Stack,
126 stack_ptr: Option<StackPointer>,
127 initial_stack_ptr: StackPointer,
128 drop_fn: unsafe fn(ptr: *mut u8),
129 marker: PhantomData<fn(Input) -> CoroutineResult<Yield, Return>>,
130 marker2: PhantomData<*mut ()>,
131}
132
133// Coroutines can be Sync if the stack is Sync.
134unsafe impl<Input, Yield, Return, Stack: stack::Stack + Sync> Sync
135 for Coroutine<Input, Yield, Return, Stack>
136{
137}
138
#[cfg(feature = "default-stack")]
impl<Input, Yield, Return> Coroutine<Input, Yield, Return, DefaultStack> {
    /// Creates a coroutine that will run `f` on a default-constructed
    /// [`DefaultStack`].
    ///
    /// Nothing runs until the returned coroutine is resumed. From then on `f`
    /// executes until it finishes, and it may hand control back to the resumer
    /// at any point through `Yielder::suspend`.
    ///
    /// This is shorthand for `Coroutine::with_stack(DefaultStack::default(), f)`.
    pub fn new<F>(f: F) -> Self
    where
        F: 'static + FnOnce(&Yielder<Input, Yield>, Input) -> Return,
        Input: 'static,
        Yield: 'static,
        Return: 'static,
    {
        let stack = DefaultStack::default();
        Self::with_stack(stack, f)
    }
}
157
impl<Input, Yield, Return, Stack: stack::Stack> Coroutine<Input, Yield, Return, Stack> {
    /// Creates a new coroutine which will execute `f` on the given stack.
    ///
    /// This function returns a coroutine which, when resumed, will execute
    /// `f` to completion. When desired `f` can suspend itself via
    /// [`Yielder::suspend`].
    ///
    /// `f` does not start running here. It is written onto `stack` and only
    /// invoked by the first call to [`Coroutine::resume`]. If the coroutine is
    /// dropped or force-unwound before that, `f` is dropped in place without
    /// ever being called.
    pub fn with_stack<F>(stack: Stack, f: F) -> Self
    where
        F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
        F: 'static,
        Input: 'static,
        Yield: 'static,
        Return: 'static,
    {
        // The ABI of the initial function is either "C" or "C-unwind" depending
        // on whether the "asm-unwind" feature is enabled.
        //
        // Only plain `//` comments are used inside this macro invocation:
        // doc comments would become attribute tokens seen by the macro.
        initial_func_abi! {
            // Entry point of the coroutine, reached on the first resume. It
            // never returns: once `f` finishes (or panics), control goes back
            // to the parent through `arch::switch_and_reset`.
            //
            // - `input`: encoded value passed to the first resume.
            // - `parent_link`: the parent link slot on the coroutine stack,
            //   which the `Yielder` wraps.
            // - `func`: the closure that `arch::init_stack` stored on the stack.
            unsafe fn coroutine_func<Input, Yield, Return, F>(
                input: EncodedValue,
                parent_link: &mut StackPointer,
                func: *mut F,
            ) -> !
            where
                F: FnOnce(&Yielder<Input, Yield>, Input) -> Return,
            {
                // The yielder is a #[repr(transparent)] wrapper around the
                // parent link on the stack.
                let yielder = &*(parent_link as *mut StackPointer as *const Yielder<Input, Yield>);

                // Read the function from the stack. This moves it out; the
                // stack copy must not be dropped again afterwards.
                debug_assert_eq!(func as usize % mem::align_of::<F>(), 0);
                let f = func.read();

                // This is the input from the first call to resume(). It is not
                // possible for a forced unwind to reach this point because we
                // check if a coroutine has been resumed at least once before
                // generating a forced unwind.
                let input: Result<Input, ForcedUnwindErr> = util::decode_val(input);
                let input = match input {
                    Ok(input) => input,
                    #[cfg_attr(feature = "asm-unwind", allow(unreachable_patterns))]
                    Err(_) => unreachable_unchecked(),
                };

                // Run the body of the generator, catching any panics.
                let result = unwind::catch_unwind_at_root(|| f(yielder, input));

                // Return any caught panics to the parent context. The parent
                // reconstructs the value with `util::decode_val`, so it is
                // wrapped in ManuallyDrop to avoid also dropping it here.
                let mut result = ManuallyDrop::new(result);
                arch::switch_and_reset(util::encode_val(&mut result), yielder.stack_ptr.as_ptr());
            }
        }

        // Drop function to free the initial state of the coroutine. It is
        // stored in `drop_fn` and used by `force_unwind_slow` when the
        // coroutine is discarded without ever having been resumed.
        unsafe fn drop_fn<T>(ptr: *mut u8) {
            ptr::drop_in_place(ptr as *mut T);
        }

        unsafe {
            // Set up the stack so that the coroutine starts executing
            // coroutine_func. Write the given function object to the stack so
            // its address is passed to coroutine_func on the first resume.
            let stack_ptr = arch::init_stack(&stack, coroutine_func::<Input, Yield, Return, F>, f);

            Self {
                stack,
                stack_ptr: Some(stack_ptr),
                // Kept so that `started()` can tell whether the coroutine has
                // moved away from its freshly initialized state.
                initial_stack_ptr: stack_ptr,
                drop_fn: drop_fn::<F>,
                marker: PhantomData,
                marker2: PhantomData,
            }
        }
    }

    /// Resumes execution of this coroutine.
    ///
    /// This function will transfer execution to the coroutine and resume from
    /// where it last left off.
    ///
    /// If the coroutine calls [`Yielder::suspend`] then this function returns
    /// [`CoroutineResult::Yield`] with the value passed to `suspend`.
    ///
    /// If the coroutine returns then this function returns
    /// [`CoroutineResult::Return`] with the return value of the coroutine.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine has already finished executing.
    ///
    /// If the coroutine itself panics during execution then the panic will be
    /// propagated to this caller.
    pub fn resume(&mut self, val: Input) -> CoroutineResult<Yield, Return> {
        unsafe {
            let stack_ptr = self
                .stack_ptr
                .expect("attempt to resume a completed coroutine");

            // If the coroutine terminated then a caught panic may have been
            // returned, in which case we must resume unwinding.
            match self.resume_inner(stack_ptr, Ok(val)) {
                CoroutineResult::Yield(val) => CoroutineResult::Yield(val),
                CoroutineResult::Return(result) => {
                    CoroutineResult::Return(unwind::maybe_resume_unwind(result))
                }
            }
        }
    }

    /// Common code for resuming execution of a coroutine.
    ///
    /// Switches to the coroutine suspended at `stack_ptr`, handing it `input`:
    /// either a normal resume value (`Ok`) or a forced-unwind request (`Err`).
    /// When control comes back, `self.stack_ptr` is `Some` if the coroutine
    /// suspended again and `None` if it terminated. The value passed back is
    /// decoded as `Yield` or `Return` accordingly, where `Return` carries any
    /// panic caught at the coroutine's root.
    ///
    /// # Safety
    ///
    /// `stack_ptr` must be the coroutine's current suspended stack pointer.
    /// The callers in this impl pass the value just read from
    /// `self.stack_ptr`.
    unsafe fn resume_inner(
        &mut self,
        stack_ptr: StackPointer,
        input: Result<Input, ForcedUnwindErr>,
    ) -> CoroutineResult<Yield, Result<Return, CaughtPanic>> {
        // Pre-emptively set the stack pointer to None in case
        // switch_and_link unwinds.
        self.stack_ptr = None;

        // Ownership of `input` passes to the coroutine through the encoded
        // value, so it must not be dropped here.
        let mut input = ManuallyDrop::new(input);
        let (result, stack_ptr) =
            arch::switch_and_link(util::encode_val(&mut input), stack_ptr, self.stack.base());
        self.stack_ptr = stack_ptr;

        // Decode the returned value depending on whether the coroutine
        // terminated.
        if stack_ptr.is_some() {
            CoroutineResult::Yield(util::decode_val(result))
        } else {
            CoroutineResult::Return(util::decode_val(result))
        }
    }

    /// Returns whether this coroutine has been resumed at least once.
    ///
    /// Note that this compares against the initial stack pointer. It therefore
    /// also returns `true` once `stack_ptr` is `None`, i.e. after the
    /// coroutine has completed, been reset, or been force-unwound, even if it
    /// was never actually resumed.
    pub fn started(&self) -> bool {
        self.stack_ptr != Some(self.initial_stack_ptr)
    }

    /// Returns whether this coroutine has finished executing.
    ///
    /// A coroutine that has returned from its initial function can no longer
    /// be resumed.
    pub fn done(&self) -> bool {
        self.stack_ptr.is_none()
    }

    /// Forcibly marks the coroutine as having completed, even if it is
    /// currently suspended in the middle of a function.
    ///
    /// # Safety
    ///
    /// This is equivalent to a `longjmp` all the way back to the initial
    /// function of the coroutine, so the same rules apply.
    ///
    /// This can only be done safely if there are no objects currently on the
    /// coroutine's stack that need to execute `Drop` code.
    pub unsafe fn force_reset(&mut self) {
        self.stack_ptr = None;
    }

    /// Unwinds the coroutine stack, dropping any live objects that are
    /// currently on the stack. This is automatically called when the coroutine
    /// is dropped.
    ///
    /// If the coroutine has already completed then this function is a no-op.
    ///
    /// If the coroutine is currently suspended on a `Yielder::suspend` call
    /// then unwinding it requires the `unwind` feature to be enabled and
    /// for the crate to be compiled with `-C panic=unwind`.
    ///
    /// # Panics
    ///
    /// This function panics if the coroutine could not be fully unwound. This
    /// can happen for one of two reasons:
    /// - The `ForcedUnwind` panic that is used internally was caught and not
    ///   rethrown.
    /// - This crate was compiled without the `unwind` feature and the
    ///   coroutine is currently suspended in the yielder (`started && !done`).
    pub fn force_unwind(&mut self) {
        // If the coroutine has already terminated then there is nothing to do.
        if let Some(stack_ptr) = self.stack_ptr {
            self.force_unwind_slow(stack_ptr);
        }
    }

    /// Slow path of `force_unwind` when the coroutine is known to not have
    /// terminated yet.
    #[cold]
    fn force_unwind_slow(&mut self, stack_ptr: StackPointer) {
        // If the coroutine has not started yet then we just need to drop the
        // initial object.
        if !self.started() {
            unsafe {
                arch::drop_initial_obj(self.stack.base(), stack_ptr, self.drop_fn);
            }
            self.stack_ptr = None;
            return;
        }

        // If the coroutine is suspended then we need the standard library so
        // that we can unwind the stack. This also requires that the code be
        // compiled with -C panic=unwind.
        #[cfg(feature = "unwind")]
        {
            extern crate std;

            // The payload records which coroutine stack is being unwound, so
            // the panic that eventually reaches this frame can be recognized
            // below as our own.
            let forced_unwind = unwind::ForcedUnwind(stack_ptr);
            let result = unwind::catch_forced_unwind(|| {
                #[cfg(not(feature = "asm-unwind"))]
                let result = unsafe { self.resume_inner(stack_ptr, Err(forced_unwind)) };
                #[cfg(feature = "asm-unwind")]
                let result = unsafe { self.resume_with_exception(stack_ptr, forced_unwind) };
                match result {
                    CoroutineResult::Yield(_) | CoroutineResult::Return(Ok(_)) => Ok(()),
                    #[cfg_attr(feature = "asm-unwind", allow(unreachable_patterns))]
                    CoroutineResult::Return(Err(e)) => Err(e),
                }
            });

            match result {
                // The coroutine yielded or returned normally, so the
                // ForcedUnwind payload was swallowed somewhere on its stack.
                Ok(_) => panic!("the ForcedUnwind panic was caught and not rethrown"),
                Err(e) => {
                    // Our own ForcedUnwind (same stack pointer) arriving here
                    // means the unwind completed. Any other payload is a
                    // genuine panic and is propagated to the caller.
                    if let Some(forced_unwind) = e.downcast_ref::<unwind::ForcedUnwind>() {
                        if forced_unwind.0 == stack_ptr {
                            return;
                        }
                    }

                    std::panic::resume_unwind(e);
                }
            }
        }

        #[cfg(not(feature = "unwind"))]
        panic!("can't unwind a suspended coroutine without the \"unwind\" feature");
    }

    /// Variant of `resume_inner` that throws an exception in the context of
    /// the coroutine instead of passing a value.
    ///
    /// Used by `force_unwind`.
    #[cfg(feature = "asm-unwind")]
    unsafe fn resume_with_exception(
        &mut self,
        stack_ptr: StackPointer,
        forced_unwind: unwind::ForcedUnwind,
    ) -> CoroutineResult<Yield, Result<Return, CaughtPanic>> {
        // Pre-emptively set the stack pointer to None in case
        // switch_and_throw unwinds.
        self.stack_ptr = None;

        let (result, stack_ptr) =
            arch::switch_and_throw(forced_unwind, stack_ptr, self.stack.base());
        self.stack_ptr = stack_ptr;

        // Decode the returned value depending on whether the coroutine
        // terminated.
        if stack_ptr.is_some() {
            CoroutineResult::Yield(util::decode_val(result))
        } else {
            CoroutineResult::Return(util::decode_val(result))
        }
    }

    /// Extracts the stack from a coroutine that has finished executing.
    ///
    /// This allows the stack to be re-used for another coroutine.
    ///
    /// # Panics
    ///
    /// Panics if the coroutine has not finished executing (see
    /// [`Coroutine::done`]).
    // `mut self` is only needed by the Windows-only TEB update below.
    #[allow(unused_mut)]
    pub fn into_stack(mut self) -> Stack {
        assert!(
            self.done(),
            "cannot extract stack from an incomplete coroutine"
        );

        #[cfg(windows)]
        unsafe {
            arch::update_stack_teb_fields(&mut self.stack);
        }

        // Move the stack out and skip `Drop for Coroutine`. There is nothing
        // left to unwind since the coroutine is done.
        unsafe {
            let stack = ptr::read(&self.stack);
            mem::forget(self);
            stack
        }
    }

    /// Returns a [`CoroutineTrapHandler`] which can be used to handle traps that
    /// occur inside the coroutine. Examples of traps that can be handled are
    /// invalid memory accesses and stack overflows.
    ///
    /// The returned [`CoroutineTrapHandler`] can be used in a trap handler to
    /// force the trapping coroutine to return with a specific value, after
    /// which it is considered to have completed and can no longer be resumed.
    ///
    /// Needless to say, this is extremely unsafe and must be used with extreme
    /// care. See [`CoroutineTrapHandler::setup_trap_handler`] for the exact
    /// safety requirements.
    pub fn trap_handler(&self) -> CoroutineTrapHandler<Return> {
        CoroutineTrapHandler {
            stack_base: self.stack.base(),
            stack_limit: self.stack.limit(),
            marker: PhantomData,
        }
    }
}
463
464impl<Input, Yield, Return, Stack: stack::Stack> Drop for Coroutine<Input, Yield, Return, Stack> {
465 fn drop(&mut self) {
466 let guard = scopeguard::guard((), |()| {
467 // We can't catch panics in #![no_std], force an abort using
468 // a double-panic.
469 panic!("cannot propagte coroutine panic with #![no_std]");
470 });
471 self.force_unwind();
472 mem::forget(guard);
473
474 #[cfg(windows)]
475 unsafe {
476 arch::update_stack_teb_fields(&mut self.stack);
477 }
478 }
479}
480
/// `Yielder` is an interface provided to a coroutine which allows it to suspend
/// itself and pass values in and out of the coroutine.
///
/// Multiple references can be created to the same `Yielder`, but these cannot
/// be moved to another thread.
#[repr(transparent)]
pub struct Yielder<Input, Yield> {
    // Internally the Yielder is just the parent link on the stack which is
    // updated every time resume() is called.
    //
    // `#[repr(transparent)]` is load-bearing: the coroutine entry function
    // casts a `&mut StackPointer` that points at the parent link directly to
    // `*const Yielder`, so this must stay the only non-zero-sized field.
    //
    // `Cell` is not `Sync`, which is what makes `&Yielder` non-`Send`.
    stack_ptr: Cell<StackPointer>,
    // Variance marker: covariant in `Input` (values returned by `suspend`)
    // and contravariant in `Yield` (values passed to `suspend`).
    marker: PhantomData<fn(Yield) -> Input>,
}
493
494impl<Input, Yield> Yielder<Input, Yield> {
495 /// Suspends the execution of a currently running coroutine.
496 ///
497 /// This function will switch control back to the original caller of
498 /// [`Coroutine::resume`]. This function will then return once the
499 /// [`Coroutine::resume`] function is called again.
500 pub fn suspend(&self, val: Yield) -> Input {
501 unsafe {
502 let mut val = ManuallyDrop::new(val);
503 let result = arch::switch_yield(util::encode_val(&mut val), self.stack_ptr.as_ptr());
504 unwind::maybe_force_unwind(util::decode_val(result))
505 }
506 }
507
508 /// Executes some code on the stack of the parent context (the one who
509 /// last resumed the current coroutine).
510 ///
511 /// This is particularly useful when executing on a coroutine with limited
512 /// stack space: stack-heavy operations can be performed in a way that
513 /// avoids stack overflows on the coroutine stack.
514 ///
515 /// # Panics
516 ///
517 /// Any panics in the provided closure are automatically propagated back up
518 /// to the caller of this function.
519 pub fn on_parent_stack<F, R>(&self, f: F) -> R
520 where
521 F: FnOnce() -> R,
522 // The F: Send bound here is somewhat subtle but important. It exists to
523 // prevent references to the Yielder from being passed into the parent
524 // thread.
525 F: Send,
526 {
527 // Get the top of the parent stack.
528 let stack_ptr = unsafe {
529 StackPointer::new_unchecked(self.stack_ptr.get().get() - arch::PARENT_LINK_OFFSET)
530 };
531
532 // Create a virtual stack that starts below the parent stack.
533 let stack = unsafe { ParentStack::new(stack_ptr) };
534
535 on_stack(stack, f)
536 }
537}
538
539/// Executes some code on the given stack.
540///
541/// This is useful when running with limited stack space: stack-intensive
542/// computation can be executed on a separate stack with more space.
543///
544/// # Panics
545///
546/// Any panics in the provided closure are automatically propagated back up to
547/// the caller of this function.
548pub fn on_stack<F, R>(stack: impl stack::Stack, f: F) -> R
549where
550 F: FnOnce() -> R,
551{
552 // Union to hold both the function and its result.
553 union FuncOrResult<F, R> {
554 func: ManuallyDrop<F>,
555 result: ManuallyDrop<Result<R, CaughtPanic>>,
556 }
557
558 initial_func_abi! {
559 unsafe fn wrapper<F, R>(ptr: *mut u8)
560 where
561 F: FnOnce() -> R,
562 {
563 // Read the function out of the union.
564 let data = &mut *(ptr as *mut FuncOrResult<F, R>);
565 let func = ManuallyDrop::take(&mut data.func);
566
567 // Call it.
568 let result = unwind::catch_unwind_at_root(func);
569
570 // And write the result back to the union.
571 data.result = ManuallyDrop::new(result);
572 }
573 }
574
575 unsafe {
576 let mut data = FuncOrResult {
577 func: ManuallyDrop::new(f),
578 };
579
580 // Call the wrapper function on the new stack.
581 arch::on_stack(&mut data as *mut _ as *mut u8, stack, wrapper::<F, R>);
582
583 // Re-throw any panics if one was caught.
584 unwind::maybe_resume_unwind(ManuallyDrop::take(&mut data.result))
585 }
586}
587
588/// Custom stack implementation used by `on_parent_stack`. This is a private
589/// type because it is generally unsafe to use:
590struct ParentStack {
591 /// Base address of the stack, below any existing data on the parent stack.
592 stack_base: StackPointer,
593
594 /// Stack pointer value of the parent stack. This is not the same as
595 /// `stack_base` since the latter has been aligned to `STACK_ALIGNMENT`.
596 ///
597 /// This is needed on Windows to access the saved TEB fields on the parent
598 /// stack.
599 #[cfg(windows)]
600 stack_ptr: StackPointer,
601}
602
603impl ParentStack {
604 #[inline]
605 unsafe fn new(stack_ptr: StackPointer) -> Self {
606 let stack_base = StackPointer::new_unchecked(stack_ptr.get() & !(STACK_ALIGNMENT - 1));
607 Self {
608 stack_base,
609 #[cfg(windows)]
610 stack_ptr,
611 }
612 }
613}
614
615unsafe impl stack::Stack for ParentStack {
616 #[inline]
617 fn base(&self) -> StackPointer {
618 self.stack_base
619 }
620
621 // We can get away with a dummy implementation here because we never expose
622 // the coroutine type to the user. This is only used for creating a
623 // CoroutineTrapHandler.
624 #[inline]
625 fn limit(&self) -> StackPointer {
626 self.stack_base
627 }
628
629 #[inline]
630 #[cfg(windows)]
631 fn teb_fields(&self) -> StackTebFields {
632 unsafe { arch::read_parent_stack_teb_fields(self.stack_ptr) }
633 }
634
635 #[inline]
636 #[cfg(windows)]
637 fn update_teb_fields(&mut self, stack_limit: usize, guaranteed_stack_bytes: usize) {
638 unsafe {
639 arch::update_parent_stack_teb_fields(
640 self.stack_ptr,
641 stack_limit,
642 guaranteed_stack_bytes,
643 );
644 }
645 }
646}