compio_runtime/runtime/mod.rs

1use std::{
2    any::Any,
3    cell::{Cell, RefCell},
4    collections::HashSet,
5    future::{Future, ready},
6    io,
7    panic::AssertUnwindSafe,
8    sync::Arc,
9    task::{Context, Poll, Waker},
10    time::Duration,
11};
12
13use async_task::Task;
14use compio_buf::IntoInner;
15use compio_driver::{
16    AsRawFd, DriverType, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
17};
18use compio_log::{debug, instrument};
19use futures_util::{FutureExt, future::Either};
20
21pub(crate) mod op;
22#[cfg(feature = "time")]
23pub(crate) mod time;
24
25mod buffer_pool;
26pub use buffer_pool::*;
27
28mod scheduler;
29
30mod opt_waker;
31pub use opt_waker::OptWaker;
32
33mod send_wrapper;
34use send_wrapper::SendWrapper;
35
36#[cfg(feature = "time")]
37use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
38use crate::{
39    BufResult,
40    affinity::bind_to_cpu_set,
41    runtime::{op::OpFuture, scheduler::Scheduler},
42};
43
44scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
45
46/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
47/// `Err` when the spawned future panicked.
48pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
49
thread_local! {
    /// Per-thread counter from which each newly built runtime takes its id.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0u64) };
}
53
54/// The async runtime of compio. It is a thread local runtime, and cannot be
55/// sent to other threads.
56pub struct Runtime {
57    driver: RefCell<Proactor>,
58    scheduler: Scheduler,
59    #[cfg(feature = "time")]
60    timer_runtime: RefCell<TimerRuntime>,
61    // Runtime id is used to check if the buffer pool is belonged to this runtime or not.
62    // Without this, if user enable `io-uring-buf-ring` feature then:
63    // 1. Create a buffer pool at runtime1
64    // 3. Create another runtime2, then use the exists buffer pool in runtime2, it may cause
65    // - io-uring report error if the buffer group id is not registered
66    // - buffer pool will return a wrong buffer which the buffer's data is uninit, that will cause
67    //   UB
68    id: u64,
69}
70
71impl Runtime {
72    /// Create [`Runtime`] with default config.
73    pub fn new() -> io::Result<Self> {
74        Self::builder().build()
75    }
76
77    /// Create a builder for [`Runtime`].
78    pub fn builder() -> RuntimeBuilder {
79        RuntimeBuilder::new()
80    }
81
82    fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
83        let RuntimeBuilder {
84            proactor_builder,
85            thread_affinity,
86            event_interval,
87        } = builder;
88        let id = RUNTIME_ID.get();
89        RUNTIME_ID.set(id + 1);
90        if !thread_affinity.is_empty() {
91            bind_to_cpu_set(thread_affinity);
92        }
93        Ok(Self {
94            driver: RefCell::new(proactor_builder.build()?),
95            scheduler: Scheduler::new(*event_interval),
96            #[cfg(feature = "time")]
97            timer_runtime: RefCell::new(TimerRuntime::new()),
98            id,
99        })
100    }
101
102    /// The current driver type.
103    pub fn driver_type(&self) -> DriverType {
104        self.driver.borrow().driver_type()
105    }
106
107    /// Try to perform a function on the current runtime, and if no runtime is
108    /// running, return the function back.
109    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
110        if CURRENT_RUNTIME.is_set() {
111            Ok(CURRENT_RUNTIME.with(f))
112        } else {
113            Err(f)
114        }
115    }
116
117    /// Perform a function on the current runtime.
118    ///
119    /// ## Panics
120    ///
121    /// This method will panic if there are no running [`Runtime`].
122    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
123        #[cold]
124        fn not_in_compio_runtime() -> ! {
125            panic!("not in a compio runtime")
126        }
127
128        if CURRENT_RUNTIME.is_set() {
129            CURRENT_RUNTIME.with(f)
130        } else {
131            not_in_compio_runtime()
132        }
133    }
134
135    /// Set this runtime as current runtime, and perform a function in the
136    /// current scope.
137    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
138        CURRENT_RUNTIME.set(self, f)
139    }
140
141    fn spawn_impl<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
142        let waker = self.driver.borrow().waker();
143        self.scheduler.spawn(future, waker)
144    }
145
146    /// Low level API to control the runtime.
147    ///
148    /// Run the scheduled tasks.
149    ///
150    /// The return value indicates whether there are still tasks in the queue.
151    pub fn run(&self) -> bool {
152        self.scheduler.run()
153    }
154
155    /// Low level API to control the runtime.
156    ///
157    /// Create a waker that always notifies the runtime when woken.
158    pub fn waker(&self) -> Waker {
159        self.driver.borrow().waker()
160    }
161
162    /// Low level API to control the runtime.
163    ///
164    /// Create an optimized waker that only notifies the runtime when woken
165    /// from another thread, or when `notify-always` is enabled.
166    pub fn opt_waker(&self) -> Arc<OptWaker> {
167        OptWaker::new(self.waker())
168    }
169
170    /// Block on the future till it completes.
171    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
172        self.enter(|| {
173            let opt_waker = self.opt_waker();
174            let waker = Waker::from(opt_waker.clone());
175            let mut context = Context::from_waker(&waker);
176            let mut future = std::pin::pin!(future);
177            loop {
178                if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
179                    self.run();
180                    return result;
181                }
182                // We always want to reset the waker here.
183                let remaining_tasks = self.run() | opt_waker.reset();
184                if remaining_tasks {
185                    self.poll_with(Some(Duration::ZERO));
186                } else {
187                    self.poll();
188                }
189            }
190        })
191    }
192
193    /// Spawns a new asynchronous task, returning a [`Task`] for it.
194    ///
195    /// Spawning a task enables the task to execute concurrently to other tasks.
196    /// There is no guarantee that a spawned task will execute to completion.
197    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
198        self.spawn_impl(AssertUnwindSafe(future).catch_unwind())
199    }
200
201    /// Spawns a blocking task in a new thread, and wait for it.
202    ///
203    /// The task will not be cancelled even if the future is dropped.
204    pub fn spawn_blocking<T: Send + 'static>(
205        &self,
206        f: impl (FnOnce() -> T) + Send + 'static,
207    ) -> JoinHandle<T> {
208        let op = Asyncify::new(move || {
209            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
210            BufResult(Ok(0), res)
211        });
212        // It is safe and sound to use `submit` here because the task is spawned
213        // immediately.
214        self.spawn_impl(self.submit(op).map(|res| res.1.into_inner()))
215    }
216
217    /// Attach a raw file descriptor/handle/socket to the runtime.
218    ///
219    /// You only need this when authoring your own high-level APIs. High-level
220    /// resources in this crate are attached automatically.
221    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
222        self.driver.borrow_mut().attach(fd)
223    }
224
225    fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
226        self.driver.borrow_mut().push(op)
227    }
228
229    /// Submit an operation to the runtime.
230    ///
231    /// You only need this when authoring your own [`OpCode`].
232    ///
233    /// It is safe to send the returned future to another runtime and poll it,
234    /// but the exact behavior is not guaranteed, e.g. it may return pending
235    /// forever or else.
236    fn submit<T: OpCode + 'static>(
237        &self,
238        op: T,
239    ) -> impl Future<Output = BufResult<usize, T>> + use<T> {
240        self.submit_with_flags(op).map(|(res, _)| res)
241    }
242
243    /// Submit an operation to the runtime.
244    ///
245    /// The difference between [`Runtime::submit`] is this method will return
246    /// the flags
247    ///
248    /// You only need this when authoring your own [`OpCode`].
249    ///
250    /// It is safe to send the returned future to another runtime and poll it,
251    /// but the exact behavior is not guaranteed, e.g. it may return pending
252    /// forever or else.
253    fn submit_with_flags<T: OpCode + 'static>(
254        &self,
255        op: T,
256    ) -> impl Future<Output = (BufResult<usize, T>, u32)> + use<T> {
257        match self.submit_raw(op) {
258            PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
259            PushEntry::Ready(res) => {
260                // submit_flags won't be ready immediately, if ready, it must be error without
261                // flags
262                Either::Right(ready((res, 0)))
263            }
264        }
265    }
266
267    pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
268        self.driver.borrow_mut().cancel(op);
269    }
270
271    #[cfg(feature = "time")]
272    pub(crate) fn cancel_timer(&self, key: &TimerKey) {
273        self.timer_runtime.borrow_mut().cancel(key);
274    }
275
276    pub(crate) fn poll_task<T: OpCode>(
277        &self,
278        cx: &mut Context,
279        op: Key<T>,
280    ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
281        instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
282        let mut driver = self.driver.borrow_mut();
283        driver.pop(op).map_pending(|mut k| {
284            driver.update_waker(&mut k, cx.waker().clone());
285            k
286        })
287    }
288
289    #[cfg(feature = "time")]
290    pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> {
291        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
292        let mut timer_runtime = self.timer_runtime.borrow_mut();
293        if timer_runtime.is_completed(key) {
294            debug!("ready");
295            Poll::Ready(())
296        } else {
297            debug!("pending");
298            timer_runtime.update_waker(key, cx.waker().clone());
299            Poll::Pending
300        }
301    }
302
303    /// Low level API to control the runtime.
304    ///
305    /// Get the timeout value to be passed to [`Proactor::poll`].
306    pub fn current_timeout(&self) -> Option<Duration> {
307        #[cfg(not(feature = "time"))]
308        let timeout = None;
309        #[cfg(feature = "time")]
310        let timeout = self.timer_runtime.borrow().min_timeout();
311        timeout
312    }
313
314    /// Low level API to control the runtime.
315    ///
316    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
317    /// with [`Runtime::current_timeout`].
318    pub fn poll(&self) {
319        instrument!(compio_log::Level::DEBUG, "poll");
320        let timeout = self.current_timeout();
321        debug!("timeout: {:?}", timeout);
322        self.poll_with(timeout)
323    }
324
325    /// Low level API to control the runtime.
326    ///
327    /// Poll the inner proactor with a custom timeout.
328    pub fn poll_with(&self, timeout: Option<Duration>) {
329        instrument!(compio_log::Level::DEBUG, "poll_with");
330
331        let mut driver = self.driver.borrow_mut();
332        match driver.poll(timeout) {
333            Ok(()) => {}
334            Err(e) => match e.kind() {
335                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
336                    debug!("expected error: {e}");
337                }
338                _ => panic!("{e:?}"),
339            },
340        }
341        #[cfg(feature = "time")]
342        self.timer_runtime.borrow_mut().wake();
343    }
344
345    pub(crate) fn create_buffer_pool(
346        &self,
347        buffer_len: u16,
348        buffer_size: usize,
349    ) -> io::Result<compio_driver::BufferPool> {
350        self.driver
351            .borrow_mut()
352            .create_buffer_pool(buffer_len, buffer_size)
353    }
354
355    pub(crate) unsafe fn release_buffer_pool(
356        &self,
357        buffer_pool: compio_driver::BufferPool,
358    ) -> io::Result<()> {
359        unsafe { self.driver.borrow_mut().release_buffer_pool(buffer_pool) }
360    }
361
362    pub(crate) fn id(&self) -> u64 {
363        self.id
364    }
365}
366
367impl Drop for Runtime {
368    fn drop(&mut self) {
369        self.enter(|| {
370            self.scheduler.clear();
371        })
372    }
373}
374
375impl AsRawFd for Runtime {
376    fn as_raw_fd(&self) -> RawFd {
377        self.driver.borrow().as_raw_fd()
378    }
379}
380
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Forwards to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // The type-relative path resolves to the inherent method first, so this
        // forwards instead of recursing into the trait method.
        Runtime::block_on(self, future)
    }
}
387
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Forwards to the inherent [`Runtime::block_on`] of the referenced
    /// runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        let runtime: &Runtime = self;
        runtime.block_on(future)
    }
}
394
395/// Builder for [`Runtime`].
396#[derive(Debug, Clone)]
397pub struct RuntimeBuilder {
398    proactor_builder: ProactorBuilder,
399    thread_affinity: HashSet<usize>,
400    event_interval: usize,
401}
402
403impl Default for RuntimeBuilder {
404    fn default() -> Self {
405        Self::new()
406    }
407}
408
409impl RuntimeBuilder {
410    /// Create the builder with default config.
411    pub fn new() -> Self {
412        Self {
413            proactor_builder: ProactorBuilder::new(),
414            event_interval: 61,
415            thread_affinity: HashSet::new(),
416        }
417    }
418
419    /// Replace proactor builder.
420    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
421        self.proactor_builder = builder;
422        self
423    }
424
425    /// Sets the thread affinity for the runtime.
426    pub fn thread_affinity(&mut self, cpus: HashSet<usize>) -> &mut Self {
427        self.thread_affinity = cpus;
428        self
429    }
430
431    /// Sets the number of scheduler ticks after which the scheduler will poll
432    /// for external events (timers, I/O, and so on).
433    ///
434    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
435    pub fn event_interval(&mut self, val: usize) -> &mut Self {
436        self.event_interval = val;
437        self
438    }
439
440    /// Build [`Runtime`].
441    pub fn build(&self) -> io::Result<Runtime> {
442        Runtime::with_builder(self)
443    }
444}
445
446/// Spawns a new asynchronous task, returning a [`Task`] for it.
447///
448/// Spawning a task enables the task to execute concurrently to other tasks.
449/// There is no guarantee that a spawned task will execute to completion.
450///
451/// ```
452/// # compio_runtime::Runtime::new().unwrap().block_on(async {
453/// let task = compio_runtime::spawn(async {
454///     println!("Hello from a spawned task!");
455///     42
456/// });
457///
458/// assert_eq!(
459///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
460///     42
461/// );
462/// # })
463/// ```
464///
465/// ## Panics
466///
467/// This method doesn't create runtime. It tries to obtain the current runtime
468/// by [`Runtime::with_current`].
469pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
470    Runtime::with_current(|r| r.spawn(future))
471}
472
473/// Spawns a blocking task in a new thread, and wait for it.
474///
475/// The task will not be cancelled even if the future is dropped.
476///
477/// ## Panics
478///
479/// This method doesn't create runtime. It tries to obtain the current runtime
480/// by [`Runtime::with_current`].
481pub fn spawn_blocking<T: Send + 'static>(
482    f: impl (FnOnce() -> T) + Send + 'static,
483) -> JoinHandle<T> {
484    Runtime::with_current(|r| r.spawn_blocking(f))
485}
486
487/// Submit an operation to the current runtime, and return a future for it.
488///
489/// ## Panics
490///
491/// This method doesn't create runtime. It tries to obtain the current runtime
492/// by [`Runtime::with_current`].
493pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
494    submit_with_flags(op).await.0
495}
496
497/// Submit an operation to the current runtime, and return a future for it with
498/// flags.
499///
500/// ## Panics
501///
502/// This method doesn't create runtime. It tries to obtain the current runtime
503/// by [`Runtime::with_current`].
504pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
505    Runtime::with_current(|r| r.submit_with_flags(op)).await
506}
507
/// Waits until `instant` using the current runtime's timers.
///
/// If the timer runtime hands back no key for `instant`, this returns
/// immediately.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let inserted =
        Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = inserted else {
        return;
    };
    TimerFuture::new(key).await;
}