fuse_backend_rs/common/
mpmc.rs

// Copyright (C) 2022 Alibaba Cloud. All rights reserved.
//
// SPDX-License-Identifier: Apache-2.0
// Async implementation of Multi-Producer-Multi-Consumer channel.

//! Asynchronous Multi-Producer Multi-Consumer channel.
//!
//! This module provides an asynchronous multi-producer multi-consumer channel based on [`tokio::sync::Notify`].
9
10use std::collections::VecDeque;
11use std::io::{Error, ErrorKind, Result};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Mutex, MutexGuard};
14use tokio::sync::Notify;
15
16/// An asynchronous multi-producer multi-consumer channel based on [tokio::sync::Notify].
17pub struct Channel<T> {
18    closed: AtomicBool,
19    notifier: Notify,
20    requests: Mutex<VecDeque<T>>,
21}
22
23impl<T> Default for Channel<T> {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl<T> Channel<T> {
30    /// Create a new instance of [`Channel`].
31    pub fn new() -> Self {
32        Channel {
33            closed: AtomicBool::new(false),
34            notifier: Notify::new(),
35            requests: Mutex::new(VecDeque::new()),
36        }
37    }
38
39    /// Close the channel.
40    pub fn close(&self) {
41        self.closed.store(true, Ordering::Release);
42        self.notifier.notify_waiters();
43    }
44
45    /// Send a message to the channel.
46    ///
47    /// The message object will be returned on error, to ease the lifecycle management.
48    pub fn send(&self, msg: T) -> std::result::Result<(), T> {
49        if self.closed.load(Ordering::Acquire) {
50            Err(msg)
51        } else {
52            self.requests.lock().unwrap().push_back(msg);
53            self.notifier.notify_one();
54            Ok(())
55        }
56    }
57
58    /// Try to receive a message from the channel.
59    pub fn try_recv(&self) -> Option<T> {
60        self.requests.lock().unwrap().pop_front()
61    }
62
63    /// Receive message from the channel in asynchronous mode.
64    pub async fn recv(&self) -> Result<T> {
65        let future = self.notifier.notified();
66        tokio::pin!(future);
67
68        loop {
69            // Make sure that no wakeup is lost if we get `None` from `try_recv`.
70            future.as_mut().enable();
71
72            if let Some(msg) = self.try_recv() {
73                return Ok(msg);
74            } else if self.closed.load(Ordering::Acquire) {
75                return Err(Error::new(ErrorKind::BrokenPipe, "channel has been closed"));
76            }
77
78            // Wait for a call to `notify_one`.
79            //
80            // This uses `.as_mut()` to avoid consuming the future,
81            // which lets us call `Pin::set` below.
82            future.as_mut().await;
83
84            // Reset the future in case another call to `try_recv` got the message before us.
85            future.set(self.notifier.notified());
86        }
87    }
88
89    /// Flush all pending requests specified by the predicator.
90    ///
91    pub fn flush_pending_prefetch_requests<F>(&self, mut f: F)
92    where
93        F: FnMut(&T) -> bool,
94    {
95        self.requests.lock().unwrap().retain(|t| !f(t));
96    }
97
98    /// Lock the channel to block all queue operations.
99    pub fn lock_channel(&self) -> MutexGuard<VecDeque<T>> {
100        self.requests.lock().unwrap()
101    }
102
103    /// Notify all waiters.
104    pub fn notify_waiters(&self) {
105        self.notifier.notify_waiters();
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Build a single-threaded runtime for driving `recv` in tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn test_new_channel() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        // Messages come out in FIFO order.
        assert_eq!(channel.try_recv(), Some(1));
        assert_eq!(channel.try_recv(), Some(2));
        assert_eq!(channel.try_recv(), None);

        channel.close();
        // A closed channel rejects the message and hands it back to the sender.
        assert_eq!(channel.send(2u32), Err(2));
    }

    #[test]
    fn test_flush_channel() {
        let channel = Channel::new();

        for v in 1u32..=4 {
            channel.send(v).unwrap();
        }
        // Only messages matching the predicate are dropped; the rest keep their order.
        channel.flush_pending_prefetch_requests(|v| *v % 2 == 0);
        assert_eq!(channel.try_recv(), Some(1));
        assert_eq!(channel.try_recv(), Some(3));
        assert!(channel.try_recv().is_none());

        channel.send(5u32).unwrap();
        channel.flush_pending_prefetch_requests(|_| true);
        assert!(channel.try_recv().is_none());

        channel.notify_waiters();
        let _guard = channel.lock_channel();
    }

    #[test]
    fn test_async_recv() {
        let channel = Arc::new(Channel::new());
        let channel2 = channel.clone();

        let t = std::thread::spawn(move || {
            channel2.send(1u32).unwrap();
        });

        runtime().block_on(async {
            let msg = channel.recv().await.unwrap();
            assert_eq!(msg, 1);
        });

        t.join().unwrap();
    }

    #[test]
    fn test_recv_after_close() {
        let channel = Channel::new();
        channel.send(7u32).unwrap();
        channel.close();

        runtime().block_on(async {
            // Messages queued before `close` are still delivered...
            assert_eq!(channel.recv().await.unwrap(), 7);
            // ...and only then does `recv` report the closed channel.
            let err = channel.recv().await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_close_wakes_receiver() {
        let channel = Arc::new(Channel::<u32>::new());
        let channel2 = channel.clone();

        // Whether `close` lands before or after `recv` starts waiting, `recv` must end
        // with an error.
        let t = std::thread::spawn(move || channel2.close());

        runtime().block_on(async {
            let err = channel.recv().await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::BrokenPipe);
        });

        t.join().unwrap();
    }
}