raii_counter/
notify.rs

1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
2use std::sync::mpsc;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use thiserror::Error;
6
/// Struct that enables functionality like waiting to be notified
/// when the count of a [`crate::Counter`] or [`crate::WeakCounter`] changes.
#[derive(Debug)]
pub struct NotifyHandle {
    // Receiving end of the channel; every `NotifySender::notify` that fires
    // while `should_send` is `true` delivers one `()` here.
    receiver: mpsc::Receiver<()>,
    // Flag shared with every linked `NotifySender`. Senders only push into the
    // channel while it is `true`; the wait methods turn it on and off.
    should_send: Arc<AtomicBool>,
    // The observed count; its current value is what wait conditions are given.
    counter: Arc<AtomicUsize>,
}
15
16#[derive(Error, Debug, PartialEq, Clone, Copy)]
17pub enum NotifyError {
18    #[error("All linked senders are disconnected, therefore count will never change!")]
19    Disconnected,
20}
21
22#[derive(Error, Debug, PartialEq, Clone, Copy)]
23pub enum NotifyTimeoutError {
24    #[error("All linked senders are disconnected, therefore count will never change!")]
25    Disconnected,
26    #[error("Timed out before condition was reached!")]
27    Timeout,
28}
29
/// Struct that can send signals to the [`NotifyHandle`].
#[derive(Debug, Clone)]
pub(crate) struct NotifySender {
    // Flag shared with the paired `NotifyHandle`; a signal is only sent while
    // it reads `true`.
    should_send: Arc<AtomicBool>,
    // Sending half of the channel whose receiver is owned by the `NotifyHandle`.
    sender: mpsc::Sender<()>,
}
36
37impl NotifyHandle {
38    /// Create a new [`NotifyHandle`] with a link to the associated count.
39    pub(crate) fn new(counter: Arc<AtomicUsize>) -> (NotifyHandle, NotifySender) {
40        // Create a new "rendezvous channel". Note that we don't
41        // buffer any data in the channel, so memory won't grow if
42        // no-one is receiving any data.
43        let (sender, receiver) = mpsc::channel();
44        let should_send = Arc::new(AtomicBool::new(false));
45        (
46            NotifyHandle {
47                receiver,
48                should_send: Arc::clone(&should_send),
49                counter,
50            },
51            NotifySender {
52                sender,
53                should_send,
54            },
55        )
56    }
57
58    /// Block the current thread until the condition is true. This is
59    /// different than spin-looping since the current thread will use channels
60    /// internally to be notified when the counter changes.
61    pub fn wait_until_condition(
62        &self,
63        condition: impl Fn(usize) -> bool,
64    ) -> Result<(), NotifyError> {
65        self.wait_until_condition_inner(condition, |_| self.receiver.recv())
66            .map_err(|e| match e {
67                mpsc::RecvError => NotifyError::Disconnected,
68            })
69    }
70
71    /// [`NotifyHandle::wait_until_condition`] with a timeout.
72    pub fn wait_until_condition_timeout(
73        &self,
74        condition: impl Fn(usize) -> bool,
75        timeout: Duration,
76    ) -> Result<(), NotifyTimeoutError> {
77        self.wait_until_condition_inner(condition, |elapsed| {
78            let remaining_time = if let Some(remaining_time) = timeout.checked_sub(elapsed) {
79                remaining_time
80            } else {
81                return Err(mpsc::RecvTimeoutError::Timeout);
82            };
83
84            self.receiver.recv_timeout(remaining_time)
85        })
86        .map_err(|e| match e {
87            mpsc::RecvTimeoutError::Disconnected => NotifyTimeoutError::Disconnected,
88            mpsc::RecvTimeoutError::Timeout => NotifyTimeoutError::Timeout,
89        })
90    }
91
92    fn wait_until_condition_inner<E>(
93        &self,
94        condition: impl Fn(usize) -> bool,
95        recv_with_elapsed: impl Fn(Duration) -> Result<(), E>,
96    ) -> Result<(), E>
97    where
98        E: FromDisconnected,
99    {
100        let start = Instant::now();
101
102        // Drain all messages in the channel before turning sends on again.
103        while let Ok(()) = self.receiver.try_recv() {}
104        self.should_send.store(true, Ordering::SeqCst);
105
106        macro_rules! return_if_condition {
107            () => {
108                if condition(self.counter.load(Ordering::SeqCst)) {
109                    self.should_send.store(false, Ordering::SeqCst);
110                    return Ok(());
111                }
112            };
113        }
114
115        return_if_condition!();
116        loop {
117            // Drain all elements from the channel until it's empty. If there were no
118            // elements drained, we block on `recv()`.
119            let recv_result = {
120                let mut received_at_least_once = false;
121                loop {
122                    match self.receiver.try_recv() {
123                        Ok(()) => received_at_least_once = true,
124                        Err(mpsc::TryRecvError::Empty) => {
125                            if received_at_least_once {
126                                break Ok(());
127                            }
128
129                            break recv_with_elapsed(start.elapsed());
130                        }
131                        Err(mpsc::TryRecvError::Disconnected) => break Err(E::from_disconnected()),
132                    }
133                }
134            };
135
136            // If the receiver thread is disconnected, then the counter
137            // will never change again.
138            if let Err(err) = recv_result {
139                // We should check if the condition is satisfied one last time, then
140                // return the error if still unsatisfied. It's possible that the
141                // condition has been met even after an error case, eg. all counters
142                // are dropped.
143                return_if_condition!();
144
145                self.should_send.store(false, Ordering::SeqCst);
146                return Err(err);
147            }
148
149            return_if_condition!();
150        }
151    }
152}
153
/// Helper trait for abstracting over `recv()` and `recv_timeout()`.
///
/// The shared wait loop is generic over the receive error type `E`. It uses
/// this trait to build an `E` when `try_recv` reports that every sender is gone.
trait FromDisconnected {
    /// Returns the value of `Self` that means "all senders are disconnected".
    fn from_disconnected() -> Self;
}
158
159impl FromDisconnected for mpsc::RecvError {
160    fn from_disconnected() -> Self {
161        mpsc::RecvError
162    }
163}
164
165impl FromDisconnected for mpsc::RecvTimeoutError {
166    fn from_disconnected() -> Self {
167        mpsc::RecvTimeoutError::Disconnected
168    }
169}
170
171impl NotifySender {
172    /// Notify the handle.
173    pub(crate) fn notify(&self) {
174        if self.should_send.load(Ordering::SeqCst) {
175            let _ = self.sender.send(());
176        }
177    }
178}