use crate::error::{IpcError, Result};
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
/// Sending half of an asynchronous IPC channel.
///
/// Implementors must be shareable across threads (`Send + Sync`). The asynchronous method
/// returns a pinned, boxed `Send` future borrowing `self`, rather than being an `async fn`.
pub trait AsyncIpcSender
where
    Self: Send + Sync,
{
    /// Type of the values carried by the channel.
    type Message: Send;

    /// Asynchronously sends `msg`; the returned future borrows `self` for its whole lifetime.
    fn send<'a>(
        &'a self,
        msg: Self::Message,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;

    /// Non-waiting variant of [`AsyncIpcSender::send`].
    fn try_send(&self, msg: Self::Message) -> Result<()>;
}
/// Receiving half of an asynchronous IPC channel.
///
/// Like [`AsyncIpcSender`], implementors must be `Send + Sync`. The waiting methods return
/// pinned, boxed `Send` futures that borrow `self`.
pub trait AsyncIpcReceiver
where
    Self: Send + Sync,
{
    /// Type of the values produced by the channel.
    type Message: Send;

    /// Asynchronously waits for the next message.
    fn recv<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<Self::Message>> + Send + 'a>>;

    /// Like [`AsyncIpcReceiver::recv`], but intended to give up once `timeout` has elapsed.
    fn recv_timeout<'a>(
        &'a self,
        timeout: Duration,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Message>> + Send + 'a>>;

    /// Non-waiting receive. `Ok(None)` presumably means "nothing available yet", as it does
    /// for the tokio-backed receiver in this file — TODO confirm for other implementors.
    fn try_recv(&self) -> Result<Option<Self::Message>>;
}
/// Marker for types that can both send and receive.
///
/// It is implemented automatically, via the blanket impl directly below, for every type that
/// implements both [`AsyncIpcSender`] and [`AsyncIpcReceiver`]. A hand-written impl would
/// conflict with it.
///
/// Both supertraits declare their own `Message` type, so code bounded by this trait must name
/// the one it means, e.g. `<C as AsyncIpcSender>::Message`.
pub trait AsyncIpcChannel
where
    Self: AsyncIpcSender + AsyncIpcReceiver,
{
}

impl<C> AsyncIpcChannel for C where C: AsyncIpcSender + AsyncIpcReceiver {}
/// Tokio-backed in-process channel whose sender(s) and receiver share one shutdown flag.
#[cfg(feature = "async")]
pub mod tokio_channel {
    use super::*;
    use crate::graceful::ShutdownState;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Sending half of an [`AsyncThreadChannel`].
    pub struct AsyncThreadSender<T> {
        inner: mpsc::Sender<T>,
        // Shared with the paired receiver and with every clone of this sender.
        shutdown: Arc<ShutdownState>,
    }

    /// Receiving half of an [`AsyncThreadChannel`].
    pub struct AsyncThreadReceiver<T> {
        inner: mpsc::Receiver<T>,
        // Shared with every sender of the same channel.
        shutdown: Arc<ShutdownState>,
    }

