use core::{
future::{Future, poll_fn},
pin::pin,
};
use std::io;
use tokio::{io::unix::AsyncFd, runtime::LocalRuntime};
mod context;
pub(crate) mod driver;
pub(crate) use context::RuntimeContext;
thread_local! {
    /// Per-thread runtime state shared between [`Runtime::block_on`] and the tokio hooks.
    ///
    /// `block_on` installs its driver handle here (`set_handle`) for the duration of the
    /// call and removes it again (`unset_driver`) when the call ends. The
    /// `on_thread_park` / `on_thread_unpark` hooks configured in [`Runtime::new`] read
    /// the handle back through `handle()`.
    pub(crate) static CONTEXT: RuntimeContext = RuntimeContext::new();
}
/// A single-threaded runtime pairing a tokio [`LocalRuntime`] with the crate's
/// io-uring driver (`driver::Handle`).
///
/// Create one with [`Runtime::new`] and run futures on it with [`Runtime::block_on`].
pub struct Runtime {
    // NOTE: field order is significant. Rust drops struct fields in declaration order,
    // so the tokio runtime (and every task it still owns) is dropped before this
    // runtime's clone of the driver handle. Keep this order unless the drop sequence
    // is deliberately revisited.
    tokio_rt: LocalRuntime,
    driver: driver::Handle,
}
/// Spawns `task` as a local (non-`Send`) task and returns a handle to its output.
///
/// This is a thin wrapper over [`tokio::task::spawn_local`], so `task` is not required
/// to be `Send`. It only has to live for `'static`.
///
/// # Panics
///
/// Panics under the same conditions as `tokio::task::spawn_local`, i.e. when there is
/// no local task context on the current thread to spawn onto (for example, when called
/// outside [`Runtime::block_on`]).
pub fn spawn<T>(task: T) -> tokio::task::JoinHandle<T::Output>
where
    T: Future + 'static,
{
    tokio::task::spawn_local(task)
}
impl Runtime {
    /// Creates a runtime from a single-threaded tokio [`LocalRuntime`] and an io-uring
    /// driver built from `b`.
    ///
    /// The tokio runtime gets two hooks that reach the driver through the thread-local
    /// [`CONTEXT`]:
    /// * before the thread parks, the driver's `flush()` is called (its result is
    ///   ignored);
    /// * after the thread unparks, the driver's `dispatch_completions()` is called if a
    ///   driver is installed.
    ///
    /// The driver is also registered with tokio's reactor, and a background task is
    /// spawned to service its readiness (see [`start_uring_wakes_task`]).
    ///
    /// # Errors
    ///
    /// Returns any I/O error from building the tokio runtime, creating the driver, or
    /// registering the driver with tokio's reactor.
    ///
    /// # Panics
    ///
    /// The park hook panics if it runs while no driver is installed in [`CONTEXT`].
    /// [`Runtime::block_on`] installs one for the whole duration of the call.
    pub fn new(b: &crate::Builder) -> io::Result<Runtime> {
        let tokio_rt = tokio::runtime::Builder::new_current_thread()
            .on_thread_park(|| {
                CONTEXT.with(|x| {
                    // Best effort: a failed flush is ignored here rather than
                    // propagated out of a scheduler hook.
                    let _ = x
                        .handle()
                        .expect("Internal error, driver context not present when invoking hooks")
                        .flush();
                });
            })
            .on_thread_unpark(|| {
                CONTEXT.with(|x| {
                    if let Some(h) = x.handle() {
                        h.dispatch_completions();
                    }
                });
            })
            .enable_all()
            .build_local(Default::default())?;

        let driver = driver::Handle::new(b)?;
        // Registration failures surface as an `Err` from `new` instead of a panic.
        start_uring_wakes_task(&tokio_rt, driver.clone())?;

        Ok(Runtime { tokio_rt, driver })
    }

    /// Runs `future` to completion on this runtime, blocking the current thread.
    ///
    /// For the duration of the call, this runtime's driver handle is installed in the
    /// thread-local [`CONTEXT`] so the park/unpark hooks can reach it. A drop guard
    /// removes it again when the call returns or unwinds.
    pub fn block_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        /// Clears the driver from `CONTEXT` when dropped (normal return or unwinding).
        struct ContextGuard;

        impl Drop for ContextGuard {
            fn drop(&mut self) {
                CONTEXT.with(|cx| cx.unset_driver());
            }
        }

        CONTEXT.with(|cx| cx.set_handle(self.driver.clone()));
        let _guard = ContextGuard;

        let mut future = pin!(future);
        let task = poll_fn(|cx| future.as_mut().poll(cx));

        self.tokio_rt.block_on(task)
    }
}

/// Registers `driver` with `tokio_rt`'s reactor and spawns [`drive_uring_wakes`] on it.
///
/// # Errors
///
/// Returns the error from [`AsyncFd::new`] if the driver's file descriptor cannot be
/// registered with the reactor.
fn start_uring_wakes_task(tokio_rt: &LocalRuntime, driver: driver::Handle) -> io::Result<()> {
    // `AsyncFd::new` registers with the *current* runtime's reactor and `spawn_local`
    // spawns onto the current local runtime, so both need the runtime entered.
    let _guard = tokio_rt.enter();

    let async_driver_handle = AsyncFd::new(driver)?;
    tokio::task::spawn_local(drive_uring_wakes(async_driver_handle));

    Ok(())
}
/// Services readiness of the io-uring driver's file descriptor, forever.
///
/// Each time tokio reports the descriptor readable, this task dispatches the driver's
/// completions and only then clears tokio's readiness flag.
///
/// Dispatching here, and not only in the runtime's `on_thread_unpark` hook, is
/// required. tokio's current-thread scheduler can also poll its reactor without
/// parking, through a non-blocking yield path that does not run the park/unpark hooks.
/// If readiness were observed on that path and merely cleared, the pending completions
/// would stay undispatched while their readiness edge was already consumed, so a later
/// blocking park could wait forever.
///
/// # Panics
///
/// Panics if tokio fails to report readiness for the registered descriptor.
async fn drive_uring_wakes(driver: AsyncFd<driver::Handle>) {
    loop {
        let mut guard = driver
            .readable()
            .await
            .expect("io-uring driver fd lost its registration with tokio's reactor");

        // Drain completions *before* clearing readiness so no signalled CQE is lost.
        guard.get_inner().dispatch_completions();
        guard.clear_ready();
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::builder;

    /// Builds a runtime from `crate::builder()`, panicking if construction fails.
    fn new_runtime() -> Runtime {
        Runtime::new(&builder()).unwrap()
    }

    /// A single `block_on` of an empty future completes.
    #[test]
    fn block_on() {
        let runtime = new_runtime();
        runtime.block_on(async {});
    }

    /// The same runtime can be driven by consecutive `block_on` calls.
    #[test]
    fn block_on_twice() {
        let runtime = new_runtime();
        for _ in 0..2 {
            runtime.block_on(async {});
        }
    }
}