use super::{mpsc, oneshot};
/// Extension methods for an unbounded `mpsc` sender that report delivery
/// failure as a plain value instead of surfacing a send error.
pub trait FallibleExt<T> {
    /// Sends `msg`, returning `true` if the send succeeded and `false` if it
    /// failed (in which case the message is dropped).
    fn send_lossy(&self, msg: T) -> bool;

    /// Performs a request/response round trip.
    ///
    /// `make_msg` is handed the sending half of a fresh oneshot channel and
    /// must wrap it in a message of type `T`. The returned future resolves to
    /// `Some(reply)` once the handler answers, or to `None` if the message
    /// could not be sent or the responder was dropped without replying.
    fn request<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> impl std::future::Future<Output = Option<R>> + Send;

    /// Like [`FallibleExt::request`], but resolves to `default` when no reply
    /// is obtained.
    fn request_or<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
        default: R,
    ) -> impl std::future::Future<Output = R> + Send;

    /// Like [`FallibleExt::request`], but resolves to `R::default()` when no
    /// reply is obtained.
    fn request_or_default<R: Default + Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> impl std::future::Future<Output = R> + Send;
}
/// Unbounded senders enqueue synchronously (`send` is not awaited), so a
/// request only suspends while waiting for the reply.
impl<T: Send> FallibleExt<T> for mpsc::UnboundedSender<T> {
    fn send_lossy(&self, msg: T) -> bool {
        matches!(self.send(msg), Ok(()))
    }

    async fn request<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> Option<R> {
        let (responder, reply) = oneshot::channel();
        // On failure the rejected message (and the responder inside it) is
        // dropped, and we bail out with `None`.
        self.send(make_msg(responder)).ok()?;
        // An error here means the responder was dropped without answering.
        reply.await.ok()
    }

    async fn request_or<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
        default: R,
    ) -> R {
        match self.request(make_msg).await {
            Some(answer) => answer,
            None => default,
        }
    }

    async fn request_or_default<R: Default + Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> R {
        self.request(make_msg).await.unwrap_or_else(R::default)
    }
}
/// Extension methods for a bounded `mpsc` sender that report delivery
/// failure as a plain value instead of surfacing a send error.
pub trait AsyncFallibleExt<T> {
    /// Sends `msg` through the sender's async `send`, resolving to `true` if
    /// the send succeeded and `false` if it failed (the message is dropped).
    fn send_lossy(&self, msg: T) -> impl std::future::Future<Output = bool> + Send;

    /// Attempts to send `msg` without awaiting, returning `false` if the
    /// attempt fails (presumably because the buffer is full or the receiver
    /// is gone); the message is dropped in that case.
    fn try_send_lossy(&self, msg: T) -> bool;

    /// Performs a request/response round trip.
    ///
    /// `make_msg` is handed the sending half of a fresh oneshot channel and
    /// must wrap it in a message of type `T`. The returned future resolves to
    /// `Some(reply)` once the handler answers, or to `None` if the message
    /// could not be sent or the responder was dropped without replying.
    fn request<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> impl std::future::Future<Output = Option<R>> + Send;

    /// Like [`AsyncFallibleExt::request`], but resolves to `default` when no
    /// reply is obtained.
    fn request_or<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
        default: R,
    ) -> impl std::future::Future<Output = R> + Send;

    /// Like [`AsyncFallibleExt::request`], but resolves to `R::default()`
    /// when no reply is obtained.
    fn request_or_default<R: Default + Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> impl std::future::Future<Output = R> + Send;
}
/// Bounded senders await `send`, so a request may suspend twice: once while
/// enqueueing the outgoing message and once while waiting for the reply.
impl<T: Send> AsyncFallibleExt<T> for mpsc::Sender<T> {
    async fn send_lossy(&self, msg: T) -> bool {
        matches!(self.send(msg).await, Ok(()))
    }

    fn try_send_lossy(&self, msg: T) -> bool {
        matches!(self.try_send(msg), Ok(()))
    }

    async fn request<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> Option<R> {
        let (responder, reply) = oneshot::channel();
        // On failure the rejected message (and the responder inside it) is
        // dropped, and we bail out with `None`.
        self.send(make_msg(responder)).await.ok()?;
        // An error here means the responder was dropped without answering.
        reply.await.ok()
    }