    // Written by hand instead of `#[derive(Clone)]` so that cloning a sender does not
    // require `T: Clone`.
    impl<T> Clone for AsyncThreadSender<T> {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                shutdown: Arc::clone(&self.shutdown),
            }
        }
    }

    impl<T: Send + 'static> AsyncThreadSender<T> {
        /// Sends `msg`, waiting for buffer space if the channel is full.
        ///
        /// NOTE(review): the shutdown flag is only checked before waiting. A send that is
        /// already parked on a full buffer is released when the receiving half is closed or
        /// dropped, not by `shutdown()` itself.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] if shutdown was requested before the call, or if the
        /// receiving half is closed or dropped.
        pub async fn send(&self, msg: T) -> Result<()> {
            if self.shutdown.is_shutdown() {
                return Err(IpcError::Closed);
            }
            self.inner.send(msg).await.map_err(|_| IpcError::Closed)
        }

        /// Sends `msg` without waiting.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::WouldBlock`] when the buffer is full. Returns
        /// [`IpcError::Closed`] when shutdown was requested or the receiving half is closed or
        /// dropped.
        pub fn try_send(&self, msg: T) -> Result<()> {
            if self.shutdown.is_shutdown() {
                return Err(IpcError::Closed);
            }
            self.inner.try_send(msg).map_err(|e| match e {
                mpsc::error::TrySendError::Full(_) => IpcError::WouldBlock,
                mpsc::error::TrySendError::Closed(_) => IpcError::Closed,
            })
        }

        /// Returns `true` when this sender can no longer deliver messages: either shutdown was
        /// requested on the shared flag, or the receiving half is closed or dropped.
        pub fn is_closed(&self) -> bool {
            // Consult the shared flag too, so this agrees with `send`/`try_send`. Those refuse
            // to send once shutdown is requested, even while the receiver is still alive.
            self.shutdown.is_shutdown() || self.inner.is_closed()
        }

        /// Signals shutdown on the state shared by every sender and the receiver of this
        /// channel.
        pub fn shutdown(&self) {
            self.shutdown.shutdown();
        }
    }

    impl<T: Send + 'static> AsyncThreadReceiver<T> {
        /// Returns `Err(IpcError::Closed)` if shutdown has been requested.
        ///
        /// On that path the underlying tokio receiver is also closed. Senders parked in
        /// `send().await` on a full buffer are then woken with an error instead of staying
        /// blocked until this receiver is dropped.
        fn ensure_open(&mut self) -> Result<()> {
            if self.shutdown.is_shutdown() {
                self.inner.close();
                return Err(IpcError::Closed);
            }
            Ok(())
        }

        /// Waits for the next message.
        ///
        /// Once shutdown has been requested, this returns [`IpcError::Closed`] immediately,
        /// even if messages are still buffered.
        ///
        /// NOTE(review): a call that is already waiting does not observe a later `shutdown()`.
        /// It keeps waiting for a message or for every sender to be dropped.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] after shutdown, or when the channel is closed and empty.
        pub async fn recv(&mut self) -> Result<T> {
            self.ensure_open()?;
            self.inner.recv().await.ok_or(IpcError::Closed)
        }

        /// Like [`recv`](Self::recv), but gives up after `timeout`.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Timeout`] if no message arrived in time. Returns
        /// [`IpcError::Closed`] after shutdown, or when the channel is closed and empty.
        pub async fn recv_timeout(&mut self, timeout: Duration) -> Result<T> {
            self.ensure_open()?;
            tokio::time::timeout(timeout, self.inner.recv())
                .await
                .map_err(|_| IpcError::Timeout)?
                .ok_or(IpcError::Closed)
        }

        /// Returns a buffered message if one is available, or `Ok(None)` if the buffer is
        /// currently empty.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] after shutdown, or when every sender has been dropped
        /// and the buffer is empty.
        pub fn try_recv(&mut self) -> Result<Option<T>> {
            self.ensure_open()?;
            match self.inner.try_recv() {
                Ok(msg) => Ok(Some(msg)),
                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
                Err(mpsc::error::TryRecvError::Disconnected) => Err(IpcError::Closed),
            }
        }

        /// Returns `true` once shutdown has been requested on the shared flag.
        ///
        /// NOTE(review): this does not report that every sender was dropped. `recv`/`try_recv`
        /// surface that case as [`IpcError::Closed`].
        pub fn is_closed(&self) -> bool {
            self.shutdown.is_shutdown()
        }

        /// Signals shutdown on the state shared with every sender of this channel.
        pub fn shutdown(&self) {
            self.shutdown.shutdown();
        }
    }

    /// Constructor namespace for tokio-backed channels. Its only field is private, so it is
    /// never instantiated outside this module.
    pub struct AsyncThreadChannel<T>(std::marker::PhantomData<T>);

    impl<T: Send + 'static> AsyncThreadChannel<T> {
        /// Creates a channel with a very large buffer (1,000,000 messages).
        ///
        /// This is not truly unbounded. Once the buffer is full, `send` waits and `try_send`
        /// returns [`IpcError::WouldBlock`]. Both constructors use tokio's bounded channel so
        /// that they return the same sender/receiver types.
        pub fn unbounded() -> (AsyncThreadSender<T>, AsyncThreadReceiver<T>) {
            const LARGE_BUFFER: usize = 1_000_000;
            Self::bounded(LARGE_BUFFER)
        }

        /// Creates a channel that buffers at most `capacity` messages.
        ///
        /// # Panics
        ///
        /// Panics if `capacity` is 0, or larger than tokio's maximum (requirements of
        /// `tokio::sync::mpsc::channel`).
        pub fn bounded(capacity: usize) -> (AsyncThreadSender<T>, AsyncThreadReceiver<T>) {
            let (tx, rx) = mpsc::channel(capacity);
            let shutdown = Arc::new(ShutdownState::new());
            (
                AsyncThreadSender {
                    inner: tx,
                    shutdown: Arc::clone(&shutdown),
                },
                AsyncThreadReceiver {
                    inner: rx,
                    shutdown,
                },
            )
        }
    }

    /// Spawns a tokio task that passes every received message to `handler`.
    ///
    /// Each handler future is awaited before the next message is received. The task finishes
    /// as soon as `recv` returns an error (shutdown, or all senders dropped and buffer empty).
    /// It must be called from within a tokio runtime, as `tokio::spawn` requires.
    pub fn spawn_handler<T, F, Fut>(
        mut receiver: AsyncThreadReceiver<T>,
        handler: F,
    ) -> tokio::task::JoinHandle<()>
    where
        T: Send + 'static,
        F: Fn(T) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
    {
        tokio::spawn(async move {
            while let Ok(msg) = receiver.recv().await {
                handler(msg).await;
            }
        })
    }
}
/// Single-value channel wrapping `tokio::sync::oneshot`, with errors mapped to [`IpcError`].
#[cfg(feature = "async")]
pub mod oneshot {
    use super::*;
    use tokio::sync::oneshot as tokio_oneshot;
    use tokio::sync::oneshot::error::TryRecvError;

