hiver_runtime/task/
raw_task.rs1use std::cell::UnsafeCell;
11use std::future::Future;
12use std::mem::MaybeUninit;
13use std::pin::Pin;
14use std::ptr::NonNull;
15use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
16use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
17
18use crate::scheduler::{RawTask, SchedulerHandle, TaskId, gen_task_id};
19
20const STATE_RUNNING: u8 = 0;
21const STATE_WAITING: u8 = 1;
22const STATE_COMPLETED: u8 = 2;
23
24struct TaskVTable {
26 poll: unsafe fn(*const TaskCore) -> bool,
27 take_output: unsafe fn(*const TaskCore, *mut ()) -> bool,
28 drop_future_and_dealloc: unsafe fn(*const TaskCore),
29}
30
31#[repr(C)]
34pub(crate) struct TaskCore {
35 vtable: &'static TaskVTable,
36 id: TaskId,
37 state: AtomicU8,
38 ref_count: AtomicUsize,
39 scheduler: SchedulerHandle,
40}
41
42impl TaskCore {
43 pub(crate) fn id(&self) -> TaskId {
46 self.id
47 }
48
49 #[must_use]
52 pub(crate) fn is_completed(&self) -> bool {
53 self.state.load(Ordering::Acquire) == STATE_COMPLETED
54 }
55
56 #[allow(dead_code)]
57 pub(crate) fn scheduler(&self) -> &SchedulerHandle {
58 &self.scheduler
59 }
60
61 unsafe fn poll(&self) -> bool {
63 (self.vtable.poll)(self)
64 }
65
66 fn inc_ref(ptr: *const Self) -> *const Self {
68 let core = &unsafe { &*ptr };
69 core.ref_count.fetch_add(1, Ordering::Relaxed);
70 ptr
71 }
72
73 unsafe fn dec_ref(ptr: *const Self) {
75 let core = &*ptr;
76 if core.ref_count.fetch_sub(1, Ordering::Release) == 1 {
77 std::sync::atomic::fence(Ordering::Acquire);
78 (core.vtable.drop_future_and_dealloc)(ptr);
79 }
80 }
81}
82
83#[repr(C)]
86struct ConcreteTask<F: Future + Send + 'static> {
87 core: TaskCore,
88 future: UnsafeCell<MaybeUninit<F>>,
89 output: UnsafeCell<MaybeUninit<F::Output>>,
90}
91
92impl<F: Future + Send + 'static> ConcreteTask<F> {
93 const fn vtable() -> &'static TaskVTable {
94 &TaskVTable {
95 poll: Self::poll_impl,
96 take_output: Self::take_output_impl,
97 drop_future_and_dealloc: Self::drop_future_and_dealloc_impl,
98 }
99 }
100
101 unsafe fn poll_impl(core: *const TaskCore) -> bool {
102 let task = &*(core as *const Self);
103 let waker = create_task_waker(core);
104 let mut cx = Context::from_waker(&waker);
105
106 let future = &mut *task.future.get();
107 let future = Pin::new_unchecked(future.assume_init_mut());
108
109 match future.poll(&mut cx) {
110 Poll::Ready(value) => {
111 (*task.output.get()).write(value);
112 task.core.state.store(STATE_COMPLETED, Ordering::Release);
113 std::ptr::drop_in_place((*task.future.get()).as_mut_ptr());
114 true
115 },
116 Poll::Pending => {
117 task.core.state.store(STATE_WAITING, Ordering::Release);
118 false
119 },
120 }
121 }
122
123 unsafe fn take_output_impl(core: *const TaskCore, dest: *mut ()) -> bool {
124 let task = &*(core as *const Self);
125 if task.core.state.load(Ordering::Acquire) != STATE_COMPLETED {
126 return false;
127 }
128 let output = (*task.output.get()).assume_init_read();
129 (dest as *mut F::Output).write(output);
130 true
131 }
132
133 unsafe fn drop_future_and_dealloc_impl(core: *const TaskCore) {
134 let task = core as *mut Self;
135 let state = (*task).core.state.load(Ordering::Acquire);
136 if state == STATE_RUNNING || state == STATE_WAITING {
137 std::ptr::drop_in_place((*(*task).future.get()).as_mut_ptr());
138 }
139 if state == STATE_COMPLETED {
140 }
148 let _ = Box::from_raw(task);
149 }
150}
151
152static WAKER_VTABLE: RawWakerVTable =
155 RawWakerVTable::new(waker_clone, waker_wake, waker_wake_by_ref, waker_drop);
156
157fn create_task_waker(core: *const TaskCore) -> Waker {
158 TaskCore::inc_ref(core);
160 unsafe { Waker::from_raw(RawWaker::new(core as *const (), &WAKER_VTABLE)) }
161}
162
163unsafe fn waker_clone(data: *const ()) -> RawWaker {
164 TaskCore::inc_ref(data as *const TaskCore);
165 RawWaker::new(data, &WAKER_VTABLE)
166}
167
168unsafe fn waker_wake(data: *const ()) {
169 let core = data as *const TaskCore;
170 try_re_enqueue(core);
171 TaskCore::dec_ref(core); }
173
174unsafe fn waker_wake_by_ref(data: *const ()) {
175 try_re_enqueue(data as *const TaskCore);
176}
177
178unsafe fn waker_drop(data: *const ()) {
179 TaskCore::dec_ref(data as *const TaskCore);
180}
181
182fn try_re_enqueue(core: *const TaskCore) {
183 unsafe {
184 let c = &*core;
185 if c.state
186 .compare_exchange(STATE_WAITING, STATE_RUNNING, Ordering::AcqRel, Ordering::Relaxed)
187 .is_err()
188 {
189 return;
190 }
191 TaskCore::inc_ref(core);
193 let _ = c.scheduler.submit(core as RawTask);
194 }
195}
196
197pub(crate) struct TaskRef(Option<NonNull<TaskCore>>);
202
203impl TaskRef {
204 pub(crate) fn new(ptr: *const TaskCore) -> Self {
205 TaskRef(Some(unsafe { NonNull::new_unchecked(ptr.cast_mut()) }))
206 }
207
208 pub(crate) fn core(&self) -> Option<&TaskCore> {
209 self.0.map(|nn| unsafe { nn.as_ref() })
210 }
211
212 #[allow(dead_code)]
213 pub(crate) fn is_some(&self) -> bool {
214 self.0.is_some()
215 }
216}
217
218impl Drop for TaskRef {
219 fn drop(&mut self) {
220 if let Some(nn) = self.0.take() {
221 unsafe {
222 TaskCore::dec_ref(nn.as_ptr());
223 }
224 }
225 }
226}
227
228unsafe impl Send for TaskRef {}
230unsafe impl Sync for TaskRef {}
231
232pub(crate) fn allocate_task<F>(future: F, scheduler: SchedulerHandle) -> (RawTask, TaskRef)
235where
236 F: Future + Send + 'static,
237 F::Output: Send + 'static,
238{
239 let id = gen_task_id();
240
241 let task = Box::new(ConcreteTask {
242 core: TaskCore {
243 vtable: ConcreteTask::<F>::vtable(),
244 id,
245 state: AtomicU8::new(STATE_RUNNING),
246 ref_count: AtomicUsize::new(2), scheduler,
248 },
249 future: UnsafeCell::new(MaybeUninit::new(future)),
250 output: UnsafeCell::new(MaybeUninit::uninit()),
251 });
252
253 let raw: *mut ConcreteTask<F> = Box::into_raw(task);
254 let core_ptr: *const TaskCore = raw as *const TaskCore;
255
256 let raw_task = core_ptr as RawTask;
257 let task_ref = TaskRef::new(core_ptr);
258
259 (raw_task, task_ref)
260}
261
262pub unsafe fn poll_raw_task(raw_task: RawTask) -> bool {
268 let core = raw_task as *const TaskCore;
269 (*core).poll()
270}
271
272pub unsafe fn deallocate_completed_task(raw_task: RawTask) {
279 TaskCore::dec_ref(raw_task as *const TaskCore);
280}
281
282pub(crate) unsafe fn read_output<T>(core: &TaskCore) -> Option<T> {
288 let mut output: MaybeUninit<T> = MaybeUninit::uninit();
289 let ok = (core.vtable.take_output)(core, &mut output as *mut _ as *mut ());
290 if ok { Some(output.assume_init()) } else { None }
291}
292
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// A future that is immediately ready completes on its first poll and yields its output.
    #[test]
    fn test_sync_future_completes() {
        let scheduler = SchedulerHandle::new_default();
        let (raw_task, task_ref) = allocate_task(async { 42 }, scheduler);

        let core = task_ref.core().unwrap();
        assert_eq!(core.state.load(Ordering::Acquire), STATE_RUNNING);
        assert!(!core.is_completed());

        unsafe {
            assert!(poll_raw_task(raw_task));
        }
        assert!(core.is_completed());

        unsafe {
            let output: i32 = read_output(core).unwrap();
            assert_eq!(output, 42);
        }

        // Release the scheduler-side reference; `task_ref` still keeps `core` alive.
        unsafe {
            deallocate_completed_task(raw_task);
        }
    }

    /// Every allocation gets a distinct id.
    #[test]
    fn test_task_id_unique() {
        let scheduler = SchedulerHandle::new_default();
        let (raw1, ref1) = allocate_task(async {}, scheduler.clone());
        let (raw2, ref2) = allocate_task(async {}, scheduler);

        assert_ne!(ref1.core().unwrap().id(), ref2.core().unwrap().id());

        // Release the scheduler-side references too; without this both tasks leak
        // once `ref1` and `ref2` are dropped.
        unsafe {
            deallocate_completed_task(raw1);
            deallocate_completed_task(raw2);
        }
    }

    /// Releasing every reference to a never-polled task drops its future.
    #[test]
    fn test_unpolled_task_drops_future_when_released() {
        let scheduler = SchedulerHandle::new_default();
        let marker = Arc::new(());
        let captured = Arc::clone(&marker);
        let (raw_task, task_ref) = allocate_task(
            async move {
                drop(captured);
            },
            scheduler,
        );
        assert_eq!(Arc::strong_count(&marker), 2);

        drop(task_ref);
        unsafe {
            deallocate_completed_task(raw_task);
        }

        // The future (and the `Arc` it captured) was dropped with the task.
        assert_eq!(Arc::strong_count(&marker), 1);
    }
}
331}