1#![warn(missing_docs)]
4
5use std::{
6 collections::HashSet,
7 future::Future,
8 io,
9 num::NonZeroUsize,
10 panic::resume_unwind,
11 thread::{JoinHandle, available_parallelism},
12};
13
14use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
15use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
16use flume::{Sender, unbounded};
17use futures_channel::oneshot;
18
19type Spawning = Box<dyn Spawnable + Send>;
20
21trait Spawnable {
22 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
23}
24
25struct Concrete<F, R> {
27 callback: oneshot::Sender<R>,
28 func: F,
29}
30
31impl<F, R> Concrete<F, R> {
32 pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
33 let (tx, rx) = oneshot::channel();
34 (Self { callback: tx, func }, rx)
35 }
36}
37
38impl<F, Fut, R> Spawnable for Concrete<F, R>
39where
40 F: FnOnce() -> Fut + Send + 'static,
41 Fut: Future<Output = R>,
42 R: Send + 'static,
43{
44 fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
45 let Concrete { callback, func } = *self;
46 handle.spawn(async move {
47 let res = func().await;
48 callback.send(res).ok();
49 })
50 }
51}
52
53impl<F, R> Dispatchable for Concrete<F, R>
54where
55 F: FnOnce() -> R + Send + 'static,
56 R: Send + 'static,
57{
58 fn run(self: Box<Self>) {
59 let Concrete { callback, func } = *self;
60 let res = func();
61 callback.send(res).ok();
62 }
63}
64
65#[derive(Debug)]
67pub struct Dispatcher {
68 sender: Sender<Spawning>,
69 threads: Vec<JoinHandle<()>>,
70 pool: AsyncifyPool,
71}
72
73impl Dispatcher {
74 pub(crate) fn new_impl(builder: DispatcherBuilder) -> io::Result<Self> {
76 let DispatcherBuilder {
77 nthreads,
78 concurrent,
79 stack_size,
80 mut thread_affinity,
81 mut names,
82 mut proactor_builder,
83 } = builder;
84 proactor_builder.force_reuse_thread_pool();
85 let pool = proactor_builder.create_or_get_thread_pool();
86 let (sender, receiver) = unbounded::<Spawning>();
87
88 let threads = (0..nthreads)
89 .map({
90 |index| {
91 let proactor_builder = proactor_builder.clone();
92 let receiver = receiver.clone();
93
94 let thread_builder = std::thread::Builder::new();
95 let thread_builder = if let Some(s) = stack_size {
96 thread_builder.stack_size(s)
97 } else {
98 thread_builder
99 };
100 let thread_builder = if let Some(f) = &mut names {
101 thread_builder.name(f(index))
102 } else {
103 thread_builder
104 };
105
106 let cpus = if let Some(f) = &mut thread_affinity {
107 f(index)
108 } else {
109 HashSet::new()
110 };
111 thread_builder.spawn(move || {
112 Runtime::builder()
113 .with_proactor(proactor_builder)
114 .thread_affinity(cpus)
115 .build()
116 .expect("cannot create compio runtime")
117 .block_on(async move {
118 while let Ok(f) = receiver.recv_async().await {
119 let task = Runtime::with_current(|rt| f.spawn(rt));
120 if concurrent {
121 task.detach()
122 } else {
123 task.await.ok();
124 }
125 }
126 });
127 })
128 }
129 })
130 .collect::<io::Result<Vec<_>>>()?;
131 Ok(Self {
132 sender,
133 threads,
134 pool,
135 })
136 }
137
138 pub fn new() -> io::Result<Self> {
140 Self::builder().build()
141 }
142
143 pub fn builder() -> DispatcherBuilder {
145 DispatcherBuilder::default()
146 }
147
148 pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
159 where
160 Fn: (FnOnce() -> Fut) + Send + 'static,
161 Fut: Future<Output = R> + 'static,
162 R: Send + 'static,
163 {
164 let (concrete, rx) = Concrete::new(f);
165
166 match self.sender.send(Box::new(concrete)) {
167 Ok(_) => Ok(rx),
168 Err(err) => {
169 let recovered =
171 unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
172 Err(DispatchError(recovered.func))
173 }
174 }
175 }
176
177 pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
190 where
191 Fn: FnOnce() -> R + Send + 'static,
192 R: Send + 'static,
193 {
194 let (concrete, rx) = Concrete::new(f);
195
196 self.pool
197 .dispatch(concrete)
198 .map_err(|e| DispatchError(e.0.func))?;
199
200 Ok(rx)
201 }
202
203 pub async fn join(self) -> io::Result<()> {
206 drop(self.sender);
207 let (tx, rx) = oneshot::channel::<Vec<_>>();
208 if let Err(f) = self.pool.dispatch({
209 move || {
210 let results = self
211 .threads
212 .into_iter()
213 .map(|thread| thread.join())
214 .collect();
215 tx.send(results).ok();
216 }
217 }) {
218 std::thread::spawn(f.0);
219 }
220 let results = rx
221 .await
222 .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
223 for res in results {
224 res.unwrap_or_else(|e| resume_unwind(e));
225 }
226 Ok(())
227 }
228}
229
230pub struct DispatcherBuilder {
232 nthreads: usize,
233 concurrent: bool,
234 stack_size: Option<usize>,
235 thread_affinity: Option<Box<dyn FnMut(usize) -> HashSet<usize>>>,
236 names: Option<Box<dyn FnMut(usize) -> String>>,
237 proactor_builder: ProactorBuilder,
238}
239
240impl DispatcherBuilder {
241 pub fn new() -> Self {
243 Self {
244 nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
245 concurrent: true,
246 stack_size: None,
247 thread_affinity: None,
248 names: None,
249 proactor_builder: ProactorBuilder::new(),
250 }
251 }
252
253 pub fn concurrent(mut self, concurrent: bool) -> Self {
258 self.concurrent = concurrent;
259 self
260 }
261
262 pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
266 self.nthreads = nthreads.get();
267 self
268 }
269
270 pub fn stack_size(mut self, s: usize) -> Self {
272 self.stack_size = Some(s);
273 self
274 }
275
276 pub fn thread_affinity(mut self, f: impl FnMut(usize) -> HashSet<usize> + 'static) -> Self {
278 self.thread_affinity = Some(Box::new(f));
279 self
280 }
281
282 pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
284 self.names = Some(Box::new(f) as _);
285 self
286 }
287
288 pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
290 self.proactor_builder = builder;
291 self
292 }
293
294 pub fn build(self) -> io::Result<Dispatcher> {
296 Dispatcher::new_impl(self)
297 }
298}
299
300impl Default for DispatcherBuilder {
301 fn default() -> Self {
302 Self::new()
303 }
304}