    /// Sending half; it is consumed by [`OneshotSender::send`].
    pub struct OneshotSender<T> {
        inner: tokio_oneshot::Sender<T>,
    }

    /// Receiving half; [`OneshotReceiver::recv`] consumes it.
    pub struct OneshotReceiver<T> {
        inner: tokio_oneshot::Receiver<T>,
    }

    impl<T> OneshotSender<T> {
        /// Delivers `value` to the receiver, consuming this sender.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] if the receiving half is gone. The undelivered `value`
        /// is dropped in that case.
        pub fn send(self, value: T) -> Result<()> {
            let outcome = self.inner.send(value);
            outcome.map_err(|_undelivered| IpcError::Closed)
        }

        /// Reports whether the receiving half has been closed or dropped.
        pub fn is_closed(&self) -> bool {
            let sender = &self.inner;
            sender.is_closed()
        }
    }

    impl<T> OneshotReceiver<T> {
        /// Waits for the value, consuming this receiver.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] if the sender is dropped without sending.
        pub async fn recv(self) -> Result<T> {
            match self.inner.await {
                Ok(value) => Ok(value),
                Err(_recv_error) => Err(IpcError::Closed),
            }
        }

        /// Polls for the value without waiting; `Ok(None)` means nothing has been sent yet.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] when tokio reports the channel closed (sender dropped
        /// without sending, or the value was already taken).
        pub fn try_recv(&mut self) -> Result<Option<T>> {
            self.inner.try_recv().map(Some).or_else(|failure| match failure {
                TryRecvError::Empty => Ok(None),
                TryRecvError::Closed => Err(IpcError::Closed),
            })
        }
    }

