1use std::{
2 any::Any,
3 cell::{Cell, RefCell},
4 collections::VecDeque,
5 future::{Future, ready},
6 io,
7 marker::PhantomData,
8 panic::AssertUnwindSafe,
9 rc::Rc,
10 sync::Arc,
11 task::{Context, Poll},
12 time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18 AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod buffer_pool;
29pub use buffer_pool::*;
30
31mod send_wrapper;
32use send_wrapper::SendWrapper;
33
34#[cfg(feature = "time")]
35use crate::runtime::time::{TimerFuture, TimerRuntime};
36use crate::{BufResult, runtime::op::OpFuture};
37
// The runtime currently entered on this thread. It is set for the duration of
// `Runtime::enter` and `Runtime::block_on`, and read through
// `Runtime::with_current` / `Runtime::try_with_current`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Handle to a task spawned with [`spawn`], [`spawn_blocking`], or their
/// [`Runtime`] method counterparts.
///
/// The task's output is wrapped in a `Result`: `Err` carries the panic payload
/// if the spawned future (or blocking closure) panicked.
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
43
/// Run queue of a [`Runtime`], shared (via `Arc`) with the schedule closure of
/// every task spawned on it.
///
/// `schedule` pushes onto `local_runnables` whenever `SendWrapper::get` succeeds
/// (presumably: when called on the thread that created the queue — TODO confirm);
/// otherwise it pushes onto `sync_runnables`.
struct RunnableQueue {
    // Single-threaded queue; only reachable through the `SendWrapper`.
    local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
    // Thread-safe queue used when the local queue is not accessible.
    sync_runnables: SegQueue<Runnable>,
}
48
49impl RunnableQueue {
50 pub fn new() -> Self {
51 Self {
52 local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
53 sync_runnables: SegQueue::new(),
54 }
55 }
56
57 pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
58 if let Some(runnables) = self.local_runnables.get() {
59 runnables.borrow_mut().push_back(runnable);
60 #[cfg(feature = "notify-always")]
61 handle.notify().ok();
62 } else {
63 self.sync_runnables.push(runnable);
64 handle.notify().ok();
65 }
66 }
67
68 pub unsafe fn run(&self, event_interval: usize) -> bool {
70 let local_runnables = self.local_runnables.get_unchecked();
71 for _i in 0..event_interval {
72 let next_task = local_runnables.borrow_mut().pop_front();
73 let has_local_task = next_task.is_some();
74 if let Some(task) = next_task {
75 task.run();
76 }
77 let has_sync_task = !self.sync_runnables.is_empty();
79 if has_sync_task {
80 if let Some(task) = self.sync_runnables.pop() {
81 task.run();
82 }
83 } else if !has_local_task {
84 break;
85 }
86 }
87 !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
88 }
89}
90
thread_local! {
    // Next id handed out to a `Runtime` created on this thread (see
    // `Runtime::with_builder`). Because the counter is thread-local, ids are only
    // unique among runtimes created on the same thread.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
94
/// A single-threaded async runtime built on a compio [`Proactor`].
///
/// It owns the proactor driver, the run queue of spawned tasks and, with the
/// `time` feature, the timer state.
pub struct Runtime {
    // Proactor used to push, poll and cancel operations.
    driver: RefCell<Proactor>,
    // Run queue; cloned into each task's schedule closure in `spawn_unchecked`.
    runnables: Arc<RunnableQueue>,
    // Registered timers; polled by `poll_timer` and woken after each driver poll.
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    // Maximum number of scheduling rounds per `run` call.
    event_interval: usize,
    // Value taken from the thread-local `RUNTIME_ID` counter at construction.
    id: u64,
    // `Rc` is neither `Send` nor `Sync`, so this marker makes `Runtime` `!Send`
    // and `!Sync`.
    _p: PhantomData<Rc<VecDeque<Runnable>>>,
}
115
116impl Runtime {
117 pub fn new() -> io::Result<Self> {
119 Self::builder().build()
120 }
121
122 pub fn builder() -> RuntimeBuilder {
124 RuntimeBuilder::new()
125 }
126
127 fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
128 let id = RUNTIME_ID.get();
129 RUNTIME_ID.set(id + 1);
130 Ok(Self {
131 driver: RefCell::new(builder.proactor_builder.build()?),
132 runnables: Arc::new(RunnableQueue::new()),
133 #[cfg(feature = "time")]
134 timer_runtime: RefCell::new(TimerRuntime::new()),
135 event_interval: builder.event_interval,
136 id,
137 _p: PhantomData,
138 })
139 }
140
141 pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
144 if CURRENT_RUNTIME.is_set() {
145 Ok(CURRENT_RUNTIME.with(f))
146 } else {
147 Err(f)
148 }
149 }
150
151 pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
157 #[cold]
158 fn not_in_compio_runtime() -> ! {
159 panic!("not in a compio runtime")
160 }
161
162 if CURRENT_RUNTIME.is_set() {
163 CURRENT_RUNTIME.with(f)
164 } else {
165 not_in_compio_runtime()
166 }
167 }
168
169 pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
172 CURRENT_RUNTIME.set(self, f)
173 }
174
175 pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
181 let runnables = self.runnables.clone();
182 let handle = self.driver.borrow().handle();
183 let schedule = move |runnable| {
184 runnables.schedule(runnable, &handle);
185 };
186 let (runnable, task) = async_task::spawn_unchecked(future, schedule);
187 runnable.schedule();
188 task
189 }
190
191 pub fn run(&self) -> bool {
197 unsafe { self.runnables.run(self.event_interval) }
199 }
200
201 pub fn block_on<F: Future>(&self, future: F) -> F::Output {
203 CURRENT_RUNTIME.set(self, || {
204 let mut result = None;
205 unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
206 loop {
207 let remaining_tasks = self.run();
208 if let Some(result) = result.take() {
209 return result;
210 }
211 if remaining_tasks {
212 self.poll_with(Some(Duration::ZERO));
213 } else {
214 self.poll();
215 }
216 }
217 })
218 }
219
220 pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
225 unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
226 }
227
228 pub fn spawn_blocking<T: Send + 'static>(
232 &self,
233 f: impl (FnOnce() -> T) + Send + 'static,
234 ) -> JoinHandle<T> {
235 let op = Asyncify::new(move || {
236 let res = std::panic::catch_unwind(AssertUnwindSafe(f));
237 BufResult(Ok(0), res)
238 });
239 #[allow(deprecated)]
242 unsafe {
243 self.spawn_unchecked(self.submit(op).map(|res| res.1.into_inner()))
244 }
245 }
246
247 pub fn attach(&self, fd: RawFd) -> io::Result<()> {
252 self.driver.borrow_mut().attach(fd)
253 }
254
255 fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
256 self.driver.borrow_mut().push(op)
257 }
258
259 #[deprecated = "use compio::runtime::submit instead"]
267 pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
268 #[allow(deprecated)]
269 self.submit_with_flags(op).map(|(res, _)| res)
270 }
271
272 #[deprecated = "use compio::runtime::submit_with_flags instead"]
283 pub fn submit_with_flags<T: OpCode + 'static>(
284 &self,
285 op: T,
286 ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
287 match self.submit_raw(op) {
288 PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
289 PushEntry::Ready(res) => {
290 Either::Right(ready((res, 0)))
293 }
294 }
295 }
296
297 pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
298 self.driver.borrow_mut().cancel(op);
299 }
300
301 #[cfg(feature = "time")]
302 pub(crate) fn cancel_timer(&self, key: usize) {
303 self.timer_runtime.borrow_mut().cancel(key);
304 }
305
306 pub(crate) fn poll_task<T: OpCode>(
307 &self,
308 cx: &mut Context,
309 op: Key<T>,
310 ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
311 instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
312 let mut driver = self.driver.borrow_mut();
313 driver.pop(op).map_pending(|mut k| {
314 driver.update_waker(&mut k, cx.waker().clone());
315 k
316 })
317 }
318
319 #[cfg(feature = "time")]
320 pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
321 instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
322 let mut timer_runtime = self.timer_runtime.borrow_mut();
323 if !timer_runtime.is_completed(key) {
324 debug!("pending");
325 timer_runtime.update_waker(key, cx.waker().clone());
326 Poll::Pending
327 } else {
328 debug!("ready");
329 Poll::Ready(())
330 }
331 }
332
333 pub fn current_timeout(&self) -> Option<Duration> {
337 #[cfg(not(feature = "time"))]
338 let timeout = None;
339 #[cfg(feature = "time")]
340 let timeout = self.timer_runtime.borrow().min_timeout();
341 timeout
342 }
343
344 pub fn poll(&self) {
349 instrument!(compio_log::Level::DEBUG, "poll");
350 let timeout = self.current_timeout();
351 debug!("timeout: {:?}", timeout);
352 self.poll_with(timeout)
353 }
354
355 pub fn poll_with(&self, timeout: Option<Duration>) {
359 instrument!(compio_log::Level::DEBUG, "poll_with");
360
361 let mut driver = self.driver.borrow_mut();
362 match driver.poll(timeout) {
363 Ok(()) => {}
364 Err(e) => match e.kind() {
365 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
366 debug!("expected error: {e}");
367 }
368 _ => panic!("{e:?}"),
369 },
370 }
371 #[cfg(feature = "time")]
372 self.timer_runtime.borrow_mut().wake();
373 }
374
375 pub(crate) fn create_buffer_pool(
376 &self,
377 buffer_len: u16,
378 buffer_size: usize,
379 ) -> io::Result<compio_driver::BufferPool> {
380 self.driver
381 .borrow_mut()
382 .create_buffer_pool(buffer_len, buffer_size)
383 }
384
385 pub(crate) unsafe fn release_buffer_pool(
386 &self,
387 buffer_pool: compio_driver::BufferPool,
388 ) -> io::Result<()> {
389 self.driver.borrow_mut().release_buffer_pool(buffer_pool)
390 }
391
392 pub(crate) fn id(&self) -> u64 {
393 self.id
394 }
395}
396
397impl Drop for Runtime {
398 fn drop(&mut self) {
399 self.enter(|| {
400 while self.runnables.sync_runnables.pop().is_some() {}
401 let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
402 loop {
403 let runnable = local_runnables.borrow_mut().pop_front();
404 if runnable.is_none() {
405 break;
406 }
407 }
408 })
409 }
410}
411
412impl AsRawFd for Runtime {
413 fn as_raw_fd(&self) -> RawFd {
414 self.driver.borrow().as_raw_fd()
415 }
416}
417
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Lets criterion benchmarks drive futures on this runtime via the inherent
    /// [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
