1use std::{
2 any::Any,
3 cell::{Cell, RefCell},
4 collections::HashSet,
5 future::{Future, ready},
6 io,
7 panic::AssertUnwindSafe,
8 sync::Arc,
9 task::{Context, Poll, Waker},
10 time::Duration,
11};
12
13use async_task::Task;
14use compio_buf::IntoInner;
15use compio_driver::{
16 AsRawFd, DriverType, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd, op::Asyncify,
17};
18use compio_log::{debug, instrument};
19use futures_util::{FutureExt, future::Either};
20
21pub(crate) mod op;
22#[cfg(feature = "time")]
23pub(crate) mod time;
24
25mod buffer_pool;
26pub use buffer_pool::*;
27
28mod scheduler;
29
30mod opt_waker;
31pub use opt_waker::OptWaker;
32
33mod send_wrapper;
34use send_wrapper::SendWrapper;
35
36#[cfg(feature = "time")]
37use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
38use crate::{
39 BufResult,
40 affinity::bind_to_cpu_set,
41 runtime::{op::OpFuture, scheduler::Scheduler},
42};
43
// Scoped thread-local slot pointing at the runtime that is currently entered
// on this thread. It is populated by `Runtime::enter` and read by
// `Runtime::with_current` / `Runtime::try_with_current`.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Handle to a task spawned on a [`Runtime`].
///
/// The task's output is `Ok(value)` when the spawned work finished, or
/// `Err(payload)` carrying the panic payload when it panicked (the spawn
/// functions wrap the work in `catch_unwind`).
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;

thread_local! {
    // Per-thread counter; each runtime built on this thread takes the current
    // value as its id and bumps the counter by one.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
53
/// Async runtime that pairs a [`Proactor`] driver with a task scheduler.
pub struct Runtime {
    /// The I/O driver. Kept in a `RefCell` because every runtime method takes
    /// `&self` but most driver operations need mutable access.
    driver: RefCell<Proactor>,
    /// Runs the tasks created through `spawn` / `spawn_blocking`.
    scheduler: Scheduler,
    /// Timer bookkeeping; only present with the `time` feature.
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    /// Identifier taken from the per-thread `RUNTIME_ID` counter when the
    /// runtime was built.
    id: u64,
}
70
71impl Runtime {
72 pub fn new() -> io::Result<Self> {
74 Self::builder().build()
75 }
76
77 pub fn builder() -> RuntimeBuilder {
79 RuntimeBuilder::new()
80 }
81
82 fn with_builder(builder: &RuntimeBuilder) -> io::Result<Self> {
83 let RuntimeBuilder {
84 proactor_builder,
85 thread_affinity,
86 event_interval,
87 } = builder;
88 let id = RUNTIME_ID.get();
89 RUNTIME_ID.set(id + 1);
90 if !thread_affinity.is_empty() {
91 bind_to_cpu_set(thread_affinity);
92 }
93 Ok(Self {
94 driver: RefCell::new(proactor_builder.build()?),
95 scheduler: Scheduler::new(*event_interval),
96 #[cfg(feature = "time")]
97 timer_runtime: RefCell::new(TimerRuntime::new()),
98 id,
99 })
100 }
101
102 pub fn driver_type(&self) -> DriverType {
104 self.driver.borrow().driver_type()
105 }
106
107 pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
110 if CURRENT_RUNTIME.is_set() {
111 Ok(CURRENT_RUNTIME.with(f))
112 } else {
113 Err(f)
114 }
115 }
116
117 pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
123 #[cold]
124 fn not_in_compio_runtime() -> ! {
125 panic!("not in a compio runtime")
126 }
127
128 if CURRENT_RUNTIME.is_set() {
129 CURRENT_RUNTIME.with(f)
130 } else {
131 not_in_compio_runtime()
132 }
133 }
134
    /// Runs `f` with this runtime installed as the current runtime of the
    /// thread, and returns whatever `f` returns.
    ///
    /// While `f` runs, [`Runtime::with_current`] and
    /// [`Runtime::try_with_current`] on this thread observe `self`.
    pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
        CURRENT_RUNTIME.set(self, f)
    }
140
141 fn spawn_impl<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
142 let waker = self.driver.borrow().waker();
143 self.scheduler.spawn(future, waker)
144 }
145
    /// Lets the scheduler run and returns the flag it reports.
    ///
    /// `block_on` treats a `true` result as "tasks remain"; the precise
    /// meaning is defined by `Scheduler::run`, which is not shown here.
    pub fn run(&self) -> bool {
        self.scheduler.run()
    }
