quickwit_actors/
channel_with_priority.rs

1// Copyright (C) 2021 Quickwit, Inc.
2//
3// Quickwit is offered under the AGPL v3.0 and as commercial software.
4// For commercial licensing, contact us at hello@quickwit.io.
5//
6// AGPL:
7// This program is free software: you can redistribute it and/or modify
8// it under the terms of the GNU Affero General Public License as
9// published by the Free Software Foundation, either version 3 of the
10// License, or (at your option) any later version.
11//
12// This program is distributed in the hope that it will be useful,
13// but WITHOUT ANY WARRANTY; without even the implied warranty of
14// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15// GNU Affero General Public License for more details.
16//
17// You should have received a copy of the GNU Affero General Public License
18// along with this program. If not, see <http://www.gnu.org/licenses/>.
19
20use std::time::Duration;
21
22use flume::TryRecvError;
23use thiserror::Error;
24
25#[derive(Debug, Error)]
26pub enum SendError {
27    #[error("The channel is closed.")]
28    Disconnected,
29    #[error("The channel is full.")]
30    Full,
31}
32
33#[derive(Debug, Error, Copy, Clone, PartialEq, Eq)]
34pub enum RecvError {
35    #[error("A timeout occured when attempting to receive a message.")]
36    Timeout,
37    #[error("All sender were dropped an no message are pending in the channel.")]
38    Disconnected,
39}
40
41impl From<flume::RecvTimeoutError> for RecvError {
42    fn from(flume_err: flume::RecvTimeoutError) -> Self {
43        match flume_err {
44            flume::RecvTimeoutError::Timeout => Self::Timeout,
45            flume::RecvTimeoutError::Disconnected => Self::Disconnected,
46        }
47    }
48}
49
50impl<T> From<flume::SendError<T>> for SendError {
51    fn from(_send_error: flume::SendError<T>) -> Self {
52        SendError::Disconnected
53    }
54}
55
56impl<T> From<flume::TrySendError<T>> for SendError {
57    fn from(try_send_error: flume::TrySendError<T>) -> Self {
58        match try_send_error {
59            flume::TrySendError::Full(_) => SendError::Full,
60            flume::TrySendError::Disconnected(_) => SendError::Disconnected,
61        }
62    }
63}
64
/// Priority attached to a message when it is sent.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Priority {
    /// Routed through the high priority queue.
    High,
    /// Routed through the low priority queue.
    Low,
}
70
/// Capacity of a message queue.
#[derive(Copy, Clone, Debug)]
pub enum QueueCapacity {
    /// The queue holds at most the given number of messages.
    Bounded(usize),
    /// The queue has no capacity limit.
    Unbounded,
}
76
77impl QueueCapacity {
78    pub(crate) fn create_channel<M>(&self) -> (flume::Sender<M>, flume::Receiver<M>) {
79        match *self {
80            QueueCapacity::Bounded(cap) => flume::bounded(cap),
81            QueueCapacity::Unbounded => flume::unbounded(),
82        }
83    }
84}
85
86pub fn channel<T>(queue_capacity: QueueCapacity) -> (Sender<T>, Receiver<T>) {
87    let (high_priority_tx, high_priority_rx) = flume::unbounded();
88    let (low_priority_tx, low_priority_rx) = queue_capacity.create_channel();
89    let receiver = Receiver {
90        low_priority_rx,
91        high_priority_rx,
92        _high_priority_tx: high_priority_tx.clone(),
93        pending: None,
94    };
95    let sender = Sender {
96        low_priority_tx,
97        high_priority_tx,
98    };
99    (sender, receiver)
100}
101
102pub struct Sender<T> {
103    low_priority_tx: flume::Sender<T>,
104    high_priority_tx: flume::Sender<T>,
105}
106
107impl<T> Sender<T> {
108    fn channel(&self, priority: Priority) -> &flume::Sender<T> {
109        match priority {
110            Priority::High => &self.high_priority_tx,
111            Priority::Low => &self.low_priority_tx,
112        }
113    }
114    pub async fn send(&self, msg: T, priority: Priority) -> Result<(), SendError> {
115        self.channel(priority).send_async(msg).await?;
116        Ok(())
117    }
118}
119
120pub struct Receiver<T> {
121    low_priority_rx: flume::Receiver<T>,
122    high_priority_rx: flume::Receiver<T>,
123    _high_priority_tx: flume::Sender<T>,
124    pending: Option<T>,
125}
126
127impl<T> Receiver<T> {
128    fn try_recv_high_priority_message(&self) -> Option<T> {
129        match self.high_priority_rx.try_recv() {
130            Ok(msg) => Some(msg),
131            Err(TryRecvError::Disconnected) => {
132                unreachable!(
133                    "This can never happen, as the high priority Sender is owned by the Receiver."
134                );
135            }
136            Err(TryRecvError::Empty) => None,
137        }
138    }
139
140    pub async fn recv_high_priority_timeout(&mut self, duration: Duration) -> Result<T, RecvError> {
141        tokio::select! {
142            high_priority_msg_res = self.high_priority_rx.recv_async() => {
143                match high_priority_msg_res {
144                    Ok(high_priority_msg) => { Ok(high_priority_msg) },
145                    Err(_) => { unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.") }, }
146                }
147            _ = tokio::time::sleep(duration) => {
148                Err(RecvError::Timeout)
149            }
150        }
151    }
152
153    pub async fn recv_timeout(&mut self, duration: Duration) -> Result<T, RecvError> {
154        if let Some(msg) = self.try_recv_high_priority_message() {
155            return Ok(msg);
156        }
157        if let Some(pending_msg) = self.pending.take() {
158            return Ok(pending_msg);
159        }
160        tokio::select! {
161            high_priority_msg_res = self.high_priority_rx.recv_async() => {
162                match high_priority_msg_res {
163                    Ok(high_priority_msg) => {
164                        Ok(high_priority_msg)
165                    },
166                    Err(_) => {
167                        unreachable!("The Receiver owns the high priority Sender to avoid any disconnection.")
168                    },
169                }
170            }
171            low_priority_msg_res = self.low_priority_rx.recv_async() => {
172                match low_priority_msg_res {
173                    Ok(low_priority_msg) => {
174                        if let Some(high_priority_msg) = self.try_recv_high_priority_message() {
175                            self.pending = Some(low_priority_msg);
176                            Ok(high_priority_msg)
177                        } else {
178                            Ok(low_priority_msg)
179                        }
180                    },
181                    Err(flume::RecvError::Disconnected) => {
182                        if let Some(high_priority_msg) = self.try_recv_high_priority_message() {
183                            Ok(high_priority_msg)
184                        } else {
185                            Err(RecvError::Disconnected)
186                        }
187                    }
188                }
189           }
190            _ = tokio::time::sleep(duration) => {
191                Err(RecvError::Timeout)
192            }
193        }
194    }
195
196    /// Drain all of the pending low priority messages and return them.
197    pub fn drain_low_priority(&self) -> Vec<T> {
198        let mut messages = Vec::new();
199        while let Ok(msg) = self.low_priority_rx.try_recv() {
200            messages.push(msg);
201        }
202        messages
203    }
204}
205
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    const TEST_TIMEOUT: Duration = Duration::from_millis(100);

    #[tokio::test]
    async fn test_recv_timeout_priority() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send(1, Priority::Low).await?;
        sender.send(2, Priority::High).await?;
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_recv_high_priority_timeout() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send(1, Priority::Low).await?;
        assert_eq!(
            receiver.recv_high_priority_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_recv_high_priority_ignore_disconnection() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        std::mem::drop(sender);
        assert_eq!(
            receiver.recv_high_priority_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_recv_disconnect() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        std::mem::drop(sender);
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Disconnected)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_recv_timeout_simple() -> anyhow::Result<()> {
        let (_sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        let start_time = Instant::now();
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Timeout)
        );
        let elapsed = start_time.elapsed();
        assert!(elapsed < crate::HEARTBEAT);
        Ok(())
    }

    #[tokio::test]
    async fn test_try_recv_priority_corner_case() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            sender.send(1, Priority::High).await?;
            sender.send(2, Priority::Low).await?;
            Result::<(), SendError>::Ok(())
        });
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
        assert_eq!(
            receiver.recv_timeout(TEST_TIMEOUT).await,
            Err(RecvError::Disconnected)
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_high_priority_bypasses_low_priority_capacity() -> anyhow::Result<()> {
        let (sender, mut receiver) = super::channel::<usize>(QueueCapacity::Bounded(1));
        // Fills the bounded low priority queue.
        sender.send(1, Priority::Low).await?;
        // The high priority queue is unbounded, so these sends do not wait.
        sender.send(2, Priority::High).await?;
        sender.send(3, Priority::High).await?;
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(2));
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(3));
        assert_eq!(receiver.recv_timeout(TEST_TIMEOUT).await, Ok(1));
        Ok(())
    }

    #[tokio::test]
    async fn test_drain_low_priority() -> anyhow::Result<()> {
        let (sender, receiver) = super::channel::<usize>(QueueCapacity::Unbounded);
        sender.send(1, Priority::Low).await?;
        sender.send(2, Priority::High).await?;
        sender.send(3, Priority::Low).await?;
        assert_eq!(receiver.drain_low_priority(), vec![1, 3]);
        assert!(receiver.drain_low_priority().is_empty());
        Ok(())
    }
}