coroutine_state/
lib.rs

1//! Inspect the state of a Rust [`Future`] created by an `async` function.
2#![no_std]
3#![warn(clippy::pedantic)]
4
5use core::{
6    future::Future,
7    hash::{Hash, Hasher},
8    mem::Discriminant,
9};
10
/// A throw-away [`Hasher`] used only to observe what a value feeds into it.
///
/// It records the argument of the first `write_u32` call and poisons itself
/// on any other input, so the caller can tell whether the hashed value was
/// exactly one `u32`.
struct DiscriminantExtractor {
    /// Tri-state capture result:
    /// - `Ok(None)`: nothing has been written yet,
    /// - `Ok(Some(v))`: exactly one `write_u32(v)` has been seen,
    /// - `Err(())`: some other write happened (or `write_u32` ran twice).
    discriminant: Result<Option<u32>, ()>,
}

impl Hasher for DiscriminantExtractor {
    /// Never called: the extractor is only fed, never asked for a hash.
    fn finish(&self) -> u64 {
        unreachable!()
    }

    /// Any raw byte write means the input was not a lone `u32`.
    ///
    /// The default `write_u8`/`write_u64`/… methods forward here, so they
    /// also poison the capture.
    fn write(&mut self, _data: &[u8]) {
        self.discriminant = Err(());
    }

    /// Capture the first `u32`; a second one poisons the capture.
    fn write_u32(&mut self, value: u32) {
        self.discriminant = match self.discriminant {
            Ok(None) => Ok(Some(value)),
            _ => Err(()),
        };
    }
}
32
33fn discriminant_value<T>(discriminant: Discriminant<T>) -> Option<u32> {
34    let mut hasher = DiscriminantExtractor {
35        discriminant: Ok(None),
36    };
37    discriminant.hash(&mut hasher);
38    hasher.discriminant.ok()?
39}
40
/// State of a coroutine, as reported by [`coroutine_state`].
///
/// Besides the usual comparison traits this also implements [`Hash`], so
/// states can be used as keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CoroutineState {
    /// The coroutine has not started running.
    Unresumed,
    /// The coroutine has finished running.
    Returned,
    /// The coroutine has panicked.
    Panicked,
    /// The coroutine is suspended at a specific suspension point.
    ///
    /// The suspension points are consecutively numbered starting from 0.
    Suspend(u32),
}
55
56/// Get the current coroutine state of a [`Future`].
57///
58/// ```
59/// use coroutine_state::{coroutine_state, CoroutineState};
60///
61/// async fn f() {}
62///
63/// let fut = f();
64/// assert_eq!(coroutine_state(&fut), CoroutineState::Unresumed);
65/// ```
66///
67/// # Panics
68///
69/// You should only pass in a reference to a [`Future`] returned by an `async`
70/// function. If you pass in a reference to another type of [`Future`],
71/// this function may either panic or return an arbitrary value.
72/// However the implemenation uses no `unsafe` code so this can never cause
73/// undefined behavior.
74///
75/// You should be careful not to pass a [`& Pin<&mut Fut>`](core::pin::Pin)
76/// instead of a reference to the inner `Fut`:
77/// ```
78/// use std::pin::pin;
79/// use coroutine_state::{coroutine_state, CoroutineState};
80///
81/// async fn f() {}
82///
83/// let fut = f();
84/// let fut = pin!(fut);
85///
86/// // This might panic or return an arbitrary state:
87/// // println!("{:?}", coroutine_state(&fut));
88///
89/// // Instead use:
90/// println!("{:?}", coroutine_state(&*fut));
91/// ```
92///
93/// Example of where an arbitrary value is returned:
94/// ```
95/// use std::future::Future;
96/// use std::pin::{pin, Pin};
97/// use std::task::{Context, Poll};
98/// use coroutine_state::{coroutine_state, CoroutineState};
99///
100/// #[repr(u32)]
101/// enum Foo {
102///     Foo,
103/// }
104///
105/// impl Future for Foo {
106///     type Output = ();
107///
108///
109///     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
110///         Poll::Ready(())
111///     }
112/// }
113///
114/// let fut = Foo::Foo;
115/// println!("{:?}", coroutine_state(&fut));
116/// ```
117pub fn coroutine_state<Fut: Future>(fut: &Fut) -> CoroutineState {
118    let Some(value) = discriminant_value(core::mem::discriminant(fut)) else {
119        panic!(
120            "couln't get coroutine state of Future with type {}",
121            core::any::type_name::<Fut>()
122        );
123    };
124    match value {
125        0 => CoroutineState::Unresumed,
126        1 => CoroutineState::Returned,
127        2 => CoroutineState::Panicked,
128        _ => CoroutineState::Suspend(value - 3),
129    }
130}
131
#[cfg(test)]
mod test {
    use core::{
        panic::AssertUnwindSafe,
        pin::{pin, Pin},
        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    };

