iroh_blobs/util/local_pool.rs
1//! A local task pool with proper shutdown
2use std::{
3 any::Any,
4 future::Future,
5 ops::Deref,
6 pin::Pin,
7 sync::{
8 atomic::{AtomicBool, Ordering},
9 Arc,
10 },
11};
12
13use futures_lite::FutureExt;
14use tokio::{
15 sync::{Notify, Semaphore},
16 task::{JoinError, JoinSet, LocalSet},
17};
18
/// A pinned, heap-allocated future producing `T`.
///
/// There is no `Send` bound: the future is created and polled on a pool thread.
type BoxedFut<T = ()> = Pin<Box<dyn Future<Output = T>>>;
/// A sendable one-shot closure that builds a [`BoxedFut`].
///
/// The closure itself must be `Send` because it travels over the channel to a
/// pool thread; the future it builds does not have to be.
type SpawnFn<T = ()> = Box<dyn FnOnce() -> BoxedFut<T> + Send + 'static>;
21
/// Messages sent over the pool's channel to the pool threads.
enum Message {
    /// Create a new task and execute it locally
    Execute(SpawnFn),
    /// Shutdown the thread after finishing all tasks
    Finish,
}
28
/// A local task pool with proper shutdown
///
/// Unlike
/// [`LocalPoolHandle`](https://docs.rs/tokio-util/latest/tokio_util/task/struct.LocalPoolHandle.html),
/// this pool will join all its threads when dropped, ensuring that all Drop
/// implementations are run to completion.
///
/// On drop, this pool will immediately cancel all *tasks* that are currently
/// being executed, and will wait for all threads to finish executing their
/// loops before returning. This means that all drop implementations will be
/// able to run to completion before drop exits.
///
/// On [`LocalPool::finish`], this pool will notify all threads to shut down,
/// and then wait for all threads to finish executing their loops before
/// returning. This means that all currently executing tasks will be allowed to
/// run to completion.
///
/// The pool captures the [`tracing::Subscriber`] that is the default on the
/// thread where the pool is created and installs it as the default subscriber
/// in all spawned threads.
#[derive(Debug)]
pub struct LocalPool {
    /// Join handles of the pool threads; they are joined when the pool is dropped.
    threads: Vec<std::thread::JoinHandle<()>>,
    /// Each pool thread adds one permit once it has left its task loops.
    shutdown_sem: Arc<Semaphore>,
    /// Cancelled by [`LocalPool::shutdown`] and on drop to stop all threads.
    cancel_token: CancellationToken,
    /// Handle used to submit work; also exposed via `Deref`.
    handle: LocalPoolHandle,
}
55
56impl Deref for LocalPool {
57 type Target = LocalPoolHandle;
58
59 fn deref(&self) -> &Self::Target {
60 &self.handle
61 }
62}
63
/// A handle to a [`LocalPool`]
///
/// Cloning the handle clones the channel sender it wraps.
#[derive(Debug, Clone)]
pub struct LocalPoolHandle {
    /// The sender half of the channel used to send tasks to the pool
    send: async_channel::Sender<Message>,
}
70
/// What to do when a task panics on a pool thread
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PanicMode {
    /// Log the panic and continue
    ///
    /// The pool thread keeps running its other tasks. The panic payload is
    /// only logged; it is not re-thrown.
    LogAndContinue,
    /// Log the panic and stop the pool thread that observed it.
    ///
    /// The panic payload is only logged; it is not re-thrown.
    ///
    /// NOTE(review): only the thread on which the task panicked leaves its
    /// loop; the other pool threads appear to keep running until the pool is
    /// shut down or dropped — confirm whether a pool-wide shutdown is intended.
    Shutdown,
}
83
/// Local task pool configuration
#[derive(Clone, Debug)]
pub struct Config {
    /// Number of threads in the pool
    pub threads: usize,
    /// Prefix for thread names; threads are named `{prefix}-{index}`
    pub thread_name_prefix: &'static str,
    /// What to do when a task panics on a pool thread
    pub panic_mode: PanicMode,
}
94
95impl Default for Config {
96 fn default() -> Self {
97 Self {
98 threads: num_cpus::get(),
99 thread_name_prefix: "local-pool",
100 panic_mode: PanicMode::Shutdown,
101 }
102 }
103}
104
105impl Default for LocalPool {
106 fn default() -> Self {
107 Self::new(Default::default())
108 }
109}
110
111impl LocalPool {
112 /// Create a new local pool with a single std thread.
113 pub fn single() -> Self {
114 Self::new(Config {
115 threads: 1,
116 ..Default::default()
117 })
118 }
119
120 /// Create a new local pool with the given config.
121 ///
122 /// This will use the current tokio runtime handle, so it must be called
123 /// from within a tokio runtime.
124 pub fn new(config: Config) -> Self {
125 let Config {
126 threads,
127 thread_name_prefix,
128 panic_mode,
129 } = config;
130 let cancel_token = CancellationToken::new();
131 let (send, recv) = async_channel::unbounded::<Message>();
132 let shutdown_sem = Arc::new(Semaphore::new(0));
133 let handle = tokio::runtime::Handle::current();
134 let handles = (0..threads)
135 .map(|i| {
136 Self::spawn_pool_thread(
137 format!("{thread_name_prefix}-{i}"),
138 recv.clone(),
139 cancel_token.clone(),
140 panic_mode,
141 shutdown_sem.clone(),
142 handle.clone(),
143 )
144 })
145 .collect::<std::io::Result<Vec<_>>>()
146 .expect("invalid thread name");
147 Self {
148 threads: handles,
149 handle: LocalPoolHandle { send },
150 cancel_token,
151 shutdown_sem,
152 }
153 }
154
155 /// Get a cheaply cloneable handle to the pool
156 ///
157 /// This is not strictly necessary since we implement deref for
158 /// LocalPoolHandle, but makes getting a handle more explicit.
159 pub fn handle(&self) -> &LocalPoolHandle {
160 &self.handle
161 }
162
163 /// Spawn a new pool thread.
164 fn spawn_pool_thread(
165 thread_name: String,
166 recv: async_channel::Receiver<Message>,
167 cancel_token: CancellationToken,
168 panic_mode: PanicMode,
169 shutdown_sem: Arc<Semaphore>,
170 handle: tokio::runtime::Handle,
171 ) -> std::io::Result<std::thread::JoinHandle<()>> {
172 let tracing_dispatcher = tracing::dispatcher::get_default(|dispatcher| dispatcher.clone());
173 std::thread::Builder::new()
174 .name(thread_name)
175 .spawn(move || {
176 let _tracing_guard = tracing::dispatcher::set_default(&tracing_dispatcher);
177 let mut s = JoinSet::new();
178 let mut last_panic = None;
179 let mut handle_join = |res: Option<std::result::Result<(), JoinError>>| -> bool {
180 if let Some(Err(e)) = res {
181 if let Ok(panic) = e.try_into_panic() {
182 let panic_info = get_panic_info(&panic);
183 let thread_name = get_thread_name();
184 tracing::error!(
185 "Panic in local pool thread: {}\n{}",
186 thread_name,
187 panic_info
188 );
189 last_panic = Some(panic);
190 }
191 }
192 panic_mode == PanicMode::LogAndContinue || last_panic.is_none()
193 };
194 let ls = LocalSet::new();
195 let shutdown_mode = handle.block_on(ls.run_until(async {
196 loop {
197 tokio::select! {
198 // poll the set of futures
199 res = s.join_next(), if !s.is_empty() => {
200 if !handle_join(res) {
201 break ShutdownMode::Stop;
202 }
203 },
204 // if the cancel token is cancelled, break the loop immediately
205 _ = cancel_token.cancelled() => break ShutdownMode::Stop,
206 // if we receive a message, execute it
207 msg = recv.recv() => {
208 match msg {
209 // just push into the join set
210 Ok(Message::Execute(f)) => {
211 s.spawn_local((f)());
212 }
213 // break with optional semaphore
214 Ok(Message::Finish) => break ShutdownMode::Finish,
215 // if the sender is dropped, break the loop immediately
216 Err(async_channel::RecvError) => break ShutdownMode::Stop,
217 }
218 },
219 }
220 }
221 }));
222 // soft shutdown mode is just like normal running, except that
223 // we don't add any more tasks and stop when there are no more
224 // tasks to run.
225 if shutdown_mode == ShutdownMode::Finish {
226 // somebody is asking for a clean shutdown, wait for all tasks to finish
227 handle.block_on(ls.run_until(async {
228 loop {
229 tokio::select! {
230 res = s.join_next() => {
231 if res.is_none() || !handle_join(res) {
232 break;
233 }
234 }
235 _ = cancel_token.cancelled() => break,
236 }
237 }
238 }));
239 }
240 // Always add the permit. If nobody is waiting for it, it does
241 // no harm.
242 shutdown_sem.add_permits(1);
243 if let Some(_panic) = last_panic {
244 // std::panic::resume_unwind(panic);
245 }
246 })
247 }
248
249 /// A future that resolves when the pool is cancelled
250 pub async fn cancelled(&self) {
251 self.cancel_token.cancelled().await
252 }
253
254 /// Immediately stop polling all tasks and wait for all threads to finish.
255 ///
256 /// This is like drop, but waits for thread completion asynchronously.
257 ///
258 /// If there was a panic on any of the threads, it will be re-thrown here.
259 pub async fn shutdown(self) {
260 self.cancel_token.cancel();
261 self.await_thread_completion().await;
262 // just make it explicit that this is where drop runs
263 drop(self);
264 }
265
266 /// Gently shut down the pool
267 ///
268 /// Notifies all the pool threads to shut down and waits for them to finish.
269 ///
270 /// If you just want to drop the pool without giving the threads a chance to
271 /// process their remaining tasks, just use [`Self::shutdown`].
272 ///
273 /// If you want to wait for only a limited time for the tasks to finish,
274 /// you can race this function with a timeout.
275 pub async fn finish(self) {
276 // we assume that there are exactly as many threads as there are handles.
277 // also, we assume that the threads are still running.
278 for _ in 0..self.threads_u32() {
279 // send the shutdown message
280 // sending will fail if all threads are already finished, but
281 // in that case we don't need to do anything.
282 //
283 // Threads will add a permit in any case, so await_thread_completion
284 // will then immediately return.
285 self.send.send(Message::Finish).await.ok();
286 }
287 self.await_thread_completion().await;
288 }
289
290 fn threads_u32(&self) -> u32 {
291 self.threads
292 .len()
293 .try_into()
294 .expect("invalid number of threads")
295 }
296
297 async fn await_thread_completion(&self) {
298 // wait for all threads to finish.
299 // Each thread will add a permit to the semaphore.
300 let wait_for_semaphore = async move {
301 let _ = self
302 .shutdown_sem
303 .acquire_many(self.threads_u32())
304 .await
305 .expect("semaphore closed");
306 };
307 // race the semaphore wait with the cancel token in case somebody
308 // cancels the pool while we are waiting.
309 tokio::select! {
310 _ = wait_for_semaphore => {}
311 _ = self.cancel_token.cancelled() => {}
312 }
313 }
314}
315
316impl Drop for LocalPool {
317 fn drop(&mut self) {
318 self.cancel_token.cancel();
319 let current_thread_id = std::thread::current().id();
320 for handle in self.threads.drain(..) {
321 // we have no control over from where Drop is called, especially
322 // if the pool ends up in an Arc. So we need to check if we are
323 // dropping from within a pool thread and skip it in that case.
324 if handle.thread().id() == current_thread_id {
325 tracing::error!("Dropping LocalPool from within a pool thread.");
326 continue;
327 }
328 // Log any panics and resume them
329 if let Err(panic) = handle.join() {
330 let panic_info = get_panic_info(&panic);
331 let thread_name = get_thread_name();
332 tracing::error!("Error joining thread: {}\n{}", thread_name, panic_info);
333 // std::panic::resume_unwind(panic);
334 }
335 }
336 }
337}
338
/// Errors for spawn failures
#[derive(thiserror::Error, Debug)]
pub enum SpawnError {
    /// Task was dropped, either due to a panic or because the pool was shut down.
    ///
    /// Also returned when a task could not be submitted because the pool's
    /// channel is closed.
    #[error("cancelled")]
    Cancelled,
}
346
/// Result of a spawn attempt; the error is always a [`SpawnError`].
type SpawnResult<T> = std::result::Result<T, SpawnError>;
348
/// Future returned by [`LocalPoolHandle::spawn`] and [`LocalPoolHandle::try_spawn`].
///
/// Dropping this future will immediately cancel the task. The task can fail if
/// the pool is shut down or if the task panics. In both cases the future will
/// resolve to [`SpawnError::Cancelled`].
///
/// Wraps the receiving half of the oneshot channel over which the task sends
/// its result.
#[repr(transparent)]
#[derive(Debug)]
pub struct Run<T>(tokio::sync::oneshot::Receiver<T>);
357
358impl<T> Run<T> {
359 /// Abort the task
360 ///
361 /// Dropping the future will also abort the task.
362 pub fn abort(&mut self) {
363 self.0.close();
364 }
365}
366
367impl<T> Future for Run<T> {
368 type Output = std::result::Result<T, SpawnError>;
369
370 fn poll(
371 mut self: Pin<&mut Self>,
372 cx: &mut std::task::Context<'_>,
373 ) -> std::task::Poll<Self::Output> {
374 // map a RecvError (other side was dropped) to a SpawnError::Shutdown
375 //
376 // The only way the receiver can be dropped is if the pool is shut down.
377 self.0.poll(cx).map_err(|_| SpawnError::Cancelled)
378 }
379}
380
381impl From<SpawnError> for std::io::Error {
382 fn from(e: SpawnError) -> Self {
383 std::io::Error::new(std::io::ErrorKind::Other, e)
384 }
385}
386
387impl LocalPoolHandle {
388 /// Get the number of tasks in the queue
389 ///
390 /// This is *not* the number of tasks being executed, but the number of
391 /// tasks waiting to be scheduled for execution. If this number is high,
392 /// it indicates that the pool is very busy.
393 ///
394 /// You might want to use this to throttle or reject requests.
395 pub fn waiting_tasks(&self) -> usize {
396 self.send.len()
397 }
398
399 /// Spawn a task in the pool and return a future that resolves when the task
400 /// is done.
401 ///
402 /// If you don't care about the result, prefer [`LocalPoolHandle::spawn_detached`]
403 /// since it is more efficient.
404 pub fn try_spawn<T, F, Fut>(&self, gen: F) -> SpawnResult<Run<T>>
405 where
406 F: FnOnce() -> Fut + Send + 'static,
407 Fut: Future<Output = T> + 'static,
408 T: Send + 'static,
409 {
410 let (mut send_res, recv_res) = tokio::sync::oneshot::channel();
411 let item = move || async move {
412 let fut = (gen)();
413 tokio::select! {
414 // send the result to the receiver
415 res = fut => { send_res.send(res).ok(); }
416 // immediately stop the task if the receiver is dropped
417 _ = send_res.closed() => {}
418 }
419 };
420 self.try_spawn_detached(item)?;
421 Ok(Run(recv_res))
422 }
423
424 /// Spawn a task in the pool.
425 ///
426 /// The task will run to completion unless the pool is shut down or the task
427 /// panics. In case of panic, the pool will either log the panic and continue
428 /// or immediately shut down, depending on the [`PanicMode`].
429 pub fn try_spawn_detached<F, Fut>(&self, gen: F) -> SpawnResult<()>
430 where
431 F: FnOnce() -> Fut + Send + 'static,
432 Fut: Future<Output = ()> + 'static,
433 {
434 let gen: SpawnFn = Box::new(move || Box::pin(gen()));
435 self.try_spawn_detached_boxed(gen)
436 }
437
438 /// Spawn a task in the pool and await the result.
439 ///
440 /// Like [`LocalPoolHandle::try_spawn`], but panics if the pool is shut down.
441 pub fn spawn<T, F, Fut>(&self, gen: F) -> Run<T>
442 where
443 F: FnOnce() -> Fut + Send + 'static,
444 Fut: Future<Output = T> + 'static,
445 T: Send + 'static,
446 {
447 self.try_spawn(gen).expect("pool is shut down")
448 }
449
450 /// Spawn a task in the pool.
451 ///
452 /// Like [`LocalPoolHandle::try_spawn_detached`], but panics if the pool is shut down.
453 pub fn spawn_detached<F, Fut>(&self, gen: F)
454 where
455 F: FnOnce() -> Fut + Send + 'static,
456 Fut: Future<Output = ()> + 'static,
457 {
458 self.try_spawn_detached(gen).expect("pool is shut down")
459 }
460
461 /// Spawn a task in the pool.
462 ///
463 /// This is like [`LocalPoolHandle::try_spawn_detached`], but assuming that the
464 /// generator function is already boxed. This is the lowest overhead way to
465 /// spawn a task in the pool.
466 pub fn try_spawn_detached_boxed(&self, gen: SpawnFn) -> SpawnResult<()> {
467 self.send
468 .send_blocking(Message::Execute(gen))
469 .map_err(|_| SpawnError::Cancelled)
470 }
471}
472
/// Thread shutdown mode
///
/// Describes why a pool thread left its main loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ShutdownMode {
    /// Finish all tasks and then stop
    Finish,
    /// Stop immediately
    Stop,
}
481
/// Extract a human readable message from a panic payload.
///
/// Handles `&str` and `String` payloads; any other payload type yields a
/// fixed placeholder message.
fn get_panic_info(panic: &Box<dyn Any + Send>) -> String {
    panic
        .downcast_ref::<&str>()
        .map(|msg| (*msg).to_string())
        .or_else(|| panic.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "Panic info unavailable".to_string())
}
491
/// Name of the current thread, or `"unnamed"` if it has none.
fn get_thread_name() -> String {
    let current = std::thread::current();
    match current.name() {
        Some(name) => name.to_string(),
        None => "unnamed".to_string(),
    }
}
498
/// A lightweight cancellation token
///
/// Clones share the same state, so cancelling one clone cancels all of them.
#[derive(Debug, Clone)]
struct CancellationToken {
    /// State shared between all clones of the token
    inner: Arc<CancellationTokenInner>,
}
504
/// Shared state behind a [`CancellationToken`]
#[derive(Debug)]
struct CancellationTokenInner {
    /// Set to `true` once the token has been cancelled; never reset
    is_cancelled: AtomicBool,
    /// Wakes the tasks waiting for cancellation
    notify: Notify,
}
510
511impl CancellationToken {
512 fn new() -> Self {
513 Self {
514 inner: Arc::new(CancellationTokenInner {
515 is_cancelled: AtomicBool::new(false),
516 notify: Notify::new(),
517 }),
518 }
519 }
520
521 fn cancel(&self) {
522 if !self.inner.is_cancelled.swap(true, Ordering::SeqCst) {
523 self.inner.notify.notify_waiters();
524 }
525 }
526
527 async fn cancelled(&self) {
528 if self.is_cancelled() {
529 return;
530 }
531
532 // Wait for notification if not cancelled
533 self.inner.notify.notified().await;
534 }
535
536 fn is_cancelled(&self) -> bool {
537 self.inner.is_cancelled.load(Ordering::SeqCst)
538 }
539}
540
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicU64, Ordering},
        time::Duration,
    };

    use tracing::info;
    use tracing_test::traced_test;

    use super::*;

    /// Simulates a value with a slow destructor; counts completed drops.
    #[derive(Debug)]
    struct TestDrop(Option<Arc<AtomicU64>>);

    impl Drop for TestDrop {
        fn drop(&mut self) {
            // sleep first, so the counter only moves if drop ran to the end
            std::thread::sleep(Duration::from_millis(100));
            if let Some(counter) = self.0.take() {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        }
    }

    impl TestDrop {
        fn new(counter: Arc<AtomicU64>) -> Self {
            Self(Some(counter))
        }

        /// Detach the counter, so that dropping no longer increments it.
        fn forget(mut self) {
            self.0 = None;
        }
    }

    /// Create a non-send test future that captures a TestDrop instance
    async fn delay_then_drop(x: TestDrop) {
        tokio::time::sleep(Duration::from_millis(100)).await;
        // Dropped explicitly at the end. If the future is cancelled before
        // the sleep is over we never get here, but drop should still run.
        drop(x);
    }

    /// Use a TestDrop instance to test cancellation
    async fn delay_then_forget(x: TestDrop, delay: Duration) {
        tokio::time::sleep(delay).await;
        x.forget();
    }

    /// Spawn `n` detached tasks that each own a `TestDrop` tied to `counter`.
    fn spawn_droppers(pool: &LocalPool, counter: &Arc<AtomicU64>, n: u64) {
        for _ in 0..n {
            let td = TestDrop::new(Arc::clone(counter));
            pool.spawn_detached(move || delay_then_drop(td));
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_tracing() {
        // This test wants to make sure that logging inside the pool propagates to the
        // tracing subscriber that was set for the current thread at the time the pool was
        // created.
        //
        // Ideally there would be a custom tracing subscriber here that lets us inspect
        // the messages sent to it, so we could verify that it received all of them. That
        // is not implemented, so for now this test always passes. To really see the
        // output, run it with:
        //
        // cargo nextest run -p iroh-blobs local_pool::tests::test_tracing --success-output final
        //
        // and check the output by eye.
        info!("hello from the test");
        let pool = LocalPool::single();
        let run = pool.spawn(|| async move {
            info!("hello from the pool");
        });
        run.await.unwrap();
    }

    #[tokio::test]
    async fn test_drop() {
        let _ = tracing_subscriber::fmt::try_init();
        let pool = LocalPool::new(Config::default());
        let counter = Arc::new(AtomicU64::new(0));
        let n = 4;
        spawn_droppers(&pool, &counter, n);
        drop(pool);
        assert_eq!(counter.load(Ordering::SeqCst), n);
    }

    #[tokio::test]
    async fn test_finish() {
        let _ = tracing_subscriber::fmt::try_init();
        let pool = LocalPool::new(Config::default());
        let counter = Arc::new(AtomicU64::new(0));
        let n = 4;
        spawn_droppers(&pool, &counter, n);
        pool.finish().await;
        assert_eq!(counter.load(Ordering::SeqCst), n);
    }

    #[tokio::test]
    async fn test_cancel() {
        let _ = tracing_subscriber::fmt::try_init();
        let config = Config {
            threads: 2,
            ..Config::default()
        };
        let pool = LocalPool::new(config);

        let c1 = Arc::new(AtomicU64::new(0));
        let td1 = TestDrop::new(c1.clone());
        // this one will be aborted anyway, so use a long delay to make sure
        // that it does not accidentally run to completion
        let aborted = pool.spawn(move || delay_then_forget(td1, Duration::from_secs(10)));
        drop(aborted);

        let c2 = Arc::new(AtomicU64::new(0));
        let td2 = TestDrop::new(c2.clone());
        // this one will not be aborted, so use a short delay so the test
        // does not take too long
        let _kept = pool.spawn(move || delay_then_forget(td2, Duration::from_millis(100)));

        pool.finish().await;
        // c1 will be aborted, so drop will run before forget, so the counter will be increased
        assert_eq!(c1.load(Ordering::SeqCst), 1);
        // c2 will not be aborted, so drop will run after forget, so the counter will not be increased
        assert_eq!(c2.load(Ordering::SeqCst), 0);
    }

    // Disabled test, kept for reference:
    //
    // #[tokio::test]
    // #[should_panic]
    // #[ignore = "todo"]
    // async fn test_panic() {
    //     let _ = tracing_subscriber::fmt::try_init();
    //     let pool = LocalPool::new(Config {
    //         threads: 2,
    //         ..Config::default()
    //     });
    //     pool.spawn_detached(|| async {
    //         panic!("test panic");
    //     });
    //     // we can't use shutdown here, because we need to allow time for the
    //     // panic to happen.
    //     pool.finish().await;
    // }
}