1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
2use std::sync::mpsc;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use thiserror::Error;
6
/// Waiting half of a notification pair: lets the caller block until a shared counter
/// satisfies a condition.
///
/// Created together with its `NotifySender` by `NotifyHandle::new`.
#[derive(Debug)]
pub struct NotifyHandle {
    /// Receives the wake-up messages pushed by the linked senders.
    receiver: mpsc::Receiver<()>,
    /// Flag shared with the senders; set to `true` while a wait is in progress so that
    /// senders know a notification is wanted.
    should_send: Arc<AtomicBool>,
    /// Counter whose current value is passed to the wait conditions.
    counter: Arc<AtomicUsize>,
}
15
16#[derive(Error, Debug, PartialEq, Clone, Copy)]
17pub enum NotifyError {
18 #[error("All linked senders are disconnected, therefore count will never change!")]
19 Disconnected,
20}
21
22#[derive(Error, Debug, PartialEq, Clone, Copy)]
23pub enum NotifyTimeoutError {
24 #[error("All linked senders are disconnected, therefore count will never change!")]
25 Disconnected,
26 #[error("Timed out before condition was reached!")]
27 Timeout,
28}
29
/// Notifying half of a notification pair. Cloneable; every clone feeds the same handle.
#[derive(Debug, Clone)]
pub(crate) struct NotifySender {
    /// Flag shared with the `NotifyHandle`; messages are only sent while it is `true`.
    should_send: Arc<AtomicBool>,
    /// Channel end used to wake the handle.
    sender: mpsc::Sender<()>,
}
36
37impl NotifyHandle {
38 pub(crate) fn new(counter: Arc<AtomicUsize>) -> (NotifyHandle, NotifySender) {
40 let (sender, receiver) = mpsc::channel();
44 let should_send = Arc::new(AtomicBool::new(false));
45 (
46 NotifyHandle {
47 receiver,
48 should_send: Arc::clone(&should_send),
49 counter,
50 },
51 NotifySender {
52 sender,
53 should_send,
54 },
55 )
56 }
57
58 pub fn wait_until_condition(
62 &self,
63 condition: impl Fn(usize) -> bool,
64 ) -> Result<(), NotifyError> {
65 self.wait_until_condition_inner(condition, |_| self.receiver.recv())
66 .map_err(|e| match e {
67 mpsc::RecvError => NotifyError::Disconnected,
68 })
69 }
70
71 pub fn wait_until_condition_timeout(
73 &self,
74 condition: impl Fn(usize) -> bool,
75 timeout: Duration,
76 ) -> Result<(), NotifyTimeoutError> {
77 self.wait_until_condition_inner(condition, |elapsed| {
78 let remaining_time = if let Some(remaining_time) = timeout.checked_sub(elapsed) {
79 remaining_time
80 } else {
81 return Err(mpsc::RecvTimeoutError::Timeout);
82 };
83
84 self.receiver.recv_timeout(remaining_time)
85 })
86 .map_err(|e| match e {
87 mpsc::RecvTimeoutError::Disconnected => NotifyTimeoutError::Disconnected,
88 mpsc::RecvTimeoutError::Timeout => NotifyTimeoutError::Timeout,
89 })
90 }
91
92 fn wait_until_condition_inner<E>(
93 &self,
94 condition: impl Fn(usize) -> bool,
95 recv_with_elapsed: impl Fn(Duration) -> Result<(), E>,
96 ) -> Result<(), E>
97 where
98 E: FromDisconnected,
99 {
100 let start = Instant::now();
101
102 while let Ok(()) = self.receiver.try_recv() {}
104 self.should_send.store(true, Ordering::SeqCst);
105
106 macro_rules! return_if_condition {
107 () => {
108 if condition(self.counter.load(Ordering::SeqCst)) {
109 self.should_send.store(false, Ordering::SeqCst);
110 return Ok(());
111 }
112 };
113 }
114
115 return_if_condition!();
116 loop {
117 let recv_result = {
120 let mut received_at_least_once = false;
121 loop {
122 match self.receiver.try_recv() {
123 Ok(()) => received_at_least_once = true,
124 Err(mpsc::TryRecvError::Empty) => {
125 if received_at_least_once {
126 break Ok(());
127 }
128
129 break recv_with_elapsed(start.elapsed());
130 }
131 Err(mpsc::TryRecvError::Disconnected) => break Err(E::from_disconnected()),
132 }
133 }
134 };
135
136 if let Err(err) = recv_result {
139 return_if_condition!();
144
145 self.should_send.store(false, Ordering::SeqCst);
146 return Err(err);
147 }
148
149 return_if_condition!();
150 }
151 }
152}
153
/// Lets the shared waiting loop build the "all senders disconnected" error for whichever
/// receive-error type the calling method works with.
trait FromDisconnected {
    /// Returns the value that represents a disconnected channel.
    fn from_disconnected() -> Self;
}
158
159impl FromDisconnected for mpsc::RecvError {
160 fn from_disconnected() -> Self {
161 mpsc::RecvError
162 }
163}
164
165impl FromDisconnected for mpsc::RecvTimeoutError {
166 fn from_disconnected() -> Self {
167 mpsc::RecvTimeoutError::Disconnected
168 }
169}
170
171impl NotifySender {
172 pub(crate) fn notify(&self) {
174 if self.should_send.load(Ordering::SeqCst) {
175 let _ = self.sender.send(());
176 }
177 }
178}