compio_runtime/runtime/
mod.rs

1use std::{
2    any::Any,
3    cell::{Cell, RefCell},
4    collections::VecDeque,
5    future::{Future, ready},
6    io,
7    marker::PhantomData,
8    panic::AssertUnwindSafe,
9    rc::Rc,
10    sync::Arc,
11    task::{Context, Poll},
12    time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18    AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod buffer_pool;
29pub use buffer_pool::*;
30
31mod send_wrapper;
32use send_wrapper::SendWrapper;
33
34#[cfg(feature = "time")]
35use crate::runtime::time::{TimerFuture, TimerRuntime};
36use crate::{BufResult, runtime::op::OpFuture};
37
// The runtime made current by `Runtime::enter` and `Runtime::block_on` for the
// duration of their closure; read by `Runtime::with_current` and
// `Runtime::try_with_current`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
/// `Err` when the spawned future panicked.
///
/// The `Err` payload is the value caught by `catch_unwind` (see
/// [`Runtime::spawn`] and [`Runtime::spawn_blocking`]).
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
43
44struct RunnableQueue {
45    local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
46    sync_runnables: SegQueue<Runnable>,
47}
48
49impl RunnableQueue {
50    pub fn new() -> Self {
51        Self {
52            local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
53            sync_runnables: SegQueue::new(),
54        }
55    }
56
57    pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
58        if let Some(runnables) = self.local_runnables.get() {
59            runnables.borrow_mut().push_back(runnable);
60        } else {
61            self.sync_runnables.push(runnable);
62            handle.notify().ok();
63        }
64    }
65
66    /// SAFETY: call in the main thread
67    pub unsafe fn run(&self, event_interval: usize) -> bool {
68        let local_runnables = self.local_runnables.get_unchecked();
69        for _i in 0..event_interval {
70            let next_task = local_runnables.borrow_mut().pop_front();
71            let has_local_task = next_task.is_some();
72            if let Some(task) = next_task {
73                task.run();
74            }
75            // Cheaper than pop.
76            let has_sync_task = !self.sync_runnables.is_empty();
77            if has_sync_task {
78                if let Some(task) = self.sync_runnables.pop() {
79                    task.run();
80                }
81            } else if !has_local_task {
82                break;
83            }
84        }
85        !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
86    }
87}
88
thread_local! {
    // Counter that hands out `Runtime` ids (see `Runtime::with_builder`).
    // Because it is thread-local, ids are only unique among runtimes created
    // on the same thread.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
92
/// The async runtime of compio. It is a thread local runtime, and cannot be
/// sent to other threads.
pub struct Runtime {
    // The I/O driver (proactor) that ops are pushed to and polled from.
    driver: RefCell<Proactor>,
    // Run queue; cloned into the schedule closure of every spawned task.
    runnables: Arc<RunnableQueue>,
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    // Maximum number of rounds `RunnableQueue::run` performs per `Runtime::run` call.
    event_interval: usize,
    // The runtime id is used to check whether a buffer pool belongs to this runtime.
    // Without it, if the user enables the `io-uring-buf-ring` feature and:
    // 1. creates a buffer pool on runtime1, then
    // 2. creates another runtime2 and uses the existing buffer pool there,
    // it may cause:
    // - io-uring to report an error because the buffer group id is not registered, or
    // - the buffer pool to return a wrong buffer whose data is uninitialized, which is UB.
    // Ids come from the thread-local `RUNTIME_ID` counter, so they are only unique among
    // runtimes created on the same thread.
    id: u64,
    // The other fields don't make it !Send, but `local_runnables` implies it must be !Send;
    // otherwise the local queue would be invalid once the runtime is sent to another thread.
    _p: PhantomData<Rc<VecDeque<Runnable>>>,
}
113
114impl Runtime {
115    /// Create [`Runtime`] with default config.
116    pub fn new() -> io::Result<Self> {
117        Self::builder().build()
118    }
119
120    /// Create a builder for [`Runtime`].
121    pub fn builder() -> RuntimeBuilder {
122        RuntimeBuilder::new()
123    }
124
125    fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
126        let id = RUNTIME_ID.get();
127        RUNTIME_ID.set(id + 1);
128        Ok(Self {
129            driver: RefCell::new(builder.proactor_builder.build()?),
130            runnables: Arc::new(RunnableQueue::new()),
131            #[cfg(feature = "time")]
132            timer_runtime: RefCell::new(TimerRuntime::new()),
133            event_interval: builder.event_interval,
134            id,
135            _p: PhantomData,
136        })
137    }
138
139    /// Try to perform a function on the current runtime, and if no runtime is
140    /// running, return the function back.
141    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
142        if CURRENT_RUNTIME.is_set() {
143            Ok(CURRENT_RUNTIME.with(f))
144        } else {
145            Err(f)
146        }
147    }
148
149    /// Perform a function on the current runtime.
150    ///
151    /// ## Panics
152    ///
153    /// This method will panic if there are no running [`Runtime`].
154    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
155        #[cold]
156        fn not_in_compio_runtime() -> ! {
157            panic!("not in a compio runtime")
158        }
159
160        if CURRENT_RUNTIME.is_set() {
161            CURRENT_RUNTIME.with(f)
162        } else {
163            not_in_compio_runtime()
164        }
165    }
166
167    /// Set this runtime as current runtime, and perform a function in the
168    /// current scope.
169    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
170        CURRENT_RUNTIME.set(self, f)
171    }
172
173    /// Spawns a new asynchronous task, returning a [`Task`] for it.
174    ///
175    /// # Safety
176    ///
177    /// The caller should ensure the captured lifetime long enough.
178    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
179        let runnables = self.runnables.clone();
180        let handle = self.driver.borrow().handle();
181        let schedule = move |runnable| {
182            runnables.schedule(runnable, &handle);
183        };
184        let (runnable, task) = async_task::spawn_unchecked(future, schedule);
185        runnable.schedule();
186        task
187    }
188
189    /// Low level API to control the runtime.
190    ///
191    /// Run the scheduled tasks.
192    ///
193    /// The return value indicates whether there are still tasks in the queue.
194    pub fn run(&self) -> bool {
195        // SAFETY: self is !Send + !Sync.
196        unsafe { self.runnables.run(self.event_interval) }
197    }
198
199    /// Block on the future till it completes.
200    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
201        CURRENT_RUNTIME.set(self, || {
202            let mut result = None;
203            unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
204            loop {
205                let remaining_tasks = self.run();
206                if let Some(result) = result.take() {
207                    return result;
208                }
209                if remaining_tasks {
210                    self.poll_with(Some(Duration::ZERO));
211                } else {
212                    self.poll();
213                }
214            }
215        })
216    }
217
218    /// Spawns a new asynchronous task, returning a [`Task`] for it.
219    ///
220    /// Spawning a task enables the task to execute concurrently to other tasks.
221    /// There is no guarantee that a spawned task will execute to completion.
222    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
223        unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
224    }
225
226    /// Spawns a blocking task in a new thread, and wait for it.
227    ///
228    /// The task will not be cancelled even if the future is dropped.
229    pub fn spawn_blocking<T: Send + 'static>(
230        &self,
231        f: impl (FnOnce() -> T) + Send + 'static,
232    ) -> JoinHandle<T> {
233        let op = Asyncify::new(move || {
234            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
235            BufResult(Ok(0), res)
236        });
237        // It is safe and sound to use `submit` here because the task is spawned
238        // immediately.
239        #[allow(deprecated)]
240        unsafe {
241            self.spawn_unchecked(self.submit(op).map(|res| res.1.into_inner()))
242        }
243    }
244
245    /// Attach a raw file descriptor/handle/socket to the runtime.
246    ///
247    /// You only need this when authoring your own high-level APIs. High-level
248    /// resources in this crate are attached automatically.
249    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
250        self.driver.borrow_mut().attach(fd)
251    }
252
253    fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
254        self.driver.borrow_mut().push(op)
255    }
256
257    /// Submit an operation to the runtime.
258    ///
259    /// You only need this when authoring your own [`OpCode`].
260    ///
261    /// It is safe to send the returned future to another runtime and poll it,
262    /// but the exact behavior is not guaranteed, e.g. it may return pending
263    /// forever or else.
264    #[deprecated = "use compio::runtime::submit instead"]
265    pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
266        #[allow(deprecated)]
267        self.submit_with_flags(op).map(|(res, _)| res)
268    }
269
270    /// Submit an operation to the runtime.
271    ///
272    /// The difference between [`Runtime::submit`] is this method will return
273    /// the flags
274    ///
275    /// You only need this when authoring your own [`OpCode`].
276    ///
277    /// It is safe to send the returned future to another runtime and poll it,
278    /// but the exact behavior is not guaranteed, e.g. it may return pending
279    /// forever or else.
280    #[deprecated = "use compio::runtime::submit_with_flags instead"]
281    pub fn submit_with_flags<T: OpCode + 'static>(
282        &self,
283        op: T,
284    ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
285        match self.submit_raw(op) {
286            PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
287            PushEntry::Ready(res) => {
288                // submit_flags won't be ready immediately, if ready, it must be error without
289                // flags
290                Either::Right(ready((res, 0)))
291            }
292        }
293    }
294
295    pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
296        self.driver.borrow_mut().cancel(op);
297    }
298
299    #[cfg(feature = "time")]
300    pub(crate) fn cancel_timer(&self, key: usize) {
301        self.timer_runtime.borrow_mut().cancel(key);
302    }
303
304    pub(crate) fn poll_task<T: OpCode>(
305        &self,
306        cx: &mut Context,
307        op: Key<T>,
308    ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
309        instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
310        let mut driver = self.driver.borrow_mut();
311        driver.pop(op).map_pending(|mut k| {
312            driver.update_waker(&mut k, cx.waker().clone());
313            k
314        })
315    }
316
317    #[cfg(feature = "time")]
318    pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
319        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
320        let mut timer_runtime = self.timer_runtime.borrow_mut();
321        if !timer_runtime.is_completed(key) {
322            debug!("pending");
323            timer_runtime.update_waker(key, cx.waker().clone());
324            Poll::Pending
325        } else {
326            debug!("ready");
327            Poll::Ready(())
328        }
329    }
330
331    /// Low level API to control the runtime.
332    ///
333    /// Get the timeout value to be passed to [`Proactor::poll`].
334    pub fn current_timeout(&self) -> Option<Duration> {
335        #[cfg(not(feature = "time"))]
336        let timeout = None;
337        #[cfg(feature = "time")]
338        let timeout = self.timer_runtime.borrow().min_timeout();
339        timeout
340    }
341
342    /// Low level API to control the runtime.
343    ///
344    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
345    /// with [`Runtime::current_timeout`].
346    pub fn poll(&self) {
347        instrument!(compio_log::Level::DEBUG, "poll");
348        let timeout = self.current_timeout();
349        debug!("timeout: {:?}", timeout);
350        self.poll_with(timeout)
351    }
352
353    /// Low level API to control the runtime.
354    ///
355    /// Poll the inner proactor with a custom timeout.
356    pub fn poll_with(&self, timeout: Option<Duration>) {
357        instrument!(compio_log::Level::DEBUG, "poll_with");
358
359        let mut driver = self.driver.borrow_mut();
360        match driver.poll(timeout) {
361            Ok(()) => {}
362            Err(e) => match e.kind() {
363                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
364                    debug!("expected error: {e}");
365                }
366                _ => panic!("{e:?}"),
367            },
368        }
369        #[cfg(feature = "time")]
370        self.timer_runtime.borrow_mut().wake();
371    }
372
373    pub(crate) fn create_buffer_pool(
374        &self,
375        buffer_len: u16,
376        buffer_size: usize,
377    ) -> io::Result<compio_driver::BufferPool> {
378        self.driver
379            .borrow_mut()
380            .create_buffer_pool(buffer_len, buffer_size)
381    }
382
383    pub(crate) unsafe fn release_buffer_pool(
384        &self,
385        buffer_pool: compio_driver::BufferPool,
386    ) -> io::Result<()> {
387        self.driver.borrow_mut().release_buffer_pool(buffer_pool)
388    }
389
390    pub(crate) fn id(&self) -> u64 {
391        self.id
392    }
393}
394
395impl Drop for Runtime {
396    fn drop(&mut self) {
397        self.enter(|| {
398            while self.runnables.sync_runnables.pop().is_some() {}
399            let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
400            loop {
401                let runnable = local_runnables.borrow_mut().pop_front();
402                if runnable.is_none() {
403                    break;
404                }
405            }
406        })
407    }
408}
409
410impl AsRawFd for Runtime {
411    fn as_raw_fd(&self) -> RawFd {
412        self.driver.borrow().as_raw_fd()
413    }
414}
415
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Delegates to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}

