use std::collections::VecDeque;
use std::io::{Error, ErrorKind, Result};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, MutexGuard};
use tokio::sync::Notify;
/// An unbounded FIFO queue that can be shared between senders and receivers, where receivers
/// may wait asynchronously for the next message.
///
/// Once closed, new sends are rejected, but messages already queued can still be received.
pub struct Channel<T> {
    /// Set by `close`; `send` refuses new messages once this is `true`.
    closed: AtomicBool,
    /// Wakes receivers parked in `recv` when a message arrives or the channel closes.
    notifier: Notify,
    /// Pending messages: pushed at the back, popped from the front.
    requests: Mutex<VecDeque<T>>,
}
impl<T> Default for Channel<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Channel<T> {
    /// Creates an empty, open channel.
    pub fn new() -> Self {
        Channel {
            closed: AtomicBool::new(false),
            notifier: Notify::new(),
            requests: Mutex::new(VecDeque::new()),
        }
    }

    /// Closes the channel and wakes every task currently waiting in [`Channel::recv`].
    ///
    /// After `close` returns, [`Channel::send`] rejects new messages. Messages that were already
    /// queued remain available to [`Channel::recv`] and [`Channel::try_recv`].
    ///
    /// The flag is flipped while holding the queue lock, so the flip is ordered with respect to
    /// `send` and `recv`. A concurrent `send` either lands before the close, and receivers will
    /// drain it, or it is rejected.
    ///
    /// Because the queue mutex is not reentrant, do not call `close` while holding the guard
    /// returned by [`Channel::lock_channel`]. For the same reason, do not call it from inside a
    /// [`Channel::flush_pending_prefetch_requests`] predicate. Either would deadlock.
    pub fn close(&self) {
        {
            let _guard = self.queue();
            self.closed.store(true, Ordering::Release);
        }
        self.notifier.notify_waiters();
    }

    /// Appends `msg` to the queue and wakes one waiting receiver.
    ///
    /// # Errors
    /// Returns `Err(msg)`, handing the message back to the caller, if the channel is closed.
    pub fn send(&self, msg: T) -> std::result::Result<(), T> {
        {
            // Check the flag under the queue lock so a racing `close` cannot slip in between
            // the check and the push.
            let mut queue = self.queue();
            if self.closed.load(Ordering::Acquire) {
                return Err(msg);
            }
            queue.push_back(msg);
        }
        // Notify after releasing the lock so the woken receiver does not immediately block on it.
        self.notifier.notify_one();
        Ok(())
    }

    /// Removes and returns the oldest queued message without waiting.
    ///
    /// Returns `None` if the queue is empty. The closed flag is not consulted, so queued
    /// messages stay retrievable after [`Channel::close`].
    pub fn try_recv(&self) -> Option<T> {
        self.queue().pop_front()
    }

    /// Waits for and returns the oldest queued message.
    ///
    /// Queued messages are still returned after the channel is closed. This only fails once the
    /// queue is empty *and* the channel is closed.
    ///
    /// # Errors
    /// Returns an [`ErrorKind::BrokenPipe`] error when the channel is closed and drained.
    pub async fn recv(&self) -> Result<T> {
        let notified = self.notifier.notified();
        tokio::pin!(notified);
        loop {
            // Register interest *before* inspecting the queue. Otherwise a `notify_one` or
            // `notify_waiters` issued between the check and the `.await` would be missed.
            notified.as_mut().enable();
            if let Some(outcome) = self.pop_or_closed() {
                return outcome;
            }
            notified.as_mut().await;
            // A `Notified` future completes only once, so arm a fresh one for the next round.
            notified.set(self.notifier.notified());
        }
    }

    /// Drops every queued message for which `f` returns `true`, keeping the rest in order.
    pub fn flush_pending_prefetch_requests<F>(&self, mut f: F)
    where
        F: FnMut(&T) -> bool,
    {
        self.queue().retain(|t| !f(t));
    }

