1use std::collections::{BTreeSet, VecDeque};
2use std::num::NonZeroUsize;
3use std::{io, marker, mem, panic, thread};
4
5mod context_switch;
6mod stack;
7mod uring;
8
9use std::cell::UnsafeCell;
10
thread_local! {
    /// Per-thread runtime slot: `Some` while `exclusive_runtime` (i.e. `start`) has a
    /// runtime installed on this thread, `None` otherwise. Accessed through raw
    /// `UnsafeCell` pointers by `runtime()`, `ensure_runtime_exists` and
    /// `exclusive_runtime`.
    static RUNTIME: UnsafeCell<Option<RuntimeState>> = UnsafeCell::new(None);
}
14
15fn ensure_runtime_exists() {
16 RUNTIME.with(|tls| {
17 let runtime = unsafe { &*tls.get() };
18 assert!(runtime.is_some());
19 })
20}
21
22unsafe fn runtime() -> &'static mut RuntimeState {
24 RUNTIME.with(|tls| {
25 let borrow = &mut *tls.get();
26 borrow.as_mut().unwrap() })
28}
29
30pub unsafe fn concurrency_pair() -> (Waker, Waiter) {
33 let running = runtime().running();
34 (Waker(running), Waiter(running))
35}
36
37#[derive(Debug)]
39pub struct Waker(FiberIndex);
40
41impl Waker {
42 pub unsafe fn schedule(self) {
45 runtime().ready_fibers.push_back(self.0);
46 }
47
48 pub unsafe fn schedule_immediately(self) {
51 runtime().ready_fibers.push_front(self.0);
52 }
53}
54
55#[derive(Debug)]
57pub struct Waiter(FiberIndex);
58
59impl Waiter {
60 pub unsafe fn park(self) {
63 let to = runtime().process_io_and_wait();
64 let to = runtime().fibers.get(to).continuation;
65 let continuation = &mut runtime().fibers.get_mut(self.0).continuation;
66 context_switch::jump(to, continuation); }
68}
69
/// Complete state of the fiber runtime installed on one thread (see `RUNTIME`).
struct RuntimeState {
    /// io_uring wrapper through which `syscall` issues operations and whose
    /// completion queue `process_io` drains.
    uring: uring::Uring,
    /// State of every live fiber, indexed by `FiberIndex`.
    fibers: Fibers,
    /// Runnable fibers; `process_io_and_wait` always takes the front one.
    ready_fibers: VecDeque<FiberIndex>,
    /// Fiber currently executing; `None` only before the root fiber is created.
    running_fiber: Option<FiberIndex>,
    /// Bases of stacks that are no longer used by any fiber, reused by
    /// `allocate_stack` and released in `Drop`.
    stack_pool: Vec<*const u8>,
    /// Context of the code that called `start`, written by the `jump` in `start`
    /// and resumed by `start_trampoline` once the root fiber has finished.
    bootstrap: mem::MaybeUninit<context_switch::Continuation>,
}
78
79impl RuntimeState {
80 fn new() -> Self {
81 RuntimeState {
82 uring: uring::Uring::new(),
83 fibers: Fibers::new(),
84 ready_fibers: VecDeque::new(),
85 running_fiber: None,
86 stack_pool: Vec::new(),
87 bootstrap: mem::MaybeUninit::uninit(),
88 }
89 }
90
91 fn allocate_stack(&mut self) -> *const u8 {
93 if let Some(stack_bottom) = self.stack_pool.pop() {
94 return stack_bottom;
95 }
96
97 let stack = stack::Stack::new(NonZeroUsize::MIN, NonZeroUsize::new(32).unwrap()).unwrap();
98 let stack_base = stack.base();
99 mem::forget(stack); stack_base
101 }
102
103 fn running(&self) -> FiberIndex {
105 self.running_fiber.unwrap() }
107
108 fn process_io(&mut self) {
109 for (user_data, result) in self.uring.process_cq() {
110 let fiber = FiberIndex(user_data.0 as u32);
111 self.fibers.get_mut(fiber).syscall_result = Some(result);
112 self.ready_fibers.push_back(fiber);
113 }
114 }
115
116 fn process_io_and_wait(&mut self) -> FiberIndex {
118 if let Some(fiber) = self.ready_fibers.pop_front() {
120 self.running_fiber = Some(fiber);
121 return fiber;
122 }
123
124 loop {
126 self.process_io();
127
128 if let Some(fiber) = self.ready_fibers.pop_front() {
129 self.running_fiber = Some(fiber);
130 break fiber;
131 }
132
133 self.uring.wait_for_completed_syscall();
134 }
135 }
136
137 fn cancel(&mut self, fiber: FiberIndex) {
138 let state = self.fibers.get_mut(fiber);
139 state.is_cancelled = true;
140
141 if state.issuing_syscall {
142 self.uring.cancel_syscall(uring::UserData(fiber.0 as u64));
143 }
144
145 let children = state.children.clone();
146 for child in children {
147 self.cancel(child);
148 }
149 }
150}
151
152impl Drop for RuntimeState {
153 fn drop(&mut self) {
154 let guard_pages = 1;
156 let usable_pages = 32;
157
158 let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) as usize };
159 assert_eq!(page_size, 4096);
160 let length = (guard_pages + usable_pages) * page_size;
161
162 for stack_bottom in self.stack_pool.drain(..) {
163 let pointer = unsafe { stack_bottom.sub(length) } as *mut u8;
164 drop(stack::Stack { pointer, length })
165 }
166 }
167}
168
169struct Fibers(slab::Slab<FiberState>);
170
171impl Fibers {
172 fn new() -> Self {
173 Fibers(slab::Slab::new())
174 }
175
176 fn get(&self, fiber: FiberIndex) -> &FiberState {
177 &self.0[fiber.0 as usize]
178 }
179
180 fn get_mut(&mut self, fiber: FiberIndex) -> &mut FiberState {
181 &mut self.0[fiber.0 as usize]
182 }
183
184 fn add(
185 &mut self,
186 parent: Option<FiberIndex>,
187 stack_base: *const u8,
188 continuation: context_switch::Continuation,
189 is_contained: bool,
190 ) -> FiberIndex {
191 let index = self.0.insert(FiberState {
192 stack_base,
193 continuation,
194 is_completed: false,
195 join_handle: JoinHandleState::Unused,
196 syscall_result: None,
197 parent,
198 children: BTreeSet::new(),
199 is_cancelled: false,
200 is_contained,
201 issuing_syscall: false,
202 });
203 FiberIndex(index as u32)
204 }
205
206 fn remove(&mut self, fiber: FiberIndex) {
207 self.0.remove(fiber.0 as usize);
208 }
209
210 fn nearest_contained_ancestor(&self, fiber: FiberIndex) -> FiberIndex {
211 let mut nearest_contained_ancestor = fiber;
212 while !self.get(nearest_contained_ancestor).is_contained {
213 nearest_contained_ancestor = self.get(nearest_contained_ancestor).parent.unwrap();
215 }
216 nearest_contained_ancestor
217 }
218}
219
/// Index of a fiber's slot in [`Fibers`].
///
/// The same value doubles as the io_uring user data of the fiber's syscalls.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
struct FiberIndex(u32);
225
/// Bookkeeping for one fiber, stored in [`Fibers`].
#[derive(Debug)]
struct FiberState {
    /// Base (one-past-the-end) of the fiber's stack. The closure, and later its
    /// result, is stored just below it.
    stack_base: *const u8,
    /// Saved execution context; overwritten whenever the fiber switches away and
    /// used to resume it.
    continuation: context_switch::Continuation,
    /// Set once the closure has returned and its result is settled. The fiber may
    /// still be suspended afterwards, waiting for its children.
    is_completed: bool,
    /// What the fiber's `JoinHandle` is doing (see `JoinHandleState`).
    join_handle: JoinHandleState,
    /// Result of the pending syscall, filled in by `process_io` and taken by `syscall`.
    syscall_result: Option<io::Result<u32>>,
    /// Spawning fiber; `None` only for the root fiber created by `start`.
    parent: Option<FiberIndex>,
    /// Spawned fibers that have not finished yet. A child removes itself in
    /// `spawn_trampoline` after its own children are done.
    children: BTreeSet<FiberIndex>,
    /// Set by `RuntimeState::cancel`; reported by `is_cancelled`.
    is_cancelled: bool,
    /// True for the root and `contain` fibers; a panic cancels the hierarchy of the
    /// nearest contained ancestor.
    is_contained: bool,
    /// True while the fiber is blocked in `syscall`, so cancellation can also cancel
    /// the in-flight operation.
    issuing_syscall: bool,
}
239
/// State of a fiber's `JoinHandle`, as seen by the fiber's trampoline.
#[derive(Debug, Clone)]
enum JoinHandleState {
    /// The handle exists but nobody is waiting on it.
    Unused,
    /// The given fiber is blocked in `join` and is queued when this fiber finishes.
    Waiting(FiberIndex),
    /// The handle was dropped, so the fiber's stack is recycled when it finishes.
    Dropped,
}
246
247pub fn start<F: FnOnce() -> T, T>(f: F) -> thread::Result<T> {
250 unsafe {
251 exclusive_runtime(|| {
252 let stack_base = runtime().allocate_stack();
253
254 let closure_pointer = (stack_base as *mut F).sub(1);
255 closure_pointer.write(f);
256
257 let continuation = context_switch::prepare_stack(
258 stack_base.sub(closure_union_size::<F, T>()) as *mut u8,
259 start_trampoline::<F, T> as *const (),
260 );
261
262 let root_fiber = runtime().fibers.add(None, stack_base, continuation, true);
263 runtime().running_fiber = Some(root_fiber);
264
265 let bootstrap = runtime().bootstrap.as_mut_ptr();
266 context_switch::jump(continuation, bootstrap); let output_pointer = (stack_base as *const thread::Result<T>).sub(1);
269 output_pointer.read()
270 })
271 }
272}
273
274unsafe fn exclusive_runtime<T>(f: impl FnOnce() -> T) -> T {
275 RUNTIME.with(|tls| {
276 let runtime = &mut *tls.get();
277 assert!(runtime.is_none());
278 *runtime = Some(RuntimeState::new());
279 });
280
281 let output = f();
282
283 RUNTIME.with(|tls| {
284 let runtime = &mut *tls.get();
285 *runtime = None;
286 });
287
288 output
289}
290
291unsafe extern "C" fn start_trampoline<F: FnOnce() -> T, T>() -> ! {
292 let running = runtime().running();
293 let stack_base = runtime().fibers.get(running).stack_base;
294 let closure_pointer = (stack_base as *const F).sub(1);
295 let output_pointer = (stack_base as *mut thread::Result<T>).sub(1);
296
297 let closure = closure_pointer.read();
299 let result = panic::catch_unwind(panic::AssertUnwindSafe(|| (closure)()));
300 output_pointer.write(result);
301 runtime().fibers.get_mut(running).is_completed = true;
302
303 if !runtime().fibers.get(running).children.is_empty() {
305 let to = runtime().process_io_and_wait();
306 let to = runtime().fibers.get(to).continuation;
307 let continuation = &mut runtime().fibers.get_mut(running).continuation;
308 context_switch::jump(to, continuation); }
310
311 runtime().stack_pool.push(stack_base);
313 runtime().fibers.remove(running);
314
315 let to = runtime().bootstrap.assume_init();
317 let mut dummy = mem::MaybeUninit::uninit();
318 unsafe { context_switch::jump(to, dummy.as_mut_ptr()) };
319 unreachable!();
320}
321
322pub fn yield_now() {
323 ensure_runtime_exists();
324
325 unsafe {
326 runtime().process_io();
327
328 if runtime().ready_fibers.is_empty() {
329 return;
330 }
331
332 let (waker, waiter) = concurrency_pair();
333 waker.schedule();
334 waiter.park();
335 }
336}
337
338pub fn spawn<F: FnOnce() -> T + 'static, T: 'static>(f: F) -> JoinHandle<T> {
340 ensure_runtime_exists();
341
342 unsafe { spawn_inner(f, false) }
343}
344
345pub fn contain<F: FnOnce() -> T + 'static, T: 'static>(f: F) -> thread::Result<T> {
347 ensure_runtime_exists();
348
349 unsafe { spawn_inner(f, true) }.join()
350}
351
352unsafe fn spawn_inner<F: FnOnce() -> T + 'static, T: 'static>(
353 f: F,
354 is_contained: bool,
355) -> JoinHandle<T> {
356 let stack_base = runtime().allocate_stack();
357
358 let closure_pointer = (stack_base as *mut F).sub(1);
359 closure_pointer.write(f);
360
361 let continuation = context_switch::prepare_stack(
362 stack_base.sub(closure_union_size::<F, T>()) as *mut u8,
363 spawn_trampoline::<F, T> as *const (),
364 );
365
366 let parent = runtime().running();
367 let child = runtime()
368 .fibers
369 .add(Some(parent), stack_base, continuation, is_contained);
370 runtime().fibers.get_mut(parent).children.insert(child);
371 runtime().ready_fibers.push_back(child);
372
373 JoinHandle::new(child)
374}
375
376unsafe extern "C" fn spawn_trampoline<F: FnOnce() -> T, T>() -> ! {
377 let running = runtime().running();
378 let stack_base = runtime().fibers.get(running).stack_base;
379 let closure_pointer = (stack_base as *const F).sub(1);
380 let output_pointer = (stack_base as *mut thread::Result<T>).sub(1);
381
382 let closure = closure_pointer.read();
384 let result = panic::catch_unwind(panic::AssertUnwindSafe(|| (closure)()));
385 let is_err = result.is_err();
386 output_pointer.write(result);
387 runtime().fibers.get_mut(running).is_completed = true;
388
389 if is_err {
391 let nearest_contained_ancestor = runtime().fibers.nearest_contained_ancestor(running);
392 runtime().cancel(nearest_contained_ancestor);
393 }
394
395 if !runtime().fibers.get(running).children.is_empty() {
397 let to = runtime().process_io_and_wait();
398 let to = runtime().fibers.get(to).continuation;
399 let continuation = &mut runtime().fibers.get_mut(running).continuation;
400 context_switch::jump(to, continuation); }
402
403 if let JoinHandleState::Waiting(fiber) = runtime().fibers.get(running).join_handle {
405 runtime().ready_fibers.push_back(fiber);
406 }
407
408 let parent = runtime().fibers.get(running).parent.unwrap();
410 runtime().fibers.get_mut(parent).children.remove(&running);
411
412 if runtime().fibers.get(parent).is_completed && runtime().fibers.get(parent).children.is_empty()
413 {
414 runtime().ready_fibers.push_back(parent);
415 }
416
417 if let JoinHandleState::Dropped = runtime().fibers.get(running).join_handle {
419 let stack_base = runtime().fibers.get(running).stack_base;
420 runtime().stack_pool.push(stack_base);
421 }
422
423 let to = runtime().process_io_and_wait();
425 let to = runtime().fibers.get(to).continuation;
426 let mut dummy = mem::MaybeUninit::uninit();
427 unsafe { context_switch::jump(to, dummy.as_mut_ptr()) };
428 unreachable!();
429}
430
431#[derive(Debug)]
433pub struct JoinHandle<T> {
434 fiber: FiberIndex,
435 output: marker::PhantomData<T>,
436}
437
438impl<T> JoinHandle<T> {
439 fn new(fiber: FiberIndex) -> Self {
440 JoinHandle {
441 fiber,
442 output: marker::PhantomData,
443 }
444 }
445
446 pub fn join(self) -> thread::Result<T> {
448 ensure_runtime_exists();
449
450 unsafe {
451 let stack_base = runtime().fibers.get(self.fiber).stack_base;
452 let output_pointer = (stack_base as *const thread::Result<T>).sub(1);
453
454 if runtime().fibers.get(self.fiber).is_completed {
456 return output_pointer.read();
457 }
458
459 let running = runtime().running();
461
462 runtime().fibers.get_mut(self.fiber).join_handle = JoinHandleState::Waiting(running);
463
464 let to = runtime().process_io_and_wait();
465 let to = runtime().fibers.get(to).continuation;
466 let continuation = &mut runtime().fibers.get_mut(running).continuation;
467 context_switch::jump(to, continuation); assert!(runtime().fibers.get(self.fiber).is_completed);
470 output_pointer.read()
471 }
472 }
473
474 pub fn cancel(&self) {
476 ensure_runtime_exists();
477
478 unsafe {
479 runtime().cancel(self.fiber);
480 }
481 }
482}
483
484impl<T> Drop for JoinHandle<T> {
485 fn drop(&mut self) {
486 ensure_runtime_exists();
487
488 unsafe {
489 runtime().fibers.get_mut(self.fiber).join_handle = JoinHandleState::Dropped;
490
491 if runtime().fibers.get(self.fiber).is_completed {
493 let stack_base = runtime().fibers.get(self.fiber).stack_base;
494 runtime().stack_pool.push(stack_base);
495 runtime().fibers.remove(self.fiber);
496 }
497 }
498 }
499}
500
501pub(crate) fn syscall(sqe: io_uring::squeue::Entry) -> io::Result<u32> {
504 ensure_runtime_exists();
505
506 unsafe {
507 let running = runtime().running();
508
509 assert!(!runtime().fibers.get(running).issuing_syscall);
510 runtime().fibers.get_mut(running).issuing_syscall = true;
511
512 assert!(runtime().fibers.get(running).syscall_result.is_none());
513 runtime()
514 .uring
515 .issue_syscall(uring::UserData(running.0 as u64), sqe); let to = runtime().process_io_and_wait();
518
519 if running != to {
520 let to = runtime().fibers.get(to).continuation;
521 let continuation = &mut runtime().fibers.get_mut(running).continuation;
522 context_switch::jump(to, continuation); }
524
525 assert!(runtime().fibers.get(running).issuing_syscall);
526 runtime().fibers.get_mut(running).issuing_syscall = false;
527
528 runtime()
529 .fibers
530 .get_mut(running)
531 .syscall_result
532 .take()
533 .unwrap()
534 }
535}
536
537pub fn nop() -> io::Result<()> {
538 let result = syscall(io_uring::opcode::Nop::new().build())?;
539 assert_eq!(result, 0);
540 Ok(())
541}
542
/// Number of bytes to reserve at the top of a fiber's stack for a slot that holds,
/// one after the other, the closure `F` (before the fiber starts) and its result
/// `thread::Result<T>` (written by the trampoline after the closure returned).
///
/// The slot has to fit the *result*, not just `T`. `thread::Result<T>` also carries
/// the boxed panic payload, so it is larger than `T`; for `T = ()` it is still a
/// fat pointer in size. Reserving only `size_of::<T>()` let the result spill below
/// the reserved region, into stack space handed to `prepare_stack`.
const fn closure_union_size<F: FnOnce() -> T, T>() -> usize {
    let closure_size = mem::size_of::<F>();
    let output_size = mem::size_of::<thread::Result<T>>();

    if closure_size > output_size {
        closure_size
    } else {
        output_size
    }
}
553
554pub fn cancel() {
556 ensure_runtime_exists();
557
558 unsafe {
559 let running = runtime().running();
560 runtime().cancel(running);
561 }
562}
563
564pub fn is_cancelled() -> bool {
566 ensure_runtime_exists();
567
568 unsafe {
569 let running = runtime().running();
570 runtime().fibers.get(running).is_cancelled
571 }
572}
573
/// Behavioural tests for the fiber runtime: `start`, `contain`, `spawn`,
/// `yield_now` and cancellation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time;
    use std::thread;
    use std::time::{Duration, Instant};

    mod start {
        use super::*;
        use std::time::Duration;

        #[test]
        fn returns_output() {
            let output = start(|| 123);

            assert_eq!(output.unwrap(), 123);
        }

        #[test]
        fn catches_panic() {
            let result = start(|| panic!("oops"));

            assert!(result.is_err());
        }

        #[test]
        #[should_panic]
        fn cant_nest() {
            start(|| {
                start(|| {}).unwrap();
            })
            .unwrap();
        }

        #[test]
        fn waits_for_children() {
            static mut VALUE: usize = 0;

            start(|| {
                let handle = spawn(|| unsafe { VALUE += 1 });
                drop(handle);

                let handle = spawn(|| unsafe { VALUE += 1 });
                mem::forget(handle);
            })
            .unwrap();

            assert_eq!(unsafe { VALUE }, 2);
        }

        #[test]
        fn works_consecutively() {
            start(|| {}).unwrap();
            start(|| {}).unwrap();
        }

        #[test]
        fn works_in_parallel() {
            let handle = thread::spawn(|| {
                start(|| {
                    thread::sleep(Duration::from_millis(2));
                })
                .unwrap();
            });

            thread::sleep(Duration::from_millis(1));
            start(|| {}).unwrap();

            assert!(handle.join().is_ok());
        }
    }

    mod contain {
        use super::*;

        #[test]
        fn returns_output() {
            start(|| {
                let output = contain(|| 123);

                assert_eq!(output.unwrap(), 123);
            })
            .unwrap();
        }

        #[test]
        fn catches_panic() {
            start(|| {
                let result = contain(|| panic!("oops"));

                assert!(result.is_err());
            })
            .unwrap();
        }

        #[test]
        fn cant_nest_start() {
            start(|| {
                let result = contain(|| start(|| {}).unwrap());

                assert!(result.is_err());
            })
            .unwrap();
        }

        #[test]
        fn waits_for_children() {
            start(|| {
                static mut VALUE: usize = 0;

                contain(|| {
                    let handle = spawn(|| unsafe { VALUE += 1 });
                    drop(handle);

                    let handle = spawn(|| unsafe { VALUE += 1 });
                    mem::forget(handle);
                })
                .unwrap();

                assert_eq!(unsafe { VALUE }, 2);
            })
            .unwrap();
        }
    }

    mod spawn {
        use super::*;

        #[test]
        fn returns_child_output() {
            start(|| {
                let handle = spawn(|| 123);

                let output = handle.join();

                assert_eq!(output.unwrap(), 123);
            })
            .unwrap();
        }

        #[test]
        fn returns_non_child_output() {
            start(|| {
                let other = spawn(|| 123);
                let handle = spawn(|| other.join().unwrap());

                let output = handle.join();

                assert_eq!(output.unwrap(), 123);
            })
            .unwrap();
        }

        #[test]
        fn returns_already_completed_output() {
            start(|| {
                let handle = spawn(|| 123);

                yield_now();
                let output = handle.join();

                assert_eq!(output.unwrap(), 123);
            })
            .unwrap();
        }

        #[test]
        fn catches_panic() {
            start(|| {
                let result = spawn(|| panic!("oops")).join();

                assert!(result.is_err());
            })
            .unwrap();
        }

        #[test]
        fn cant_nest_start() {
            start(|| {
                let result = spawn(|| start(|| {})).join();

                assert!(result.is_err());
            })
            .unwrap();
        }

        #[test]
        fn waits_for_children() {
            start(|| {
                static mut VALUE: usize = 0;

                spawn(|| {
                    let handle = spawn(|| unsafe { VALUE += 1 });
                    drop(handle);

                    let handle = spawn(|| unsafe { VALUE += 1 });
                    mem::forget(handle);
                })
                .join()
                .unwrap();

                assert_eq!(unsafe { VALUE }, 2);
            })
            .unwrap();
        }
    }

    mod yield_now {
        use super::*;

        #[test]
        fn to_same_fiber() {
            start(|| {
                yield_now();
            })
            .unwrap();
        }

        #[test]
        fn to_other_fiber() {
            start(|| {
                static mut VALUE: usize = 0;

                spawn(|| unsafe { VALUE += 1 });
                assert_eq!(unsafe { VALUE }, 0);

                yield_now();

                assert_eq!(unsafe { VALUE }, 1);
            })
            .unwrap();
        }
    }

    mod cancellation {
        use super::*;

        #[test]
        fn initially_not_cancelled() {
            start(|| {
                assert!(!is_cancelled());
            })
            .unwrap();
        }

        #[test]
        fn function_cancels_fiber_hierarchy() {
            start(|| {
                contain(|| {
                    let handle = spawn(|| {
                        assert!(is_cancelled());
                    });

                    cancel();
                    handle.join().unwrap();
                    assert!(is_cancelled());
                })
                .unwrap();

                assert!(!is_cancelled());
            })
            .unwrap();
        }

        #[test]
        fn panic_cancels_fiber_hierarchy() {
            static mut GRAND_CHILD_CANCELLED: bool = false;

            start(|| {
                let _ = contain(|| {
                    let handle = spawn(|| {
                        spawn(|| unsafe {
                            GRAND_CHILD_CANCELLED = is_cancelled();
                        });

                        panic!("oops");
                    });

                    handle.join().unwrap();
                    assert!(is_cancelled());
                });

                assert!(unsafe { GRAND_CHILD_CANCELLED });

                assert!(!is_cancelled());
            })
            .unwrap();
        }

        #[test]
        fn method_cancels_fiber_hierarchy() {
            start(|| {
                let handle = spawn(|| {
                    spawn(|| {
                        assert!(is_cancelled());
                    });

                    yield_now();
                    assert!(is_cancelled());
                });

                yield_now();
                handle.cancel();
                handle.join().unwrap();

                assert!(!is_cancelled());
            })
            .unwrap();
        }
    }
}