#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Delegates to the inherent [`Runtime::block_on`] of the referenced runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
429
430/// Builder for [`Runtime`].
431#[derive(Debug, Clone)]
432pub struct RuntimeBuilder {
433    proactor_builder: ProactorBuilder,
434    event_interval: usize,
435}
436
437impl Default for RuntimeBuilder {
438    fn default() -> Self {
439        Self::new()
440    }
441}
442
443impl RuntimeBuilder {
444    /// Create the builder with default config.
445    pub fn new() -> Self {
446        Self {
447            proactor_builder: ProactorBuilder::new(),
448            event_interval: 61,
449        }
450    }
451
452    /// Replace proactor builder.
453    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
454        self.proactor_builder = builder;
455        self
456    }
457
458    /// Sets the number of scheduler ticks after which the scheduler will poll
459    /// for external events (timers, I/O, and so on).
460    ///
461    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
462    pub fn event_interval(&mut self, val: usize) -> &mut Self {
463        self.event_interval = val;
464        self
465    }
466
467    /// Build [`Runtime`].
468    pub fn build(&self) -> io::Result<Runtime> {
469        Runtime::with_builder(self)
470    }
471}
472
473/// Spawns a new asynchronous task, returning a [`Task`] for it.
474///
475/// Spawning a task enables the task to execute concurrently to other tasks.
476/// There is no guarantee that a spawned task will execute to completion.
477///
478/// ```
479/// # compio_runtime::Runtime::new().unwrap().block_on(async {
480/// let task = compio_runtime::spawn(async {
481///     println!("Hello from a spawned task!");
482///     42
483/// });
484///
485/// assert_eq!(
486///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
487///     42
488/// );
489/// # })
490/// ```
491///
492/// ## Panics
493///
494/// This method doesn't create runtime. It tries to obtain the current runtime
495/// by [`Runtime::with_current`].
496pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
497    Runtime::with_current(|r| r.spawn(future))
498}
499
500/// Spawns a blocking task in a new thread, and wait for it.
501///
502/// The task will not be cancelled even if the future is dropped.
503///
504/// ## Panics
505///
506/// This method doesn't create runtime. It tries to obtain the current runtime
507/// by [`Runtime::with_current`].
508pub fn spawn_blocking<T: Send + 'static>(
509    f: impl (FnOnce() -> T) + Send + 'static,
510) -> JoinHandle<T> {
511    Runtime::with_current(|r| r.spawn_blocking(f))
512}
513
514/// Submit an operation to the current runtime, and return a future for it.
515///
516/// ## Panics
517///
518/// This method doesn't create runtime. It tries to obtain the current runtime
519/// by [`Runtime::with_current`].
520pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
521    submit_with_flags(op).await.0
522}
523
524/// Submit an operation to the current runtime, and return a future for it with
525/// flags.
526///
527/// ## Panics
528///
529/// This method doesn't create runtime. It tries to obtain the current runtime
530/// by [`Runtime::with_current`].
531pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
532    let state = Runtime::with_current(|r| r.submit_raw(op));
533    match state {
534        PushEntry::Pending(user_data) => OpFuture::new(user_data).await,
535        PushEntry::Ready(res) => {
536            // submit_flags won't be ready immediately, if ready, it must be error without
537            // flags, or the flags are not necessary
538            (res, 0)
539        }
540    }
541}
542
/// Registers a timer for `instant` on the current runtime and waits for it.
///
/// When the timer runtime hands back no key, there is nothing to wait on and
/// the future completes immediately.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let Some(key) = Runtime::with_current(|rt| rt.timer_runtime.borrow_mut().insert(instant))
    else {
        return;
    };
    TimerFuture::new(key).await
}