    /// Creates a connected sender/receiver pair.
    pub fn channel<T>() -> (OneshotSender<T>, OneshotReceiver<T>) {
        let (sender, receiver) = tokio_oneshot::channel();
        let tx = OneshotSender { inner: sender };
        let rx = OneshotReceiver { inner: receiver };
        (tx, rx)
    }
}
/// Multi-consumer broadcast channel wrapping `tokio::sync::broadcast`. Every active receiver
/// gets its own clone of each message.
#[cfg(feature = "async")]
pub mod broadcast {
    use super::*;
    use tokio::sync::broadcast as tokio_broadcast;

    /// Sending half. Cloning it yields another sender for the same channel.
    // `T: Clone` lives on the impls, not on the struct, matching the wrapped tokio type;
    // merely naming the type therefore imposes no bound.
    #[derive(Clone)]
    pub struct BroadcastSender<T> {
        inner: tokio_broadcast::Sender<T>,
    }

    /// Receiving half. Obtain more receivers with [`BroadcastSender::subscribe`].
    pub struct BroadcastReceiver<T> {
        inner: tokio_broadcast::Receiver<T>,
    }

    impl<T: Clone + Send + 'static> BroadcastSender<T> {
        /// Broadcasts `value` to every currently subscribed receiver and returns how many
        /// receivers it was sent to.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] when there are no active receivers at the time of the
        /// call. This is not permanent: after a [`subscribe`](Self::subscribe), sends can
        /// succeed again.
        pub fn send(&self, value: T) -> Result<usize> {
            self.inner.send(value).map_err(|_| IpcError::Closed)
        }

        /// Number of currently active receivers.
        pub fn receiver_count(&self) -> usize {
            self.inner.receiver_count()
        }

        /// Creates a new receiver that sees messages sent after this call.
        pub fn subscribe(&self) -> BroadcastReceiver<T> {
            BroadcastReceiver {
                inner: self.inner.subscribe(),
            }
        }
    }

    impl<T: Clone + Send + 'static> BroadcastReceiver<T> {
        /// Waits for the next message.
        ///
        /// If this receiver fell behind and older messages were overwritten, the gap is
        /// skipped silently. The oldest message still buffered is returned instead.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Closed`] when tokio reports the channel closed (all senders
        /// dropped).
        pub async fn recv(&mut self) -> Result<T> {
            loop {
                match self.inner.recv().await {
                    Ok(v) => return Ok(v),
                    // Deliberately best-effort: the count of missed messages is discarded.
                    Err(tokio_broadcast::error::RecvError::Lagged(_missed)) => continue,
                    Err(tokio_broadcast::error::RecvError::Closed) => {
                        return Err(IpcError::Closed)
                    }
                }
            }
        }
    }

    /// Creates a broadcast channel that retains at most `capacity` messages per receiver
    /// backlog, together with its first receiver.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 or greater than `usize::MAX / 2` (requirements of
    /// `tokio::sync::broadcast::channel`).
    pub fn channel<T: Clone>(capacity: usize) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
        let (tx, rx) = tokio_broadcast::channel(capacity);
        (
            BroadcastSender { inner: tx },
            BroadcastReceiver { inner: rx },
        )
    }
}
#[cfg(all(test, feature = "async"))]
mod tests {
    //! Behavioral tests for the tokio-backed channels: happy paths plus closure, shutdown,
    //! back-pressure and lag handling.
    use super::tokio_channel::AsyncThreadChannel;
    use super::*;