424
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Lets criterion benchmarks drive futures through a borrowed runtime by
    /// delegating to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        let runtime: &Runtime = self;
        Runtime::block_on(runtime, future)
    }
}
431
/// Builder for [`Runtime`]; obtain one with [`Runtime::builder`] or
/// [`RuntimeBuilder::new`].
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    // Configuration for the proactor driver built in `Runtime::with_builder`.
    proactor_builder: ProactorBuilder,
    // Maximum scheduling rounds per `Runtime::run` call.
    event_interval: usize,
}
438
439impl Default for RuntimeBuilder {
440 fn default() -> Self {
441 Self::new()
442 }
443}
444
445impl RuntimeBuilder {
446 pub fn new() -> Self {
448 Self {
449 proactor_builder: ProactorBuilder::new(),
450 event_interval: 61,
451 }
452 }
453
454 pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
456 self.proactor_builder = builder;
457 self
458 }
459
460 pub fn event_interval(&mut self, val: usize) -> &mut Self {
465 self.event_interval = val;
466 self
467 }
468
469 pub fn build(&self) -> io::Result<Runtime> {
471 Runtime::with_builder(self)
472 }
473}
474
475pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
499 Runtime::with_current(|r| r.spawn(future))
500}
501
502pub fn spawn_blocking<T: Send + 'static>(
511 f: impl (FnOnce() -> T) + Send + 'static,
512) -> JoinHandle<T> {
513 Runtime::with_current(|r| r.spawn_blocking(f))
514}
515
516pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
523 submit_with_flags(op).await.0
524}
525
526pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
534 let state = Runtime::with_current(|r| r.submit_raw(op));
535 match state {
536 PushEntry::Pending(user_data) => OpFuture::new(user_data).await,
537 PushEntry::Ready(res) => {
538 (res, 0)
541 }
542 }
543}
544
/// Waits until `instant` using the current runtime's timers.
///
/// Completes immediately if `TimerRuntime::insert` returns `None`. NOTE(review):
/// presumably that happens when `instant` has already passed — confirm.
///
/// # Panics
///
/// Panics if called outside a compio runtime.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let registered =
        Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = registered else {
        return;
    };
    TimerFuture::new(key).await
}