use super::mpsc::{
self,
error::{SendError, TrySendError},
OwnedPermit,
};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// Outcome of waiting for a channel slot: an owned permit, or a value-less [`SendError`]
/// when the reservation cannot be completed.
type ReserveResult<T> = Result<OwnedPermit<T>, SendError<()>>;
/// Heap-allocated, pinned, type-erased future resolving to a [`ReserveResult`].
type ReserveFuture<T> = Pin<Box<dyn Future<Output = ReserveResult<T>> + Send + 'static>>;
/// A reserved channel slot paired with the value that will be written into it.
///
/// Obtained by awaiting a [`Reservation`]. Call [`Reserved::send`] to deliver the value;
/// dropping this without sending drops the value (and, presumably, releases the slot via
/// `OwnedPermit`'s drop behavior — confirm against `super::mpsc`).
#[must_use = "call send to deliver the reserved message"]
pub struct Reserved<T> {
    permit: OwnedPermit<T>,
    value: T,
}

impl<T> Reserved<T> {
    /// Writes the held value through the reserved permit.
    ///
    /// Returns the [`mpsc::Sender`] handed back by `OwnedPermit::send`.
    pub fn send(self) -> mpsc::Sender<T> {
        let Self { permit, value } = self;
        permit.send(value)
    }
}
/// A not-yet-acquired channel slot, bundled with the value destined for it.
///
/// Awaiting it resolves to a [`Reserved`] once a slot is obtained. If the reservation
/// fails, it resolves to the original value wrapped in [`SendError`].
#[must_use = "await the reservation to acquire a channel slot"]
pub struct Reservation<T> {
    /// Boxed future that acquires the slot.
    future: ReserveFuture<T>,
    /// The value to deliver; taken (left as `None`) when the reservation resolves.
    value: Option<T>,
}

impl<T> Reservation<T> {
    /// Packages a slot-acquiring future with the value it will eventually carry.
    fn new(future: impl Future<Output = ReserveResult<T>> + Send + 'static, value: T) -> Self {
        let pinned: ReserveFuture<T> = Box::pin(future);
        Self {
            future: pinned,
            value: Some(value),
        }
    }
}
// `Reservation` never exposes a pinned reference to `value`; `value` is only moved out
// via `Option::take`. `future` is already behind `Pin<Box<..>>`, so moving a
// `Reservation` cannot break any pinning guarantee. Implementing `Unpin` unconditionally
// lets `poll` reach both fields through `Pin<&mut Self>` even when `T: !Unpin`.
impl<T> Unpin for Reservation<T> {}

impl<T> Future for Reservation<T> {
    type Output = Result<Reserved<T>, SendError<T>>;

    /// Drives the slot acquisition and pairs the resulting permit with the stored value.
    ///
    /// Resolves to `Ok(Reserved)` when a permit is obtained. If the inner reservation
    /// fails, it resolves to `Err(SendError(value))`, handing the original value back.
    ///
    /// # Panics
    ///
    /// Panics with "reservation polled after completion" if polled again after it has
    /// already returned [`Poll::Ready`].
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Detect reuse *before* touching the inner future. Polling a completed future is
        // unspecified: it may panic with an unrelated message or stay pending forever, so
        // checking afterwards would never reach this diagnostic.
        assert!(self.value.is_some(), "reservation polled after completion");
        let permit = match self.future.as_mut().poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(permit) => permit,
        };
        // Guaranteed `Some` by the assertion above; nothing in between takes it.
        let value = self
            .value
            .take()
            .expect("reservation polled after completion");
        Poll::Ready(match permit {
            Ok(permit) => Ok(Reserved { permit, value }),
            Err(SendError(())) => Err(SendError(value)),
        })
    }
}
/// Sender extension: deliver a value right away when the channel has room. Otherwise,
/// return a [`Reservation`] that can be stored and awaited later.
pub trait ReservationExt<T> {
    /// Tries to send `value` without waiting.
    ///
    /// # Returns
    ///
    /// - `Ok(None)` if the value was enqueued immediately.
    /// - `Ok(Some(reservation))` if the channel was full. Awaiting `reservation` waits
    ///   for capacity and yields a [`Reserved`] that delivers the value.
    ///
    /// # Errors
    ///
    /// Returns the value wrapped in [`SendError`] if the channel is closed.
    #[must_use = "await and send any reservation"]
    fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
    where
        T: 'static;
}

impl<T: Send> ReservationExt<T> for mpsc::Sender<T> {
    fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
    where
        T: 'static,
    {
        // Only a full channel falls through; the other outcomes return immediately.
        let value = match self.try_send(value) {
            Ok(()) => return Ok(None),
            Err(TrySendError::Closed(value)) => return Err(SendError(value)),
            Err(TrySendError::Full(value)) => value,
        };
        // A `Reservation` must be `'static`, so it drives an owned clone of the sender
        // rather than borrowing `self`.
        // NOTE(review): if `reserve_owned` is lazy (e.g. an `async fn`), no slot is
        // requested until the reservation is first polled — confirm in `super::mpsc`.
        let owned_sender = self.clone();
        Ok(Some(Reservation::new(owned_sender.reserve_owned(), value)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_macros::test_async;
    use std::collections::BTreeMap;

    // With spare capacity the value goes straight into the channel.
    #[test]
    fn test_send_or_reserve_sends_immediately() {
        let (tx, mut rx) = mpsc::channel(1);
        let outcome = tx.send_or_reserve(1).unwrap();
        assert!(outcome.is_none());
        assert_eq!(rx.try_recv(), Ok(1));
    }

    // A closed channel hands the value straight back.
    #[test]
    fn test_send_or_reserve_closed_returns_value() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);
        let SendError(returned) = tx.send_or_reserve(1).err().expect("send should fail");
        assert_eq!(returned, 1);
    }

    // A full channel yields a reservation that completes once capacity frees up.
    #[test_async]
    async fn test_send_or_reserve_waits_for_capacity() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.try_send(1).unwrap();
        let pending = tx.send_or_reserve(2).unwrap();
        let reservation = pending.expect("channel should be full");
        assert_eq!(rx.recv().await, Some(1));
        let reserved = reservation.await.unwrap();
        reserved.send();
        assert_eq!(rx.recv().await, Some(2));
    }

    // Closing the channel while a reservation is pending returns the value to the caller.
    #[test_async]
    async fn test_send_or_reserve_returns_value_when_closed_while_waiting() {
        let (tx, rx) = mpsc::channel(1);
        tx.try_send(1).unwrap();
        let pending = tx.send_or_reserve(2).unwrap();
        let reservation = pending.expect("channel should be full");
        drop(rx);
        let SendError(returned) = reservation.await.err().expect("reservation should fail");
        assert_eq!(returned, 2);
    }

    // Reservations are ordinary values that can live in collections until awaited.
    #[test_async]
    async fn test_send_or_reserve_reservations_can_be_stored() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.try_send(0).unwrap();

        let for_vec = tx
            .send_or_reserve(1)
            .unwrap()
            .expect("channel should be full");
        let mut queued = vec![for_vec];

        let mut keyed = BTreeMap::new();
        let for_map = tx
            .send_or_reserve(2)
            .unwrap()
            .expect("channel should be full");
        keyed.insert("next", for_map);

        assert_eq!(rx.recv().await, Some(0));
        let from_vec = queued.pop().unwrap();
        from_vec.await.unwrap().send();
        assert_eq!(rx.recv().await, Some(1));

        let from_map = keyed.remove("next").unwrap();
        from_map.await.unwrap().send();
        assert_eq!(rx.recv().await, Some(2));
    }
}