1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;

/// Struct that enables functionality like waiting to be notified
/// when the count of a [`crate::Counter`] or [`crate::WeakCounter`] changes.
#[derive(Debug)]
pub struct NotifyHandle {
    receiver: mpsc::Receiver<()>,
    should_send: Arc<AtomicBool>,
    counter: Arc<AtomicUsize>,
}

#[derive(Error, Debug, PartialEq, Clone, Copy)]
pub enum NotifyError {
    #[error("All linked senders are disconnected, therefore count will never change!")]
    Disconnected,
}

#[derive(Error, Debug, PartialEq, Clone, Copy)]
pub enum NotifyTimeoutError {
    #[error("All linked senders are disconnected, therefore count will never change!")]
    Disconnected,
    #[error("Timed out before condition was reached!")]
    Timeout,
}

/// Struct that can send signals to the [`NotifyHandle`].
#[derive(Debug, Clone)]
pub(crate) struct NotifySender {
    should_send: Arc<AtomicBool>,
    sender: mpsc::Sender<()>,
}

impl NotifyHandle {
    /// Create a new [`NotifyHandle`] with a link to the associated count.
    pub(crate) fn new(counter: Arc<AtomicUsize>) -> (NotifyHandle, NotifySender) {
        // Create a new "rendezvous channel". Note that we don't
        // buffer any data in the channel, so memory won't grow if
        // no-one is receiving any data.
        let (sender, receiver) = mpsc::channel();
        let should_send = Arc::new(AtomicBool::new(false));
        (
            NotifyHandle {
                receiver,
                should_send: Arc::clone(&should_send),
                counter,
            },
            NotifySender {
                sender,
                should_send,
            },
        )
    }

    /// Block the current thread until the condition is true. This is
    /// different than spin-looping since the current thread will use channels
    /// internally to be notified when the counter changes.
    pub fn wait_until_condition(
        &self,
        condition: impl Fn(usize) -> bool,
    ) -> Result<(), NotifyError> {
        self.wait_until_condition_inner(condition, |_| self.receiver.recv())
            .map_err(|e| match e {
                mpsc::RecvError => NotifyError::Disconnected,
            })
    }

    /// [`NotifyHandle::wait_until_condition`] with a timeout.
    pub fn wait_until_condition_timeout(
        &self,
        condition: impl Fn(usize) -> bool,
        timeout: Duration,
    ) -> Result<(), NotifyTimeoutError> {
        self.wait_until_condition_inner(condition, |elapsed| {
            let remaining_time = if let Some(remaining_time) = timeout.checked_sub(elapsed) {
                remaining_time
            } else {
                return Err(mpsc::RecvTimeoutError::Timeout);
            };

            self.receiver.recv_timeout(remaining_time)
        })
        .map_err(|e| match e {
            mpsc::RecvTimeoutError::Disconnected => NotifyTimeoutError::Disconnected,
            mpsc::RecvTimeoutError::Timeout => NotifyTimeoutError::Timeout,
        })
    }

    fn wait_until_condition_inner<E>(
        &self,
        condition: impl Fn(usize) -> bool,
        recv_with_elapsed: impl Fn(Duration) -> Result<(), E>,
    ) -> Result<(), E>
    where
        E: FromDisconnected,
    {
        let start = Instant::now();

        // Drain all messages in the channel before turning sends on again.
        while let Ok(()) = self.receiver.try_recv() {}
        self.should_send.store(true, Ordering::SeqCst);

        macro_rules! return_if_condition {
            () => {
                if condition(self.counter.load(Ordering::SeqCst)) {
                    self.should_send.store(false, Ordering::SeqCst);
                    return Ok(());
                }
            };
        }

        return_if_condition!();
        loop {
            // Drain all elements from the channel until it's empty. If there were no
            // elements drained, we block on `recv()`.
            let recv_result = {
                let mut received_at_least_once = false;
                loop {
                    match self.receiver.try_recv() {
                        Ok(()) => received_at_least_once = true,
                        Err(mpsc::TryRecvError::Empty) => {
                            if received_at_least_once {
                                break Ok(());
                            }

                            break recv_with_elapsed(start.elapsed());
                        }
                        Err(mpsc::TryRecvError::Disconnected) => break Err(E::from_disconnected()),
                    }
                }
            };

            // If the receiver thread is disconnected, then the counter
            // will never change again.
            if let Err(err) = recv_result {
                // We should check if the condition is satisfied one last time, then
                // return the error if still unsatisfied. It's possible that the
                // condition has been met even after an error case, eg. all counters
                // are dropped.
                return_if_condition!();

                self.should_send.store(false, Ordering::SeqCst);
                return Err(err);
            }

            return_if_condition!();
        }
    }
}

/// Helper trait for abstracting over `recv()` and `recv_timeout()`.
trait FromDisconnected {
    fn from_disconnected() -> Self;
}

impl FromDisconnected for mpsc::RecvError {
    fn from_disconnected() -> Self {
        mpsc::RecvError
    }
}

impl FromDisconnected for mpsc::RecvTimeoutError {
    fn from_disconnected() -> Self {
        mpsc::RecvTimeoutError::Disconnected
    }
}

impl NotifySender {
    /// Notify the handle.
    pub(crate) fn notify(&self) {
        if self.should_send.load(Ordering::SeqCst) {
            let _ = self.sender.send(());
        }
    }
}