1use std::any::Any;
7use std::cell::Cell;
8use std::io;
9use std::marker::PhantomData;
10use std::ops::Range;
11use std::panic::{self, AssertUnwindSafe};
12
// Select the platform-specific fiber backend at compile time and expose it
// under the common alias `imp`. The rest of this file refers to the backend's
// `FiberStack`, `Fiber`, and `Suspend` types only through that alias.
cfg_if::cfg_if! {
    if #[cfg(windows)] {
        // Windows targets use the `windows` submodule.
        mod windows;
        use windows as imp;
    } else if #[cfg(unix)] {
        // Unix-like targets use the `unix` submodule.
        mod unix;
        use unix as imp;
    } else {
        // No backend exists for any other target: fail the build with a clear
        // message rather than with an unresolved `imp` path further down.
        compile_error!("fibers are not supported on this platform");
    }
}
24
25#[derive(Debug)]
27pub struct FiberStack(imp::FiberStack);
28
29impl FiberStack {
30 pub fn new(size: usize) -> io::Result<Self> {
32 Ok(Self(imp::FiberStack::new(size)?))
33 }
34
35 pub unsafe fn from_raw_parts(bottom: *mut u8, len: usize) -> io::Result<Self> {
48 Ok(Self(imp::FiberStack::from_raw_parts(bottom, len)?))
49 }
50
51 pub fn top(&self) -> Option<*mut u8> {
56 self.0.top()
57 }
58
59 pub fn range(&self) -> Option<Range<usize>> {
62 self.0.range()
63 }
64}
65
/// A resumable unit of execution created from a [`FiberStack`] and a closure.
///
/// `Resume` is the type of value passed in by [`Fiber::resume`], `Yield` is
/// the type handed back each time the closure suspends, and `Return` is the
/// closure's final result. The lifetime `'a` bounds the closure supplied to
/// [`Fiber::new`].
pub struct Fiber<'a, Resume, Yield, Return> {
    /// The stack handed to [`Fiber::new`]; owned by the fiber from then on.
    stack: FiberStack,
    /// Platform-specific fiber state.
    inner: imp::Fiber,
    /// Set to `true` at the start of every resume and reset to `false` only
    /// when the closure yields, so it stays `true` once the closure has
    /// returned or panicked.
    done: Cell<bool>,
    /// Ties the lifetime and type parameters to the struct; none of them is
    /// stored directly in a field.
    _phantom: PhantomData<&'a (Resume, Yield, Return)>,
}
72
/// Handle passed by reference to a fiber's closure, through which the closure
/// can hand a `Yield` value back to its resumer (see [`Suspend::suspend`]).
pub struct Suspend<Resume, Yield, Return> {
    /// Platform-specific state used to switch away from the fiber.
    inner: imp::Suspend,
    /// Records the fiber's type parameters, which no field stores directly.
    _phantom: PhantomData<(Resume, Yield, Return)>,
}
77
/// Message exchanged between `Fiber::resume` and the running closure.
///
/// `Fiber::resume` places one of these in a `Cell`, hands a reference to the
/// platform backend, and inspects the cell's contents once control comes back.
enum RunResult<Resume, Yield, Return> {
    /// Never constructed in this file.
    /// NOTE(review): presumably a placeholder the platform backend stores
    /// while the fiber is running — confirm in `imp`.
    Executing,
    /// The value supplied to `Fiber::resume`, waiting to be picked up.
    Resuming(Resume),
    /// The value the closure passed to `Suspend::suspend`.
    Yield(Yield),
    /// The closure ran to completion and produced this value.
    Returned(Return),
    /// The closure panicked; this is the caught payload, which
    /// `Fiber::resume` re-raises on the resumer's side.
    Panicked(Box<dyn Any + Send>),
}
85
86impl<'a, Resume, Yield, Return> Fiber<'a, Resume, Yield, Return> {
87 pub fn new(
93 stack: FiberStack,
94 func: impl FnOnce(Resume, &Suspend<Resume, Yield, Return>) -> Return + 'a,
95 ) -> io::Result<Self> {
96 let inner = imp::Fiber::new(&stack.0, func)?;
97
98 Ok(Self {
99 stack,
100 inner,
101 done: Cell::new(false),
102 _phantom: PhantomData,
103 })
104 }
105
106 pub fn resume(&self, val: Resume) -> Result<Return, Yield> {
122 assert!(!self.done.replace(true), "cannot resume a finished fiber");
123 let result = Cell::new(RunResult::Resuming(val));
124 self.inner.resume(&self.stack.0, &result);
125 match result.into_inner() {
126 RunResult::Resuming(_) | RunResult::Executing => unreachable!(),
127 RunResult::Yield(y) => {
128 self.done.set(false);
129 Err(y)
130 }
131 RunResult::Returned(r) => Ok(r),
132 RunResult::Panicked(payload) => std::panic::resume_unwind(payload),
133 }
134 }
135
136 pub fn done(&self) -> bool {
138 self.done.get()
139 }
140
141 pub fn stack(&self) -> &FiberStack {
143 &self.stack
144 }
145}
146
147impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
148 pub fn suspend(&self, value: Yield) -> Resume {
158 self.inner
159 .switch::<Resume, Yield, Return>(RunResult::Yield(value))
160 }
161
162 fn execute(
163 inner: imp::Suspend,
164 initial: Resume,
165 func: impl FnOnce(Resume, &Suspend<Resume, Yield, Return>) -> Return,
166 ) {
167 let suspend = Suspend {
168 inner,
169 _phantom: PhantomData,
170 };
171 let result = panic::catch_unwind(AssertUnwindSafe(|| (func)(initial, &suspend)));
172 suspend.inner.switch::<Resume, Yield, Return>(match result {
173 Ok(result) => RunResult::Returned(result),
174 Err(panic) => RunResult::Panicked(panic),
175 });
176 }
177}
178
179impl<A, B, C> Drop for Fiber<'_, A, B, C> {
180 fn drop(&mut self) {
181 debug_assert!(self.done.get(), "fiber dropped without finishing");
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::{Fiber, FiberStack};
    use std::cell::Cell;
    use std::panic::{self, AssertUnwindSafe};
    use std::rc::Rc;

    /// Allocates the 1 MiB stack used by most of the tests below.
    fn megabyte_stack() -> FiberStack {
        FiberStack::new(1024 * 1024).unwrap()
    }

    /// Requesting a tiny stack (0 or 1) must still produce a working fiber.
    #[test]
    fn small_stacks() {
        for &size in &[0usize, 1] {
            let stack = FiberStack::new(size).unwrap();
            let fiber = Fiber::<(), (), ()>::new(stack, |_, _| {}).unwrap();
            fiber.resume(()).unwrap();
        }
    }

    /// The closure runs on the first resume, not at construction time.
    #[test]
    fn smoke() {
        let flag = Rc::new(Cell::new(false));
        let fiber = {
            let flag = Rc::clone(&flag);
            Fiber::<(), (), ()>::new(megabyte_stack(), move |_, _| {
                flag.set(true);
            })
            .unwrap()
        };

        assert!(!flag.get());
        fiber.resume(()).unwrap();
        assert!(flag.get());
    }

    /// Each `suspend` surfaces as `Err` from `resume`; the final return
    /// surfaces as `Ok`.
    #[test]
    fn suspend_and_resume() {
        let flag = Rc::new(Cell::new(false));
        let fiber = {
            let flag = Rc::clone(&flag);
            Fiber::<(), (), ()>::new(megabyte_stack(), move |_, suspend| {
                suspend.suspend(());
                flag.set(true);
                suspend.suspend(());
            })
            .unwrap()
        };

        // Not started yet.
        assert!(!flag.get());

        // First resume stops at the first suspension point, before the flag
        // is set.
        let first = fiber.resume(());
        assert!(first.is_err());
        assert!(!flag.get());

        // Second resume sets the flag and stops at the second suspension.
        let second = fiber.resume(());
        assert!(second.is_err());
        assert!(flag.get());

        // Third resume lets the closure return.
        let third = fiber.resume(());
        assert!(third.is_ok());
        assert!(flag.get());
    }

    /// A backtrace captured inside the fiber should include the host frame
    /// that started it.
    #[test]
    fn backtrace_traces_to_host() {
        // Kept out-of-line so a frame with this name can appear in the trace.
        #[inline(never)]
        fn look_for_me() {
            run_test();
        }

        fn assert_contains_host() {
            let trace = backtrace::Backtrace::new();
            println!("{:?}", trace);

            let saw_host_frame = trace
                .frames()
                .iter()
                .flat_map(|frame| frame.symbols())
                .filter_map(|symbol| symbol.name().map(|name| name.to_string()))
                .any(|name| name.contains("look_for_me"));

            // The host frame is not required on Windows or on aarch64 macOS.
            // NOTE(review): presumably backtraces cannot cross the fiber
            // boundary on those targets — confirm.
            let exempt_target =
                cfg!(windows) || cfg!(all(target_os = "macos", target_arch = "aarch64"));

            assert!(saw_host_frame || exempt_target);
        }

        fn run_test() {
            let fiber = Fiber::<(), (), ()>::new(megabyte_stack(), move |(), suspend| {
                assert_contains_host();
                suspend.suspend(());
                assert_contains_host();
                suspend.suspend(());
                assert_contains_host();
            })
            .unwrap();

            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_ok());
        }

        look_for_me();
    }

    /// A panic inside the fiber unwinds the fiber's own frames (running their
    /// destructors) and then reappears in the caller of `resume`.
    #[test]
    fn panics_propagated() {
        /// Sets the shared flag when dropped.
        struct SetOnDrop(Rc<Cell<bool>>);

        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                self.0.set(true);
            }
        }

        let dropped = Rc::new(Cell::new(false));
        let guard = SetOnDrop(Rc::clone(&dropped));
        let fiber = Fiber::<(), (), ()>::new(megabyte_stack(), move |(), _s| {
            // Reference the guard so the closure takes ownership of it.
            let _ = &guard;
            panic!();
        })
        .unwrap();

        let outcome = panic::catch_unwind(AssertUnwindSafe(|| fiber.resume(())));
        assert!(outcome.is_err());
        assert!(dropped.get());
    }

    /// Values travel in both directions: into the fiber via `resume`, out of
    /// it via `suspend` and the final return.
    #[test]
    fn suspend_and_resume_values() {
        let fiber = Fiber::new(megabyte_stack(), move |first, suspend| {
            assert_eq!(first, 2.0);
            assert_eq!(suspend.suspend(4), 3.0);
            String::from("hello")
        })
        .unwrap();

        assert_eq!(fiber.resume(2.0), Err(4));
        assert_eq!(fiber.resume(3.0), Ok(String::from("hello")));
    }
}