1use std::{
2 any::Any,
3 cell::{Cell, RefCell},
4 collections::VecDeque,
5 future::{Future, ready},
6 io,
7 marker::PhantomData,
8 panic::AssertUnwindSafe,
9 rc::Rc,
10 sync::Arc,
11 task::{Context, Poll},
12 time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18 AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod buffer_pool;
29pub use buffer_pool::*;
30
31mod send_wrapper;
32use send_wrapper::SendWrapper;
33
34#[cfg(feature = "time")]
35use crate::runtime::time::{TimerFuture, TimerRuntime};
36use crate::{BufResult, runtime::op::OpFuture};
37
38scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
39
40pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
43
44struct RunnableQueue {
45 local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
46 sync_runnables: SegQueue<Runnable>,
47}
48
49impl RunnableQueue {
50 pub fn new() -> Self {
51 Self {
52 local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
53 sync_runnables: SegQueue::new(),
54 }
55 }
56
57 pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
58 if let Some(runnables) = self.local_runnables.get() {
59 runnables.borrow_mut().push_back(runnable);
60 } else {
61 self.sync_runnables.push(runnable);
62 handle.notify().ok();
63 }
64 }
65
66 pub unsafe fn run(&self, event_interval: usize) -> bool {
68 let local_runnables = self.local_runnables.get_unchecked();
69 for _i in 0..event_interval {
70 let next_task = local_runnables.borrow_mut().pop_front();
71 let has_local_task = next_task.is_some();
72 if let Some(task) = next_task {
73 task.run();
74 }
75 let has_sync_task = !self.sync_runnables.is_empty();
77 if has_sync_task {
78 if let Some(task) = self.sync_runnables.pop() {
79 task.run();
80 }
81 } else if !has_local_task {
82 break;
83 }
84 }
85 !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
86 }
87}
88
thread_local! {
    /// Per-thread counter from which every newly built `Runtime` takes its id.
    ///
    /// Because the counter is thread-local, ids are only unique among runtimes
    /// created on the same thread.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0_u64) };
}
92
93pub struct Runtime {
96 driver: RefCell<Proactor>,
97 runnables: Arc<RunnableQueue>,
98 #[cfg(feature = "time")]
99 timer_runtime: RefCell<TimerRuntime>,
100 event_interval: usize,
101 id: u64,
109 _p: PhantomData<Rc<VecDeque<Runnable>>>,
112}
113
114impl Runtime {
115 pub fn new() -> io::Result<Self> {
117 Self::builder().build()
118 }
119
120 pub fn builder() -> RuntimeBuilder {
122 RuntimeBuilder::new()
123 }
124
125 fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
126 let id = RUNTIME_ID.get();
127 RUNTIME_ID.set(id + 1);
128 Ok(Self {
129 driver: RefCell::new(builder.proactor_builder.build()?),
130 runnables: Arc::new(RunnableQueue::new()),
131 #[cfg(feature = "time")]
132 timer_runtime: RefCell::new(TimerRuntime::new()),
133 event_interval: builder.event_interval,
134 id,
135 _p: PhantomData,
136 })
137 }
138
139 pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
142 if CURRENT_RUNTIME.is_set() {
143 Ok(CURRENT_RUNTIME.with(f))
144 } else {
145 Err(f)
146 }
147 }
148
149 pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
155 #[cold]
156 fn not_in_compio_runtime() -> ! {
157 panic!("not in a compio runtime")
158 }
159
160 if CURRENT_RUNTIME.is_set() {
161 CURRENT_RUNTIME.with(f)
162 } else {
163 not_in_compio_runtime()
164 }
165 }
166
167 pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
170 CURRENT_RUNTIME.set(self, f)
171 }
172
173 pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
179 let runnables = self.runnables.clone();
180 let handle = self.driver.borrow().handle();
181 let schedule = move |runnable| {
182 runnables.schedule(runnable, &handle);
183 };
184 let (runnable, task) = async_task::spawn_unchecked(future, schedule);
185 runnable.schedule();
186 task
187 }
188
189 pub fn run(&self) -> bool {
195 unsafe { self.runnables.run(self.event_interval) }
197 }
198
199 pub fn block_on<F: Future>(&self, future: F) -> F::Output {
201 CURRENT_RUNTIME.set(self, || {
202 let mut result = None;
203 unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
204 loop {
205 let remaining_tasks = self.run();
206 if let Some(result) = result.take() {
207 return result;
208 }
209 if remaining_tasks {
210 self.poll_with(Some(Duration::ZERO));
211 } else {
212 self.poll();
213 }
214 }
215 })
216 }
217
218 pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
223 unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
224 }
225
226 pub fn spawn_blocking<T: Send + 'static>(
230 &self,
231 f: impl (FnOnce() -> T) + Send + 'static,
232 ) -> JoinHandle<T> {
233 let op = Asyncify::new(move || {
234 let res = std::panic::catch_unwind(AssertUnwindSafe(f));
235 BufResult(Ok(0), res)
236 });
237 #[allow(deprecated)]
240 unsafe {
241 self.spawn_unchecked(self.submit(op).map(|res| res.1.into_inner()))
242 }
243 }
244
245 pub fn attach(&self, fd: RawFd) -> io::Result<()> {
250 self.driver.borrow_mut().attach(fd)
251 }
252
253 fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
254 self.driver.borrow_mut().push(op)
255 }
256
257 #[deprecated = "use compio::runtime::submit instead"]
265 pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
266 #[allow(deprecated)]
267 self.submit_with_flags(op).map(|(res, _)| res)
268 }
269
270 #[deprecated = "use compio::runtime::submit_with_flags instead"]
281 pub fn submit_with_flags<T: OpCode + 'static>(
282 &self,
283 op: T,
284 ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
285 match self.submit_raw(op) {
286 PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
287 PushEntry::Ready(res) => {
288 Either::Right(ready((res, 0)))
291 }
292 }
293 }
294
295 pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
296 self.driver.borrow_mut().cancel(op);
297 }
298
299 #[cfg(feature = "time")]
300 pub(crate) fn cancel_timer(&self, key: usize) {
301 self.timer_runtime.borrow_mut().cancel(key);
302 }
303
304 pub(crate) fn poll_task<T: OpCode>(
305 &self,
306 cx: &mut Context,
307 op: Key<T>,
308 ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
309 instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
310 let mut driver = self.driver.borrow_mut();
311 driver.pop(op).map_pending(|mut k| {
312 driver.update_waker(&mut k, cx.waker().clone());
313 k
314 })
315 }
316
317 #[cfg(feature = "time")]
318 pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
319 instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
320 let mut timer_runtime = self.timer_runtime.borrow_mut();
321 if !timer_runtime.is_completed(key) {
322 debug!("pending");
323 timer_runtime.update_waker(key, cx.waker().clone());
324 Poll::Pending
325 } else {
326 debug!("ready");
327 Poll::Ready(())
328 }
329 }
330
331 pub fn current_timeout(&self) -> Option<Duration> {
335 #[cfg(not(feature = "time"))]
336 let timeout = None;
337 #[cfg(feature = "time")]
338 let timeout = self.timer_runtime.borrow().min_timeout();
339 timeout
340 }
341
342 pub fn poll(&self) {
347 instrument!(compio_log::Level::DEBUG, "poll");
348 let timeout = self.current_timeout();
349 debug!("timeout: {:?}", timeout);
350 self.poll_with(timeout)
351 }
352
353 pub fn poll_with(&self, timeout: Option<Duration>) {
357 instrument!(compio_log::Level::DEBUG, "poll_with");
358
359 let mut driver = self.driver.borrow_mut();
360 match driver.poll(timeout) {
361 Ok(()) => {}
362 Err(e) => match e.kind() {
363 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
364 debug!("expected error: {e}");
365 }
366 _ => panic!("{e:?}"),
367 },
368 }
369 #[cfg(feature = "time")]
370 self.timer_runtime.borrow_mut().wake();
371 }
372
373 pub(crate) fn create_buffer_pool(
374 &self,
375 buffer_len: u16,
376 buffer_size: usize,
377 ) -> io::Result<compio_driver::BufferPool> {
378 self.driver
379 .borrow_mut()
380 .create_buffer_pool(buffer_len, buffer_size)
381 }
382
383 pub(crate) unsafe fn release_buffer_pool(
384 &self,
385 buffer_pool: compio_driver::BufferPool,
386 ) -> io::Result<()> {
387 self.driver.borrow_mut().release_buffer_pool(buffer_pool)
388 }
389
390 pub(crate) fn id(&self) -> u64 {
391 self.id
392 }
393}
394
395impl Drop for Runtime {
396 fn drop(&mut self) {
397 self.enter(|| {
398 while self.runnables.sync_runnables.pop().is_some() {}
399 let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
400 loop {
401 let runnable = local_runnables.borrow_mut().pop_front();
402 if runnable.is_none() {
403 break;
404 }
405 }
406 })
407 }
408}
409
410impl AsRawFd for Runtime {
411 fn as_raw_fd(&self) -> RawFd {
412 self.driver.borrow().as_raw_fd()
413 }
414}
415
/// Lets criterion run async benchmarks on a [`Runtime`] by delegating to
/// [`Runtime::block_on`].
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
422
/// Borrowed variant of the criterion executor impl, delegating to
/// [`Runtime::block_on`] on the referenced runtime.
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(*self, future)
    }
}
429
430#[derive(Debug, Clone)]
432pub struct RuntimeBuilder {
433 proactor_builder: ProactorBuilder,
434 event_interval: usize,
435}
436
437impl Default for RuntimeBuilder {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
443impl RuntimeBuilder {
444 pub fn new() -> Self {
446 Self {
447 proactor_builder: ProactorBuilder::new(),
448 event_interval: 61,
449 }
450 }
451
452 pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
454 self.proactor_builder = builder;
455 self
456 }
457
458 pub fn event_interval(&mut self, val: usize) -> &mut Self {
463 self.event_interval = val;
464 self
465 }
466
467 pub fn build(&self) -> io::Result<Runtime> {
469 Runtime::with_builder(self)
470 }
471}
472
473pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
497 Runtime::with_current(|r| r.spawn(future))
498}
499
500pub fn spawn_blocking<T: Send + 'static>(
509 f: impl (FnOnce() -> T) + Send + 'static,
510) -> JoinHandle<T> {
511 Runtime::with_current(|r| r.spawn_blocking(f))
512}
513
514pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
521 submit_with_flags(op).await.0
522}
523
524pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
532 let state = Runtime::with_current(|r| r.submit_raw(op));
533 match state {
534 PushEntry::Pending(user_data) => OpFuture::new(user_data).await,
535 PushEntry::Ready(res) => {
536 (res, 0)
539 }
540 }
541}
542
// Registers a timer for `instant` on the current runtime and waits for it. If
// the timer runtime declines the insertion (returns `None`), this completes
// immediately.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let Some(key) = Runtime::with_current(|r| r.timer_runtime.borrow_mut().insert(instant)) else {
        return;
    };
    TimerFuture::new(key).await
}