// threader/thread_pool/task/raw.rs

1use super::ExecutorFuture;
2use crate::thread_pool::Shared;
3use crossbeam::utils::Backoff;
4use std::{
5    cell::UnsafeCell,
6    mem,
7    pin::Pin,
8    ptr,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc, Weak,
12    },
13    task::{Context, Poll},
14};
15
16#[derive(Debug)]
17pub(super) struct RawTask {
18    ptr: *const Header,
19}
20
21impl RawTask {
22    pub(super) fn new<F: ExecutorFuture>(future: F, shared: Weak<Shared>) -> Self {
23        let inner = Inner::new(future, shared);
24        let header = inner.header();
25        mem::forget(inner);
26
27        Self { ptr: header }
28    }
29
30    pub(super) fn poll(&self, cx: &mut Context) {
31        unsafe {
32            // While this technically does block, in most cases
33            // it's only for a very short amount of time, so it
34            // should be fine to lock here. This is required to
35            // make the call to poll below defined behavior.
36            lock(self.ptr);
37
38            let vtable_poll = (*(self.ptr)).vtable.poll;
39            vtable_poll(self.ptr, cx);
40        }
41    }
42
43    pub(super) fn ptr(&self) -> *const Header {
44        self.ptr
45    }
46
47    pub(super) unsafe fn from_header(header: *const Header) -> Self {
48        Self { ptr: header }
49    }
50}
51
52impl Clone for RawTask {
53    fn clone(&self) -> Self {
54        unsafe {
55            let vtable = (*(self.ptr)).vtable;
56            (vtable.inc_refcount)(self.ptr);
57        }
58
59        RawTask { ptr: self.ptr }
60    }
61}
62
63impl Drop for RawTask {
64    fn drop(&mut self) {
65        unsafe {
66            let vtable = (*(self.ptr)).vtable;
67            (vtable.drop)(self.ptr);
68        }
69    }
70}
71
72unsafe impl Send for RawTask {}
73unsafe impl Sync for RawTask {}
74
/// Zero-sized marker type; its role is defined by users outside this view.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct InvalidGuard;
77
78#[repr(C)]
79struct Inner<F>
80where
81    F: ExecutorFuture,
82{
83    header: Header,
84    future: UnsafeCell<F>,
85}
86
87impl<F> Inner<F>
88where
89    F: ExecutorFuture,
90{
91    fn new(future: F, shared: Weak<Shared>) -> Arc<Self> {
92        let header = Header {
93            shared,
94            state: AtomicUsize::new(UNLOCKED),
95            vtable: VTable::new::<F>(),
96        };
97
98        Arc::new(Self {
99            header,
100            future: UnsafeCell::new(future),
101        })
102    }
103
104    fn header(&self) -> *const Header {
105        &self.header
106    }
107}
108
109pub(super) struct Header {
110    pub(super) shared: Weak<Shared>,
111    pub(super) state: AtomicUsize,
112    pub(super) vtable: &'static VTable,
113}
114
115pub(super) struct VTable {
116    pub(super) inc_refcount: unsafe fn(*const Header),
117    poll: unsafe fn(*const Header, &mut Context),
118    drop: unsafe fn(*const Header),
119}
120
121impl VTable {
122    fn new<F>() -> &'static VTable
123    where
124        F: ExecutorFuture,
125    {
126        &VTable {
127            inc_refcount: inc_refcount::<F>,
128            poll: poll::<F>,
129            drop: drop_raw::<F>,
130        }
131    }
132}
133
/// Task is idle and may be locked for polling.
const UNLOCKED: usize = 0;
/// A thread currently holds the task's lock and may access the future.
const LOCKED: usize = 1;
/// The future returned `Ready`; nothing in this module leaves this state.
const COMPLETE: usize = 2;
137
138// before calling this function, you *must* ensure
139// that you have unique access to the contained
140// future. This is usually done with the lock function.
141unsafe fn poll<F>(ptr: *const Header, cx: &mut Context)
142where
143    F: ExecutorFuture,
144{
145    debug_assert!(!ptr.is_null());
146    debug_assert!((*ptr).state.load(Ordering::SeqCst) == LOCKED);
147
148    if (*ptr).state.load(Ordering::Acquire) != COMPLETE {
149        let inner = &*(ptr as *const Inner<F>);
150        let future = Pin::new_unchecked(&mut *inner.future.get());
151
152        match future.poll(cx) {
153            Poll::Ready(()) => set_complete(ptr),
154            _ => unlock(ptr),
155        }
156    }
157}
158
159unsafe fn inc_refcount<F>(ptr: *const Header)
160where
161    F: ExecutorFuture,
162{
163    let arc = Arc::from_raw(ptr as *const Inner<F>);
164    let cloned = Arc::clone(&arc);
165    mem::forget(arc);
166    mem::forget(cloned);
167}
168
169unsafe fn drop_raw<F>(ptr: *const Header)
170where
171    F: ExecutorFuture,
172{
173    let arc = Arc::from_raw(ptr as *const Inner<F>);
174    drop(arc);
175}
176
177unsafe fn lock(ptr: *const Header) {
178    let backoff = Backoff::new();
179    while (*ptr)
180        .state
181        .compare_and_swap(UNLOCKED, LOCKED, Ordering::AcqRel)
182        != UNLOCKED
183    {
184        backoff.snooze();
185    }
186}
187
188unsafe fn unlock(ptr: *const Header) {
189    let old = (*ptr).state.swap(UNLOCKED, Ordering::Release);
190    debug_assert!(old == LOCKED);
191}
192
193unsafe fn set_complete(ptr: *const Header) {
194    let old = (*ptr).state.swap(COMPLETE, Ordering::Release);
195    debug_assert!(old == LOCKED);
196}