quickwit_actors/
channel_with_priority.rs1use std::time::Duration;
21
22use flume::TryRecvError;
23use thiserror::Error;
24
25#[derive(Debug, Error)]
26pub enum SendError {
27 #[error("The channel is closed.")]
28 Disconnected,
29 #[error("The channel is full.")]
30 Full,
31}
32
33#[derive(Debug, Error, Copy, Clone, PartialEq, Eq)]
34pub enum RecvError {
35 #[error("A timeout occured when attempting to receive a message.")]
36 Timeout,
37 #[error("All sender were dropped an no message are pending in the channel.")]
38 Disconnected,
39}
40
41impl From<flume::RecvTimeoutError> for RecvError {
42 fn from(flume_err: flume::RecvTimeoutError) -> Self {
43 match flume_err {
44 flume::RecvTimeoutError::Timeout => Self::Timeout,
45 flume::RecvTimeoutError::Disconnected => Self::Disconnected,
46 }
47 }
48}
49
50impl<T> From<flume::SendError<T>> for SendError {
51 fn from(_send_error: flume::SendError<T>) -> Self {
52 SendError::Disconnected
53 }
54}
55
56impl<T> From<flume::TrySendError<T>> for SendError {
57 fn from(try_send_error: flume::TrySendError<T>) -> Self {
58 match try_send_error {
59 flume::TrySendError::Full(_) => SendError::Full,
60 flume::TrySendError::Disconnected(_) => SendError::Disconnected,
61 }
62 }
63}
64
/// Lane a message is sent on.
///
/// `Receiver::recv_timeout` returns queued `High` messages ahead of `Low` ones.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Priority {
    /// Sent through the unbounded high-priority lane.
    High,
    /// Sent through the low-priority lane, whose capacity is chosen by `QueueCapacity`.
    Low,
}
70
/// Capacity of a channel built by `QueueCapacity::create_channel`
/// (used by `channel` for the low-priority lane).
#[derive(Debug, Clone, Copy)]
pub enum QueueCapacity {
    /// At most this many messages can be queued at once.
    Bounded(usize),
    /// No limit on the number of queued messages.
    Unbounded,
}
76
77impl QueueCapacity {
78 pub(crate) fn create_channel<M>(&self) -> (flume::Sender<M>, flume::Receiver<M>) {
79 match *self {
80 QueueCapacity::Bounded(cap) => flume::bounded(cap),
81 QueueCapacity::Unbounded => flume::unbounded(),
82 }
83 }
84}
85
86pub fn channel<T>(queue_capacity: QueueCapacity) -> (Sender<T>, Receiver<T>) {
87 let (high_priority_tx, high_priority_rx) = flume::unbounded();
88 let (low_priority_tx, low_priority_rx) = queue_capacity.create_channel();
89 let receiver = Receiver {
90 low_priority_rx,
91 high_priority_rx,
92 _high_priority_tx: high_priority_tx.clone(),
93 pending: None,
94 };
95 let sender = Sender {
96 low_priority_tx,
97 high_priority_tx,
98 };
99 (sender, receiver)
100}
101
102pub struct Sender<T> {
103 low_priority_tx: flume::Sender<T>,
104 high_priority_tx: flume::Sender<T>,
105}
106
107impl<T> Sender<T> {
108 fn channel(&self, priority: Priority) -> &flume::Sender<T> {
109 match priority {
110 Priority::High => &self.high_priority_tx,
111 Priority::Low => &self.low_priority_tx,
112 }
113 }
114 pub async fn send(&self, msg: T, priority: Priority) -> Result<(), SendError> {
115 self.channel(priority).send_async(msg).await?;
116 Ok(())
117 }
118}
119
120pub struct Receiver<T> {
121 low_priority_rx: flume::Receiver<T>,
122 high_priority_rx: flume::Receiver<T>,
123 _high_priority_tx: flume::Sender<T>,
124 pending: Option<T>,
125}
126
127impl<T> Receiver<T> {
128 fn try_recv_high_priority_message(&self) -> Option<T> {
129 match self.high_priority_rx.try_recv() {
130 Ok(msg) => Some(msg),
131 Err(TryRecvError::Disconnected) => {
132 unreachable!(
133 "This can never happen, as the high priority Sender is owned by the Receiver."
134 );
135 }
136 Err(TryRecvError::Empty) => None,
137 }
138 }
139
140 pub async fn recv_high_priority_timeout(&mut self, duration: Duration) -> Result<T, RecvError> {
141 tokio::select! {
142 high_priority_msg_res = self.high_priority_rx.recv_async() => {
143 match high_priority_msg_res {
144 Ok(high_priority_msg) => { Ok(high_priority_msg) },
145 Err(_) => { unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.") }, }
146 }
147 _ = tokio::time::sleep(duration) => {
148 Err(RecvError::Timeout)
149 }
150 }
151 }
152
153 pub async fn recv_timeout(&mut self, duration: Duration) -> Result<T, RecvError> {
154 if let Some(msg) = self.try_recv_high_priority_message() {
155 return Ok(msg);
156 }
157 if let Some(pending_msg) = self.pending.take() {
158 return Ok(pending_msg);
159 }
160 tokio::select! {
161 high_priority_msg_res = self.high_priority_rx.recv_async() => {
162 match high_priority_msg_res {
163 Ok(high_priority_msg) => {
164 Ok(high_priority_msg)
165 },
166 Err(_) => {
167 unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.")
168 },
169 }
170 }
171 low_priority_msg_res = self.low_priority_rx.recv_async() => {
172 match low_priority_msg_res {
173 Ok(low_priority_msg) => {
174 if let Some(high_priority_msg) = self.try_recv_high_priority_message() {
175 self.pending = Some(low_priority_msg);
176 Ok(high_priority_msg)
177 } else {
178 Ok(low_priority_msg)
179 }
180 },
181 Err(flume::RecvError::Disconnected) => {
182 if let Some(high_priority_msg) = self.try_recv_high_priority_message() {
183 Ok(high_priority_msg)
184 } else {
185 Err(RecvError::Disconnected)
186 }
187 }
188 }
189 }
190 _ = tokio::time::sleep(duration) => {
191 Err(RecvError::Timeout)
192 }
193 }
194 }
195
196 pub fn drain_low_priority(&self) -> Vec<T> {
198 let mut messages = Vec::new();
199 while let Ok(msg) = self.low_priority_rx.try_recv() {
200 messages.push(msg);
201 }
202 messages
203 }
204}
205
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    const TEST_TIMEOUT: Duration = Duration::from_millis(100);

    /// A high-priority message sent after a low-priority one is still received first.
    #[tokio::test]
    async fn test_recv_timeout_priority() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send(1, Priority::Low).await?;
        sender.send(2, Priority::High).await?;
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        Ok(())
    }

    /// `recv_high_priority_timeout` ignores low-priority messages.
    #[tokio::test]
    async fn test_recv_high_priority_timeout() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send(1, Priority::Low).await?;
        assert_eq!(
            receiver.recv_high_priority_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        Ok(())
    }

    /// The high-priority lane never reports a disconnection, even with no senders left.
    #[tokio::test]
    async fn test_recv_high_priority_ignore_disconnection() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        std::mem::drop(sender);
        assert_eq!(
            receiver.recv_high_priority_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_recv_disconnect() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        std::mem::drop(sender);
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Disconnected)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_recv_timeout_simple() -> anyhow::Result<()> {
        let (_sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        let start_time = Instant::now();
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        let elapsed = start_time.elapsed();
        assert!(elapsed < crate::HEARTBEAT);
        Ok(())
    }

    /// Messages sent while the receiver is already waiting still respect priority.
    #[tokio::test]
    async fn test_try_recv_priority_corner_case() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        let send_task = tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            sender.send(1, Priority::High).await?;
            sender.send(2, Priority::Low).await?;
            Result::<(), SendError>::Ok(())
        });
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Disconnected)
        );
        // Surface a panic or send failure in the producer instead of silently ignoring it.
        send_task.await??;
        Ok(())
    }

    /// Draining only returns low-priority messages and empties that lane.
    #[tokio::test]
    async fn test_drain_low_priority() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send(1, Priority::Low).await?;
        sender.send(2, Priority::High).await?;
        sender.send(3, Priority::Low).await?;
        assert_eq!(receiver.drain_low_priority(), vec![1, 3]);
        assert!(receiver.drain_low_priority().is_empty());
        Ok(())
    }

    /// A bounded low-priority lane delivers queued messages in send order.
    #[tokio::test]
    async fn test_bounded_channel_delivers_in_order() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Bounded(2));
        sender.send(1, Priority::Low).await?;
        sender.send(2, Priority::Low).await?;
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        Ok(())
    }
}