use std::thread;
use tokio::runtime::{Builder, Handle, Runtime};
use tokio::sync::oneshot;
use crate::plugin::error::PluginError;
/// Scheduling profile used when building the plugin async runtime.
///
/// The only thing a mode controls here is how many tokio worker threads are started.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AsyncMode {
    /// Exactly one worker thread. This is the default mode.
    #[default]
    Deterministic,
    /// One worker per unit of available parallelism reported by the OS,
    /// or a single worker when that value cannot be queried.
    Throughput,
}

impl AsyncMode {
    /// Worker-thread count to configure on the tokio runtime builder for this mode.
    fn worker_threads(self) -> usize {
        match self {
            Self::Deterministic => 1,
            Self::Throughput => {
                thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get)
            }
        }
    }
}
/// A tokio runtime hosted on its own dedicated OS thread.
///
/// Dropping a value of this type signals the hosting thread to stop and then
/// joins it.
pub struct AsyncRuntime {
    /// Mode the runtime was built with.
    mode: AsyncMode,
    /// Handle for spawning tasks on, or blocking on futures with, the runtime.
    handle: Handle,
    /// One-shot signal that releases the hosting thread; taken during drop.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Join handle of the hosting thread; taken during drop.
    thread: Option<std::thread::JoinHandle<()>>,
}
impl AsyncRuntime {
    /// Starts a dedicated OS thread named `orts-plugin-runtime` that builds and
    /// owns a multi-threaded tokio runtime, and returns a wrapper around it.
    ///
    /// The number of worker threads is chosen by `mode`. Worker threads are named
    /// `orts-plugin-worker`. The hosting thread stays parked inside the runtime
    /// until the returned value is dropped.
    ///
    /// # Errors
    ///
    /// Returns [`PluginError::Init`] in three cases:
    /// - the hosting thread cannot be spawned;
    /// - the tokio runtime fails to build (the builder's error is included in the message);
    /// - the hosting thread terminates before reporting the runtime handle.
    pub fn new(mode: AsyncMode) -> Result<Self, PluginError> {
        // The hosting thread reports either the runtime handle or the build error.
        // This lets the real cause reach the caller instead of only a log line on
        // a background thread.
        let (handle_tx, handle_rx) = std::sync::mpsc::channel::<std::io::Result<Handle>>();
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let worker_threads = mode.worker_threads();
        let join_handle = thread::Builder::new()
            .name("orts-plugin-runtime".to_string())
            .spawn(move || {
                let built = Builder::new_multi_thread()
                    .worker_threads(worker_threads)
                    .thread_name("orts-plugin-worker")
                    .enable_all()
                    .build();
                let rt: Runtime = match built {
                    Ok(rt) => rt,
                    Err(e) => {
                        // There is no runtime to park in: hand the error to `new` and exit.
                        let _ = handle_tx.send(Err(e));
                        return;
                    }
                };
                if handle_tx.send(Ok(rt.handle().clone())).is_err() {
                    // `new` is no longer waiting, so nobody can ever use this runtime.
                    return;
                }
                // Keep the runtime alive until the shutdown signal arrives. The
                // receiver also resolves if the sender is dropped.
                rt.block_on(async move {
                    let _ = shutdown_rx.await;
                });
            })
            .map_err(|e| PluginError::Init(format!("failed to spawn async runtime thread: {e}")))?;

        match handle_rx.recv() {
            Ok(Ok(handle)) => Ok(Self {
                handle,
                shutdown_tx: Some(shutdown_tx),
                thread: Some(join_handle),
                mode,
            }),
            Ok(Err(e)) => {
                // The thread returns right after reporting; reap it instead of detaching.
                let _ = join_handle.join();
                Err(PluginError::Init(format!("tokio runtime build failed: {e}")))
            }
            Err(_) => {
                // The sender was dropped, so the thread has finished (or panicked).
                let _ = join_handle.join();
                Err(PluginError::Init(
                    "async runtime thread exited before reporting handle".to_string(),
                ))
            }
        }
    }

    /// Returns the handle used to spawn tasks on, or block on futures with, the runtime.
    pub fn handle(&self) -> &Handle {
        &self.handle
    }

    /// Returns the mode this runtime was created with.
    pub fn mode(&self) -> AsyncMode {
        self.mode
    }
}
impl Drop for AsyncRuntime {
    /// Signals the runtime-hosting thread to stop, then waits for it to exit.
    ///
    /// Both steps are best-effort, and their failures are deliberately ignored:
    /// - a failed send means the receiving side is already gone;
    /// - a failed join means the thread panicked.
    fn drop(&mut self) {
        let shutdown = self.shutdown_tx.take();
        let runtime_thread = self.thread.take();

        // Signal before joining. The hosting thread only finishes once this
        // signal (or the sender's drop) releases it.
        if let Some(tx) = shutdown {
            let _ = tx.send(());
        }
        if let Some(worker) = runtime_thread {
            let _ = worker.join();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Sets the shared flag when dropped; used to observe when a task is torn down.
    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn runtime_starts_and_shuts_down() {
        let rt = AsyncRuntime::new(AsyncMode::Deterministic).expect("runtime must start");
        let handle = rt.handle().clone();
        let result: i32 = handle.block_on(async { 1 + 2 });
        assert_eq!(result, 3);
        drop(rt);
    }

    #[test]
    fn drop_joins_runtime_thread() {
        let rt = AsyncRuntime::new(AsyncMode::Deterministic).expect("runtime must start");
        let result: u64 = rt.handle().block_on(async { 42 });
        assert_eq!(result, 42);

        // This task never completes on its own. It is only torn down, dropping the
        // guard, when the runtime shuts down on its hosting thread. If `drop`
        // returned before joining that thread, the flag could still be false.
        let torn_down = Arc::new(AtomicBool::new(false));
        let guard = SetOnDrop(Arc::clone(&torn_down));
        let _task = rt.handle().spawn(async move {
            let _guard = guard;
            std::future::pending::<()>().await;
        });

        drop(rt);
        assert!(torn_down.load(Ordering::SeqCst));
    }

    #[test]
    fn default_mode_is_reported() {
        let rt = AsyncRuntime::new(AsyncMode::default()).expect("runtime must start");
        assert_eq!(rt.mode(), AsyncMode::Deterministic);
    }

    #[test]
    fn throughput_mode_runs_tasks_on_named_workers() {
        let rt = AsyncRuntime::new(AsyncMode::Throughput).expect("runtime must start");
        assert_eq!(rt.mode(), AsyncMode::Throughput);
        let handle = rt.handle();
        let worker_name = handle
            .block_on(handle.spawn(async { std::thread::current().name().map(str::to_owned) }))
            .expect("spawned task must complete");
        assert_eq!(worker_name.as_deref(), Some("orts-plugin-worker"));
    }
}