    use super::*;

    extern crate std;
    use std::{prelude::rust_2021::*, vec};

    /// A future that is `Pending` on its first poll and `Ready` on every later poll.
    struct PendingOnce {
        is_ready: bool,
    }

    impl Future for PendingOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.is_ready {
                Poll::Ready(())
            } else {
                self.is_ready = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    /// Returns a future that yields to the poller exactly once, so each
    /// `.await` on it is one suspension point of the enclosing `async fn`.
    fn async_yield() -> impl Future<Output = ()> {
        PendingOnce { is_ready: false }
    }

    /// A raw waker whose `clone` returns another no-op waker and whose
    /// `wake`, `wake_by_ref` and `drop` do nothing.
    const NOOP_RAW_WAKER: RawWaker = {
        const VTABLE: RawWakerVTable =
            RawWakerVTable::new(|_| NOOP_RAW_WAKER, |_| {}, |_| {}, |_| {});
        RawWaker::new(core::ptr::null(), &VTABLE)
    };

    // SAFETY: no vtable function ever dereferences the (null) data pointer.
    // `clone` returns another no-op `RawWaker` and the other functions are
    // empty, so they are trivially thread-safe. That satisfies the
    // `RawWaker`/`RawWakerVTable` contract required by `Waker::from_raw`.
    static NOOP_WAKER: Waker = unsafe { Waker::from_raw(NOOP_RAW_WAKER) };
    const NOOP_CONTEXT: Context = Context::from_waker(&NOOP_WAKER);

    /// Records the state of `fut` before the first poll and after every poll,
    /// polling until it completes.
    ///
    /// If a poll panics, the states recorded so far stay in `states`.
    fn coroutine_states_inner<Fut: Future>(
        mut fut: Pin<&mut Fut>,
        states: &mut Vec<CoroutineState>,
    ) {
        states.push(coroutine_state(&*fut));

        let mut cx = NOOP_CONTEXT;
        loop {
            let is_ready = fut.as_mut().poll(&mut cx).is_ready();
            states.push(coroutine_state(&*fut));
            if is_ready {
                break;
            }
        }
    }

    /// Pins `fut` and returns every state it passes through until completion.
    fn coroutine_states<Fut: Future>(fut: Fut) -> Vec<CoroutineState> {
        let mut states = vec![];
        let fut = pin!(fut);
        coroutine_states_inner(fut, &mut states);
        states
    }

    #[test]
    fn coroutine_state_empty() {
        // `f` must be `async` to produce a coroutine, even though it never awaits.
        #[allow(clippy::unused_async)]
        async fn f() {}

        assert_eq!(
            coroutine_states(f()),
            vec![CoroutineState::Unresumed, CoroutineState::Returned]
        );
    }

    #[test]
    fn coroutine_state_one_yield() {
        async fn f() {
            async_yield().await;
        }

        assert_eq!(
            coroutine_states(f()),
            vec![
                CoroutineState::Unresumed,
                CoroutineState::Suspend(0),
                CoroutineState::Returned,
            ]
        );
    }

    #[test]
    fn coroutine_state_two_yields() {
        async fn f() {
            async_yield().await;
            async_yield().await;
        }

        assert_eq!(
            coroutine_states(f()),
            vec![
                CoroutineState::Unresumed,
                CoroutineState::Suspend(0),
                CoroutineState::Suspend(1),
                CoroutineState::Returned,
            ]
        );
    }

    #[test]
    fn coroutine_state_yield_in_loop() {
        async fn f() {
            for _ in 0..3 {
                async_yield().await;
            }
        }

        // A single `.await` inside a loop is one suspension point, so every
        // intermediate state reports the same index.
        assert_eq!(
            coroutine_states(f()),
            vec![
                CoroutineState::Unresumed,
                CoroutineState::Suspend(0),
                CoroutineState::Suspend(0),
                CoroutineState::Suspend(0),
                CoroutineState::Returned,
            ]
        );
    }

    #[test]
    fn coroutine_state_panic() {
        // `f` must be `async` to produce a coroutine, even though it never awaits.
        #[allow(clippy::unused_async)]
        async fn f() {
            panic!();
        }

        let mut states = vec![];
        let fut = f();
        let mut fut = pin!(fut);

        let r = {
            let fut = fut.as_mut();
            std::panic::catch_unwind(AssertUnwindSafe(|| {
                coroutine_states_inner(fut, &mut states);
            }))
        };

        assert!(r.is_err());
        assert_eq!(states, vec![CoroutineState::Unresumed]);

        assert_eq!(coroutine_state(&*fut), CoroutineState::Panicked);
    }
}
267}