    /// Locks the underlying queue and returns it for direct inspection or manipulation.
    ///
    /// While the guard is alive, every other channel operation blocks, including `close`. Keep
    /// the guard short-lived and never hold it across an `.await`.
    ///
    /// Messages pushed through the guard do not wake receivers. Call
    /// [`Channel::notify_waiters`] afterwards if that is needed.
    pub fn lock_channel(&self) -> MutexGuard<'_, VecDeque<T>> {
        self.queue()
    }

    /// Wakes every task currently waiting in [`Channel::recv`] so it re-checks the queue.
    pub fn notify_waiters(&self) {
        self.notifier.notify_waiters();
    }

    /// Locks the queue, recovering the guard if a previous holder panicked.
    ///
    /// The queue carries no invariant beyond `VecDeque`'s own memory safety. A panicking holder,
    /// such as a panicking `flush_pending_prefetch_requests` predicate, cannot leave it unusable.
    /// Propagating the poison would only turn one panic into a panic on every later operation.
    fn queue(&self) -> MutexGuard<'_, VecDeque<T>> {
        self.requests
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Pops the oldest message, or reports closure if the queue is empty and closed.
    ///
    /// Both conditions are checked under a single acquisition of the queue lock, which `send`
    /// and `close` also take. "Empty and closed" therefore really means no further message can
    /// arrive. Returns `None` when the caller should wait for a notification.
    ///
    /// The guard is released before returning, so `recv` never holds it across an `.await`.
    fn pop_or_closed(&self) -> Option<Result<T>> {
        let mut queue = self.queue();
        if let Some(msg) = queue.pop_front() {
            Some(Ok(msg))
        } else if self.closed.load(Ordering::Acquire) {
            Some(Err(Error::new(
                ErrorKind::BrokenPipe,
                "channel has been closed",
            )))
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Single-threaded tokio runtime used to drive `recv` in tests.
    fn new_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn test_new_channel() {
        let channel = Channel::new();
        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        assert_eq!(channel.try_recv().unwrap(), 1);
        assert_eq!(channel.try_recv().unwrap(), 2);
        channel.close();
        channel.send(2u32).unwrap_err();
    }

    #[test]
    fn test_flush_channel() {
        let channel = Channel::new();
        channel.send(1u32).unwrap();
        channel.send(2u32).unwrap();
        channel.flush_pending_prefetch_requests(|_| true);
        assert!(channel.try_recv().is_none());
        channel.notify_waiters();
        let guard = channel.lock_channel();
        assert!(guard.is_empty());
    }

    #[test]
    fn test_flush_only_matching() {
        let channel = Channel::new();
        for v in 1u32..=4 {
            channel.send(v).unwrap();
        }
        // Even values are flushed; odd ones must survive in their original order.
        channel.flush_pending_prefetch_requests(|v| v % 2 == 0);
        assert_eq!(channel.try_recv(), Some(1));
        assert_eq!(channel.try_recv(), Some(3));
        assert_eq!(channel.try_recv(), None);
    }

    #[test]
    fn test_async_recv() {
        let channel = Arc::new(Channel::new());
        let channel2 = channel.clone();
        let t = std::thread::spawn(move || {
            channel2.send(1u32).unwrap();
        });
        let rt = new_runtime();
        rt.block_on(async {
            let msg = channel.recv().await.unwrap();
            assert_eq!(msg, 1);
        });
        t.join().unwrap();
    }

    #[test]
    fn test_recv_drains_before_reporting_close() {
        let channel = Channel::new();
        channel.send(7u32).unwrap();
        channel.close();
        let rt = new_runtime();
        rt.block_on(async {
            // A message queued before `close` is still delivered...
            assert_eq!(channel.recv().await.unwrap(), 7);
            // ...and only then does `recv` report the closed channel.
            let err = channel.recv().await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_close_wakes_pending_receiver() {
        let channel = Arc::new(Channel::<u32>::new());
        let closer = {
            let channel = channel.clone();
            std::thread::spawn(move || channel.close())
        };
        // Whether `close` runs before or while `recv` is parked, `recv` must return an error
        // rather than hang.
        let rt = new_runtime();
        let err = rt.block_on(channel.recv()).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::BrokenPipe);
        closer.join().unwrap();
    }
}