1use std::{
2 any::Any,
3 cell::{Cell, RefCell},
4 collections::HashSet,
5 future::{Future, ready},
6 io,
7 panic::AssertUnwindSafe,
8 sync::Arc,
9 task::{Context, Poll, Waker},
10 time::Duration,
11};
12
13use async_task::Task;
14use compio_buf::IntoInner;
15use compio_driver::{
16 AsRawFd, DriverType, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
17};
18use compio_log::{debug, instrument};
19use futures_util::{FutureExt, future::Either};
20
21pub(crate) mod op;
22#[cfg(feature = "time")]
23pub(crate) mod time;
24
25mod buffer_pool;
26pub use buffer_pool::*;
27
28mod scheduler;
29
30mod opt_waker;
31pub use opt_waker::OptWaker;
32
33mod send_wrapper;
34use send_wrapper::SendWrapper;
35
36#[cfg(feature = "time")]
37use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
38use crate::{
39 BufResult,
40 affinity::bind_to_cpu_set,
41 runtime::{op::OpFuture, scheduler::Scheduler},
42};
43
// Scoped thread-local slot holding the runtime that is currently entered on
// this thread; it is set for the duration of `Runtime::enter`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Handle to a task spawned with [`spawn`], [`spawn_blocking`] or the
/// corresponding [`Runtime`] methods.
///
/// The task's output is wrapped in a `Result`: `Err` carries the panic payload
/// if the spawned future (or blocking closure) panicked.
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;