    #[tokio::test]
    async fn test_async_thread_channel() {
        let (tx, mut rx) = AsyncThreadChannel::<String>::unbounded();
        tx.send("Hello".to_string()).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "Hello");
    }

    #[tokio::test]
    async fn test_async_thread_channel_timeout() {
        let (_tx, mut rx) = AsyncThreadChannel::<String>::bounded(1);
        let result = rx.recv_timeout(Duration::from_millis(10)).await;
        assert!(matches!(result, Err(IpcError::Timeout)));
    }

    #[tokio::test]
    async fn test_try_send_full_and_try_recv() {
        let (tx, mut rx) = AsyncThreadChannel::<u32>::bounded(1);
        assert!(matches!(rx.try_recv(), Ok(None)));
        tx.try_send(1).unwrap();
        // The single slot is occupied, so a second non-waiting send must not block.
        assert!(matches!(tx.try_send(2), Err(IpcError::WouldBlock)));
        assert!(matches!(rx.try_recv(), Ok(Some(1))));
    }

    #[tokio::test]
    async fn test_receiver_shutdown_rejects_operations() {
        let (tx, mut rx) = AsyncThreadChannel::<u32>::bounded(4);
        tx.send(1).await.unwrap();
        rx.shutdown();
        assert!(rx.is_closed());
        assert!(matches!(tx.send(2).await, Err(IpcError::Closed)));
        assert!(matches!(tx.try_send(3), Err(IpcError::Closed)));
        // Shutdown wins over the still-buffered message.
        assert!(matches!(rx.recv().await, Err(IpcError::Closed)));
        assert!(matches!(rx.try_recv(), Err(IpcError::Closed)));
    }

    #[tokio::test]
    async fn test_sender_shutdown_is_visible_to_receiver() {
        let (tx, mut rx) = AsyncThreadChannel::<u32>::bounded(4);
        tx.shutdown();
        assert!(rx.is_closed());
        assert!(matches!(
            rx.recv_timeout(Duration::from_millis(10)).await,
            Err(IpcError::Closed)
        ));
    }

    #[tokio::test]
    async fn test_recv_after_all_senders_dropped() {
        let (tx, mut rx) = AsyncThreadChannel::<u32>::bounded(4);
        tx.send(7).await.unwrap();
        drop(tx);
        // Buffered messages are still delivered before the channel reports closure.
        assert_eq!(rx.recv().await.unwrap(), 7);
        assert!(matches!(rx.recv().await, Err(IpcError::Closed)));
        assert!(matches!(rx.try_recv(), Err(IpcError::Closed)));
    }

    #[tokio::test]
    async fn test_oneshot() {
        let (tx, rx) = oneshot::channel::<i32>();
        tx.send(42).unwrap();
        let value = rx.recv().await.unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn test_oneshot_sender_dropped() {
        let (tx, mut rx) = oneshot::channel::<i32>();
        assert!(matches!(rx.try_recv(), Ok(None)));
        drop(tx);
        assert!(matches!(rx.recv().await, Err(IpcError::Closed)));
    }

    #[tokio::test]
    async fn test_oneshot_receiver_dropped() {
        let (tx, rx) = oneshot::channel::<i32>();
        drop(rx);
        assert!(tx.is_closed());
        assert!(matches!(tx.send(1), Err(IpcError::Closed)));
    }

    #[tokio::test]
    async fn test_broadcast() {
        let (tx, mut rx1) = broadcast::channel::<String>(16);
        let mut rx2 = tx.subscribe();
        tx.send("Hello".to_string()).unwrap();
        let msg1 = rx1.recv().await.unwrap();
        let msg2 = rx2.recv().await.unwrap();
        assert_eq!(msg1, "Hello");
        assert_eq!(msg2, "Hello");
    }

    #[tokio::test]
    async fn test_broadcast_without_receivers() {
        let (tx, rx) = broadcast::channel::<u8>(4);
        drop(rx);
        assert_eq!(tx.receiver_count(), 0);
        assert!(matches!(tx.send(1), Err(IpcError::Closed)));
        // A fresh subscription makes sending succeed again.
        let mut rx = tx.subscribe();
        assert_eq!(tx.send(2).unwrap(), 1);
        assert_eq!(rx.recv().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_broadcast_lagged_receiver_skips_to_oldest_retained() {
        let (tx, mut rx) = broadcast::channel::<u32>(2);
        for i in 0..5 {
            tx.send(i).unwrap();
        }
        // Only the two newest values survive; the lag is skipped transparently.
        assert_eq!(rx.recv().await.unwrap(), 3);
        assert_eq!(rx.recv().await.unwrap(), 4);
    }
}