compio_dispatcher/
lib.rs

1//! Multithreading dispatcher for compio.
2
3#![warn(missing_docs)]
4
5use std::{
6    collections::HashSet,
7    future::Future,
8    io,
9    num::NonZeroUsize,
10    panic::resume_unwind,
11    thread::{JoinHandle, available_parallelism},
12};
13
14use compio_driver::{AsyncifyPool, DispatchError, Dispatchable, ProactorBuilder};
15use compio_runtime::{JoinHandle as CompioJoinHandle, Runtime};
16use flume::{Sender, unbounded};
17use futures_channel::oneshot;
18
/// Type-erased, sendable handle to a task waiting to be spawned onto a
/// worker thread's runtime.
type Spawning = Box<dyn Spawnable + Send>;

/// Object-safe abstraction over "spawn this concrete future on a runtime".
trait Spawnable {
    /// Consume the boxed task and spawn it on `handle`, returning the
    /// runtime's join handle (the task's output is erased to `()`).
    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()>;
}
24
/// Concrete type for the closure we're sending to worker threads
struct Concrete<F, R> {
    // One-shot channel used to deliver the closure's result back to the
    // dispatching side.
    callback: oneshot::Sender<R>,
    // The user-provided closure to run on a worker thread.
    func: F,
}
30
31impl<F, R> Concrete<F, R> {
32    pub fn new(func: F) -> (Self, oneshot::Receiver<R>) {
33        let (tx, rx) = oneshot::channel();
34        (Self { callback: tx, func }, rx)
35    }
36}
37
38impl<F, Fut, R> Spawnable for Concrete<F, R>
39where
40    F: FnOnce() -> Fut + Send + 'static,
41    Fut: Future<Output = R>,
42    R: Send + 'static,
43{
44    fn spawn(self: Box<Self>, handle: &Runtime) -> CompioJoinHandle<()> {
45        let Concrete { callback, func } = *self;
46        handle.spawn(async move {
47            let res = func().await;
48            callback.send(res).ok();
49        })
50    }
51}
52
53impl<F, R> Dispatchable for Concrete<F, R>
54where
55    F: FnOnce() -> R + Send + 'static,
56    R: Send + 'static,
57{
58    fn run(self: Box<Self>) {
59        let Concrete { callback, func } = *self;
60        let res = func();
61        callback.send(res).ok();
62    }
63}
64
/// The dispatcher. It manages the threads and dispatches the tasks.
#[derive(Debug)]
pub struct Dispatcher {
    // Work queue feeding the worker threads; dropping it signals shutdown
    // (see `join`).
    sender: Sender<Spawning>,
    // Handles of the spawned worker threads, joined in `join`.
    threads: Vec<JoinHandle<()>>,
    // Blocking-task pool obtained from the proactor builder; used by
    // `dispatch_blocking` and `join`.
    pool: AsyncifyPool,
}
72
73impl Dispatcher {
    /// Create the dispatcher with specified number of threads.
    pub(crate) fn new_impl(builder: DispatcherBuilder) -> io::Result<Self> {
        let DispatcherBuilder {
            nthreads,
            concurrent,
            stack_size,
            mut thread_affinity,
            mut names,
            mut proactor_builder,
        } = builder;
        // Reuse a single blocking thread pool across all worker proactors,
        // and keep a handle to it for `dispatch_blocking`/`join`.
        proactor_builder.force_reuse_thread_pool();
        let pool = proactor_builder.create_or_get_thread_pool();
        // MPMC channel: every worker thread clones the receiver and competes
        // for dispatched tasks.
        let (sender, receiver) = unbounded::<Spawning>();

        let threads = (0..nthreads)
            .map({
                |index| {
                    let proactor_builder = proactor_builder.clone();
                    let receiver = receiver.clone();

                    let thread_builder = std::thread::Builder::new();
                    // Optional per-thread stack size.
                    let thread_builder = if let Some(s) = stack_size {
                        thread_builder.stack_size(s)
                    } else {
                        thread_builder
                    };
                    // Optional per-thread name, derived from the thread index.
                    let thread_builder = if let Some(f) = &mut names {
                        thread_builder.name(f(index))
                    } else {
                        thread_builder
                    };

                    // CPU set for this worker; empty when no affinity
                    // callback was configured.
                    let cpus = if let Some(f) = &mut thread_affinity {
                        f(index)
                    } else {
                        HashSet::new()
                    };
                    thread_builder.spawn(move || {
                        // Each worker owns its own compio runtime built from
                        // a clone of the shared proactor configuration.
                        Runtime::builder()
                            .with_proactor(proactor_builder)
                            .thread_affinity(cpus)
                            .build()
                            .expect("cannot create compio runtime")
                            .block_on(async move {
                                // Runs until the dispatcher drops the sender
                                // (see `join`), which closes the channel.
                                while let Ok(f) = receiver.recv_async().await {
                                    let task = Runtime::with_current(|rt| f.spawn(rt));
                                    if concurrent {
                                        task.detach()
                                    } else {
                                        // Sequential mode: finish this task
                                        // before receiving the next one.
                                        task.await.ok();
                                    }
                                }
                            });
                    })
                }
            })
            .collect::<io::Result<Vec<_>>>()?;
        Ok(Self {
            sender,
            threads,
            pool,
        })
    }
137
138    /// Create the dispatcher with default config.
139    pub fn new() -> io::Result<Self> {
140        Self::builder().build()
141    }
142
143    /// Create a builder to build a dispatcher.
144    pub fn builder() -> DispatcherBuilder {
145        DispatcherBuilder::default()
146    }
147
    /// Dispatch a task to the threads
    ///
    /// The provided `f` should be [`Send`] because it will be sent to another
    /// thread before being called. The returned [`Future`] need not be
    /// [`Send`] because it will be executed on only one thread.
    ///
    /// # Error
    ///
    /// If all threads have panicked, this method will return an error with the
    /// sent closure.
    pub fn dispatch<Fn, Fut, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
    where
        Fn: (FnOnce() -> Fut) + Send + 'static,
        Fut: Future<Output = R> + 'static,
        R: Send + 'static,
    {
        let (concrete, rx) = Concrete::new(f);

        match self.sender.send(Box::new(concrete)) {
            Ok(_) => Ok(rx),
            Err(err) => {
                // A send failure means the channel is closed, i.e. every
                // worker thread has exited; recover the closure for the caller.
                // SAFETY: We know the dispatchable we sent has type `Concrete<Fn, R>`,
                // so casting the returned boxed trait object back to it is sound.
                let recovered =
                    unsafe { Box::from_raw(Box::into_raw(err.0) as *mut Concrete<Fn, R>) };
                Err(DispatchError(recovered.func))
            }
        }
    }
176
177    /// Dispatch a blocking task to the threads.
178    ///
179    /// Blocking pool of the dispatcher will be obtained from the proactor
180    /// builder. So any configuration of the proactor's blocking pool will be
181    /// applied to the dispatcher.
182    ///
183    /// # Error
184    ///
185    /// If all threads are busy and the thread pool is full, this method will
186    /// return an error with the original closure. The limit can be configured
187    /// with [`DispatcherBuilder::proactor_builder`] and
188    /// [`ProactorBuilder::thread_pool_limit`].
189    pub fn dispatch_blocking<Fn, R>(&self, f: Fn) -> Result<oneshot::Receiver<R>, DispatchError<Fn>>
190    where
191        Fn: FnOnce() -> R + Send + 'static,
192        R: Send + 'static,
193    {
194        let (concrete, rx) = Concrete::new(f);
195
196        self.pool
197            .dispatch(concrete)
198            .map_err(|e| DispatchError(e.0.func))?;
199
200        Ok(rx)
201    }
202
    /// Stop the dispatcher and wait for the threads to complete. If there is a
    /// thread panicked, this method will resume the panic.
    pub async fn join(self) -> io::Result<()> {
        // Closing the channel makes every worker's `recv_async` return `Err`,
        // so the worker loops terminate.
        drop(self.sender);
        let (tx, rx) = oneshot::channel::<Vec<_>>();
        // Joining OS threads blocks, so do it on the blocking pool; if the
        // pool rejects the task, fall back to a dedicated thread.
        if let Err(f) = self.pool.dispatch({
            move || {
                let results = self
                    .threads
                    .into_iter()
                    .map(|thread| thread.join())
                    .collect();
                tx.send(results).ok();
            }
        }) {
            // `f.0` is the rejected closure; run it on a fresh thread instead.
            std::thread::spawn(f.0);
        }
        let results = rx
            .await
            .map_err(|_| io::Error::other("the join task cancelled unexpectedly"))?;
        for res in results {
            // Re-raise any worker-thread panic in the caller's context.
            res.unwrap_or_else(|e| resume_unwind(e));
        }
        Ok(())
    }
228}
229
/// A builder for [`Dispatcher`].
pub struct DispatcherBuilder {
    // Number of worker threads to spawn.
    nthreads: usize,
    // Whether tasks on one worker thread may run concurrently.
    concurrent: bool,
    // Optional stack size for each worker thread.
    stack_size: Option<usize>,
    // Optional CPU-affinity callback (thread index -> CPU set).
    thread_affinity: Option<Box<dyn FnMut(usize) -> HashSet<usize>>>,
    // Optional thread-name callback (thread index -> name).
    names: Option<Box<dyn FnMut(usize) -> String>>,
    // Proactor configuration used for every worker's runtime.
    proactor_builder: ProactorBuilder,
}
239
240impl DispatcherBuilder {
241    /// Create a builder with default settings.
242    pub fn new() -> Self {
243        Self {
244            nthreads: available_parallelism().map(|n| n.get()).unwrap_or(1),
245            concurrent: true,
246            stack_size: None,
247            thread_affinity: None,
248            names: None,
249            proactor_builder: ProactorBuilder::new(),
250        }
251    }
252
253    /// If execute tasks concurrently. Default to be `true`.
254    ///
255    /// When set to `false`, tasks are executed sequentially without any
256    /// concurrency within the thread.
257    pub fn concurrent(mut self, concurrent: bool) -> Self {
258        self.concurrent = concurrent;
259        self
260    }
261
262    /// Set the number of worker threads of the dispatcher. The default value is
263    /// the CPU number. If the CPU number could not be retrieved, the
264    /// default value is 1.
265    pub fn worker_threads(mut self, nthreads: NonZeroUsize) -> Self {
266        self.nthreads = nthreads.get();
267        self
268    }
269
270    /// Set the size of stack of the worker threads.
271    pub fn stack_size(mut self, s: usize) -> Self {
272        self.stack_size = Some(s);
273        self
274    }
275
276    /// Set the thread affinity for the dispatcher.
277    pub fn thread_affinity(mut self, f: impl FnMut(usize) -> HashSet<usize> + 'static) -> Self {
278        self.thread_affinity = Some(Box::new(f));
279        self
280    }
281
282    /// Provide a function to assign names to the worker threads.
283    pub fn thread_names(mut self, f: impl (FnMut(usize) -> String) + 'static) -> Self {
284        self.names = Some(Box::new(f) as _);
285        self
286    }
287
288    /// Set the proactor builder for the inner runtimes.
289    pub fn proactor_builder(mut self, builder: ProactorBuilder) -> Self {
290        self.proactor_builder = builder;
291        self
292    }
293
294    /// Build the [`Dispatcher`].
295    pub fn build(self) -> io::Result<Dispatcher> {
296        Dispatcher::new_impl(self)
297    }
298}
299
300impl Default for DispatcherBuilder {
301    fn default() -> Self {
302        Self::new()
303    }
304}