use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::{Handle, Runtime};
use tokio::task::JoinHandle;
/// Tuning knobs shared by the config-driven dispatchers in this module.
///
/// The [`Default`] impl below yields `throughput = 10` and no throughput deadline.
/// `DefaultDispatcher` and `SingleThreadDispatcher` report these values back through
/// `Dispatcher::throughput` / `Dispatcher::throughput_deadline`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatcherConfig {
    /// Throughput value a dispatcher built from this config advertises.
    pub throughput: u32,
    /// Optional deadline a dispatcher built from this config advertises; `None` means none is set.
    pub throughput_deadline: Option<Duration>,
}
impl Default for DispatcherConfig {
    fn default() -> Self {
        let throughput = 10;
        let throughput_deadline = None;
        DispatcherConfig { throughput, throughput_deadline }
    }
}
/// An executor abstraction: something that can run boxed `'static` futures.
///
/// Only [`Dispatcher::spawn_task`] is required. `throughput` and `throughput_deadline`
/// fall back to `10` and `None`, the same values `DispatcherConfig::default()` produces;
/// keep the two in sync if either changes.
pub trait Dispatcher: Send + Sync {
    /// Hands `task` to this dispatcher's executor and returns a handle to the spawned task.
    fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle;

    /// Throughput value this dispatcher advertises (default: `10`).
    ///
    /// NOTE(review): nothing in this module consumes the value; presumably callers use it
    /// as a per-turn processing limit — confirm against the call sites.
    fn throughput(&self) -> u32 {
        const FALLBACK_THROUGHPUT: u32 = 10;
        FALLBACK_THROUGHPUT
    }

