Skip to main content

compio_runtime/runtime/
mod.rs

1use std::{
2    any::Any,
3    cell::{Cell, RefCell},
4    collections::HashSet,
5    future::Future,
6    io,
7    ops::Deref,
8    panic::AssertUnwindSafe,
9    rc::Rc,
10    sync::Arc,
11    task::{Context, Poll, Waker},
12    time::Duration,
13};
14
15use async_task::Task;
16use compio_buf::IntoInner;
17use compio_driver::{
18    AsRawFd, DriverType, Extra, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd,
19    op::Asyncify,
20};
21use compio_log::{debug, instrument};
22use futures_util::FutureExt;
23
24mod future;
25pub use future::Submit;
26
27#[cfg(feature = "time")]
28pub(crate) mod time;
29
30mod buffer_pool;
31pub use buffer_pool::*;
32
33mod scheduler;
34
35mod opt_waker;
36pub use opt_waker::OptWaker;
37
38mod send_wrapper;
39use send_wrapper::SendWrapper;
40
41#[cfg(feature = "time")]
42use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
43use crate::{BufResult, affinity::bind_to_cpu_set, runtime::scheduler::Scheduler};
44
// Scoped thread-local slot for the runtime that is currently entered on this thread.
// `Runtime::enter` sets it for the duration of a closure; `Runtime::with_current` and
// `Runtime::try_with_current` read it.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
46
/// Handle returned by [`Runtime::spawn`] and [`Runtime::spawn_blocking`].
///
/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
/// `Err` when the spawned future panicked.
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
50
thread_local! {
    // Per-thread counter that supplies the `id` of the next runtime built on this thread.
    // `RuntimeBuilder::build` reads the current value and stores it back incremented by one.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
54
/// Inner structure of [`Runtime`].
///
/// [`Runtime`] is a reference-counted handle to this state and dereferences
/// to it.
pub struct RuntimeInner {
    // The proactor that operations are pushed to and polled from. Kept in a `RefCell` because
    // the runtime is only ever handed out by shared reference.
    driver: RefCell<Proactor>,
    // Task scheduler used to spawn and run tasks.
    scheduler: Scheduler,
    // Timer bookkeeping; only present when the `time` feature is enabled.
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    // The runtime id is used to check whether a buffer pool belongs to this runtime or not.
    // Without it, if the user enables the `io-uring-buf-ring` feature and then:
    // 1. creates a buffer pool in runtime1,
    // 2. creates another runtime2 and uses the existing buffer pool in runtime2, it may cause
    // - io-uring to report an error if the buffer group id is not registered
    // - the buffer pool to return a wrong buffer whose data is uninitialized, which will cause
    //   UB
    id: u64,
}
70
/// The async runtime of compio.
///
/// It is a thread-local runtime, meaning it cannot be sent to other threads.
///
/// The handle wraps an [`Rc`] of [`RuntimeInner`], so cloning it only produces
/// another reference to the same runtime state.
#[derive(Clone)]
pub struct Runtime(Rc<RuntimeInner>);
76
77impl Deref for Runtime {
78    type Target = RuntimeInner;
79
80    fn deref(&self) -> &Self::Target {
81        &self.0
82    }
83}
84
85impl Runtime {
86    /// Create [`Runtime`] with default config.
87    pub fn new() -> io::Result<Self> {
88        Self::builder().build()
89    }
90
91    /// Create a builder for [`Runtime`].
92    pub fn builder() -> RuntimeBuilder {
93        RuntimeBuilder::new()
94    }
95
96    /// The current driver type.
97    pub fn driver_type(&self) -> DriverType {
98        self.driver.borrow().driver_type()
99    }
100
101    /// Try to perform a function on the current runtime, and if no runtime is
102    /// running, return the function back.
103    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
104        if CURRENT_RUNTIME.is_set() {
105            Ok(CURRENT_RUNTIME.with(f))
106        } else {
107            Err(f)
108        }
109    }
110
111    /// Perform a function on the current runtime.
112    ///
113    /// ## Panics
114    ///
115    /// This method will panic if there is no running [`Runtime`].
116    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
117        #[cold]
118        fn not_in_compio_runtime() -> ! {
119            panic!("not in a compio runtime")
120        }
121
122        if CURRENT_RUNTIME.is_set() {
123            CURRENT_RUNTIME.with(f)
124        } else {
125            not_in_compio_runtime()
126        }
127    }
128
129    /// Set this runtime as current runtime, and perform a function in the
130    /// current scope.
131    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
132        CURRENT_RUNTIME.set(self, f)
133    }
134
135    fn spawn_impl<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
136        unsafe { self.spawn_unchecked(future) }
137    }
138
139    /// Low level API to control the runtime.
140    ///
141    /// Spawns a new asynchronous task, returning a [`Task`] for it.
142    ///
143    /// # Safety
144    ///
145    /// Borrowed variables must outlive the future.
146    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
147        let waker = self.waker();
148        unsafe { self.scheduler.spawn_unchecked(future, waker) }
149    }
150
151    /// Low level API to control the runtime.
152    ///
153    /// Run the scheduled tasks.
154    ///
155    /// The return value indicates whether there are still tasks in the queue.
156    pub fn run(&self) -> bool {
157        self.scheduler.run()
158    }
159
160    /// Low level API to control the runtime.
161    ///
162    /// Create a waker that always notifies the runtime when woken.
163    pub fn waker(&self) -> Waker {
164        self.driver.borrow().waker()
165    }
166
167    /// Low level API to control the runtime.
168    ///
169    /// Create an optimized waker that only notifies the runtime when woken
170    /// from another thread, or when `notify-always` is enabled.
171    pub fn opt_waker(&self) -> Arc<OptWaker> {
172        OptWaker::new(self.waker())
173    }
174
175    /// Block on the future till it completes.
176    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
177        self.enter(|| {
178            let opt_waker = self.opt_waker();
179            let waker = Waker::from(opt_waker.clone());
180            let mut context = Context::from_waker(&waker);
181            let mut future = std::pin::pin!(future);
182            loop {
183                if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
184                    self.run();
185                    return result;
186                }
187                // We always want to reset the waker here.
188                let remaining_tasks = self.run() | opt_waker.reset();
189                if remaining_tasks {
190                    self.poll_with(Some(Duration::ZERO));
191                } else {
192                    self.poll();
193                }
194            }
195        })
196    }
197
198    /// Spawns a new asynchronous task, returning a [`Task`] for it.
199    ///
200    /// Spawning a task enables the task to execute concurrently to other tasks.
201    /// There is no guarantee that a spawned task will execute to completion.
202    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
203        self.spawn_impl(AssertUnwindSafe(future).catch_unwind())
204    }
205
206    /// Spawns a blocking task in a new thread, and wait for it.
207    ///
208    /// The task will not be cancelled even if the future is dropped.
209    pub fn spawn_blocking<T: Send + 'static>(
210        &self,
211        f: impl (FnOnce() -> T) + Send + 'static,
212    ) -> JoinHandle<T> {
213        let op = Asyncify::new(move || {
214            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
215            BufResult(Ok(0), res)
216        });
217        // It is safe and sound to use `submit` here because the task is spawned
218        // immediately.
219        self.spawn_impl(self.submit(op).map(|res| res.1.into_inner()))
220    }
221
222    /// Attach a raw file descriptor/handle/socket to the runtime.
223    ///
224    /// You only need this when authoring your own high-level APIs. High-level
225    /// resources in this crate are attached automatically.
226    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
227        self.driver.borrow_mut().attach(fd)
228    }
229
230    fn submit_raw<T: OpCode + 'static>(
231        &self,
232        op: T,
233        extra: Option<Extra>,
234    ) -> PushEntry<Key<T>, BufResult<usize, T>> {
235        let mut this = self.driver.borrow_mut();
236        match extra {
237            Some(e) => this.push_with_extra(op, e),
238            None => this.push(op),
239        }
240    }
241
242    fn default_extra(&self) -> Extra {
243        self.driver.borrow().default_extra()
244    }
245
246    /// Submit an operation to the runtime.
247    ///
248    /// You only need this when authoring your own [`OpCode`].
249    fn submit<T: OpCode + 'static>(&self, op: T) -> Submit<T> {
250        Submit::new(self.clone(), op)
251    }
252
253    pub(crate) fn cancel<T: OpCode>(&self, op: Key<T>) {
254        self.driver.borrow_mut().cancel(op);
255    }
256
257    #[cfg(feature = "time")]
258    pub(crate) fn cancel_timer(&self, key: &TimerKey) {
259        self.timer_runtime.borrow_mut().cancel(key);
260    }
261
262    pub(crate) fn poll_task<T: OpCode>(
263        &self,
264        waker: &Waker,
265        key: Key<T>,
266    ) -> PushEntry<Key<T>, BufResult<usize, T>> {
267        instrument!(compio_log::Level::DEBUG, "poll_task", ?key);
268        let mut driver = self.driver.borrow_mut();
269        driver.pop(key).map_pending(|mut k| {
270            driver.update_waker(&mut k, waker);
271            k
272        })
273    }
274
275    pub(crate) fn poll_task_with_extra<T: OpCode>(
276        &self,
277        cx: &mut Context,
278        key: Key<T>,
279    ) -> PushEntry<Key<T>, (BufResult<usize, T>, Extra)> {
280        instrument!(compio_log::Level::DEBUG, "poll_task_with_extra", ?key);
281        let mut driver = self.driver.borrow_mut();
282        driver.pop_with_extra(key).map_pending(|mut k| {
283            driver.update_waker(&mut k, cx.waker());
284            k
285        })
286    }
287
288    #[cfg(feature = "time")]
289    pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> {
290        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
291        let mut timer_runtime = self.timer_runtime.borrow_mut();
292        if timer_runtime.is_completed(key) {
293            debug!("ready");
294            Poll::Ready(())
295        } else {
296            debug!("pending");
297            timer_runtime.update_waker(key, cx.waker().clone());
298            Poll::Pending
299        }
300    }
301
302    /// Low level API to control the runtime.
303    ///
304    /// Get the timeout value to be passed to [`Proactor::poll`].
305    pub fn current_timeout(&self) -> Option<Duration> {
306        #[cfg(not(feature = "time"))]
307        let timeout = None;
308        #[cfg(feature = "time")]
309        let timeout = self.timer_runtime.borrow().min_timeout();
310        timeout
311    }
312
313    /// Low level API to control the runtime.
314    ///
315    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
316    /// with [`Runtime::current_timeout`].
317    pub fn poll(&self) {
318        instrument!(compio_log::Level::DEBUG, "poll");
319        let timeout = self.current_timeout();
320        debug!("timeout: {:?}", timeout);
321        self.poll_with(timeout)
322    }
323
324    /// Low level API to control the runtime.
325    ///
326    /// Poll the inner proactor with a custom timeout.
327    pub fn poll_with(&self, timeout: Option<Duration>) {
328        instrument!(compio_log::Level::DEBUG, "poll_with");
329
330        let mut driver = self.driver.borrow_mut();
331        match driver.poll(timeout) {
332            Ok(()) => {}
333            Err(e) => match e.kind() {
334                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
335                    debug!("expected error: {e}");
336                }
337                _ => panic!("{e:?}"),
338            },
339        }
340        #[cfg(feature = "time")]
341        self.timer_runtime.borrow_mut().wake();
342    }
343
344    pub(crate) fn create_buffer_pool(
345        &self,
346        buffer_len: u16,
347        buffer_size: usize,
348    ) -> io::Result<compio_driver::BufferPool> {
349        self.driver
350            .borrow_mut()
351            .create_buffer_pool(buffer_len, buffer_size)
352    }
353
354    pub(crate) unsafe fn release_buffer_pool(
355        &self,
356        buffer_pool: compio_driver::BufferPool,
357    ) -> io::Result<()> {
358        unsafe { self.driver.borrow_mut().release_buffer_pool(buffer_pool) }
359    }
360
361    pub(crate) fn id(&self) -> u64 {
362        self.id
363    }
364
365    /// Register the personality for the runtime.
366    ///
367    /// This is only supported on io-uring driver, and will return an
368    /// [`Unsupported`] io error on all other drivers.
369    ///
370    /// The returned personality can be used with `FutureExt::with_personality`
371    /// if the `future-combinator` feature is turned on.
372    ///
373    /// [`Unsupported`]: std::io::ErrorKind::Unsupported
374    pub fn register_personality(&self) -> io::Result<u16> {
375        self.driver.borrow_mut().register_personality()
376    }
377
378    /// Unregister the given personality for the runtime.
379    ///
380    /// This is only supported on io-uring driver, and will return an
381    /// [`Unsupported`] io error on all other drivers.
382    ///
383    /// [`Unsupported`]: std::io::ErrorKind::Unsupported
384    pub fn unregister_personality(&self, personality: u16) -> io::Result<()> {
385        self.driver.borrow_mut().unregister_personality(personality)
386    }
387}
388
389impl Drop for Runtime {
390    fn drop(&mut self) {
391        // this is not the last runtime reference, no need to clear
392        if Rc::strong_count(&self.0) > 1 {
393            return;
394        }
395
396        self.enter(|| {
397            self.scheduler.clear();
398        })
399    }
400}
401
402impl AsRawFd for Runtime {
403    fn as_raw_fd(&self) -> RawFd {
404        self.driver.borrow().as_raw_fd()
405    }
406}
407
/// Lets criterion drive async benchmarks on a [`Runtime`].
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // Name the inherent method explicitly so it cannot be confused with this trait method.
        Runtime::block_on(self, future)
    }
}
414
/// Lets criterion drive async benchmarks on a borrowed [`Runtime`].
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // `*self` is the `&Runtime`; call the inherent method on it directly.
        Runtime::block_on(*self, future)
    }
}
421
/// Builder for [`Runtime`].
///
/// Obtained through [`Runtime::builder`] or [`RuntimeBuilder::new`]; call
/// [`RuntimeBuilder::build`] to create the runtime.
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    // Builder used to create the runtime's proactor (driver).
    proactor_builder: ProactorBuilder,
    // CPUs handed to `bind_to_cpu_set` at build time; an empty set skips the binding.
    thread_affinity: HashSet<usize>,
    // Value passed to `Scheduler::new`; see `RuntimeBuilder::event_interval`.
    event_interval: usize,
}
429
430impl Default for RuntimeBuilder {
431    fn default() -> Self {
432        Self::new()
433    }
434}
435
436impl RuntimeBuilder {
437    /// Create the builder with default config.
438    pub fn new() -> Self {
439        Self {
440            proactor_builder: ProactorBuilder::new(),
441            event_interval: 61,
442            thread_affinity: HashSet::new(),
443        }
444    }
445
446    /// Replace proactor builder.
447    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
448        self.proactor_builder = builder;
449        self
450    }
451
452    /// Sets the thread affinity for the runtime.
453    pub fn thread_affinity(&mut self, cpus: HashSet<usize>) -> &mut Self {
454        self.thread_affinity = cpus;
455        self
456    }
457
458    /// Sets the number of scheduler ticks after which the scheduler will poll
459    /// for external events (timers, I/O, and so on).
460    ///
461    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
462    pub fn event_interval(&mut self, val: usize) -> &mut Self {
463        self.event_interval = val;
464        self
465    }
466
467    /// Build [`Runtime`].
468    pub fn build(&self) -> io::Result<Runtime> {
469        let RuntimeBuilder {
470            proactor_builder,
471            thread_affinity,
472            event_interval,
473        } = self;
474        let id = RUNTIME_ID.get();
475        RUNTIME_ID.set(id + 1);
476        if !thread_affinity.is_empty() {
477            bind_to_cpu_set(thread_affinity);
478        }
479        let inner = RuntimeInner {
480            driver: RefCell::new(proactor_builder.build()?),
481            scheduler: Scheduler::new(*event_interval),
482            #[cfg(feature = "time")]
483            timer_runtime: RefCell::new(TimerRuntime::new()),
484            id,
485        };
486        Ok(Runtime(Rc::new(inner)))
487    }
488}
489
490/// Spawns a new asynchronous task, returning a [`Task`] for it.
491///
492/// Spawning a task enables the task to execute concurrently to other tasks.
493/// There is no guarantee that a spawned task will execute to completion.
494///
495/// ```
496/// # compio_runtime::Runtime::new().unwrap().block_on(async {
497/// let task = compio_runtime::spawn(async {
498///     println!("Hello from a spawned task!");
499///     42
500/// });
501///
502/// assert_eq!(
503///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
504///     42
505/// );
506/// # })
507/// ```
508///
509/// ## Panics
510///
511/// This method doesn't create runtime. It tries to obtain the current runtime
512/// by [`Runtime::with_current`].
513pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
514    Runtime::with_current(|r| r.spawn(future))
515}
516
517/// Spawns a blocking task in a new thread, and wait for it.
518///
519/// The task will not be cancelled even if the future is dropped.
520///
521/// ## Panics
522///
523/// This method doesn't create runtime. It tries to obtain the current runtime
524/// by [`Runtime::with_current`].
525pub fn spawn_blocking<T: Send + 'static>(
526    f: impl (FnOnce() -> T) + Send + 'static,
527) -> JoinHandle<T> {
528    Runtime::with_current(|r| r.spawn_blocking(f))
529}
530
531/// Submit an operation to the current runtime, and return a future for it.
532///
533/// ## Panics
534///
535/// This method doesn't create runtime and will panic if it's not within a
536/// runtime. It tries to obtain the current runtime by
537/// [`Runtime::with_current`].
538pub fn submit<T: OpCode + 'static>(op: T) -> Submit<T> {
539    Runtime::with_current(|r| r.submit(op))
540}
541
542/// Submit an operation to the current runtime, and return a future for it with
543/// flags.
544///
545/// ## Panics
546///
547/// This method doesn't create runtime. It tries to obtain the current runtime
548/// by [`Runtime::with_current`].
549#[deprecated(since = "0.8.0", note = "use `submit(op).with_extra()` instead")]
550pub fn submit_with_extra<T: OpCode + 'static>(op: T) -> Submit<T, Extra> {
551    Runtime::with_current(|r| r.submit(op).with_extra())
552}
553
/// Register a timer for `instant` with the current runtime and wait for it.
///
/// When the timer runtime hands back no key for `instant`, there is nothing
/// to wait for and the future completes right away.
///
/// ## Panics
///
/// Panics when called outside of a running [`Runtime`].
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let registered =
        Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = registered else {
        return;
    };
    TimerFuture::new(key).await
}