1use std::{
2 any::Any,
3 cell::{Cell, RefCell},
4 collections::HashSet,
5 future::Future,
6 io,
7 ops::Deref,
8 panic::AssertUnwindSafe,
9 rc::Rc,
10 sync::Arc,
11 task::{Context, Poll, Waker},
12 time::Duration,
13};
14
15use async_task::Task;
16use compio_buf::IntoInner;
17use compio_driver::{
18 AsRawFd, DriverType, Extra, Key, OpCode, Proactor, ProactorBuilder, PushEntry, RawFd,
19 op::Asyncify,
20};
21use compio_log::{debug, instrument};
22use futures_util::FutureExt;
23
24mod future;
25pub use future::Submit;
26
27#[cfg(feature = "time")]
28pub(crate) mod time;
29
30mod buffer_pool;
31pub use buffer_pool::*;
32
33mod scheduler;
34
35mod opt_waker;
36pub use opt_waker::OptWaker;
37
38mod send_wrapper;
39use send_wrapper::SendWrapper;
40
41#[cfg(feature = "time")]
42use crate::runtime::time::{TimerFuture, TimerKey, TimerRuntime};
43use crate::{BufResult, affinity::bind_to_cpu_set, runtime::scheduler::Scheduler};
44
// Thread-scoped slot holding the runtime made current by `Runtime::enter`;
// `Runtime::with_current` / `Runtime::try_with_current` read it.
scoped_tls::scoped_thread_local!(static CURRENT_RUNTIME: Runtime);

/// Handle to a task spawned on a [`Runtime`].
///
/// Resolves to `Ok` with the task's output, or to `Err` carrying the panic
/// payload when the spawned future (or blocking closure) panicked.
pub type JoinHandle<T> = Task<Result<T, Box<dyn Any + Send>>>;