154
155 pub fn waker(&self) -> Waker {
159 self.driver.borrow().waker()
160 }
161
162 pub fn opt_waker(&self) -> Arc<OptWaker> {
167 OptWaker::new(self.waker())
168 }
169
170 pub fn block_on<F: Future>(&self, future: F) -> F::Output {
172 self.enter(|| {
173 let opt_waker = self.opt_waker();
174 let waker = Waker::from(opt_waker.clone());
175 let mut context = Context::from_waker(&waker);
176 let mut future = std::pin::pin!(future);
177 loop {
178 if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
179 self.run();
180 return result;
181 }
182 let remaining_tasks = self.run() | opt_waker.reset();
184 if remaining_tasks {
185 self.poll_with(Some(Duration::ZERO));
186 } else {
187 self.poll();
188 }
189 }
190 })
191 }
192
193 pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
198 self.spawn_impl(AssertUnwindSafe(future).catch_unwind())
199 }
200
201 pub fn spawn_blocking<T: Send + 'static>(
205 &self,
206 f: impl (FnOnce() -> T) + Send + 'static,
207 ) -> JoinHandle<T> {
208 let op = Asyncify::new(move || {
209 let res = std::panic::catch_unwind(AssertUnwindSafe(f));
210 BufResult(Ok(0), res)
211 });
212 self.spawn_impl(self.submit(op).map(|res| res.1.into_inner()))
215 }
216
217 pub fn attach(&self, fd: RawFd) -> io::Result<()> {
222 self.driver.borrow_mut().attach(fd)
223 }
224
225 fn submit_raw<T: OpCode + 'static>(&self, op: T) -> PushEntry<Key<T>, BufResult<usize, T>> {
226 self.driver.borrow_mut().push(op)
227 }
228
229 fn submit<T: OpCode + 'static>(
237 &self,
238 op: T,
239 ) -> impl Future<Output = BufResult<usize, T>> + use<T> {
240 self.submit_with_flags(op).map(|(res, _)| res)
241 }
242
243 fn submit_with_flags<T: OpCode + 'static>(
254 &self,
255 op: T,
256 ) -> impl Future<Output = (BufResult<usize, T>, u32)> + use<T> {
257 match self.submit_raw(op) {
258 PushEntry::Pending(user_data) => Either::Left(OpFuture::new(user_data)),
259 PushEntry::Ready(res) => {
260 Either::Right(ready((res, 0)))
263 }
264 }
265 }
266
267 pub(crate) fn cancel_op<T: OpCode>(&self, op: Key<T>) {
268 self.driver.borrow_mut().cancel(op);
269 }
270
271 #[cfg(feature = "time")]
272 pub(crate) fn cancel_timer(&self, key: &TimerKey) {
273 self.timer_runtime.borrow_mut().cancel(key);
274 }
275
276 pub(crate) fn poll_task<T: OpCode>(
277 &self,
278 cx: &mut Context,
279 op: Key<T>,
280 ) -> PushEntry<Key<T>, (BufResult<usize, T>, u32)> {
281 instrument!(compio_log::Level::DEBUG, "poll_task", ?op);
282 let mut driver = self.driver.borrow_mut();
283 driver.pop(op).map_pending(|mut k| {
284 driver.update_waker(&mut k, cx.waker().clone());
285 k
286 })
287 }
288
289 #[cfg(feature = "time")]
290 pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> {
291 instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
292 let mut timer_runtime = self.timer_runtime.borrow_mut();
293 if timer_runtime.is_completed(key) {
294 debug!("ready");
295 Poll::Ready(())
296 } else {
297 debug!("pending");
298 timer_runtime.update_waker(key, cx.waker().clone());
299 Poll::Pending
300 }
301 }
302
303 pub fn current_timeout(&self) -> Option<Duration> {
307 #[cfg(not(feature = "time"))]
308 let timeout = None;
309 #[cfg(feature = "time")]
310 let timeout = self.timer_runtime.borrow().min_timeout();
311 timeout
312 }
313
314 pub fn poll(&self) {
319 instrument!(compio_log::Level::DEBUG, "poll");
320 let timeout = self.current_timeout();
321 debug!("timeout: {:?}", timeout);
322 self.poll_with(timeout)
323 }
324
325 pub fn poll_with(&self, timeout: Option<Duration>) {
329 instrument!(compio_log::Level::DEBUG, "poll_with");
330
331 let mut driver = self.driver.borrow_mut();
332 match driver.poll(timeout) {
333 Ok(()) => {}
334 Err(e) => match e.kind() {
335 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
336 debug!("expected error: {e}");
337 }
338 _ => panic!("{e:?}"),
339 },
340 }
341 #[cfg(feature = "time")]
342 self.timer_runtime.borrow_mut().wake();
343 }
344
345 pub(crate) fn create_buffer_pool(
346 &self,
347 buffer_len: u16,
348 buffer_size: usize,
349 ) -> io::Result<compio_driver::BufferPool> {
350 self.driver
351 .borrow_mut()
352 .create_buffer_pool(buffer_len, buffer_size)
353 }
354
355 pub(crate) unsafe fn release_buffer_pool(
356 &self,
357 buffer_pool: compio_driver::BufferPool,
358 ) -> io::Result<()> {
359 unsafe { self.driver.borrow_mut().release_buffer_pool(buffer_pool) }
360 }
361
    /// Returns the id assigned to this runtime when it was built (taken from
    /// the per-thread `RUNTIME_ID` counter).
    pub(crate) fn id(&self) -> u64 {
        self.id
    }
