atomr_core/dispatch/
dispatcher.rs1use std::future::Future;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::runtime::{Handle, Runtime};
8use tokio::task::JoinHandle;
9
/// Tuning knobs for dispatchers that are built from an explicit configuration.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DispatcherConfig {
    /// Advisory per-turn batch size reported by dispatchers built from this config.
    ///
    /// NOTE(review): presumably the number of messages an actor processes before
    /// yielding — confirm against the mailbox run loop.
    pub throughput: u32,
    /// Optional time budget reported alongside `throughput`; `None` means no
    /// deadline is configured.
    pub throughput_deadline: Option<Duration>,
}
21
22impl Default for DispatcherConfig {
23 fn default() -> Self {
24 Self { throughput: 10, throughput_deadline: None }
25 }
26}
27
28pub trait Dispatcher: Send + Sync {
30 fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle;
31
32 fn throughput(&self) -> u32 {
33 10
34 }
35
36 fn throughput_deadline(&self) -> Option<Duration> {
38 None
39 }
40}
41
42pub struct DispatcherHandle(pub(crate) JoinHandle<()>);
43
44impl DispatcherHandle {
45 pub async fn join(self) {
46 let _ = self.0.await;
47 }
48
49 pub fn abort(&self) {
50 self.0.abort();
51 }
52}
53
54pub struct DefaultDispatcher {
56 handle: Handle,
57 config: DispatcherConfig,
58}
59
60impl DefaultDispatcher {
61 pub fn new(handle: Handle, throughput: u32) -> Self {
62 Self { handle, config: DispatcherConfig { throughput, throughput_deadline: None } }
63 }
64
65 pub fn with_config(handle: Handle, config: DispatcherConfig) -> Self {
66 Self { handle, config }
67 }
68
69 pub fn current() -> Self {
70 Self::with_config(Handle::current(), DispatcherConfig::default())
71 }
72}
73
74impl Dispatcher for DefaultDispatcher {
75 fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
76 DispatcherHandle(self.handle.spawn(task))
77 }
78
79 fn throughput(&self) -> u32 {
80 self.config.throughput
81 }
82
83 fn throughput_deadline(&self) -> Option<Duration> {
84 self.config.throughput_deadline
85 }
86}
87
88pub struct PinnedDispatcher {
90 rt: Arc<Runtime>,
91}
92
93impl PinnedDispatcher {
94 pub fn new() -> std::io::Result<Self> {
95 let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?;
96 Ok(Self { rt: Arc::new(rt) })
97 }
98}
99
100impl Dispatcher for PinnedDispatcher {
101 fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
102 DispatcherHandle(self.rt.spawn(task))
103 }
104}
105
106pub fn spawn<F>(f: F) -> JoinHandle<F::Output>
108where
109 F: Future + Send + 'static,
110 F::Output: Send + 'static,
111{
112 tokio::spawn(f)
113}
114
115pub struct ThreadPoolDispatcher {
117 rt: Arc<Runtime>,
118 throughput: u32,
119}
120
121impl ThreadPoolDispatcher {
122 pub fn new(worker_threads: usize, throughput: u32) -> std::io::Result<Self> {
123 let rt = tokio::runtime::Builder::new_multi_thread()
124 .worker_threads(worker_threads.max(1))
125 .enable_all()
126 .build()?;
127 Ok(Self { rt: Arc::new(rt), throughput })
128 }
129}
130
131impl Dispatcher for ThreadPoolDispatcher {
132 fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
133 DispatcherHandle(self.rt.spawn(task))
134 }
135 fn throughput(&self) -> u32 {
136 self.throughput
137 }
138}
139
/// Stateless dispatcher reporting a throughput of one.
///
/// NOTE(review): despite the name, its `Dispatcher` impl in this module spawns
/// via `tokio::task::spawn` onto the ambient runtime. The task is not run
/// inline on the calling thread; confirm that is intended.
#[derive(Debug, Clone, Copy, Default)]
pub struct CallingThreadDispatcher;
144
145impl Dispatcher for CallingThreadDispatcher {
146 fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
147 DispatcherHandle(tokio::task::spawn(task))
148 }
149 fn throughput(&self) -> u32 {
150 1
151 }
152}
153
154pub struct SingleThreadDispatcher {
160 rt: Arc<Runtime>,
161 config: DispatcherConfig,
162}
163
164impl SingleThreadDispatcher {
165 pub fn new(config: DispatcherConfig) -> std::io::Result<Self> {
166 let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?;
167 Ok(Self { rt: Arc::new(rt), config })
168 }
169}
170
171impl Dispatcher for SingleThreadDispatcher {
172 fn spawn_task(&self, task: futures_util::future::BoxFuture<'static, ()>) -> DispatcherHandle {
173 DispatcherHandle(self.rt.spawn(task))
174 }
175 fn throughput(&self) -> u32 {
176 self.config.throughput
177 }
178 fn throughput_deadline(&self) -> Option<Duration> {
179 self.config.throughput_deadline
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// A task spawned through the ambient-runtime dispatcher runs and can be joined.
    #[tokio::test]
    async fn default_dispatcher_runs_task() {
        let dispatcher = DefaultDispatcher::current();
        let (reply_tx, reply_rx) = tokio::sync::oneshot::channel::<u32>();
        let handle = dispatcher.spawn_task(Box::pin(async move {
            reply_tx.send(42).unwrap();
        }));
        let received = reply_rx.await.unwrap();
        assert_eq!(received, 42);
        handle.join().await;
    }

    /// The default config has throughput 10 and no deadline.
    #[test]
    fn dispatcher_config_default_is_unbounded_deadline() {
        let DispatcherConfig { throughput, throughput_deadline } = DispatcherConfig::default();
        assert_eq!(throughput, 10);
        assert_eq!(throughput_deadline, None);
    }

    /// An explicit config is reported back through the trait accessors.
    #[tokio::test]
    async fn default_dispatcher_with_config_exposes_knobs() {
        let deadline = Duration::from_millis(5);
        let dispatcher = DefaultDispatcher::with_config(
            Handle::current(),
            DispatcherConfig { throughput: 50, throughput_deadline: Some(deadline) },
        );
        assert_eq!(dispatcher.throughput(), 50);
        assert_eq!(dispatcher.throughput_deadline(), Some(deadline));
    }

    /// Spawning and aborting on a single-thread dispatcher does not panic.
    #[test]
    fn single_thread_dispatcher_runs_task() {
        let dispatcher = SingleThreadDispatcher::new(DispatcherConfig::default()).unwrap();
        let (sender, receiver) = std::sync::mpsc::channel();
        let handle = dispatcher.spawn_task(Box::pin(async move {
            sender.send(7u32).unwrap();
        }));
        std::thread::sleep(Duration::from_millis(20));
        handle.abort();
        // The receive result is deliberately ignored. Whether the task ran depends
        // on how the dispatcher's runtime is driven, so this test only checks that
        // spawn/abort complete without panicking.
        let _ = receiver.recv_timeout(Duration::from_millis(50));
    }
}