nydus_utils/
mpmc.rs

1// Copyright (C) 2022 Alibaba Cloud. All rights reserved.
2//
3// SPDX-License-Identifier: Apache-2.0
4
5//! Asynchronous Multi-Producer Multi-Consumer channel.
6//!
7//! This module provides an asynchronous multi-producer multi-consumer channel based on [tokio::sync::Notify].
8
9use std::collections::VecDeque;
10use std::io::{Error, ErrorKind, Result};
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Mutex, MutexGuard};
13use tokio::sync::Notify;
14
15/// An asynchronous multi-producer multi-consumer channel based on [tokio::sync::Notify].
16pub struct Channel<T> {
17    closed: AtomicBool,
18    notifier: Notify,
19    requests: Mutex<VecDeque<T>>,
20}
21
22impl<T> Default for Channel<T> {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl<T> Channel<T> {
29    /// Create a new instance of [`Channel`].
30    pub fn new() -> Self {
31        Channel {
32            closed: AtomicBool::new(false),
33            notifier: Notify::new(),
34            requests: Mutex::new(VecDeque::new()),
35        }
36    }
37
38    /// Close the channel.
39    pub fn close(&self) {
40        self.closed.store(true, Ordering::Release);
41        self.notifier.notify_waiters();
42    }
43
44    /// Send a message to the channel.
45    ///
46    /// The message object will be returned on error, to ease the lifecycle management.
47    pub fn send(&self, msg: T) -> std::result::Result<(), T> {
48        if self.closed.load(Ordering::Acquire) {
49            Err(msg)
50        } else {
51            self.requests.lock().unwrap().push_back(msg);
52            self.notifier.notify_one();
53            Ok(())
54        }
55    }
56
57    /// Try to receive a message from the channel.
58    pub fn try_recv(&self) -> Option<T> {
59        self.requests.lock().unwrap().pop_front()
60    }
61
62    /// Receive message from the channel in asynchronous mode.
63    pub async fn recv(&self) -> Result<T> {
64        let future = self.notifier.notified();
65        tokio::pin!(future);
66
67        loop {
68            // Make sure that no wakeup is lost if we get `None` from `try_recv`.
69            future.as_mut().enable();
70
71            if let Some(msg) = self.try_recv() {
72                return Ok(msg);
73            } else if self.closed.load(Ordering::Acquire) {
74                return Err(Error::new(ErrorKind::BrokenPipe, "channel has been closed"));
75            }
76
77            // Wait for a call to `notify_one`.
78            //
79            // This uses `.as_mut()` to avoid consuming the future,
80            // which lets us call `Pin::set` below.
81            future.as_mut().await;
82
83            // Reset the future in case another call to `try_recv` got the message before us.
84            future.set(self.notifier.notified());
85        }
86    }
87
88    /// Flush all pending requests specified by the predicator.
89    ///
90    pub fn flush_pending_prefetch_requests<F>(&self, mut f: F)
91    where
92        F: FnMut(&T) -> bool,
93    {
94        self.requests.lock().unwrap().retain(|t| !f(t));
95    }
96
97    /// Lock the channel to block all queue operations.
98    pub fn lock_channel(&self) -> MutexGuard<VecDeque<T>> {
99        self.requests.lock().unwrap()
100    }
101
102    /// Notify all waiters.
103    pub fn notify_waiters(&self) {
104        self.notifier.notify_waiters();
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_new_channel() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        assert_eq!(channel.try_recv().unwrap(), 1);
        assert_eq!(channel.try_recv().unwrap(), 2);

        channel.close();
        channel.send(2u32).unwrap_err();
    }

    #[test]
    fn test_flush_channel() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        channel.flush_pending_prefetch_requests(|_| true);
        assert!(channel.try_recv().is_none());

        channel.notify_waiters();
        let _guard = channel.lock_channel();
    }

    #[test]
    fn test_flush_channel_with_predicate() {
        let channel = Channel::new();

        for i in 1u32..=4 {
            channel.send(i).unwrap();
        }
        // Drop only the even messages; the survivors must keep their relative order.
        channel.flush_pending_prefetch_requests(|v| v % 2 == 0);
        assert_eq!(channel.try_recv(), Some(1));
        assert_eq!(channel.try_recv(), Some(3));
        assert!(channel.try_recv().is_none());
    }

    #[test]
    fn test_send_after_close_returns_message() {
        let channel = Channel::new();

        channel.close();
        assert_eq!(channel.send(7u32), Err(7));
    }

    #[test]
    fn test_async_recv() {
        let channel = Arc::new(Channel::new());
        let channel2 = channel.clone();

        let t = std::thread::spawn(move || {
            channel2.send(1u32).unwrap();
        });

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let msg = channel.recv().await.unwrap();
            assert_eq!(msg, 1);
        });

        t.join().unwrap();
    }

    #[test]
    fn test_recv_drains_queue_before_reporting_close() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        channel.close();

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            // Messages queued before `close` are still delivered, in order...
            assert_eq!(channel.recv().await.unwrap(), 1);
            assert_eq!(channel.recv().await.unwrap(), 2);
            // ...and only then does `recv` report the closed channel.
            let err = channel.recv().await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_close_wakes_pending_receiver() {
        let channel = Arc::new(Channel::<u32>::new());

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let receiver = Arc::clone(&channel);
            let handle = tokio::spawn(async move { receiver.recv().await });
            // Let the spawned task park in `recv` before closing; the test also holds if it
            // has not started yet, since `recv` checks the closed flag first thing.
            tokio::task::yield_now().await;
            channel.close();

            let err = handle.await.unwrap().unwrap_err();
            assert_eq!(err.kind(), ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_default_channel_send_and_recv() {
        let channel = Channel::default();

        channel.send(0x1u32).unwrap();
        channel.send(0x2u32).unwrap();
        assert_eq!(channel.try_recv().unwrap(), 0x1);
        assert_eq!(channel.try_recv().unwrap(), 0x2);

        channel.close();
        channel.send(2u32).unwrap_err();
    }
}