365}
366
impl Drop for Runtime {
    /// Clears the scheduler while this runtime is entered as the current one.
    fn drop(&mut self) {
        // NOTE(review): entering first presumably lets destructors of the
        // remaining tasks still reach the current runtime — confirm against
        // `Scheduler::clear`.
        self.enter(|| {
            self.scheduler.clear();
        })
    }
}
374
375impl AsRawFd for Runtime {
376 fn as_raw_fd(&self) -> RawFd {
377 self.driver.borrow().as_raw_fd()
378 }
379}
380
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    /// Runs `future` to completion using the inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        Runtime::block_on(self, future)
    }
}
387
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    /// Runs `future` to completion on the referenced runtime using the
    /// inherent [`Runtime::block_on`].
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // `self` is `&&Runtime` here; peel one reference layer.
        Runtime::block_on(*self, future)
    }
}
394
/// Collects the settings used to build a [`Runtime`].
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    /// Builder for the proactor that becomes the runtime's driver.
    proactor_builder: ProactorBuilder,
    /// CPUs the building thread is bound to; an empty set means no binding is
    /// attempted.
    thread_affinity: HashSet<usize>,
    /// Value handed to the scheduler on construction (defaults to 61).
    event_interval: usize,
}
402
403impl Default for RuntimeBuilder {
404 fn default() -> Self {
405 Self::new()
406 }
407}
408
409impl RuntimeBuilder {
410 pub fn new() -> Self {
412 Self {
413 proactor_builder: ProactorBuilder::new(),
414 event_interval: 61,
415 thread_affinity: HashSet::new(),
416 }
417 }
418
419 pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
421 self.proactor_builder = builder;
422 self
423 }
424
425 pub fn thread_affinity(&mut self, cpus: HashSet<usize>) -> &mut Self {
427 self.thread_affinity = cpus;
428 self
429 }
430
431 pub fn event_interval(&mut self, val: usize) -> &mut Self {
436 self.event_interval = val;
437 self
438 }
439
440 pub fn build(&self) -> io::Result<Runtime> {
442 Runtime::with_builder(self)
443 }
444}
445
446pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
470 Runtime::with_current(|r| r.spawn(future))
471}
472
473pub fn spawn_blocking<T: Send + 'static>(
482 f: impl (FnOnce() -> T) + Send + 'static,
483) -> JoinHandle<T> {
484 Runtime::with_current(|r| r.spawn_blocking(f))
485}
486
487pub async fn submit<T: OpCode + 'static>(op: T) -> BufResult<usize, T> {
494 submit_with_flags(op).await.0
495}
496
497pub async fn submit_with_flags<T: OpCode + 'static>(op: T) -> (BufResult<usize, T>, u32) {
505 Runtime::with_current(|r| r.submit_with_flags(op)).await
506}
507
/// Registers a timer for `instant` with the current runtime and waits for it.
///
/// When the timer runtime's `insert` yields no key, the function returns
/// immediately without waiting.
/// NOTE(review): presumably that means `instant` is not in the future —
/// confirm against `TimerRuntime::insert`.
///
/// # Panics
///
/// Panics when awaited outside of a runtime.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    let registered = Runtime::with_current(|rt| rt.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = registered else {
        return;
    };
    TimerFuture::new(key).await
}