// compio_runtime/runtime/mod.rs

1use std::{
2    any::Any,
3    cell::{Cell, RefCell},
4    collections::HashSet,
5    future::{Future, ready},
6    io,
7    panic::AssertUnwindSafe,
8    sync::Arc,
9    task::{Context, Poll, Waker},
10    time::Duration,
11};
12
13use async_task::Task;
14use compio_buf::IntoInner;
15use compio_driver::{
16    AsRawFd, DriverType, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
17};
18use compio_log::{debug, instrument};
19use futures_util::{FutureExt, future::Either};
20
21pub(crate) mod op;
22#[cfg(feature = "time")]
23pub(crate) mod time;
24
25mod buffer_pool;
26pub use buffer_pool::*;
27
28mod scheduler;
29
30mod opt_waker;
31pub use opt_waker::OptWaker;
32
33mod send_wrapper;
34use send_wrapper::SendWrapper;
35
36#[cfg(feature = "time")]
37use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
38use crate::{
39    BufResult,
40    affinity::bind_to_cpu_set,
41    runtime::{op::OpFuture, scheduler::Scheduler},
42};
43
44scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);
45
46/// Type alias for `Task<Result<T, Box<dyn Any + Send>>>`, which resolves to an
47/// `Err` when the spawned future panicked.
48pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;
49
thread_local! {
    /// Id handed to the next [`Runtime`] built on this thread.
    ///
    /// Starts at 0 and is incremented every time a runtime is built on this
    /// thread (even if building the driver then fails). Because the counter is
    /// thread-local, ids are only unique among runtimes of the same thread.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
53
54/// The async runtime of compio. It is a thread local runtime, and cannot be
55/// sent to other threads.
56pub struct Runtime {
57    driver: RefCell<Proactor>,
58    scheduler: Scheduler,
59    #[cfg(feature = "time")]
60    timer_runtime: RefCell<TimerRuntime>,
61    // Runtime id is used to check if the buffer pool is belonged to this runtime or not.
62    // Without this, if user enable `io-uring-buf-ring` feature then:
63    // 1. Create a buffer pool at runtime1
64    // 3. Create another runtime2, then use the exists buffer pool in runtime2, it may cause
65    // - io-uring report error if the buffer group id is not registered
66    // - buffer pool will return a wrong buffer which the buffer's data is uninit, that will cause
67    //   UB
68    id: u64,
69}
70
71impl Runtime {
72    /// Create [`Runtime`] with default config.
73    pub fn new() -> io::Result<Self> {
74        Self::builder().build()
75    }
76
77    /// Create a builder for [`Runtime`].
78    pub fn builder() -> RuntimeBuilder {
79        RuntimeBuilder::new()
80    }
81
82    fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
83        let RuntimeBuilder {
84            proactor_builder,
85            thread_affinity,
86            event_interval,
87        } = builder;
88        let id = RUNTIME_ID.get();
89        RUNTIME_ID.set(id + 1);
90        if !thread_affinity.is_empty() {
91            bind_to_cpu_set(thread_affinity);
92        }
93        Ok(Self {
94            driver: RefCell::new(proactor_builder.build()?),
95            scheduler: Scheduler::new(*event_interval),
96            #[cfg(feature = "time")]
97            timer_runtime: RefCell::new(TimerRuntime::new()),
98            id,
99        })
100    }
101
102    /// The current driver type.
103    pub fn driver_type(&self) -> DriverType {
104        self.driver.borrow().driver_type()
105    }
106
107    /// Try to perform a function on the current runtime, and if no runtime is
108    /// running, return the function back.
109    pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
110        if CURRENT_RUNTIME.is_set() {
111            Ok(CURRENT_RUNTIME.with(f))
112        } else {
113            Err(f)
114        }
115    }
116
117    /// Perform a function on the current runtime.
118    ///
119    /// ## Panics
120    ///
121    /// This method will panic if there are no running [`Runtime`].
122    pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
123        #[cold]
124        fn not_in_compio_runtime() -> ! {
125            panic!("not in a compio runtime")
126        }
127
128        if CURRENT_RUNTIME.is_set() {
129            CURRENT_RUNTIME.with(f)
130        } else {
131            not_in_compio_runtime()
132        }
133    }
134
135    /// Set this runtime as current runtime, and perform a function in the
136    /// current scope.
137    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
138        CURRENT_RUNTIME.set(self, f)
139    }
140
141    fn spawn_impl<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
142        unsafe { self.spawn_unchecked(future) }
143    }
144
145    /// Low level API to control the runtime.
146    ///
147    /// Spawns a new asynchronous task, returning a [`Task`] for it.
148    ///
149    /// # Safety
150    ///
151    /// Borrowed variables must outlive the future.
152    pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
153        let waker = self.waker();
154        unsafe { self.scheduler.spawn_unchecked(future, waker) }
155    }
156
157    /// Low level API to control the runtime.
158    ///
159    /// Run the scheduled tasks.
160    ///
161    /// The return value indicates whether there are still tasks in the queue.
162    pub fn run(&self) -> bool {
163        self.scheduler.run()
164    }
165
166    /// Low level API to control the runtime.
167    ///
168    /// Create a waker that always notifies the runtime when woken.
169    pub fn waker(&self) -> Waker {
170        self.driver.borrow().waker()
171    }
172
173    /// Low level API to control the runtime.
174    ///
175    /// Create an optimized waker that only notifies the runtime when woken
176    /// from another thread, or when `notify-always` is enabled.
177    pub fn opt_waker(&self) -> Arc<OptWaker> {
178        OptWaker::new(self.waker())
179    }
180
181    /// Block on the future till it completes.
182    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
183        self.enter(|| {
184            let opt_waker = self.opt_waker();
185            let waker = Waker::from(opt_waker.clone());
186            let mut context = Context::from_waker(&waker);
187            let mut future = std::pin::pin!(future);
188            loop {
189                if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
190                    self.run();
191                    return result;
192                }
193                // We always want to reset the waker here.
194                let remaining_tasks = self.run() | opt_waker.reset();
195                if remaining_tasks {
196                    self.poll_with(Some(Duration::ZERO));
197                } else {
198                    self.poll();
199                }
200            }
201        })
202    }
203
204    /// Spawns a new asynchronous task, returning a [`Task`] for it.
205    ///
206    /// Spawning a task enables the task to execute concurrently to other tasks.
207    /// There is no guarantee that a spawned task will execute to completion.
208    pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
209        self.spawn_impl(AssertUnwindSafe(future).catch_unwind())
210    }
211
212    /// Spawns a blocking task in a new thread, and wait for it.
213    ///
214    /// The task will not be cancelled even if the future is dropped.
215    pub fn spawn_blocking<T: Send + 'static>(
216        &self,
217        f: impl (FnOnce() -> T) + Send + 'static,
218    ) -> JoinHandle<T> {
219        let op = Asyncify::new(move || {
220            let res = std::panic::catch_unwind(AssertUnwindSafe(f));
221            BufResult(Ok(0), res)
222        });
223        // It is safe and sound to use `submit` here because the task is spawned
224        // immediately.
225        self.spawn_impl(self.submit(op).map(|res| res.1.into_inner()))
226    }
227
228    /// Attach a raw file descriptor/handle/socket to the runtime.
229    ///
230    /// You only need this when authoring your own high-level APIs. High-level
231    /// resources in this crate are attached automatically.
232    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
233        self.driver.borrow_mut().attach(fd)
234    }
235
236    fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
237        self.driver.borrow_mut().push(op)
238    }
239
240    /// Submit an operation to the runtime.
241    ///
242    /// You only need this when authoring your own [`OpCode`].
243    ///
244    /// It is safe to send the returned future to another runtime and poll it,
245    /// but the exact behavior is not guaranteed, e.g. it may return pending
246    /// forever or else.
247    fn submit<T: OpCode + 'static>(
248        &self,
249        op: T,
250    ) -> impl Future<Output = BufResult<usize, T>> + use<T> {
251        self.submit_with_flags(op).map(|(res, _)| res)
252    }
253
254    /// Submit an operation to the runtime.
255    ///
256    /// The difference between [`Runtime::submit`] is this method will return
257    /// the flags
258    ///
259    /// You only need this when authoring your own [`OpCode`].
260    ///
261    /// It is safe to send the returned future to another runtime and poll it,
262    /// but the exact behavior is not guaranteed, e.g. it may return pending
263    /// forever or else.
264    fn submit_with_flags<T: OpCode + 'static>(
265        &self,
266        op: T,
267    ) -> impl Future<Output = (BufResult<usize, T>, u32)> + use<T> {
268        match self.submit_raw(op) {
269            PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
270            PushEntry::Ready(res) => {
271                // submit_flags won't be ready immediately, if ready, it must be error without
272                // flags
273                Either::Right(ready((res, 0)))
274            }
275        }
276    }
277
278    pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
279        self.driver.borrow_mut().cancel(op);
280    }
281
282    #[cfg(feature = "time")]
283    pub(crate) fn cancel_timer(&self, key: &TimerKey) {
284        self.timer_runtime.borrow_mut().cancel(key);
285    }
286
287    pub(crate) fn poll_task<T: OpCode>(
288        &self,
289        cx: &mut Context,
290        op: Key<T>,
291    ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
292        instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
293        let mut driver = self.driver.borrow_mut();
294        driver.pop(op).map_pending(|mut k| {
295            driver.update_waker(&mut k, cx.waker().clone());
296            k
297        })
298    }
299
300    #[cfg(feature = "time")]
301    pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> {
302        instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
303        let mut timer_runtime = self.timer_runtime.borrow_mut();
304        if timer_runtime.is_completed(key) {
305            debug!("ready");
306            Poll::Ready(())
307        } else {
308            debug!("pending");
309            timer_runtime.update_waker(key, cx.waker().clone());
310            Poll::Pending
311        }
312    }
313
314    /// Low level API to control the runtime.
315    ///
316    /// Get the timeout value to be passed to [`Proactor::poll`].
317    pub fn current_timeout(&self) -> Option<Duration> {
318        #[cfg(not(feature = "time"))]
319        let timeout = None;
320        #[cfg(feature = "time")]
321        let timeout = self.timer_runtime.borrow().min_timeout();
322        timeout
323    }
324
325    /// Low level API to control the runtime.
326    ///
327    /// Poll the inner proactor. It is equal to calling [`Runtime::poll_with`]
328    /// with [`Runtime::current_timeout`].
329    pub fn poll(&self) {
330        instrument!(compio_log::Level::DEBUG, "poll");
331        let timeout = self.current_timeout();
332        debug!("timeout: {:?}", timeout);
333        self.poll_with(timeout)
334    }
335
336    /// Low level API to control the runtime.
337    ///
338    /// Poll the inner proactor with a custom timeout.
339    pub fn poll_with(&self, timeout: Option<Duration>) {
340        instrument!(compio_log::Level::DEBUG, "poll_with");
341
342        let mut driver = self.driver.borrow_mut();
343        match driver.poll(timeout) {
344            Ok(()) => {}
345            Err(e) => match e.kind() {
346                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
347                    debug!("expected error: {e}");
348                }
349                _ => panic!("{e:?}"),
350            },
351        }
352        #[cfg(feature = "time")]
353        self.timer_runtime.borrow_mut().wake();
354    }
355
356    pub(crate) fn create_buffer_pool(
357        &self,
358        buffer_len: u16,
359        buffer_size: usize,
360    ) -> io::Result<compio_driver::BufferPool> {
361        self.driver
362            .borrow_mut()
363            .create_buffer_pool(buffer_len, buffer_size)
364    }
365
366    pub(crate) unsafe fn release_buffer_pool(
367        &self,
368        buffer_pool: compio_driver::BufferPool,
369    ) -> io::Result<()> {
370        unsafe { self.driver.borrow_mut().release_buffer_pool(buffer_pool) }
371    }
372
373    pub(crate) fn id(&self) -> u64 {
374        self.id
375    }
376}
377
378impl Drop for Runtime {
379    fn drop(&mut self) {
380        self.enter(|| {
381            self.scheduler.clear();
382        })
383    }
384}
385
386impl AsRawFd for Runtime {
387    fn as_raw_fd(&self) -> RawFd {
388        self.driver.borrow().as_raw_fd()
389    }
390}
391
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Forwards to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
398
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Forwards to the inherent [`Runtime::block_on`] of the borrowed runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(*self, future)
    }
}
405
406/// Builder for [`Runtime`].
407#[derive(Debug, Clone)]
408pub struct RuntimeBuilder {
409    proactor_builder: ProactorBuilder,
410    thread_affinity: HashSet<usize>,
411    event_interval: usize,
412}
413
414impl Default for RuntimeBuilder {
415    fn default() -> Self {
416        Self::new()
417    }
418}
419
420impl RuntimeBuilder {
421    /// Create the builder with default config.
422    pub fn new() -> Self {
423        Self {
424            proactor_builder: ProactorBuilder::new(),
425            event_interval: 61,
426            thread_affinity: HashSet::new(),
427        }
428    }
429
430    /// Replace proactor builder.
431    pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
432        self.proactor_builder = builder;
433        self
434    }
435
436    /// Sets the thread affinity for the runtime.
437    pub fn thread_affinity(&mut self, cpus: HashSet<usize>) -> &mut Self {
438        self.thread_affinity = cpus;
439        self
440    }
441
442    /// Sets the number of scheduler ticks after which the scheduler will poll
443    /// for external events (timers, I/O, and so on).
444    ///
445    /// A scheduler “tick” roughly corresponds to one poll invocation on a task.
446    pub fn event_interval(&mut self, val: usize) -> &mut Self {
447        self.event_interval = val;
448        self
449    }
450
451    /// Build [`Runtime`].
452    pub fn build(&self) -> io::Result<Runtime> {
453        Runtime::with_builder(self)
454    }
455}
456
457/// Spawns a new asynchronous task, returning a [`Task`] for it.
458///
459/// Spawning a task enables the task to execute concurrently to other tasks.
460/// There is no guarantee that a spawned task will execute to completion.
461///
462/// ```
463/// # compio_runtime::Runtime::new().unwrap().block_on(async {
464/// let task = compio_runtime::spawn(async {
465///     println!("Hello from a spawned task!");
466///     42
467/// });
468///
469/// assert_eq!(
470///     task.await.unwrap_or_else(|e| std::panic::resume_unwind(e)),
471///     42
472/// );
473/// # })
474/// ```
475///
476/// ## Panics
477///
478/// This method doesn't create runtime. It tries to obtain the current runtime
479/// by [`Runtime::with_current`].
480pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
481    Runtime::with_current(|r| r.spawn(future))
482}
483
484/// Spawns a blocking task in a new thread, and wait for it.
485///
486/// The task will not be cancelled even if the future is dropped.
487///
488/// ## Panics
489///
490/// This method doesn't create runtime. It tries to obtain the current runtime
491/// by [`Runtime::with_current`].
492pub fn spawn_blocking<T: Send + 'static>(
493    f: impl (FnOnce() -> T) + Send + 'static,
494) -> JoinHandle<T> {
495    Runtime::with_current(|r| r.spawn_blocking(f))
496}
497
498/// Submit an operation to the current runtime, and return a future for it.
499///
500/// ## Panics
501///
502/// This method doesn't create runtime. It tries to obtain the current runtime
503/// by [`Runtime::with_current`].
504pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
505    submit_with_flags(op).await.0
506}
507
508/// Submit an operation to the current runtime, and return a future for it with
509/// flags.
510///
511/// ## Panics
512///
513/// This method doesn't create runtime. It tries to obtain the current runtime
514/// by [`Runtime::with_current`].
515pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
516    Runtime::with_current(|r| r.submit_with_flags(op)).await
517}
518
/// Wait until `instant` using the current runtime's timer.
///
/// Returns immediately when the timer runtime does not register a timer
/// (its `insert` returns `None`). NOTE(review): presumably this happens when
/// `instant` has already passed — confirm in `TimerRuntime::insert`.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let key = Runtime::with_current(|r| r.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = key else {
        return;
    };
    TimerFuture::new(key).await;
}