1use std::collections::VecDeque;
10use std::io::{Error, ErrorKind, Result};
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Mutex, MutexGuard};
13use tokio::sync::Notify;
14
15pub struct Channel<T> {
17 closed: AtomicBool,
18 notifier: Notify,
19 requests: Mutex<VecDeque<T>>,
20}
21
22impl<T> Default for Channel<T> {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl<T> Channel<T> {
29 pub fn new() -> Self {
31 Channel {
32 closed: AtomicBool::new(false),
33 notifier: Notify::new(),
34 requests: Mutex::new(VecDeque::new()),
35 }
36 }
37
38 pub fn close(&self) {
40 self.closed.store(true, Ordering::Release);
41 self.notifier.notify_waiters();
42 }
43
44 pub fn send(&self, msg: T) -> std::result::Result<(), T> {
48 if self.closed.load(Ordering::Acquire) {
49 Err(msg)
50 } else {
51 self.requests.lock().unwrap().push_back(msg);
52 self.notifier.notify_one();
53 Ok(())
54 }
55 }
56
57 pub fn try_recv(&self) -> Option<T> {
59 self.requests.lock().unwrap().pop_front()
60 }
61
62 pub async fn recv(&self) -> Result<T> {
64 let future = self.notifier.notified();
65 tokio::pin!(future);
66
67 loop {
68 future.as_mut().enable();
70
71 if let Some(msg) = self.try_recv() {
72 return Ok(msg);
73 } else if self.closed.load(Ordering::Acquire) {
74 return Err(Error::new(ErrorKind::BrokenPipe, "channel has been closed"));
75 }
76
77 future.as_mut().await;
82
83 future.set(self.notifier.notified());
85 }
86 }
87
88 pub fn flush_pending_prefetch_requests<F>(&self, mut f: F)
91 where
92 F: FnMut(&T) -> bool,
93 {
94 self.requests.lock().unwrap().retain(|t| !f(t));
95 }
96
97 pub fn lock_channel(&self) -> MutexGuard<VecDeque<T>> {
99 self.requests.lock().unwrap()
100 }
101
102 pub fn notify_waiters(&self) {
104 self.notifier.notify_waiters();
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds the single-threaded tokio runtime used by the async tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn test_new_channel() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        assert_eq!(channel.try_recv().unwrap(), 1);
        assert_eq!(channel.try_recv().unwrap(), 2);

        channel.close();
        channel.send(2u32).unwrap_err();
    }

    #[test]
    fn test_flush_channel() {
        let channel = Channel::new();

        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        channel.flush_pending_prefetch_requests(|_| true);
        assert!(channel.try_recv().is_none());

        channel.notify_waiters();
        let _guard = channel.lock_channel();
    }

    #[test]
    fn test_async_recv() {
        let channel = Arc::new(Channel::new());
        let channel2 = channel.clone();

        let t = std::thread::spawn(move || {
            channel2.send(1u32).unwrap();
        });

        runtime().block_on(async {
            let msg = channel.recv().await.unwrap();
            assert_eq!(msg, 1);
        });

        t.join().unwrap();
    }

    #[test]
    fn test_default_channel_send_and_recv() {
        let channel = Channel::default();

        channel.send(0x1u32).unwrap();
        channel.send(0x2u32).unwrap();
        assert_eq!(channel.try_recv().unwrap(), 0x1);
        assert_eq!(channel.try_recv().unwrap(), 0x2);

        channel.close();
        channel.send(2u32).unwrap_err();
    }

    #[test]
    fn test_recv_drains_queue_before_reporting_close() {
        let channel = Channel::new();
        channel.send(7u32).unwrap();
        channel.close();

        runtime().block_on(async {
            // A message queued before `close` must still be delivered...
            assert_eq!(channel.recv().await.unwrap(), 7);
            // ...and only then does `recv` report the closed channel.
            let err = channel.recv().await.unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_recv_fails_after_concurrent_close() {
        let channel = Arc::new(Channel::<u32>::new());
        let closer = {
            let channel = channel.clone();
            std::thread::spawn(move || channel.close())
        };

        // Whether `close` lands before `recv` starts or while it is already waiting,
        // `recv` must return an error rather than hang.
        let err = runtime().block_on(channel.recv()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);

        closer.join().unwrap();
    }
}