thread_local! {
    // Counter from which `RuntimeBuilder::build` takes runtime ids. Because it
    // is thread-local, ids are only unique among runtimes built on the same
    // thread.
    static RUNTIME_ID: Cell<u64> = const { Cell::new(0) };
}
54
/// Shared state behind every clone of a [`Runtime`] handle.
///
/// Fields are dropped in declaration order (standard Rust drop semantics).
pub struct RuntimeInner {
    // The proactor that executes submitted operations. Wrapped in `RefCell`
    // because runtime methods borrow it mutably through `&self`.
    driver: RefCell<Proactor>,
    // Task scheduler used by `spawn`/`run`.
    scheduler: Scheduler,
    // Timer bookkeeping; only present with the `time` feature.
    #[cfg(feature = "time")]
    timer_runtime: RefCell<TimerRuntime>,
    // Id taken from the thread-local `RUNTIME_ID` counter at build time.
    id: u64,
}
70
/// Handle to a compio runtime.
///
/// Clones are cheap and share the same [`RuntimeInner`] through an `Rc`.
/// As a consequence the handle is `!Send`/`!Sync` and stays on the thread
/// that built it.
#[derive(Clone)]
pub struct Runtime(Rc<RuntimeInner>);
76
77impl Deref for Runtime {
78 type Target = RuntimeInner;
79
80 fn deref(&self) -> &Self::Target {
81 &self.0
82 }
83}
84
85impl Runtime {
86 pub fn new() -> io::Result<Self> {
88 Self::builder().build()
89 }
90
91 pub fn builder() -> RuntimeBuilder {
93 RuntimeBuilder::new()
94 }
95
96 pub fn driver_type(&self) -> DriverType {
98 self.driver.borrow().driver_type()
99 }
100
101 pub fn try_with_current<T, F: FnOnce(&Self) -> T>(f: F) -> Result<T, F> {
104 if CURRENT_RUNTIME.is_set() {
105 Ok(CURRENT_RUNTIME.with(f))
106 } else {
107 Err(f)
108 }
109 }
110
111 pub fn with_current<T, F: FnOnce(&Self) -> T>(f: F) -> T {
117 #[cold]
118 fn not_in_compio_runtime() -> ! {
119 panic!("not in a compio runtime")
120 }
121
122 if CURRENT_RUNTIME.is_set() {
123 CURRENT_RUNTIME.with(f)
124 } else {
125 not_in_compio_runtime()
126 }
127 }
128
129 pub fn enter<T, F: FnOnce() -> T>(&self, f: F) -> T {
132 CURRENT_RUNTIME.set(self, f)
133 }
134
135 fn spawn_impl<F: Future + 'static>(&self, future: F) -> Task<F::Output> {
136 unsafe { self.spawn_unchecked(future) }
137 }
138
139 pub unsafe fn spawn_unchecked<F: Future>(&self, future: F) -> Task<F::Output> {
147 let waker = self.waker();
148 unsafe { self.scheduler.spawn_unchecked(future, waker) }
149 }
150
151 pub fn run(&self) -> bool {
157 self.scheduler.run()
158 }
159
160 pub fn waker(&self) -> Waker {
164 self.driver.borrow().waker()
165 }
166
167 pub fn opt_waker(&self) -> Arc<OptWaker> {
172 OptWaker::new(self.waker())
173 }
174
175 pub fn block_on<F: Future>(&self, future: F) -> F::Output {
177 self.enter(|| {
178 let opt_waker = self.opt_waker();
179 let waker = Waker::from(opt_waker.clone());
180 let mut context = Context::from_waker(&waker);
181 let mut future = std::pin::pin!(future);
182 loop {
183 if let Poll::Ready(result) = future.as_mut().poll(&mut context) {
184 self.run();
185 return result;
186 }
187 let remaining_tasks = self.run() | opt_waker.reset();
189 if remaining_tasks {
190 self.poll_with(Some(Duration::ZERO));
191 } else {
192 self.poll();
193 }
194 }
195 })
196 }
197
198 pub fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
203 self.spawn_impl(AssertUnwindSafe(future).catch_unwind())
204 }
205
206 pub fn spawn_blocking<T: Send + 'static>(
210 &self,
211 f: impl (FnOnce() -> T) + Send + 'static,
212 ) -> JoinHandle<T> {
213 let op = Asyncify::new(move || {
214 let res = std::panic::catch_unwind(AssertUnwindSafe(f));
215 BufResult(Ok(0), res)
216 });
217 self.spawn_impl(self.submit(op).map(|res| res.1.into_inner()))
220 }
221
222 pub fn attach(&self, fd: RawFd) -> io::Result<()> {
227 self.driver.borrow_mut().attach(fd)
228 }
229
230 fn submit_raw<T: OpCode + 'static>(
231 &self,
232 op: T,
233 extra: Option<Extra>,
234 ) -> PushEntry<Key<T>, BufResult<usize, T>> {
235 let mut this = self.driver.borrow_mut();
236 match extra {
237 Some(e) => this.push_with_extra(op, e),
238 None => this.push(op),
239 }
240 }
241
242 fn default_extra(&self) -> Extra {
243 self.driver.borrow().default_extra()
244 }
245
246 fn submit<T: OpCode + 'static>(&self, op: T) -> Submit<T> {
250 Submit::new(self.clone(), op)
251 }
252
253 pub(crate) fn cancel<T: OpCode>(&self, op: Key<T>) {
254 self.driver.borrow_mut().cancel(op);
255 }
256
257 #[cfg(feature = "time")]
258 pub(crate) fn cancel_timer(&self, key: &TimerKey) {
259 self.timer_runtime.borrow_mut().cancel(key);
260 }
261
262 pub(crate) fn poll_task<T: OpCode>(
263 &self,
264 waker: &Waker,
265 key: Key<T>,
266 ) -> PushEntry<Key<T>, BufResult<usize, T>> {
267 instrument!(compio_log::Level::DEBUG, "poll_task", ?key);
268 let mut driver = self.driver.borrow_mut();
269 driver.pop(key).map_pending(|mut k| {
270 driver.update_waker(&mut k, waker);
271 k
272 })
273 }
274
275 pub(crate) fn poll_task_with_extra<T: OpCode>(
276 &self,
277 cx: &mut Context,
278 key: Key<T>,
279 ) -> PushEntry<Key<T>, (BufResult<usize, T>, Extra)> {
280 instrument!(compio_log::Level::DEBUG, "poll_task_with_extra", ?key);
281 let mut driver = self.driver.borrow_mut();
282 driver.pop_with_extra(key).map_pending(|mut k| {
283 driver.update_waker(&mut k, cx.waker());
284 k
285 })
286 }
287
288 #[cfg(feature = "time")]
289 pub(crate) fn poll_timer(&self, cx: &mut Context, key: &TimerKey) -> Poll<()> {
290 instrument!(compio_log::Level::DEBUG, "poll_timer", ?cx, ?key);
291 let mut timer_runtime = self.timer_runtime.borrow_mut();
292 if timer_runtime.is_completed(key) {
293 debug!("ready");
294 Poll::Ready(())
295 } else {
296 debug!("pending");
297 timer_runtime.update_waker(key, cx.waker().clone());
298 Poll::Pending
299 }
300 }
301
302 pub fn current_timeout(&self) -> Option<Duration> {
306 #[cfg(not(feature = "time"))]
307 let timeout = None;
308 #[cfg(feature = "time")]
309 let timeout = self.timer_runtime.borrow().min_timeout();
310 timeout
311 }
312
313 pub fn poll(&self) {
318 instrument!(compio_log::Level::DEBUG, "poll");
319 let timeout = self.current_timeout();
320 debug!("timeout: {:?}", timeout);
321 self.poll_with(timeout)
322 }
323
324 pub fn poll_with(&self, timeout: Option<Duration>) {
328 instrument!(compio_log::Level::DEBUG, "poll_with");
329
330 let mut driver = self.driver.borrow_mut();
331 match driver.poll(timeout) {
332 Ok(()) => {}
333 Err(e) => match e.kind() {
334 io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
335 debug!("expected error: {e}");
336 }
337 _ => panic!("{e:?}"),
338 },
339 }
340 #[cfg(feature = "time")]
341 self.timer_runtime.borrow_mut().wake();
342 }
343
344 pub(crate) fn create_buffer_pool(
345 &self,
346 buffer_len: u16,
347 buffer_size: usize,
348 ) -> io::Result<compio_driver::BufferPool> {
349 self.driver
350 .borrow_mut()
351 .create_buffer_pool(buffer_len, buffer_size)
352 }
353
354 pub(crate) unsafe fn release_buffer_pool(
355 &self,
356 buffer_pool: compio_driver::BufferPool,
357 ) -> io::Result<()> {
358 unsafe { self.driver.borrow_mut().release_buffer_pool(buffer_pool) }
359 }
360
361 pub(crate) fn id(&self) -> u64 {
362 self.id
363 }
364
365 pub fn register_personality(&self) -> io::Result<u16> {
375 self.driver.borrow_mut().register_personality()
376 }
377
378 pub fn unregister_personality(&self, personality: u16) -> io::Result<()> {
385 self.driver.borrow_mut().unregister_personality(personality)
386 }
387}
388
389impl Drop for Runtime {
390 fn drop(&mut self) {
391 if Rc::strong_count(&self.0) > 1 {
393 return;
394 }
395
396 self.enter(|| {
397 self.scheduler.clear();
398 })
399 }
400}
401
402impl AsRawFd for Runtime {
403 fn as_raw_fd(&self) -> RawFd {
404 self.driver.borrow().as_raw_fd()
405 }
406}
407
/// Lets criterion benchmarks drive async code on a [`Runtime`].
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // Path call resolves to the inherent `Runtime::block_on`.
        Runtime::block_on(self, future)
    }
}
414
/// Lets criterion benchmarks drive async code on a borrowed [`Runtime`].
#[cfg(feature = "criterion")]
impl criterion::async_executor::AsyncExecutor for &Runtime {
    fn block_on<T>(&self, future: impl Future<Output = T>) -> T {
        // `&&Runtime` coerces to `&Runtime` for the inherent method.
        Runtime::block_on(self, future)
    }
}
421
/// Builder for [`Runtime`].
#[derive(Debug, Clone)]
pub struct RuntimeBuilder {
    // Configuration used to build the runtime's proactor.
    proactor_builder: ProactorBuilder,
    // CPUs passed to `bind_to_cpu_set` during `build`; an empty set skips the
    // binding.
    thread_affinity: HashSet<usize>,
    // Value passed to `Scheduler::new`. NOTE(review): presumably how many
    // tasks run between driver polls — confirm.
    event_interval: usize,
}
429
430impl Default for RuntimeBuilder {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
436impl RuntimeBuilder {
437 pub fn new() -> Self {
439 Self {
440 proactor_builder: ProactorBuilder::new(),
441 event_interval: 61,
442 thread_affinity: HashSet::new(),
443 }
444 }
445
446 pub fn with_proactor(&mut self, builder: ProactorBuilder) -> &mut Self {
448 self.proactor_builder = builder;
449 self
450 }
451
452 pub fn thread_affinity(&mut self, cpus: HashSet<usize>) -> &mut Self {
454 self.thread_affinity = cpus;
455 self
456 }
457
458 pub fn event_interval(&mut self, val: usize) -> &mut Self {
463 self.event_interval = val;
464 self
465 }
466
467 pub fn build(&self) -> io::Result<Runtime> {
469 let RuntimeBuilder {
470 proactor_builder,
471 thread_affinity,
472 event_interval,
473 } = self;
474 let id = RUNTIME_ID.get();
475 RUNTIME_ID.set(id + 1);
476 if !thread_affinity.is_empty() {
477 bind_to_cpu_set(thread_affinity);
478 }
479 let inner = RuntimeInner {
480 driver: RefCell::new(proactor_builder.build()?),
481 scheduler: Scheduler::new(*event_interval),
482 #[cfg(feature = "time")]
483 timer_runtime: RefCell::new(TimerRuntime::new()),
484 id,
485 };
486 Ok(Runtime(Rc::new(inner)))
487 }
488}
489
490pub fn spawn<F: Future + 'static>(future: F) -> JoinHandle<F::Output> {
514 Runtime::with_current(|r| r.spawn(future))
515}
516
517pub fn spawn_blocking<T: Send + 'static>(
526 f: impl (FnOnce() -> T) + Send + 'static,
527) -> JoinHandle<T> {
528 Runtime::with_current(|r| r.spawn_blocking(f))
529}
530
531pub fn submit<T: OpCode + 'static>(op: T) -> Submit<T> {
539 Runtime::with_current(|r| r.submit(op))
540}
541
542#[deprecated(since = "0.8.0", note = "use `submit(op).with_extra()` instead")]
550pub fn submit_with_extra<T: OpCode + 'static>(op: T) -> Submit<T, Extra> {
551 Runtime::with_current(|r| r.submit(op).with_extra())
552}
553
/// Completes once `instant` is reached, according to the current runtime's
/// timers.
///
/// Returns immediately when the timer runtime issues no key.
/// NOTE(review): presumably that happens when `instant` is already in the
/// past — confirm.
///
/// # Panics
///
/// Panics if called outside a compio runtime.
#[cfg(feature = "time")]
pub(crate) async fn create_timer(instant: std::time::Instant) {
    // The timer-runtime borrow is confined to the closure, so it is released
    // before the await below.
    let inserted =
        Runtime::with_current(|runtime| runtime.timer_runtime.borrow_mut().insert(instant));
    let Some(key) = inserted else {
        return;
    };
    TimerFuture::new(key).await;
}