iroh_blobs/util/local_pool.rs

//! A local task pool with proper shutdown
use std::{
    any::Any,
    future::Future,
    ops::Deref,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};

use futures_lite::FutureExt;
use tokio::{
    sync::{Notify, Semaphore},
    task::{JoinError, JoinSet, LocalSet},
};

type BoxedFut<T = ()> = Pin<Box<dyn Future<Output = T>>>;
type SpawnFn<T = ()> = Box<dyn FnOnce() -> BoxedFut<T> + Send + 'static>;

enum Message {
    /// Create a new task and execute it locally
    Execute(SpawnFn),
    /// Shutdown the thread after finishing all tasks
    Finish,
}

/// A local task pool with proper shutdown
///
/// Unlike
/// [`LocalPoolHandle`](https://docs.rs/tokio-util/latest/tokio_util/task/struct.LocalPoolHandle.html),
/// this pool will join all its threads when dropped, ensuring that all Drop
/// implementations are run to completion.
///
/// On drop, this pool will immediately cancel all *tasks* that are currently
/// being executed, and will wait for all threads to finish executing their
/// loops before returning. This means that all drop implementations will be
/// able to run to completion before drop exits.
///
/// On [`LocalPool::finish`], this pool will notify all threads to shut down,
/// and then wait for all threads to finish executing their loops before
/// returning. This means that all currently executing tasks will be allowed to
/// run to completion.
///
/// The pool will install the [`tracing::Subscriber`] which was set on the
/// thread where it was created as the default subscriber in all spawned threads.
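///
/// # Example
///
/// A minimal usage sketch; the task body and values are illustrative, and a
/// tokio runtime is assumed to be running:
///
/// ```ignore
/// // create a pool with a single thread
/// let pool = LocalPool::single();
/// // spawn a task and await its result
/// let answer = pool.spawn(|| async move { 6 * 7 }).await.unwrap();
/// assert_eq!(answer, 42);
/// // let the remaining tasks finish, then join the pool threads
/// pool.finish().await;
/// ```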
#[derive(Debug)]
pub struct LocalPool {
    threads: Vec<std::thread::JoinHandle<()>>,
    shutdown_sem: Arc<Semaphore>,
    cancel_token: CancellationToken,
    handle: LocalPoolHandle,
}

impl Deref for LocalPool {
    type Target = LocalPoolHandle;

    fn deref(&self) -> &Self::Target {
        &self.handle
    }
}

/// A handle to a [`LocalPool`]
#[derive(Debug, Clone)]
pub struct LocalPoolHandle {
    /// The sender half of the channel used to send tasks to the pool
    send: async_channel::Sender<Message>,
}

/// What to do when a panic occurs in a pool thread
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PanicMode {
    /// Log the panic and continue
    ///
    /// The panic will be re-thrown when the pool is dropped.
    LogAndContinue,
    /// Log the panic and immediately shut down the pool.
    ///
    /// The panic will be re-thrown when the pool is dropped.
    Shutdown,
}

/// Local task pool configuration
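///
/// # Example
///
/// A sketch of a non-default configuration (the values are illustrative):
///
/// ```ignore
/// let pool = LocalPool::new(Config {
///     threads: 2,
///     thread_name_prefix: "blob-worker",
///     panic_mode: PanicMode::LogAndContinue,
/// });
/// ```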
#[derive(Clone, Debug)]
pub struct Config {
    /// Number of threads in the pool
    pub threads: usize,
    /// Prefix for thread names
    pub thread_name_prefix: &'static str,
    /// How to handle panics in pool threads
    pub panic_mode: PanicMode,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            threads: num_cpus::get(),
            thread_name_prefix: "local-pool",
            panic_mode: PanicMode::Shutdown,
        }
    }
}

impl Default for LocalPool {
    fn default() -> Self {
        Self::new(Default::default())
    }
}

impl LocalPool {
    /// Create a new local pool with a single std thread.
    pub fn single() -> Self {
        Self::new(Config {
            threads: 1,
            ..Default::default()
        })
    }

    /// Create a new local pool with the given config.
    ///
    /// This will use the current tokio runtime handle, so it must be called
    /// from within a tokio runtime.
    pub fn new(config: Config) -> Self {
        let Config {
            threads,
            thread_name_prefix,
            panic_mode,
        } = config;
        let cancel_token = CancellationToken::new();
        let (send, recv) = async_channel::unbounded::<Message>();
        let shutdown_sem = Arc::new(Semaphore::new(0));
        let handle = tokio::runtime::Handle::current();
        let handles = (0..threads)
            .map(|i| {
                Self::spawn_pool_thread(
                    format!("{thread_name_prefix}-{i}"),
                    recv.clone(),
                    cancel_token.clone(),
                    panic_mode,
                    shutdown_sem.clone(),
                    handle.clone(),
                )
            })
            .collect::<std::io::Result<Vec<_>>>()
            .expect("invalid thread name");
        Self {
            threads: handles,
            handle: LocalPoolHandle { send },
            cancel_token,
            shutdown_sem,
        }
    }

    /// Get a cheaply cloneable handle to the pool
    ///
    /// This is not strictly necessary since [`LocalPool`] derefs to
    /// [`LocalPoolHandle`], but it makes getting a handle more explicit.
    pub fn handle(&self) -> &LocalPoolHandle {
        &self.handle
    }

    /// Spawn a new pool thread.
    fn spawn_pool_thread(
        thread_name: String,
        recv: async_channel::Receiver<Message>,
        cancel_token: CancellationToken,
        panic_mode: PanicMode,
        shutdown_sem: Arc<Semaphore>,
        handle: tokio::runtime::Handle,
    ) -> std::io::Result<std::thread::JoinHandle<()>> {
        let tracing_dispatcher = tracing::dispatcher::get_default(|dispatcher| dispatcher.clone());
        std::thread::Builder::new()
            .name(thread_name)
            .spawn(move || {
                let _tracing_guard = tracing::dispatcher::set_default(&tracing_dispatcher);
                let mut s = JoinSet::new();
                let mut last_panic = None;
                let mut handle_join = |res: Option<std::result::Result<(), JoinError>>| -> bool {
                    if let Some(Err(e)) = res {
                        if let Ok(panic) = e.try_into_panic() {
                            let panic_info = get_panic_info(&panic);
                            let thread_name = get_thread_name();
                            tracing::error!(
                                "Panic in local pool thread: {}\n{}",
                                thread_name,
                                panic_info
                            );
                            last_panic = Some(panic);
                        }
                    }
                    panic_mode == PanicMode::LogAndContinue || last_panic.is_none()
                };
                let ls = LocalSet::new();
                let shutdown_mode = handle.block_on(ls.run_until(async {
                    loop {
                        tokio::select! {
                            // poll the set of futures
                            res = s.join_next(), if !s.is_empty() => {
                                if !handle_join(res) {
                                    break ShutdownMode::Stop;
                                }
                            },
                            // if the cancel token is cancelled, break the loop immediately
                            _ = cancel_token.cancelled() => break ShutdownMode::Stop,
                            // if we receive a message, execute it
                            msg = recv.recv() => {
                                match msg {
                                    // just push into the join set
                                    Ok(Message::Execute(f)) => {
                                        s.spawn_local((f)());
                                    }
                                    // finish requested: stop taking new tasks, drain the rest below
                                    Ok(Message::Finish) => break ShutdownMode::Finish,
                                    // if the sender is dropped, break the loop immediately
                                    Err(async_channel::RecvError) => break ShutdownMode::Stop,
                                }
                            },
                        }
                    }
                }));
                // soft shutdown mode is just like normal running, except that
                // we don't add any more tasks and stop when there are no more
                // tasks to run.
                if shutdown_mode == ShutdownMode::Finish {
                    // somebody is asking for a clean shutdown, wait for all tasks to finish
                    handle.block_on(ls.run_until(async {
                        loop {
                            tokio::select! {
                                res = s.join_next() => {
                                    if res.is_none() || !handle_join(res) {
                                        break;
                                    }
                                }
                                _ = cancel_token.cancelled() => break,
                            }
                        }
                    }));
                }
                // Always add the permit. If nobody is waiting for it, it does
                // no harm.
                shutdown_sem.add_permits(1);
                if let Some(_panic) = last_panic {
                    // std::panic::resume_unwind(panic);
                }
            })
    }

    /// A future that resolves when the pool is cancelled
    pub async fn cancelled(&self) {
        self.cancel_token.cancelled().await
    }

    /// Immediately stop polling all tasks and wait for all threads to finish.
    ///
    /// This is like drop, but waits for thread completion asynchronously.
    ///
    /// If there was a panic on any of the threads, it will be re-thrown here.
    pub async fn shutdown(self) {
        self.cancel_token.cancel();
        self.await_thread_completion().await;
        // just make it explicit that this is where drop runs
        drop(self);
    }

    /// Gently shut down the pool
    ///
    /// Notifies all the pool threads to shut down and waits for them to finish.
    ///
    /// If you want to drop the pool without giving the threads a chance to
    /// process their remaining tasks, use [`Self::shutdown`] instead.
    ///
    /// If you want to wait for only a limited time for the tasks to finish,
    /// you can race this function with a timeout.
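    ///
    /// A sketch of racing the shutdown against a timeout, where `pool` is a
    /// [`LocalPool`] and the duration is illustrative; if the timeout wins, the
    /// `finish` future is dropped together with the pool, which cancels the
    /// remaining tasks:
    ///
    /// ```ignore
    /// let _ = tokio::time::timeout(Duration::from_secs(1), pool.finish()).await;
    /// ```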
    pub async fn finish(self) {
        // we assume that there are exactly as many threads as there are handles.
        // also, we assume that the threads are still running.
        for _ in 0..self.threads_u32() {
            // send the shutdown message
            // sending will fail if all threads are already finished, but
            // in that case we don't need to do anything.
            //
            // Threads will add a permit in any case, so await_thread_completion
            // will then immediately return.
            self.send.send(Message::Finish).await.ok();
        }
        self.await_thread_completion().await;
    }

    fn threads_u32(&self) -> u32 {
        self.threads
            .len()
            .try_into()
            .expect("invalid number of threads")
    }

    async fn await_thread_completion(&self) {
        // wait for all threads to finish.
        // Each thread will add a permit to the semaphore.
        let wait_for_semaphore = async move {
            let _ = self
                .shutdown_sem
                .acquire_many(self.threads_u32())
                .await
                .expect("semaphore closed");
        };
        // race the semaphore wait with the cancel token in case somebody
        // cancels the pool while we are waiting.
        tokio::select! {
            _ = wait_for_semaphore => {}
            _ = self.cancel_token.cancelled() => {}
        }
    }
}

impl Drop for LocalPool {
    fn drop(&mut self) {
        self.cancel_token.cancel();
        let current_thread_id = std::thread::current().id();
        for handle in self.threads.drain(..) {
            // we have no control over where Drop is called from, especially
            // if the pool ends up in an Arc. So we need to check if we are
            // dropping from within a pool thread and skip it in that case.
            if handle.thread().id() == current_thread_id {
                tracing::error!("Dropping LocalPool from within a pool thread.");
                continue;
            }
            // Log any panics from the joined thread; resuming them here is
            // currently disabled.
            if let Err(panic) = handle.join() {
                let panic_info = get_panic_info(&panic);
                let thread_name = get_thread_name();
                tracing::error!("Error joining thread: {}\n{}", thread_name, panic_info);
                // std::panic::resume_unwind(panic);
            }
        }
    }
}

/// Errors for spawn failures
#[derive(thiserror::Error, Debug)]
pub enum SpawnError {
    /// Task was dropped, either due to a panic or because the pool was shut down.
    #[error("cancelled")]
    Cancelled,
}

type SpawnResult<T> = std::result::Result<T, SpawnError>;

/// Future returned by [`LocalPoolHandle::spawn`] and [`LocalPoolHandle::try_spawn`].
///
/// Dropping this future will immediately cancel the task. The task can fail if
/// the pool is shut down or if the task panics. In both cases the future will
/// resolve to [`SpawnError::Cancelled`].
#[repr(transparent)]
#[derive(Debug)]
pub struct Run<T>(tokio::sync::oneshot::Receiver<T>);

impl<T> Run<T> {
    /// Abort the task
    ///
    /// Dropping the future will also abort the task.
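    ///
    /// A small sketch, assuming `pool` is a [`LocalPool`] (the task body is
    /// illustrative):
    ///
    /// ```ignore
    /// let mut run = pool.spawn(|| async { /* long running work */ });
    /// // the task is cancelled and dropped at its next yield point
    /// run.abort();
    /// ```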
    pub fn abort(&mut self) {
        self.0.close();
    }
}

impl<T> Future for Run<T> {
    type Output = std::result::Result<T, SpawnError>;

    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // map a RecvError to a SpawnError::Cancelled
        //
        // A RecvError means the sender was dropped without sending a result,
        // which only happens if the task was dropped before completion, i.e.
        // the pool was shut down or the task panicked.
        self.0.poll(cx).map_err(|_| SpawnError::Cancelled)
    }
}

impl From<SpawnError> for std::io::Error {
    fn from(e: SpawnError) -> Self {
        std::io::Error::new(std::io::ErrorKind::Other, e)
    }
}

impl LocalPoolHandle {
    /// Get the number of tasks in the queue
    ///
    /// This is *not* the number of tasks being executed, but the number of
    /// tasks waiting to be scheduled for execution. If this number is high,
    /// it indicates that the pool is very busy.
    ///
    /// You might want to use this to throttle or reject requests.
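    ///
    /// A sketch of simple backpressure, assuming `handle` is a
    /// [`LocalPoolHandle`] (the threshold is illustrative):
    ///
    /// ```ignore
    /// if handle.waiting_tasks() > 128 {
    ///     // too much queued work, reject or delay the request
    /// } else {
    ///     handle.spawn_detached(|| async { /* do the work */ });
    /// }
    /// ```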
    pub fn waiting_tasks(&self) -> usize {
        self.send.len()
    }

    /// Spawn a task in the pool and return a future that resolves when the task
    /// is done.
    ///
    /// If you don't care about the result, prefer [`LocalPoolHandle::try_spawn_detached`]
    /// since it is more efficient.
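    ///
    /// A usage sketch, assuming `handle` is a [`LocalPoolHandle`]; note that
    /// only the closure has to be `Send`, the returned future does not (the
    /// `Rc` is illustrative):
    ///
    /// ```ignore
    /// let run = handle.try_spawn(|| async move {
    ///     // this future is !Send, which is fine on a local pool
    ///     let x = std::rc::Rc::new(20);
    ///     *x + 1
    /// }).unwrap();
    /// assert_eq!(run.await.unwrap(), 21);
    /// ```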
    pub fn try_spawn<T, F, Fut>(&self, gen: F) -> SpawnResult<Run<T>>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + 'static,
        T: Send + 'static,
    {
        let (mut send_res, recv_res) = tokio::sync::oneshot::channel();
        let item = move || async move {
            let fut = (gen)();
            tokio::select! {
                // send the result to the receiver
                res = fut => { send_res.send(res).ok(); }
                // immediately stop the task if the receiver is dropped
                _ = send_res.closed() => {}
            }
        };
        self.try_spawn_detached(item)?;
        Ok(Run(recv_res))
    }

    /// Spawn a task in the pool.
    ///
    /// The task will run to completion unless the pool is shut down or the task
    /// panics. In case of panic, the pool will either log the panic and continue
    /// or immediately shut down, depending on the [`PanicMode`].
    pub fn try_spawn_detached<F, Fut>(&self, gen: F) -> SpawnResult<()>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + 'static,
    {
        let gen: SpawnFn = Box::new(move || Box::pin(gen()));
        self.try_spawn_detached_boxed(gen)
    }

    /// Spawn a task in the pool and return a future that resolves to its result.
    ///
    /// Like [`LocalPoolHandle::try_spawn`], but panics if the pool is shut down.
    pub fn spawn<T, F, Fut>(&self, gen: F) -> Run<T>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + 'static,
        T: Send + 'static,
    {
        self.try_spawn(gen).expect("pool is shut down")
    }

    /// Spawn a task in the pool.
    ///
    /// Like [`LocalPoolHandle::try_spawn_detached`], but panics if the pool is shut down.
    pub fn spawn_detached<F, Fut>(&self, gen: F)
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + 'static,
    {
        self.try_spawn_detached(gen).expect("pool is shut down")
    }

    /// Spawn a task in the pool.
    ///
    /// This is like [`LocalPoolHandle::try_spawn_detached`], but it assumes that
    /// the generator function is already boxed. This is the lowest-overhead way
    /// to spawn a task in the pool.
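    ///
    /// A sketch of manually boxing the generator, mirroring what
    /// [`LocalPoolHandle::try_spawn_detached`] does internally (the task body is
    /// illustrative and `handle` is assumed to be a [`LocalPoolHandle`]):
    ///
    /// ```ignore
    /// let gen: SpawnFn = Box::new(|| Box::pin(async { /* do the work */ }));
    /// handle.try_spawn_detached_boxed(gen).unwrap();
    /// ```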
    pub fn try_spawn_detached_boxed(&self, gen: SpawnFn) -> SpawnResult<()> {
        self.send
            .send_blocking(Message::Execute(gen))
            .map_err(|_| SpawnError::Cancelled)
    }
}

/// Thread shutdown mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ShutdownMode {
    /// Finish all tasks and then stop
    Finish,
    /// Stop immediately
    Stop,
}

fn get_panic_info(panic: &Box<dyn Any + Send>) -> String {
    if let Some(s) = panic.downcast_ref::<&str>() {
        s.to_string()
    } else if let Some(s) = panic.downcast_ref::<String>() {
        s.clone()
    } else {
        "Panic info unavailable".to_string()
    }
}

fn get_thread_name() -> String {
    std::thread::current()
        .name()
        .unwrap_or("unnamed")
        .to_string()
}

/// A lightweight cancellation token
#[derive(Debug, Clone)]
struct CancellationToken {
    inner: Arc<CancellationTokenInner>,
}

#[derive(Debug)]
struct CancellationTokenInner {
    is_cancelled: AtomicBool,
    notify: Notify,
}

impl CancellationToken {
    fn new() -> Self {
        Self {
            inner: Arc::new(CancellationTokenInner {
                is_cancelled: AtomicBool::new(false),
                notify: Notify::new(),
            }),
        }
    }

    fn cancel(&self) {
        if !self.inner.is_cancelled.swap(true, Ordering::SeqCst) {
            self.inner.notify.notify_waiters();
        }
    }
    async fn cancelled(&self) {
        // register for the notification before checking the flag, so that a
        // cancel() between the check and the await cannot be missed
        let notified = self.inner.notify.notified();
        tokio::pin!(notified);
        notified.as_mut().enable();

        if self.is_cancelled() {
            return;
        }

        // Wait for notification if not cancelled
        notified.await;
    }

    fn is_cancelled(&self) -> bool {
        self.inner.is_cancelled.load(Ordering::SeqCst)
    }
}

#[cfg(test)]
mod tests {
    use std::{sync::atomic::AtomicU64, time::Duration};

    use tracing::info;
    use tracing_test::traced_test;

    use super::*;

    /// A struct that simulates a long running drop operation
    #[derive(Debug)]
    struct TestDrop(Option<Arc<AtomicU64>>);

    impl Drop for TestDrop {
        fn drop(&mut self) {
            // delay to make sure the drop is executed completely
            std::thread::sleep(Duration::from_millis(100));
            // increment the drop counter
            if let Some(counter) = self.0.take() {
                counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            }
        }
    }

    impl TestDrop {
        fn new(counter: Arc<AtomicU64>) -> Self {
            Self(Some(counter))
        }

        fn forget(mut self) {
            self.0.take();
        }
    }

    /// Create a non-send test future that captures a TestDrop instance
    async fn delay_then_drop(x: TestDrop) {
        tokio::time::sleep(Duration::from_millis(100)).await;
        // drop x at the end. we will never get here when the future is
        // no longer polled, but drop should still be called
        drop(x);
    }

    /// Use a TestDrop instance to test cancellation
    async fn delay_then_forget(x: TestDrop, delay: Duration) {
        tokio::time::sleep(delay).await;
        x.forget();
    }

    #[tokio::test]
    #[traced_test]
    async fn test_tracing() {
        // This test makes sure that logging inside the pool propagates to the
        // tracing subscriber that was set for the current thread at the time the
        // pool was created.
        //
        // Look, there should be a custom tracing subscriber here that allows us to
        // inspect the messages sent to it so we can verify it received all the
        // messages. But have you ever tried to implement a tracing subscriber?
        // In the meantime this test will just always pass; to really see it, run
        // it with:
        //
        // cargo nextest run -p iroh-blobs local_pool::tests::test_tracing --success-output final
        //
        // and eyeball the output. yolo
        info!("hello from the test");
        let pool = LocalPool::single();
        pool.spawn(|| async move {
            info!("hello from the pool");
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_drop() {
        let _ = tracing_subscriber::fmt::try_init();
        let pool = LocalPool::new(Config::default());
        let counter = Arc::new(AtomicU64::new(0));
        let n = 4;
        for _ in 0..n {
            let td = TestDrop::new(counter.clone());
            pool.spawn_detached(move || delay_then_drop(td));
        }
        drop(pool);
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), n);
    }

    #[tokio::test]
    async fn test_finish() {
        let _ = tracing_subscriber::fmt::try_init();
        let pool = LocalPool::new(Config::default());
        let counter = Arc::new(AtomicU64::new(0));
        let n = 4;
        for _ in 0..n {
            let td = TestDrop::new(counter.clone());
            pool.spawn_detached(move || delay_then_drop(td));
        }
        pool.finish().await;
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), n);
    }

    #[tokio::test]
    async fn test_cancel() {
        let _ = tracing_subscriber::fmt::try_init();
        let pool = LocalPool::new(Config {
            threads: 2,
            ..Config::default()
        });
        let c1 = Arc::new(AtomicU64::new(0));
        let td1 = TestDrop::new(c1.clone());
        let handle = pool.spawn(move || {
            // this one will be aborted anyway, so use a long delay to make sure
            // that it does not accidentally run to completion
            delay_then_forget(td1, Duration::from_secs(10))
        });
        drop(handle);
        let c2 = Arc::new(AtomicU64::new(0));
        let td2 = TestDrop::new(c2.clone());
        let _handle = pool.spawn(move || {
            // this one will not be aborted, so use a short delay so the test
            // does not take too long
            delay_then_forget(td2, Duration::from_millis(100))
        });
        pool.finish().await;
        // the first task was aborted, so Drop ran before forget and incremented c1
        assert_eq!(c1.load(std::sync::atomic::Ordering::SeqCst), 1);
        // the second task ran to completion, so forget ran before Drop and c2 stayed at 0
        assert_eq!(c2.load(std::sync::atomic::Ordering::SeqCst), 0);
    }

    // #[tokio::test]
    // #[should_panic]
    // #[ignore = "todo"]
    // async fn test_panic() {
    //     let _ = tracing_subscriber::fmt::try_init();
    //     let pool = LocalPool::new(Config {
    //         threads: 2,
    //         ..Config::default()
    //     });
    //     pool.spawn_detached(|| async {
    //         panic!("test panic");
    //     });
    //     // we can't use shutdown here, because we need to allow time for the
    //     // panic to happen.
    //     pool.finish().await;
    // }
}