fuse_backend_rs/common/
mpmc.rs1use std::collections::VecDeque;
11use std::io::{Error, ErrorKind, Result};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Mutex, MutexGuard};
14use tokio::sync::Notify;
15
/// A multi-producer, multi-consumer message queue whose receivers can wait asynchronously.
///
/// Messages are kept in a mutex-protected FIFO queue. A tokio `Notify` wakes tasks blocked in
/// `recv` when a message is sent (`notify_one`) or when the channel is closed (`notify_waiters`).
pub struct Channel<T> {
    // Set to `true` by `close`; once set, `send` rejects new messages.
    closed: AtomicBool,
    // Wakes receivers that are waiting for a message or for the close signal.
    notifier: Notify,
    // Pending messages: pushed at the back by `send`, popped from the front by receivers.
    requests: Mutex<VecDeque<T>>,
}
22
23impl<T> Default for Channel<T> {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl<T> Channel<T> {
30 pub fn new() -> Self {
32 Channel {
33 closed: AtomicBool::new(false),
34 notifier: Notify::new(),
35 requests: Mutex::new(VecDeque::new()),
36 }
37 }
38
39 pub fn close(&self) {
41 self.closed.store(true, Ordering::Release);
42 self.notifier.notify_waiters();
43 }
44
45 pub fn send(&self, msg: T) -> std::result::Result<(), T> {
49 if self.closed.load(Ordering::Acquire) {
50 Err(msg)
51 } else {
52 self.requests.lock().unwrap().push_back(msg);
53 self.notifier.notify_one();
54 Ok(())
55 }
56 }
57
58 pub fn try_recv(&self) -> Option<T> {
60 self.requests.lock().unwrap().pop_front()
61 }
62
63 pub async fn recv(&self) -> Result<T> {
65 let future = self.notifier.notified();
66 tokio::pin!(future);
67
68 loop {
69 future.as_mut().enable();
71
72 if let Some(msg) = self.try_recv() {
73 return Ok(msg);
74 } else if self.closed.load(Ordering::Acquire) {
75 return Err(Error::new(ErrorKind::BrokenPipe, "channel has been closed"));
76 }
77
78 future.as_mut().await;
83
84 future.set(self.notifier.notified());
86 }
87 }
88
89 pub fn flush_pending_prefetch_requests<F>(&self, mut f: F)
92 where
93 F: FnMut(&T) -> bool,
94 {
95 self.requests.lock().unwrap().retain(|t| !f(t));
96 }
97
98 pub fn lock_channel(&self) -> MutexGuard<VecDeque<T>> {
100 self.requests.lock().unwrap()
101 }
102
103 pub fn notify_waiters(&self) {
105 self.notifier.notify_waiters();
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds the single-threaded runtime used by the async tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn test_new_channel() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        assert_eq!(channel.try_recv().unwrap(), 1);
        assert_eq!(channel.try_recv().unwrap(), 2);
        assert!(channel.try_recv().is_none());

        channel.close();
        // A message rejected by a closed channel is handed back to the caller.
        assert_eq!(channel.send(2u32).unwrap_err(), 2);
    }

    #[test]
    fn test_flush_channel() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        channel.flush_pending_prefetch_requests(|_| true);
        assert!(channel.try_recv().is_none());

        // Only messages matching the predicate are removed; the rest keep their order.
        for v in 1u32..=4 {
            channel.send(v).unwrap();
        }
        channel.flush_pending_prefetch_requests(|v| v % 2 == 0);
        assert_eq!(channel.try_recv(), Some(1));
        assert_eq!(channel.try_recv(), Some(3));
        assert!(channel.try_recv().is_none());

        channel.notify_waiters();
        let guard = channel.lock_channel();
        assert!(guard.is_empty());
    }

    #[test]
    fn test_async_recv() {
        let channel = Arc::new(Channel::new());
        let channel2 = channel.clone();

        let t = std::thread::spawn(move || {
            channel2.send(1u32).unwrap();
        });

        runtime().block_on(async {
            let msg = channel.recv().await.unwrap();
            assert_eq!(msg, 1);
        });

        t.join().unwrap();
    }

    #[test]
    fn test_recv_after_close() {
        let channel = Channel::new();
        channel.send(1u32).unwrap();
        channel.close();

        runtime().block_on(async {
            // Messages queued before `close` are still delivered...
            assert_eq!(channel.recv().await.unwrap(), 1);
            // ...and once the queue is drained, `recv` fails instead of blocking forever.
            let err = channel.recv().await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_close_wakes_waiting_receiver() {
        let channel = Arc::new(Channel::<u32>::new());
        let channel2 = channel.clone();

        runtime().block_on(async move {
            let waiter = tokio::spawn(async move { channel2.recv().await });
            // Give the receiver a chance to start waiting before the channel is closed.
            tokio::task::yield_now().await;
            channel.close();

            let result = waiter.await.unwrap();
            assert_eq!(result.unwrap_err().kind(), ErrorKind::BrokenPipe);
        });
    }
}