compio_runtime/runtime/
mod.rs

1use std::{
2    any::Any,
3    cell::{Cell, RefCell},
4    collections::VecDeque,
5    future::{Future, ready},
6    io,
7    marker::PhantomData,
8    panic::AssertUnwindSafe,
9    rc::Rc,
10    sync::Arc,
11    task::{Context, Poll},
12    time::Duration,
13};
14
15use async_task::{Runnable, Task};
16use compio_buf::IntoInner;
17use compio_driver::{
18    AsRawFd, Key, NotifyHandle, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
19};
20use compio_log::{debug, instrument};
21use crossbeam_queue::SegQueue;
22use futures_util::{FutureExt, future::Either};
23
24pub(crate) mod op;
25#[cfg(feature = "time")]
26pub(crate) mod time;
27
28mod buffer_pool;
29pub use buffer_pool::*;
30
31mod send_wrapper;
32use send_wrapper::SendWrapper;
33
34#[cfg(feature = "time")]
35use crate::runtime::time::{TimerFuture, TimerRuntime};
36use crate::{BufResult, runtime::op::OpFuture};
37
38scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
39
40/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
41/// `Err` when the spawned future panicked.
42pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
43
44struct RunnableQueue {
45    local_runnables: SendWrapper<RefCell<VecDeque<Runnable>>>,
46    sync_runnables: SegQueue<Runnable>,
47}
48
49impl RunnableQueue {
50    pub fn new() -> Self {
51        Self {
52            local_runnables: SendWrapper::new(RefCell::new(VecDeque::new())),
53            sync_runnables: SegQueue::new(),
54        }
55    }
56
57    pub fn schedule(&self, runnable: Runnable, handle: &NotifyHandle) {
58        if let Some(runnables) = self.local_runnables.get() {
59            runnables.borrow_mut().push_back(runnable);
60            #[cfg(feature = "notify-always")]
61            handle.notify().ok();
62        } else {
63            self.sync_runnables.push(runnable);
64            handle.notify().ok();
65        }
66    }
67
68    /// SAFETY: call in the main thread
69    pub unsafe fn run(&self, event_interval: usize) -> bool {
70        let local_runnables = self.local_runnables.get_unchecked();
71        for _i in 0..event_interval {
72            let next_task = local_runnables.borrow_mut().pop_front();
73            let has_local_task = next_task.is_some();
74            if let Some(task) = next_task {
75                task.run();
76            }
77            // Cheaper than pop.
78            let has_sync_task = !self.sync_runnables.is_empty();
79            if has_sync_task {
80                if let Some(task) = self.sync_runnables.pop() {
81                    task.run();
82                }
83            } else if !has_local_task {
84                break;
85            }
86        }
87        !(local_runnables.borrow_mut().is_empty() && self.sync_runnables.is_empty())
88    }
89}
90
thread_local! {
    /// Per-thread counter providing the `id` of the next [`Runtime`] built on
    /// this thread.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
94
95/// The async runtime of compio. It is a thread local runtime, and cannot be
96/// sent to other threads.
97pub struct Runtime {
98    driver: RefCell<Proactor>,
99    runnables: Arc<RunnableQueue>,
100    #[cfg(feature = "time")]
101    timer_runtime: RefCell<TimerRuntime>,
102    event_interval: usize,
103    // Runtime id is used to check if the buffer pool is belonged to this runtime or not.
104    // Without this, if user enable `io-uring-buf-ring` feature then:
105    // 1. Create a buffer pool at runtime1
106    // 3. Create another runtime2, then use the exists buffer pool in runtime2, it may cause
107    // - io-uring report error if the buffer group id is not registered
108    // - buffer pool will return a wrong buffer which the buffer's data is uninit, that will cause
109    //   UB
110    id: u64,
111    // Other fields don't make it !Send, but actually `local_runnables` implies it should be !Send,
112    // otherwise it won't be valid if the runtime is sent to other threads.
113    _p: PhantomData<Rc<VecDeque<Runnable>>>,
114}
115
116impl Runtime {
117    /// Create [`Runtime`] with default config.
118    pub fn new() -> io::Result<Self> {
119        Self::builder().build()
120    }
121
122    /// Create a builder for [`Runtime`].
123    pub fn builder() -> RuntimeBuilder {
124        RuntimeBuilder::new()
125    }
126
127    fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
128        let id = RUNTIME_ID.get();
129        RUNTIME_ID.set(id + 1);
130        Ok(Self {
131            driver: RefCell::new(builder.proactor_builder.build()?),
132            runnables: Arc::new(RunnableQueue::new()),
133            #[cfg(feature = "time")]
134            timer_runtime: RefCell::new(TimerRuntime::new()),
135            event_interval: builder.event_interval,
136            id,
137            _p: PhantomData,
138        })
139    }
140
141    /// Try to perform a function on the current runtime, and if no runtime is
142    /// running, return the function back.
143    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
144        if CURRENT_RUNTIME.is_set() {
145            Ok(CURRENT_RUNTIME.with(f))
146        } else {
147            Err(f)
148        }
149    }
150
151    /// Perform a function on the current runtime.
152    ///
153    /// ## Panics
154    ///
155    /// This method will panic if there are no running [`Runtime`].
156    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
157        #[cold]
158        fn not_in_compio_runtime() -> ! {
159            panic!("not in a compio runtime")
160        }
161
162        if CURRENT_RUNTIME.is_set() {
163            CURRENT_RUNTIME.with(f)
164        } else {
165            not_in_compio_runtime()
166        }
167    }
168
169    /// Set this runtime as current runtime, and perform a function in the
170    /// current scope.
171    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
172        CURRENT_RUNTIME.set(self, f)
173    }
174
175    /// Spawns a new asynchronous task, returning a [`Task`] for it.
176    ///
177    /// # Safety
178    ///
179    /// The caller should ensure the captured lifetime long enough.
180    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
181        let runnables = self.runnables.clone();
182        let handle = self.driver.borrow().handle();
183        let schedule = move |runnable| {
184            runnables.schedule(runnable, &handle);
185        };
186        let (runnable, task) = async_task::spawn_unchecked(future, schedule);
187        runnable.schedule();
188        task
189    }
190
191    /// Low level API to control the runtime.
192    ///
193    /// Run the scheduled tasks.
194    ///
195    /// The return value indicates whether there are still tasks in the queue.
196    pub fn run(&self) -> bool {
197        // SAFETY: self is !Send + !Sync.
198        unsafe { self.runnables.run(self.event_interval) }
199    }
200
201    /// Block on the future till it completes.
202    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
203        CURRENT_RUNTIME.set(self, || {
204            let mut result = None;
205            unsafe { self.spawn_unchecked(async { result = Some(future.await) }) }.detach();
206            loop {
207                let remaining_tasks = self.run();
208                if let Some(result) = result.take() {
209                    return result;
210                }
211                if remaining_tasks {
212                    self.poll_with(Some(Duration::ZERO));
213                } else {
214                    self.poll();
215                }
216            }
217        })
218    }
219
220    /// Spawns a new asynchronous task, returning a [`Task`] for it.
221    ///
222    /// Spawning a task enables the task to execute concurrently to other tasks.
223    /// There is no guarantee that a spawned task will execute to completion.
224    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
225        unsafe { self.spawn_unchecked(AssertUnwindSafe(future).catch_unwind()) }
226    }
227
228    /// Spawns a blocking task in a new thread, and wait for it.
229    ///
230    /// The task will not be cancelled even if the future is dropped.
231    pub fn spawn_blocking<T: Send + 'static>(
232        &self,
233        f: impl (FnOnce() -> T) + Send + 'static,
234    ) -> JoinHandle<T> {
235        let op = Asyncify::new(move || {
236            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
237            BufResult(Ok(0), res)
238        });
239        // It is safe and sound to use `submit` here because the task is spawned
240        // immediately.
241        #[allow(deprecated)]
242        unsafe {
243            self.spawn_unchecked(self.submit(op).map(|res| res.1.into_inner()))
244        }
245    }
246
247    /// Attach a raw file descriptor/handle/socket to the runtime.
248    ///
249    /// You only need this when authoring your own high-level APIs. High-level
250    /// resources in this crate are attached automatically.
251    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
252        self.driver.borrow_mut().attach(fd)
253    }
254
255    fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
256        self.driver.borrow_mut().push(op)
257    }
258
259    /// Submit an operation to the runtime.
260    ///
261    /// You only need this when authoring your own [`OpCode`].
262    ///
263    /// It is safe to send the returned future to another runtime and poll it,
264    /// but the exact behavior is not guaranteed, e.g. it may return pending
265    /// forever or else.
266    #[deprecated = "use compio::runtime::submit instead"]
267    pub fn submit<T: OpCode + 'static>(&self, op: T) -> impl Future<Output = BufResult<usize, T>> {
268        #[allow(deprecated)]
269        self.submit_with_flags(op).map(|(res, _)| res)
270    }
271
272    /// Submit an operation to the runtime.
273    ///
274    /// The difference between [`Runtime::submit`] is this method will return
275    /// the flags
276    ///
277    /// You only need this when authoring your own [`OpCode`].
278    ///
279    /// It is safe to send the returned future to another runtime and poll it,
280    /// but the exact behavior is not guaranteed, e.g. it may return pending
281    /// forever or else.
282    #[deprecated = "use compio::runtime::submit_with_flags instead"]
283    pub fn submit_with_flags<T: OpCode + 'static>(
284        &self,
285        op: T,
286    ) -> impl Future<Output = (BufResult<usize, T>, u32)> {
287        match self.submit_raw(op) {
288            PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
289            PushEntry::Ready(res) => {
290                // submit_flags won't be ready immediately, if ready, it must be error without
291                // flags
292                Either::Right(ready((res, 0)))
293            }
294        }
295    }
296
297    pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
298        self.driver.borrow_mut().cancel(op);
299    }
300
301    #[cfg(feature = "time")]
302    pub(crate) fn cancel_timer(&self, key: usize) {
303        self.timer_runtime.borrow_mut().cancel(key);
304    }
305
306    pub(crate) fn poll_task<T: OpCode>(
307        &self,
308        cx: &mut Context,
309        op: Key<T>,
310    ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
311        instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
312        let mut driver = self.driver.borrow_mut();
313        driver.pop(op).map_pending(|mut k| {
314            driver.update_waker(&mut k, cx.waker().clone());
315            k
316        })
317    }
318
319    #[cfg(feature = "time")]
320    pub(crate) fn poll_timer(&self, cx: &mut Context, key: usize) -> Poll<()> {
321        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
322        let mut timer_runtime = self.timer_runtime.borrow_mut();
323        if !timer_runtime.is_completed(key) {
324            debug!("pending");
325            timer_runtime.update_waker(key, cx.waker().clone());
326            Poll::Pending
327        } else {
328            debug!("ready");
329            Poll::Ready(())
330        }
331    }
332
333    /// Low level API to control the runtime.
334    ///
335    /// Get the timeout value to be passed to [`Proactor::poll`].
336    pub fn current_timeout(&self) -> Option<Duration> {
337        #[cfg(not(feature = "time"))]
338        let timeout = None;
339        #[cfg(feature = "time")]
340        let timeout = self.timer_runtime.borrow().min_timeout();
341        timeout
342    }
343
344    /// Low level API to control the runtime.
345    ///
346    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
347    /// with [`Runtime::current_timeout`].
348    pub fn poll(&self) {
349        instrument!(compio_log::Level::DEBUG, "poll");
350        let timeout = self.current_timeout();
351        debug!("timeout: {:?}", timeout);
352        self.poll_with(timeout)
353    }
354
355    /// Low level API to control the runtime.
356    ///
357    /// Poll the inner proactor with a custom timeout.
358    pub fn poll_with(&self, timeout: Option<Duration>) {
359        instrument!(compio_log::Level::DEBUG, "poll_with");
360
361        let mut driver = self.driver.borrow_mut();
362        match driver.poll(timeout) {
363            Ok(()) => {}
364            Err(e) => match e.kind() {
365                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
366                    debug!("expected error: {e}");
367                }
368                _ => panic!("{e:?}"),
369            },
370        }
371        #[cfg(feature = "time")]
372        self.timer_runtime.borrow_mut().wake();
373    }
374
375    pub(crate) fn create_buffer_pool(
376        &self,
377        buffer_len: u16,
378        buffer_size: usize,
379    ) -> io::Result<compio_driver::BufferPool> {
380        self.driver
381            .borrow_mut()
382            .create_buffer_pool(buffer_len, buffer_size)
383    }
384
385    pub(crate) unsafe fn release_buffer_pool(
386        &self,
387        buffer_pool: compio_driver::BufferPool,
388    ) -> io::Result<()> {
389        self.driver.borrow_mut().release_buffer_pool(buffer_pool)
390    }
391
392    pub(crate) fn id(&self) -> u64 {
393        self.id
394    }
395}
396
397impl Drop for Runtime {
398    fn drop(&mut self) {
399        self.enter(|| {
400            while self.runnables.sync_runnables.pop().is_some() {}
401            let local_runnables = unsafe { self.runnables.local_runnables.get_unchecked() };
402            loop {
403                let runnable = local_runnables.borrow_mut().pop_front();
404                if runnable.is_none() {
405                    break;
406                }
407            }
408        })
409    }
410}
411
412impl AsRawFd for Runtime {
413    fn as_raw_fd(&self) -> RawFd {
414        self.driver.borrow().as_raw_fd()
415    }
416}
417
/// Lets criterion drive async benchmarks with a [`Runtime`].
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // Delegates to the inherent `Runtime::block_on`.
        Runtime::block_on(self, future)
    }
}
424
/// Lets criterion drive async benchmarks with a borrowed [`Runtime`].
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // `self` is `&&Runtime`; unwrap one level and delegate to the
        // inherent `Runtime::block_on`.
        let runtime: &Runtime = self;
        Runtime::block_on(runtime, future)
    }
}
431
432/// Builder for [`Runtime`].
433#[derive(Debug, Clone)]
434pub struct RuntimeBuilder {
435    proactor_builder: ProactorBuilder,
436    event_interval: usize,
437}
438
439impl Default for RuntimeBuilder {
440    fn default() -> Self {
441        Self::new()
442    }
443}
444
445impl RuntimeBuilder {
446    /// Create the builder with default config.
447    pub fn new() -> Self {
448        Self {
449            proactor_builder: ProactorBuilder::new(),
450            event_interval: 61,
451        }
452    }
453
454    /// Replace proactor builder.
455    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
456        self.proactor_builder = builder;
457        self
458    }
459
460    /// Sets the number of scheduler ticks after which the scheduler will poll
461    /// for external events (timers, I/O, and so on).
462    ///
463    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
464    pub fn event_interval(&mut self, val: usize) -> &mut Self {
465        self.event_interval = val;
466        self
467    }
468
469    /// Build [`Runtime`].
470    pub fn build(&self) -> io::Result<Runtime> {
471        Runtime::with_builder(self)
472    }
473}
474
475/// Spawns a new asynchronous task, returning a [`Task`] for it.
476///
477/// Spawning a task enables the task to execute concurrently to other tasks.
478/// There is no guarantee that a spawned task will execute to completion.
479///
480/// ```
481/// # compio_runtime::Runtime::new().unwrap().block_on(async {
482/// let task = compio_runtime::spawn(async {
483///     println!("Hello from a spawned task!");
484///     42
485/// });
486///
487/// assert_eq!(
488///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
489///     42
490/// );
491/// # })
492/// ```
493///
494/// ## Panics
495///
496/// This method doesn't create runtime. It tries to obtain the current runtime
497/// by [`Runtime::with_current`].
498pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
499    Runtime::with_current(|r| r.spawn(future))
500}
501
502/// Spawns a blocking task in a new thread, and wait for it.
503///
504/// The task will not be cancelled even if the future is dropped.
505///
506/// ## Panics
507///
508/// This method doesn't create runtime. It tries to obtain the current runtime
509/// by [`Runtime::with_current`].
510pub fn spawn_blocking<T: Send + 'static>(
511    f: impl (FnOnce() -> T) + Send + 'static,
512) -> JoinHandle<T> {
513    Runtime::with_current(|r| r.spawn_blocking(f))
514}
515
516/// Submit an operation to the current runtime, and return a future for it.
517///
518/// ## Panics
519///
520/// This method doesn't create runtime. It tries to obtain the current runtime
521/// by [`Runtime::with_current`].
522pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
523    submit_with_flags(op).await.0
524}
525
526/// Submit an operation to the current runtime, and return a future for it with
527/// flags.
528///
529/// ## Panics
530///
531/// This method doesn't create runtime. It tries to obtain the current runtime
532/// by [`Runtime::with_current`].
533pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
534    let state = Runtime::with_current(|r| r.submit_raw(op));
535    match state {
536        PushEntry::Pending(user_data) => OpFuture::new(user_data).await,
537        PushEntry::Ready(res) => {
538            // submit_flags won't be ready immediately, if ready, it must be error without
539            // flags, or the flags are not necessary
540            (res, 0)
541        }
542    }
543}
544
/// Registers a timer for `instant` in the current runtime and waits for it.
///
/// Returns immediately when the timer runtime hands back no key.
/// NOTE(review): presumably that means `instant` has already passed — confirm
/// against `TimerRuntime::insert`.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let registered =
        Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = registered else {
        return;
    };
    TimerFuture::new(key).await
}
551}