    async fn request_or<R: Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
        default: R,
    ) -> R {
        match self.request(make_msg).await {
            Some(answer) => answer,
            None => default,
        }
    }

    async fn request_or_default<R: Default + Send, F: FnOnce(oneshot::Sender<R>) -> T + Send>(
        &self,
        make_msg: F,
    ) -> R {
        self.request(make_msg).await.unwrap_or_else(R::default)
    }
}
/// Extension method for a oneshot sender whose delivery may fail silently.
pub trait OneshotExt<T> {
    /// Consumes the sender and delivers `msg`, returning `true` on success
    /// and `false` if the send failed (the message is dropped).
    fn send_lossy(self, msg: T) -> bool;
}
/// Forwards to the oneshot sender's own `send`, collapsing its `Result` into
/// a success flag.
impl<T> OneshotExt<T> for oneshot::Sender<T> {
    fn send_lossy(self, msg: T) -> bool {
        matches!(self.send(msg), Ok(()))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the lossy send / request helpers.
    //!
    //! The success and dropped-responder paths are driven by having `make_msg`
    //! answer (or drop) the responder itself. The reply is therefore settled
    //! before `request` awaits it, so no second task or executor support is
    //! needed.
    use super::*;
    use commonware_macros::test_async;

    #[derive(Debug)]
    #[allow(dead_code)] // payload fields exist only to carry responders
    enum TestMessage {
        FireAndForget(u32),
        Request {
            responder: oneshot::Sender<String>,
        },
        RequestBool {
            responder: oneshot::Sender<bool>,
        },
        RequestVec {
            responder: oneshot::Sender<Vec<u32>>,
        },
    }

    #[test]
    fn test_send_lossy_success() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        assert!(tx.send_lossy(TestMessage::FireAndForget(42)));
        assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(42))));
    }

    #[test]
    fn test_send_lossy_disconnected() {
        let (tx, rx) = mpsc::unbounded_channel::<TestMessage>();
        drop(rx);
        assert!(!tx.send_lossy(TestMessage::FireAndForget(42)));
    }

    #[test_async]
    async fn test_request_send_disconnected() {
        let (tx, rx) = mpsc::unbounded_channel::<TestMessage>();
        drop(rx);
        let result: Option<String> = tx
            .request(|responder| TestMessage::Request { responder })
            .await;
        assert_eq!(result, None);
    }

    #[test_async]
    async fn test_request_success() {
        let (tx, mut rx) = mpsc::unbounded_channel::<TestMessage>();
        let result: Option<String> = tx
            .request(|responder: oneshot::Sender<String>| {
                assert!(responder.send(String::from("pong")).is_ok());
                TestMessage::FireAndForget(1)
            })
            .await;
        assert_eq!(result, Some(String::from("pong")));
        assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(1))));
    }

    #[test_async]
    async fn test_request_responder_dropped() {
        let (tx, mut rx) = mpsc::unbounded_channel::<TestMessage>();
        let result: Option<String> = tx
            .request(|responder: oneshot::Sender<String>| {
                drop(responder);
                TestMessage::FireAndForget(2)
            })
            .await;
        assert_eq!(result, None);
        // The message itself was delivered; only the reply was lost.
        assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(2))));
    }

    #[test_async]
    async fn test_request_or_disconnected() {
        let (tx, rx) = mpsc::unbounded_channel::<TestMessage>();
        drop(rx);
        let result = tx
            .request_or(|responder| TestMessage::RequestBool { responder }, false)
            .await;
        assert!(!result);
    }

    #[test_async]
    async fn test_request_or_responder_dropped() {
        // Keep the receiver alive so the send succeeds and only the reply fails.
        let (tx, _rx) = mpsc::unbounded_channel::<TestMessage>();
        let result = tx
            .request_or(
                |responder: oneshot::Sender<bool>| {
                    drop(responder);
                    TestMessage::FireAndForget(4)
                },
                true,
            )
            .await;
        assert!(result);
    }

    #[test_async]
    async fn test_request_or_default_disconnected() {
        let (tx, rx) = mpsc::unbounded_channel::<TestMessage>();
        drop(rx);
        let result: Vec<u32> = tx
            .request_or_default(|responder| TestMessage::RequestVec { responder })
            .await;
        assert!(result.is_empty());
    }

    #[test_async]
    async fn test_request_or_default_success() {
        let (tx, _rx) = mpsc::unbounded_channel::<TestMessage>();
        let result: Vec<u32> = tx
            .request_or_default(|responder: oneshot::Sender<Vec<u32>>| {
                assert!(responder.send(vec![1, 2, 3]).is_ok());
                TestMessage::FireAndForget(5)
            })
            .await;
        assert_eq!(result, vec![1, 2, 3]);
    }

    #[test_async]
    async fn test_async_send_lossy_success() {
        let (tx, mut rx) = mpsc::channel(1);
        assert!(tx.send_lossy(TestMessage::FireAndForget(42)).await);
        assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(42))));
    }

    #[test_async]
    async fn test_async_send_lossy_disconnected() {
        let (tx, rx) = mpsc::channel::<TestMessage>(1);
        drop(rx);
        assert!(!tx.send_lossy(TestMessage::FireAndForget(42)).await);
    }

    #[test_async]
    async fn test_async_request_send_disconnected() {
        let (tx, rx) = mpsc::channel::<TestMessage>(1);
        drop(rx);
        let result: Option<String> =
            AsyncFallibleExt::request(&tx, |responder| TestMessage::Request { responder }).await;
        assert_eq!(result, None);
    }

    #[test_async]
    async fn test_async_request_success() {
        // Capacity 1 means the outgoing send completes without waiting.
        let (tx, mut rx) = mpsc::channel::<TestMessage>(1);
        let result: Option<String> =
            AsyncFallibleExt::request(&tx, |responder: oneshot::Sender<String>| {
                assert!(responder.send(String::from("pong")).is_ok());
                TestMessage::FireAndForget(3)
            })
            .await;
        assert_eq!(result, Some(String::from("pong")));
        assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(3))));
    }

    #[test_async]
    async fn test_async_request_responder_dropped() {
        let (tx, mut rx) = mpsc::channel::<TestMessage>(1);
        let result: Option<String> =
            AsyncFallibleExt::request(&tx, |responder: oneshot::Sender<String>| {
                drop(responder);
                TestMessage::FireAndForget(6)
            })
            .await;
        assert_eq!(result, None);
        assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(6))));
    }

    #[test_async]
    async fn test_async_request_or_disconnected() {
        let (tx, rx) = mpsc::channel::<TestMessage>(1);
        drop(rx);
        let result = AsyncFallibleExt::request_or(
            &tx,
            |responder| TestMessage::RequestBool { responder },
            false,
        )
        .await;
        assert!(!result);
    }

    #[test_async]
    async fn test_async_request_or_default_disconnected() {
        let (tx, rx) = mpsc::channel::<TestMessage>(1);
        drop(rx);
        let result: Vec<u32> = AsyncFallibleExt::request_or_default(&tx, |responder| {
            TestMessage::RequestVec { responder }
        })
        .await;
        assert!(result.is_empty());
    }

    #[test]
    fn test_try_send_lossy_success() {
        let (tx, mut rx) = mpsc::channel(1);
        assert!(tx.try_send_lossy(TestMessage::FireAndForget(42)));
        assert!(matches!(rx.try_recv(), Ok(TestMessage::FireAndForget(42))));
    }

    #[test]
    fn test_try_send_lossy_full() {
        // The receiver stays alive; the second send fails only because the
        // single buffer slot is occupied.
        let (tx, _rx) = mpsc::channel::<TestMessage>(1);
        assert!(tx.try_send_lossy(TestMessage::FireAndForget(1)));
        assert!(!tx.try_send_lossy(TestMessage::FireAndForget(2)));
    }

    #[test]
    fn test_try_send_lossy_disconnected() {
        let (tx, rx) = mpsc::channel::<TestMessage>(1);
        drop(rx);
        assert!(!tx.try_send_lossy(TestMessage::FireAndForget(42)));
    }

    #[test]
    fn test_oneshot_send_lossy_success() {
        let (tx, mut rx) = oneshot::channel::<u32>();
        assert!(tx.send_lossy(42));
        assert_eq!(rx.try_recv(), Ok(42));
    }

    #[test]
    fn test_oneshot_send_lossy_disconnected() {
        let (tx, rx) = oneshot::channel::<u32>();
        drop(rx);
        assert!(!tx.send_lossy(42));
    }
}