commonware_utils/channel/
reservation.rs1use super::mpsc::{
4 self,
5 error::{SendError, TrySendError},
6 OwnedPermit,
7};
8use std::{
9 future::Future,
10 pin::Pin,
11 task::{Context, Poll},
12};
13
14type ReserveResult<T> = Result<OwnedPermit<T>, SendError<()>>;
16
17type ReserveFuture<T> = Pin<Box<dyn Future<Output = ReserveResult<T>> + Send>>;
19
20#[must_use = "call send to deliver the reserved message"]
22pub struct Reserved<T> {
23 permit: OwnedPermit<T>,
24 value: T,
25}
26
27impl<T> Reserved<T> {
28 pub fn send(self) -> mpsc::Sender<T> {
30 self.permit.send(self.value)
31 }
32}
33
34#[must_use = "await the reservation to acquire a channel slot"]
36pub struct Reservation<T> {
37 future: ReserveFuture<T>,
38 value: Option<T>,
39}
40
41impl<T> Reservation<T> {
42 fn new(future: impl Future<Output = ReserveResult<T>> + Send + 'static, value: T) -> Self {
43 Self {
44 future: Box::pin(future),
45 value: Some(value),
46 }
47 }
48}
49
50impl<T> Unpin for Reservation<T> {}
51
52impl<T> Future for Reservation<T> {
53 type Output = Result<Reserved<T>, SendError<T>>;
54
55 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
56 let permit = match self.future.as_mut().poll(cx) {
57 Poll::Pending => return Poll::Pending,
58 Poll::Ready(permit) => permit,
59 };
60 let value = self
61 .value
62 .take()
63 .expect("reservation polled after completion");
64 Poll::Ready(match permit {
65 Ok(permit) => Ok(Reserved { permit, value }),
66 Err(SendError(())) => Err(SendError(value)),
67 })
68 }
69}
70
71pub trait ReservationExt<T> {
73 #[must_use = "await and send any reservation"]
81 fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
82 where
83 T: 'static;
84}
85
86impl<T: Send> ReservationExt<T> for mpsc::Sender<T> {
87 fn send_or_reserve(&self, value: T) -> Result<Option<Reservation<T>>, SendError<T>>
88 where
89 T: 'static,
90 {
91 match self.try_send(value) {
92 Ok(()) => Ok(None),
93 Err(TrySendError::Full(value)) => {
94 Ok(Some(Reservation::new(self.clone().reserve_owned(), value)))
95 }
96 Err(TrySendError::Closed(value)) => Err(SendError(value)),
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_macros::test_async;
    use std::collections::BTreeMap;

    /// Calls `send_or_reserve` on a channel expected to be full and returns
    /// the resulting reservation.
    fn reserve_on_full(sender: &mpsc::Sender<i32>, value: i32) -> Reservation<i32> {
        sender
            .send_or_reserve(value)
            .unwrap()
            .expect("channel should be full")
    }

    /// With spare capacity, the value is delivered without a reservation.
    #[test]
    fn test_send_or_reserve_sends_immediately() {
        let (tx, mut rx) = mpsc::channel(1);
        let pending = tx.send_or_reserve(1).unwrap();
        assert!(pending.is_none());
        assert_eq!(rx.try_recv(), Ok(1));
    }

    /// A closed channel hands the value back in the error.
    #[test]
    fn test_send_or_reserve_closed_returns_value() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);

        if let Err(SendError(returned)) = tx.send_or_reserve(1) {
            assert_eq!(returned, 1);
        } else {
            panic!("send should fail");
        }
    }

    /// A full channel yields a reservation that completes once space frees up.
    #[test_async]
    async fn test_send_or_reserve_waits_for_capacity() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.try_send(1).unwrap();

        let pending = reserve_on_full(&tx, 2);
        assert_eq!(rx.recv().await, Some(1));

        let reserved = pending.await.unwrap();
        reserved.send();
        assert_eq!(rx.recv().await, Some(2));
    }

    /// If the receiver is dropped while waiting, the value is returned.
    #[test_async]
    async fn test_send_or_reserve_returns_value_when_closed_while_waiting() {
        let (tx, rx) = mpsc::channel(1);
        tx.try_send(1).unwrap();

        let pending = reserve_on_full(&tx, 2);
        drop(rx);

        if let Err(SendError(returned)) = pending.await {
            assert_eq!(returned, 2);
        } else {
            panic!("reservation should fail");
        }
    }

    /// Reservations can be stored in ordinary collections before being awaited.
    #[test_async]
    async fn test_send_or_reserve_reservations_can_be_stored() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.try_send(0).unwrap();

        let mut stacked = Vec::new();
        stacked.push(reserve_on_full(&tx, 1));

        let mut keyed = BTreeMap::new();
        keyed.insert("next", reserve_on_full(&tx, 2));

        assert_eq!(rx.recv().await, Some(0));

        let first = stacked.pop().unwrap();
        first.await.unwrap().send();
        assert_eq!(rx.recv().await, Some(1));

        let second = keyed.remove("next").unwrap();
        second.await.unwrap().send();
        assert_eq!(rx.recv().await, Some(2));
    }
}