1#![no_std]
3#![warn(clippy::pedantic)]
4
5use core::{
6 future::Future,
7 hash::{Hash, Hasher},
8 mem::Discriminant,
9};
10
/// A [`Hasher`] that never computes a hash: it only remembers the one `u32`
/// that was fed into it so the caller can read it back afterwards.
///
/// The `discriminant` field records what has been observed so far:
///
/// * `Ok(None)` – nothing has been written yet,
/// * `Ok(Some(v))` – exactly one `u32` with value `v` has been written,
/// * `Err(())` – the input had any other shape (raw bytes, or more than one
///   `u32`).
struct DiscriminantExtractor {
    discriminant: Result<Option<u32>, ()>,
}

impl Hasher for DiscriminantExtractor {
    /// Not supported: this type is read through its `discriminant` field,
    /// so asking it for a hash value is treated as a bug.
    fn finish(&self) -> u64 {
        unreachable!()
    }

    /// Marks the extraction as failed.
    ///
    /// The provided `write_*` methods of [`Hasher`] that are not overridden
    /// here forward to `write`, so any input other than a `u32` ends up in
    /// the error state.
    fn write(&mut self, _bytes: &[u8]) {
        self.discriminant = Err(());
    }

    /// Stores `i` if it is the very first thing written; any later write
    /// (or a write after a failure) switches to the error state.
    fn write_u32(&mut self, i: u32) {
        self.discriminant = match self.discriminant {
            Ok(None) => Ok(Some(i)),
            Ok(Some(_)) | Err(()) => Err(()),
        };
    }
}
32
33fn discriminant_value<T>(discriminant: Discriminant<T>) -> Option<u32> {
34 let mut hasher = DiscriminantExtractor {
35 discriminant: Ok(None),
36 };
37 discriminant.hash(&mut hasher);
38 hasher.discriminant.ok()?
39}
40
/// The state of the coroutine backing a future, as reported by
/// [`coroutine_state`].
///
/// `Hash` is derived alongside `Eq` so that states can be used as keys in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CoroutineState {
    /// The coroutine has not been resumed yet (raw discriminant `0`).
    Unresumed,
    /// The coroutine has run to completion (raw discriminant `1`).
    Returned,
    /// The coroutine panicked while it was being resumed (raw discriminant
    /// `2`).
    Panicked,
    /// The coroutine is parked at a suspension point.
    ///
    /// The payload is the zero-based index of that suspension point, i.e. the
    /// raw discriminant minus `3`.
    Suspend(u32),
}
55
56pub fn coroutine_state<Fut: Future>(fut: &Fut) -> CoroutineState {
118 let Some(value) = discriminant_value(core::mem::discriminant(fut)) else {
119 panic!(
120 "couln't get coroutine state of Future with type {}",
121 core::any::type_name::<Fut>()
122 );
123 };
124 match value {
125 0 => CoroutineState::Unresumed,
126 1 => CoroutineState::Returned,
127 2 => CoroutineState::Panicked,
128 _ => CoroutineState::Suspend(value - 3),
129 }
130}
131
#[cfg(test)]
mod test {
    use core::{
        panic::AssertUnwindSafe,
        pin::{pin, Pin},
        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    };

    use super::*;

    // The crate is `no_std`; the tests need `Vec`, `vec!` and `catch_unwind`.
    extern crate std;
    use std::{prelude::rust_2021::*, vec};

    /// A future that is `Pending` on its first poll and `Ready` on every
    /// later poll.
    struct PendingOnce {
        is_ready: bool,
    }

    impl Future for PendingOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.is_ready {
                // First poll: flip the flag, request another poll, suspend.
                self.is_ready = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    /// Returns a future that suspends the awaiting coroutine exactly once.
    fn async_yield() -> impl Future<Output = ()> {
        PendingOnce { is_ready: false }
    }

    /// A raw waker with a null data pointer whose vtable entries do nothing;
    /// cloning it yields the same raw waker again.
    const NOOP_RAW_WAKER: RawWaker = {
        fn noop_clone(_: *const ()) -> RawWaker {
            NOOP_RAW_WAKER
        }
        fn noop(_: *const ()) {}

        const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        RawWaker::new(core::ptr::null(), &VTABLE)
    };

    // SAFETY: every vtable function ignores the (null) data pointer and has
    // no side effects, so there is no state whose handling could violate the
    // `RawWaker` contract.
    static NOOP_WAKER: Waker = unsafe { Waker::from_raw(NOOP_RAW_WAKER) };
    const NOOP_CONTEXT: Context = Context::from_waker(&NOOP_WAKER);

    /// Polls `fut` until it is ready, appending the coroutine state observed
    /// before the first poll and after every poll to `states`.
    ///
    /// If a poll panics, the states recorded up to that point stay in
    /// `states`.
    fn coroutine_states_inner<Fut: Future>(
        mut fut: Pin<&mut Fut>,
        states: &mut Vec<CoroutineState>,
    ) {
        let mut cx = NOOP_CONTEXT;

        states.push(coroutine_state(&*fut));

        let mut done = false;
        while !done {
            done = fut.as_mut().poll(&mut cx).is_ready();
            states.push(coroutine_state(&*fut));
        }
    }

    /// Drives `fut` to completion and returns every state it went through.
    fn coroutine_states<Fut: Future>(fut: Fut) -> Vec<CoroutineState> {
        let mut observed = Vec::new();
        let pinned = pin!(fut);
        coroutine_states_inner(pinned, &mut observed);
        observed
    }

    #[test]
    fn coroutine_state_empty() {
        // `async` is the point of the test even though nothing is awaited.
        #[allow(clippy::unused_async)]
        async fn f() {}

        let expected = vec![CoroutineState::Unresumed, CoroutineState::Returned];
        assert_eq!(coroutine_states(f()), expected);
    }

    #[test]
    fn coroutine_state_one_yield() {
        async fn f() {
            async_yield().await;
        }

        let expected = vec![
            CoroutineState::Unresumed,
            CoroutineState::Suspend(0),
            CoroutineState::Returned,
        ];
        assert_eq!(coroutine_states(f()), expected);
    }

    #[test]
    fn coroutine_state_two_yields() {
        async fn f() {
            async_yield().await;
            async_yield().await;
        }

        let expected = vec![
            CoroutineState::Unresumed,
            CoroutineState::Suspend(0),
            CoroutineState::Suspend(1),
            CoroutineState::Returned,
        ];
        assert_eq!(coroutine_states(f()), expected);
    }

    #[test]
    fn coroutine_state_panic() {
        // `async` is the point of the test even though nothing is awaited.
        #[allow(clippy::unused_async)]
        async fn f() {
            panic!();
        }

        let mut states = vec![];
        let fut = f();
        let mut fut = pin!(fut);

        // The first poll panics, so only the pre-poll state gets recorded.
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| {
            coroutine_states_inner(fut.as_mut(), &mut states);
        }));

        assert!(outcome.is_err());
        assert_eq!(states, vec![CoroutineState::Unresumed]);

        assert_eq!(coroutine_state(&*fut), CoroutineState::Panicked);
    }
}