Skip to main content

commonware_utils/channel/
reservation.rs

1//! Channel reservation helpers.
2
3use super::mpsc::{
4    self,
5    error::{SendError, TrySendError},
6    OwnedPermit,
7};
8use std::{
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
14// The reserve future only reports channel closure; the message value is stored separately.
15type ReserveResult<T> = Result<OwnedPermit<T>, SendError<()>>;
16
17// Tokio's `reserve_owned` future is not nameable, so box it instead of exposing a future parameter.
18type ReserveFuture<T> = Pin<Box<dyn Future<Output = ReserveResult<T>> + Send>>;
19
20/// A reserved channel slot bundled with the value to send.
21#[must_use = "call send to deliver the reserved message"]
22pub struct Reserved<T> {
23    permit: OwnedPermit<T>,
24    value: T,
25}
26
27impl<T> Reserved<T> {
28    /// Sends the buffered value through the reserved slot.
29    pub fn send(self) -> mpsc::Sender<T> {
30        self.permit.send(self.value)
31    }
32}
33
34/// A future that waits for a channel slot and keeps ownership of the value.
35#[must_use = "await the reservation to acquire a channel slot"]
36pub struct Reservation<T> {
37    future: ReserveFuture<T>,
38    value: Option<T>,
39}
40
41impl<T> Reservation<T> {
42    fn new(future: impl Future<Output = ReserveResult<T>> + Send + 'static, value: T) -> Self {
43        Self {
44            future: Box::pin(future),
45            value: Some(value),
46        }
47    }
48}
49
50impl<T> Unpin for Reservation<T> {}
51
52impl<T> Future for Reservation<T> {
53    type Output = Result<Reserved<T>, SendError<T>>;
54
55    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
56        let permit = match self.future.as_mut().poll(cx) {
57            Poll::Pending => return Poll::Pending,
58            Poll::Ready(permit) => permit,
59        };
60        let value = self
61            .value
62            .take()
63            .expect("reservation polled after completion");
64        Poll::Ready(match permit {
65            Ok(permit) => Ok(Reserved { permit, value }),
66            Err(SendError(())) => Err(SendError(value)),
67        })
68    }
69}
70
71/// Extension trait for bounded channel sends that can reserve capacity.
72pub trait ReservationExt<T> {
73    /// Attempts to send immediately, reserving the message when the channel is full.
74    ///
75    /// Returns:
76    /// - `Ok(None)` when the value was sent immediately.
77    /// - `Ok(Some(_))` when the channel was full. Await the reservation and call
78    ///   [`Reserved::send`] to deliver the value.
79    /// - `Err(_)` when the receiver has been dropped.
80    #[must_use = "await and send any reservation"]
81    fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
82    where
83        T: 'static;
84}
85
86impl<T: Send> ReservationExt<T> for mpsc::Sender<T> {
87    fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
88    where
89        T: 'static,
90    {
91        match self.try_send(value) {
92            Ok(()) => Ok(None),
93            Err(TrySendError::Full(value)) => {
94                Ok(Some(Reservation::new(self.clone().reserve_owned(), value)))
95            }
96            Err(TrySendError::Closed(value)) => Err(SendError(value)),
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_macros::test_async;
    use std::collections::BTreeMap;

    /// Calls `send_or_reserve` on a channel the test expects to be full and returns the
    /// resulting reservation.
    fn reserve_when_full(sender: &mpsc::Sender<i32>, value: i32) -> Reservation<i32> {
        sender
            .send_or_reserve(value)
            .unwrap()
            .expect("channel should be full")
    }

    /// With free capacity the value is delivered directly and no reservation is returned.
    #[test]
    fn test_send_or_reserve_sends_immediately() {
        let (sender, mut receiver) = mpsc::channel(1);
        let outcome = sender.send_or_reserve(1).unwrap();
        assert!(outcome.is_none());
        assert_eq!(receiver.try_recv(), Ok(1));
    }

    /// A dropped receiver makes the call fail and gives the value back.
    #[test]
    fn test_send_or_reserve_closed_returns_value() {
        let (sender, receiver) = mpsc::channel(1);
        drop(receiver);

        let SendError(returned) = sender.send_or_reserve(1).err().expect("send should fail");
        assert_eq!(returned, 1);
    }

    /// A reservation taken on a full channel completes once the receiver frees a slot.
    #[test_async]
    async fn test_send_or_reserve_waits_for_capacity() {
        let (sender, mut receiver) = mpsc::channel(1);
        sender.try_send(1).unwrap();

        let pending = reserve_when_full(&sender, 2);
        assert_eq!(receiver.recv().await, Some(1));

        let reserved = pending.await.unwrap();
        reserved.send();
        assert_eq!(receiver.recv().await, Some(2));
    }

    /// Dropping the receiver while a reservation is outstanding returns the value as an error.
    #[test_async]
    async fn test_send_or_reserve_returns_value_when_closed_while_waiting() {
        let (sender, receiver) = mpsc::channel(1);
        sender.try_send(1).unwrap();

        let pending = reserve_when_full(&sender, 2);
        drop(receiver);

        let SendError(returned) = pending.await.err().expect("reservation should fail");
        assert_eq!(returned, 2);
    }

    /// Reservations can be kept in ordinary collections and awaited later.
    #[test_async]
    async fn test_send_or_reserve_reservations_can_be_stored() {
        let (sender, mut receiver) = mpsc::channel(1);
        sender.try_send(0).unwrap();

        let mut reservations = vec![reserve_when_full(&sender, 1)];
        let mut reservation_map = BTreeMap::from([("next", reserve_when_full(&sender, 2))]);

        assert_eq!(receiver.recv().await, Some(0));

        let first = reservations.pop().unwrap();
        first.await.unwrap().send();
        assert_eq!(receiver.recv().await, Some(1));

        let second = reservation_map.remove("next").unwrap();
        second.await.unwrap().send();
        assert_eq!(receiver.recv().await, Some(2));
    }
}
189}