    /// Optional deadline this dispatcher advertises (default: no deadline).
    fn throughput_deadline(&self) -> Option<Duration> {
        let no_deadline: Option<Duration> = None;
        no_deadline
    }
}
/// Handle to a task started through [`Dispatcher::spawn_task`].
///
/// Wraps Tokio's [`JoinHandle`]. Dropping it without joining leaves the task detached;
/// it keeps running.
pub struct DispatcherHandle(pub(crate) JoinHandle<()>);
impl DispatcherHandle {
    /// Waits until the task has finished.
    ///
    /// The outcome is discarded on purpose: if the task was aborted or panicked, the
    /// resulting `JoinError` is dropped rather than propagated to the caller.
    pub async fn join(self) {
        let Self(inner) = self;
        let _outcome = inner.await;
    }
    /// Requests cancellation of the task by delegating to `JoinHandle::abort`.
    pub fn abort(&self) {
        let Self(inner) = self;
        inner.abort();
    }
}
/// Dispatcher that spawns onto an existing Tokio runtime, identified by its [`Handle`].
pub struct DefaultDispatcher {
    handle: Handle,
    config: DispatcherConfig,
}
impl DefaultDispatcher {
    /// Builds a dispatcher on `handle` with the given `throughput` and no throughput deadline.
    pub fn new(handle: Handle, throughput: u32) -> Self {
        let config = DispatcherConfig { throughput, throughput_deadline: None };
        Self::with_config(handle, config)
    }
    /// Builds a dispatcher on `handle` using an explicit `config`.
    pub fn with_config(handle: Handle, config: DispatcherConfig) -> Self {
        DefaultDispatcher { handle, config }
    }
    /// Builds a dispatcher on the runtime of the calling context with `DispatcherConfig::default()`.
    ///
    /// # Panics
    ///
    /// `Handle::current()` panics when called outside a Tokio runtime context (per the tokio docs).
    pub fn current() -> Self {
        let handle = Handle::current();
        Self::with_config(handle, DispatcherConfig::default())
    }
}
impl Dispatcher for DefaultDispatcher {
    /// Spawns `task` on the runtime this dispatcher was built with, not on whichever runtime is ambient.
    fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
        let join = self.handle.spawn(task);
        DispatcherHandle(join)
    }
    /// Reports the configured throughput.
    fn throughput(&self) -> u32 {
        let cfg = &self.config;
        cfg.throughput
    }
    /// Reports the configured throughput deadline, if any.
    fn throughput_deadline(&self) -> Option<Duration> {
        let cfg = &self.config;
        cfg.throughput_deadline
    }
}
/// Dispatcher that runs every task it spawns on its own private Tokio runtime.
///
/// The runtime has exactly one worker thread, so all tasks spawned through this
/// dispatcher are polled on that one thread.
pub struct PinnedDispatcher {
    rt: Arc<Runtime>,
}
impl PinnedDispatcher {
    /// Creates a dispatcher backed by a fresh single-worker Tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the Tokio runtime builder if the runtime cannot be built.
    pub fn new() -> std::io::Result<Self> {
        // A current-thread runtime only polls spawned tasks while some thread is inside
        // `Runtime::block_on`. Nothing in this dispatcher ever calls it, so tasks handed to
        // `spawn_task` would never run. A multi-thread runtime with one worker keeps the
        // "single executor thread" property and drives tasks in the background.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(1)
            .enable_all()
            .build()?;
        Ok(Self { rt: Arc::new(rt) })
    }
}
impl Dispatcher for PinnedDispatcher {
fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
DispatcherHandle(self.rt.spawn(task))
}
}
/// Spawns `f` onto the ambient Tokio runtime and returns its [`JoinHandle`].
///
/// Thin wrapper over Tokio's task spawner. It must be called from within a Tokio runtime
/// context (tokio panics otherwise, per its docs).
pub fn spawn<F: Future + Send + 'static>(f: F) -> JoinHandle<F::Output>
where
    F::Output: Send + 'static,
{
    tokio::task::spawn(f)
}
/// Dispatcher backed by a private multi-thread Tokio runtime with a fixed number of workers.
pub struct ThreadPoolDispatcher {
    rt: Arc<Runtime>,
    throughput: u32,
}
impl ThreadPoolDispatcher {
    /// Creates a pool with `worker_threads` workers (at least one) that advertises `throughput`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the Tokio runtime builder if the runtime cannot be built.
    pub fn new(worker_threads: usize, throughput: u32) -> std::io::Result<Self> {
        // Tokio's `Builder::worker_threads` panics on 0 (per tokio docs), so clamp to one.
        let workers = if worker_threads == 0 { 1 } else { worker_threads };
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(workers)
            .enable_all()
            .build()?;
        let rt = Arc::new(runtime);
        Ok(ThreadPoolDispatcher { rt, throughput })
    }
}
impl Dispatcher for ThreadPoolDispatcher {
fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
DispatcherHandle(self.rt.spawn(task))
}
fn throughput(&self) -> u32 {
self.throughput
}
}
/// Dispatcher that spawns onto whichever Tokio runtime is current at the call site and
/// advertises a throughput of `1`.
///
/// NOTE(review): despite the name, tasks are handed to the ambient runtime's scheduler.
/// On a multi-thread runtime they may be polled on any worker, not necessarily the calling
/// thread. Calling `spawn_task` outside a runtime context panics (per tokio docs).
pub struct CallingThreadDispatcher;
impl Dispatcher for CallingThreadDispatcher {
    fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
        let join = tokio::spawn(task);
        DispatcherHandle(join)
    }
    fn throughput(&self) -> u32 {
        const ONE: u32 = 1;
        ONE
    }
}
/// Dispatcher backed by its own single-worker Tokio runtime, tuned by a [`DispatcherConfig`].
pub struct SingleThreadDispatcher {
    rt: Arc<Runtime>,
    config: DispatcherConfig,
}
impl SingleThreadDispatcher {
    /// Creates a dispatcher with a fresh single-worker runtime that advertises `config`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the Tokio runtime builder if the runtime cannot be built.
    pub fn new(config: DispatcherConfig) -> std::io::Result<Self> {
        // A current-thread runtime only polls spawned tasks while some thread is inside
        // `Runtime::block_on`. Nothing here calls it, so spawned tasks would never make
        // progress. A multi-thread runtime with exactly one worker keeps a single executor
        // thread and actually drives the tasks.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(1)
            .enable_all()
            .build()?;
        Ok(Self { rt: Arc::new(rt), config })
    }
}
impl Dispatcher for SingleThreadDispatcher {
fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
DispatcherHandle(self.rt.spawn(task))
}
fn throughput(&self) -> u32 {
self.config.throughput
}
fn throughput_deadline(&self) -> Option<Duration> {
self.config.throughput_deadline
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn default_dispatcher_runs_task() {
        let d = DefaultDispatcher::current();
        let (tx, rx) = tokio::sync::oneshot::channel();
        let h = d.spawn_task(Box::pin(async move {
            tx.send(42u32).unwrap();
        }));
        assert_eq!(rx.await.unwrap(), 42);
        h.join().await;
    }

    #[test]
    fn dispatcher_config_default_is_unbounded_deadline() {
        let c = DispatcherConfig::default();
        assert_eq!(c.throughput, 10);
        assert_eq!(c.throughput_deadline, None);
    }

    #[tokio::test]
    async fn default_dispatcher_with_config_exposes_knobs() {
        let cfg = DispatcherConfig { throughput: 50, throughput_deadline: Some(Duration::from_millis(5)) };
        let d = DefaultDispatcher::with_config(Handle::current(), cfg.clone());
        assert_eq!(d.throughput(), 50);
        assert_eq!(d.throughput_deadline(), Some(Duration::from_millis(5)));
    }

    #[test]
    fn single_thread_dispatcher_runs_task() {
        let d = SingleThreadDispatcher::new(DispatcherConfig::default()).unwrap();
        let (tx, rx) = std::sync::mpsc::channel();
        let h = d.spawn_task(Box::pin(async move {
            tx.send(7u32).unwrap();
        }));
        // Wait for completion deterministically instead of sleeping. `block_on` also drives the
        // runtime, so the task is guaranteed to have finished once this returns.
        d.rt.block_on(h.join());
        assert_eq!(rx.try_recv(), Ok(7));
        assert_eq!(d.throughput(), 10);
        assert_eq!(d.throughput_deadline(), None);
    }

    #[test]
    fn pinned_dispatcher_runs_task() {
        let d = PinnedDispatcher::new().unwrap();
        let (tx, rx) = std::sync::mpsc::channel();
        let h = d.spawn_task(Box::pin(async move {
            tx.send(8u32).unwrap();
        }));
        d.rt.block_on(h.join());
        assert_eq!(rx.try_recv(), Ok(8));
        // No override: the trait default applies.
        assert_eq!(d.throughput(), 10);
    }

    #[test]
    fn thread_pool_dispatcher_clamps_workers_and_runs_task() {
        // Zero workers must be clamped rather than rejected by the runtime builder.
        let d = ThreadPoolDispatcher::new(0, 5).unwrap();
        assert_eq!(d.throughput(), 5);
        assert_eq!(d.throughput_deadline(), None);
        let (tx, rx) = std::sync::mpsc::channel();
        let h = d.spawn_task(Box::pin(async move {
            tx.send(9u32).unwrap();
        }));
        d.rt.block_on(h.join());
        assert_eq!(rx.try_recv(), Ok(9));
    }

    #[tokio::test]
    async fn calling_thread_dispatcher_runs_task() {
        let d = CallingThreadDispatcher;
        assert_eq!(d.throughput(), 1);
        let (tx, rx) = tokio::sync::oneshot::channel();
        let h = d.spawn_task(Box::pin(async move {
            tx.send(3u32).unwrap();
        }));
        assert_eq!(rx.await.unwrap(), 3);
        h.join().await;
    }
}