richat_shared/
shutdown.rs

1use {
2    slab::Slab,
3    std::{
4        future::Future,
5        pin::Pin,
6        sync::{Arc, Mutex, MutexGuard},
7        task::{Context, Poll, Waker},
8    },
9};
10
11#[derive(Debug)]
12pub struct Shutdown {
13    state: Arc<Mutex<State>>,
14    index: usize,
15}
16
17impl Shutdown {
18    pub fn new() -> Self {
19        let mut state = State {
20            shutdown: false,
21            wakers: Slab::with_capacity(64),
22        };
23        let index = state.wakers.insert(None);
24
25        Self {
26            state: Arc::new(Mutex::new(state)),
27            index,
28        }
29    }
30
31    fn state_lock(&self) -> MutexGuard<'_, State> {
32        match self.state.lock() {
33            Ok(guard) => guard,
34            Err(error) => error.into_inner(),
35        }
36    }
37
38    pub fn shutdown(&self) {
39        let mut state = self.state_lock();
40        state.shutdown = true;
41        for (_index, value) in state.wakers.iter_mut() {
42            if let Some(waker) = value.take() {
43                waker.wake();
44            }
45        }
46    }
47
48    pub fn is_set(&self) -> bool {
49        self.state_lock().shutdown
50    }
51}
52
53impl Default for Shutdown {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl Clone for Shutdown {
60    fn clone(&self) -> Self {
61        let mut state = self.state_lock();
62        let index = state.wakers.insert(None);
63
64        Self {
65            state: Arc::clone(&self.state),
66            index,
67        }
68    }
69}
70
71impl Drop for Shutdown {
72    fn drop(&mut self) {
73        let mut state = self.state_lock();
74        state.wakers.remove(self.index);
75    }
76}
77
78impl Future for Shutdown {
79    type Output = ();
80
81    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
82        let me = self.as_ref().get_ref();
83        let mut state = me.state_lock();
84
85        if state.shutdown {
86            return Poll::Ready(());
87        }
88
89        state.wakers[self.index] = Some(cx.waker().clone());
90        Poll::Pending
91    }
92}
93
94#[derive(Debug)]
95struct State {
96    shutdown: bool,
97    wakers: Slab<Option<Waker>>,
98}