thread_local! {
    // Counter from which each new `Runtime` takes its id. It is thread-local,
    // so ids are only unique among runtimes created on the same thread.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
53
/// The compio runtime: a proactor I/O driver, a task scheduler and, with the
/// `time` feature, a timer runtime.
///
/// NOTE: fields are dropped in declaration order, so reordering them changes
/// teardown order. Check `Drop for Runtime` before touching the layout.
pub struct Runtime {
    // Completion-based driver. It is borrowed mutably to push, pop, cancel and
    // poll operations.
    driver: RefCell<Proactor>,
    // Queue of spawned tasks, driven by `Runtime::run`.
    scheduler: Scheduler,
    // Pending timers. Polled and woken from `poll_timer` / `poll_with`.
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    // Per-thread identifier taken from `RUNTIME_ID` at construction.
    id: u64,
}
70
71impl Runtime {
72 pub fn new() -> io::Result<Self> {
74 Self::builder().build()
75 }
76
77 pub fn builder() -> RuntimeBuilder {
79 RuntimeBuilder::new()
80 }
81
82 fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
83 let RuntimeBuilder {
84 proactor_builder,
85 thread_affinity,
86 event_interval,
87 } = builder;
88 let id = RUNTIME_ID.get();
89 RUNTIME_ID.set(id + 1);
90 if !thread_affinity.is_empty() {
91 bind_to_cpu_set(thread_affinity);
92 }
93 Ok(Self {
94 driver: RefCell::new(proactor_builder.build()?),
95 scheduler: Scheduler::new(*event_interval),
96 #[cfg(feature = "time")]
97 timer_runtime: RefCell::new(TimerRuntime::new()),
98 id,
99 })
100 }
101
102 pub fn driver_type(&self) -> DriverType {
104 self.driver.borrow().driver_type()
105 }
106
107 pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
110 if CURRENT_RUNTIME.is_set() {
111 Ok(CURRENT_RUNTIME.with(f))
112 } else {
113 Err(f)
114 }
115 }
116
117 pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
123 #[cold]
124 fn not_in_compio_runtime() -> ! {
125 panic!("not in a compio runtime")
126 }
127
128 if CURRENT_RUNTIME.is_set() {
129 CURRENT_RUNTIME.with(f)
130 } else {
131 not_in_compio_runtime()
132 }
133 }
134
135 pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
138 CURRENT_RUNTIME.set(self, f)
139 }
140
141 fn spawn_impl<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
142 unsafe { self.spawn_unchecked(future) }
143 }
144
145 pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
153 let waker = self.waker();
154 unsafe { self.scheduler.spawn_unchecked(future, waker) }
155 }
156
157 pub fn run(&self) -> bool {
163 self.scheduler.run()
164 }
165
166 pub fn waker(&self) -> Waker {
170 self.driver.borrow().waker()
171 }
172
173 pub fn opt_waker(&self) -> Arc<OptWaker> {
178 OptWaker::new(self.waker())
179 }
180
181 pub fn block_on<F: Future>(&self, future: F) -> F::Output {
183 self.enter(|| {
184 let opt_waker = self.opt_waker();
185 let waker = Waker::from(opt_waker.clone());
186 let mut context = Context::from_waker(&waker);
187 let mut future = std::pin::pin!(future);
188 loop {
189 if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
190 self.run();
191 return result;
192 }
193 let remaining_tasks = self.run() | opt_waker.reset();
195 if remaining_tasks {
196 self.poll_with(Some(Duration::ZERO));
197 } else {
198 self.poll();
199 }
200 }
201 })
202 }
203
204 pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
209 self.spawn_impl(AssertUnwindSafe(future).catch_unwind())
210 }
211
212 pub fn spawn_blocking<T: Send + 'static>(
216 &self,
217 f: impl (FnOnce() -> T) + Send + 'static,
218 ) -> JoinHandle<T> {
219 let op = Asyncify::new(move || {
220 let res = std::panic::catch_unwind(AssertUnwindSafe(f));
221 BufResult(Ok(0), res)
222 });
223 self.spawn_impl(self.submit(op).map(|res| res.1.into_inner()))
226 }
227
228 pub fn attach(&self, fd: RawFd) -> io::Result<()> {
233 self.driver.borrow_mut().attach(fd)
234 }
235
236 fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
237 self.driver.borrow_mut().push(op)
238 }
239
240 fn submit<T: OpCode + 'static>(
248 &self,
249 op: T,
250 ) -> impl Future<Output = BufResult<usize, T>> + use<T> {
251 self.submit_with_flags(op).map(|(res, _)| res)
252 }
253
254 fn submit_with_flags<T: OpCode + 'static>(
265 &self,
266 op: T,
267 ) -> impl Future<Output = (BufResult<usize, T>, u32)> + use<T> {
268 match self.submit_raw(op) {
269 PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
270 PushEntry::Ready(res) => {
271 Either::Right(ready((res, 0)))
274 }
275 }
276 }
277
278 pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
279 self.driver.borrow_mut().cancel(op);
280 }
281
282 #[cfg(feature = "time")]
283 pub(crate) fn cancel_timer(&self, key: &TimerKey) {
284 self.timer_runtime.borrow_mut().cancel(key);
285 }
286
287 pub(crate) fn poll_task<T: OpCode>(
288 &self,
289 cx: &mut Context,
290 op: Key<T>,
291 ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
292 instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
293 let mut driver = self.driver.borrow_mut();
294 driver.pop(op).map_pending(|mut k| {
295 driver.update_waker(&mut k, cx.waker().clone());
296 k
297 })
298 }
299
300 #[cfg(feature = "time")]
301 pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> {
302 instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
303 let mut timer_runtime = self.timer_runtime.borrow_mut();
304 if timer_runtime.is_completed(key) {
305 debug!("ready");
306 Poll::Ready(())
307 } else {
308 debug!("pending");
309 timer_runtime.update_waker(key, cx.waker().clone());
310 Poll::Pending
311 }
312 }
313
314 pub fn current_timeout(&self) -> Option<Duration> {
318 #[cfg(not(feature = "time"))]
319 let timeout = None;
320 #[cfg(feature = "time")]
321 let timeout = self.timer_runtime.borrow().min_timeout();
322 timeout
323 }
324
325 pub fn poll(&self) {
330 instrument!(compio_log::Level::DEBUG, "poll");
331 let timeout = self.current_timeout();
332 debug!("timeout: {:?}", timeout);
333 self.poll_with(timeout)
334 }
335
336 pub fn poll_with(&self, timeout: Option<Duration>) {
340 instrument!(compio_log::Level::DEBUG, "poll_with");
341
342 let mut driver = self.driver.borrow_mut();
343 match driver.poll(timeout) {
344 Ok(()) => {}
345 Err(e) => match e.kind() {
346 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
347 debug!("expected error: {e}");
348 }
349 _ => panic!("{e:?}"),
350 },
351 }
352 #[cfg(feature = "time")]
353 self.timer_runtime.borrow_mut().wake();
354 }
355
356 pub(crate) fn create_buffer_pool(
357 &self,
358 buffer_len: u16,
359 buffer_size: usize,
360 ) -> io::Result<compio_driver::BufferPool> {
361 self.driver
362 .borrow_mut()
363 .create_buffer_pool(buffer_len, buffer_size)
364 }
365
366 pub(crate) unsafe fn release_buffer_pool(
367 &self,
368 buffer_pool: compio_driver::BufferPool,
369 ) -> io::Result<()> {
370 unsafe { self.driver.borrow_mut().release_buffer_pool(buffer_pool) }
371 }
372
373 pub(crate) fn id(&self) -> u64 {
374 self.id
375 }
376}
377
378impl Drop for Runtime {
379 fn drop(&mut self) {
380 self.enter(|| {
381 self.scheduler.clear();
382 })
383 }
384}
385
386impl AsRawFd for Runtime {
387 fn as_raw_fd(&self) -> RawFd {
388 self.driver.borrow().as_raw_fd()
389 }
390}
391
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Delegates to the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
398
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Delegates to the inherent [`Runtime::block_on`] of the referenced
    /// runtime.
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        let runtime: &Runtime = self;
        Runtime::block_on(runtime, future)
    }
}
405
/// Builder for [`Runtime`].
///
/// NOTE: the derived `Debug` output follows field declaration order.
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    // Configuration used to build the runtime's proactor driver.
    proactor_builder: ProactorBuilder,
    // CPUs passed to `bind_to_cpu_set` when the runtime is built; an empty
    // set skips the call.
    thread_affinity: HashSet<usize>,
    // Value passed to `Scheduler::new` when the runtime is built.
    event_interval: usize,
}
413
414impl Default for RuntimeBuilder {
415 fn default() -> Self {
416 Self::new()
417 }
418}
419
420impl RuntimeBuilder {
421 pub fn new() -> Self {
423 Self {
424 proactor_builder: ProactorBuilder::new(),
425 event_interval: 61,
426 thread_affinity: HashSet::new(),
427 }
428 }
429
430 pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
432 self.proactor_builder = builder;
433 self
434 }
435
436 pub fn thread_affinity(&mut self, cpus: HashSet<usize>) -> &mut Self {
438 self.thread_affinity = cpus;
439 self
440 }
441
442 pub fn event_interval(&mut self, val: usize) -> &mut Self {
447 self.event_interval = val;
448 self
449 }
450
451 pub fn build(&self) -> io::Result<Runtime> {
453 Runtime::with_builder(self)
454 }
455}
456
457pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
481 Runtime::with_current(|r| r.spawn(future))
482}
483
484pub fn spawn_blocking<T: Send + 'static>(
493 f: impl (FnOnce() -> T) + Send + 'static,
494) -> JoinHandle<T> {
495 Runtime::with_current(|r| r.spawn_blocking(f))
496}
497
498pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
505 submit_with_flags(op).await.0
506}
507
508pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
516 Runtime::with_current(|r| r.submit_with_flags(op)).await
517}
518
/// Registers a timer for `instant` on the current runtime and waits for it.
///
/// Returns immediately if the timer runtime's `insert` yields no key.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let inserted =
        Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = inserted else {
        // No key was issued, so there is no timer to await.
        return;
    };
    TimerFuture::new(key).await
}