1use std::pin::Pin;
2use std::sync::atomic::{self, AtomicBool};
3use std::task::{Context, Poll};
4
5use pin_project_lite::pin_project;
6use rama_core::error::BoxError;
7use rama_http_types::dep::http_body::Body;
8use rama_http_types::{Request, Response};
9use tokio::sync::{mpsc, oneshot};
10use tracing::trace;
11
12use crate::{body::Incoming, proto::h2::client::ResponseFutMap};
13
/// Receives the outcome of a message sent with a retry-capable callback; on
/// failure the [`TrySendError`] may carry the original message back.
pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, TrySendError<T>>>;
/// Receives the outcome of a message sent without retry support; failures
/// only report the error.
pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>;
16
/// Error produced when a message could not be delivered through the dispatch
/// channel.
///
/// Depending on where the failure happened, the original message may be
/// handed back so the caller can retry it (see `take_message`).
#[derive(Debug)]
pub struct TrySendError<T> {
    /// Why the send failed.
    pub(crate) error: crate::Error,
    /// The message that failed to send, when it could be recovered.
    pub(crate) message: Option<T>,
}
28
29pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
30 let (tx, rx) = mpsc::unbounded_channel();
31 let (giver, taker) = want::new();
32 let tx = Sender {
33 buffered_once: AtomicBool::new(false),
34 giver,
35 inner: tx,
36 };
37 let rx = Receiver { inner: rx, taker };
38 (tx, rx)
39}
40
/// Sending half of a dispatch channel, created by [`channel`].
///
/// Sends are gated on the receiver's want-signal, with a one-message
/// allowance (see `can_send`).
pub(crate) struct Sender<T, U> {
    /// Flag backing the one-message allowance in `can_send`: a message may be
    /// queued without a want-signal from the receiver only while this is
    /// still `false`.
    buffered_once: AtomicBool,
    /// Want handle paired with the receiver's `want::Taker`; used to check and
    /// wait for the receiver's interest in another message.
    giver: want::Giver,
    /// Channel that carries the message envelopes to the receiver.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
58
/// Sending half obtained from [`Sender::unbound`].
///
/// Unlike [`Sender`], it does not wait for the receiver's want-signal before
/// queueing, and it can be cloned.
pub(crate) struct UnboundedSender<T, U> {
    /// Shared want handle; only consulted to detect cancellation.
    giver: want::SharedGiver,
    /// Channel that carries the message envelopes to the receiver.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
68
69impl<T, U> Sender<T, U> {
70 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
71 self.giver
72 .poll_want(cx)
73 .map_err(|_| crate::Error::new_closed())
74 }
75
76 pub(crate) fn is_ready(&self) -> bool {
77 self.giver.is_wanting()
78 }
79
80 pub(crate) fn is_closed(&self) -> bool {
81 self.giver.is_canceled()
82 }
83
84 fn can_send(&self) -> bool {
85 self.giver.give() || !self.buffered_once.swap(true, atomic::Ordering::AcqRel)
90 }
91
92 pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
93 if !self.can_send() {
94 return Err(val);
95 }
96 let (tx, rx) = oneshot::channel();
97 self.inner
98 .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
99 .map(move |_| rx)
100 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
101 }
102
103 pub(crate) fn send(&self, val: T) -> Result<Promise<U>, T> {
104 if !self.can_send() {
105 return Err(val);
106 }
107 let (tx, rx) = oneshot::channel();
108 self.inner
109 .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
110 .map(move |_| rx)
111 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
112 }
113
114 pub(crate) fn unbound(self) -> UnboundedSender<T, U> {
115 UnboundedSender {
116 giver: self.giver.shared(),
117 inner: self.inner,
118 }
119 }
120}
121
122impl<T, U> UnboundedSender<T, U> {
123 pub(crate) fn is_ready(&self) -> bool {
124 !self.giver.is_canceled()
125 }
126
127 pub(crate) fn is_closed(&self) -> bool {
128 self.giver.is_canceled()
129 }
130
131 pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
132 let (tx, rx) = oneshot::channel();
133 self.inner
134 .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
135 .map(move |_| rx)
136 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
137 }
138
139 pub(crate) fn send(&self, val: T) -> Result<Promise<U>, T> {
140 let (tx, rx) = oneshot::channel();
141 self.inner
142 .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
143 .map(move |_| rx)
144 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
145 }
146}
147
148impl<T, U> Clone for UnboundedSender<T, U> {
149 fn clone(&self) -> Self {
150 UnboundedSender {
151 giver: self.giver.clone(),
152 inner: self.inner.clone(),
153 }
154 }
155}
156
/// Receiving half of a dispatch channel, created by [`channel`].
pub(crate) struct Receiver<T, U> {
    /// Channel the message envelopes arrive on.
    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
    /// Want handle used to tell the sending side that another message is
    /// wanted, or that the receiver is closed.
    taker: want::Taker,
}
161
162impl<T, U> Receiver<T, U> {
163 pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<(T, Callback<T, U>)>> {
164 match self.inner.poll_recv(cx) {
165 Poll::Ready(item) => {
166 Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
167 }
168 Poll::Pending => {
169 self.taker.want();
170 Poll::Pending
171 }
172 }
173 }
174
175 pub(crate) fn close(&mut self) {
176 self.taker.cancel();
177 self.inner.close();
178 }
179
180 pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> {
181 use futures_util::FutureExt;
182 match self.inner.recv().now_or_never() {
183 Some(Some(mut env)) => env.0.take(),
184 _ => None,
185 }
186 }
187}
188
impl<T, U> Drop for Receiver<T, U> {
    fn drop(&mut self) {
        // Notify the sending side through the want handle first. The mpsc
        // receiver (`inner`) is dropped afterwards, as a field. Envelopes still
        // queued then report their messages back as canceled; the
        // `drop_receiver_sends_cancel_errors` test exercises this path.
        self.taker.cancel();
    }
}
196
/// Wrapper around a queued message and its reply callback.
///
/// The payload is an `Option` so the receiver can take it out; an envelope
/// that is dropped while still holding its payload reports the message back
/// through its callback (see its `Drop` impl).
struct Envelope<T, U>(Option<(T, Callback<T, U>)>);
198
199impl<T, U> Drop for Envelope<T, U> {
200 fn drop(&mut self) {
201 if let Some((val, cb)) = self.0.take() {
202 cb.send(Err(TrySendError {
203 error: crate::Error::new_canceled().with("connection closed"),
204 message: Some(val),
205 }));
206 }
207 }
208}
209
/// One-shot reply channel attached to every dispatched message.
///
/// The inner `Option` is emptied once a reply has been sent.
pub(crate) enum Callback<T, U> {
    /// Replies with a [`TrySendError`] on failure, which may carry the
    /// original message back so the caller can retry it.
    // NOTE(review): `unused` is presumably allowed because some build
    // configurations never construct this variant — confirm.
    #[allow(unused)]
    Retry(Option<oneshot::Sender<Result<U, TrySendError<T>>>>),
    /// Replies with only the error on failure; any recovered message is
    /// discarded.
    NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>),
}
215
216impl<T, U> Drop for Callback<T, U> {
217 fn drop(&mut self) {
218 match self {
219 Callback::Retry(tx) => {
220 if let Some(tx) = tx.take() {
221 let _ = tx.send(Err(TrySendError {
222 error: dispatch_gone(),
223 message: None,
224 }));
225 }
226 }
227 Callback::NoRetry(tx) => {
228 if let Some(tx) = tx.take() {
229 let _ = tx.send(Err(dispatch_gone()));
230 }
231 }
232 }
233 }
234}
235
236#[cold]
237fn dispatch_gone() -> crate::Error {
238 crate::Error::new_user_dispatch_gone().with(if std::thread::panicking() {
240 "user code panicked"
241 } else {
242 "runtime dropped the dispatch task"
243 })
244}
245
246impl<T, U> Callback<T, U> {
247 pub(crate) fn is_canceled(&self) -> bool {
248 match *self {
249 Callback::Retry(Some(ref tx)) => tx.is_closed(),
250 Callback::NoRetry(Some(ref tx)) => tx.is_closed(),
251 _ => unreachable!(),
252 }
253 }
254
255 pub(crate) fn poll_canceled(&mut self, cx: &mut Context<'_>) -> Poll<()> {
256 match *self {
257 Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx),
258 Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx),
259 _ => unreachable!(),
260 }
261 }
262
263 pub(crate) fn send(mut self, val: Result<U, TrySendError<T>>) {
264 match self {
265 Callback::Retry(ref mut tx) => {
266 let _ = tx.take().unwrap().send(val);
267 }
268 Callback::NoRetry(ref mut tx) => {
269 let _ = tx.take().unwrap().send(val.map_err(|e| e.error));
270 }
271 }
272 }
273}
274
275impl<T> TrySendError<T> {
276 pub fn take_message(&mut self) -> Option<T> {
282 self.message.take()
283 }
284
285 pub fn into_error(self) -> crate::Error {
287 self.error
288 }
289}
290
pin_project! {
    /// Future that waits on a dispatched request's response future (`when`)
    /// and reports its outcome through `call_back`.
    ///
    /// It also completes early, without forwarding a result, if the callback's
    /// receiving side is closed while the response is still pending.
    pub struct SendWhen<B>
    where
        B: Body,
        B: Send,
        B: 'static,
        B: Unpin,
        B::Data: Send,
        B::Data: 'static,
        B::Error: Into<BoxError>,
    {
        // Response future whose output is forwarded to `call_back`.
        #[pin]
        pub(crate) when: ResponseFutMap<B>,
        // Taken out on every poll and put back while still pending; stays
        // `None` after completion, so polling again panics.
        #[pin]
        pub(crate) call_back: Option<Callback<Request<B>, Response<Incoming>>>,
    }
}
308
309impl<B> Future for SendWhen<B>
310where
311 B: Body<Data: Send + 'static, Error: Into<BoxError>> + Send + 'static + Unpin,
312{
313 type Output = ();
314
315 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
316 let mut this = self.project();
317
318 let mut call_back = this.call_back.take().expect("polled after complete");
319
320 match Pin::new(&mut this.when).poll(cx) {
321 Poll::Ready(Ok(res)) => {
322 call_back.send(Ok(res));
323 Poll::Ready(())
324 }
325 Poll::Pending => {
326 match call_back.poll_canceled(cx) {
328 Poll::Ready(v) => v,
329 Poll::Pending => {
330 this.call_back.set(Some(call_back));
332 return Poll::Pending;
333 }
334 };
335 trace!("send_when canceled");
336 Poll::Ready(())
337 }
338 Poll::Ready(Err((error, message))) => {
339 call_back.send(Err(TrySendError { error, message }));
340 Poll::Ready(())
341 }
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    // Unit tests for the dispatch channel: want-gated sending, the
    // unbounded sender, and cancellation when the receiver is dropped.
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{Callback, Receiver, channel};

    /// Dummy message type; its payload is never read.
    #[derive(Debug)]
    struct Custom(#[allow(dead_code)] i32);

    // Test-only `Future` impl so a `Receiver` can be polled with `.await`.
    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_recv(cx)
        }
    }

    /// Polls the wrapped future exactly once: resolves to `Some(())` if it was
    /// ready and `None` if it was pending.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match Pin::new(&mut self.0).poll(cx) {
                Poll::Ready(_) => Poll::Ready(Some(())),
                Poll::Pending => Poll::Ready(None),
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        // A message still queued when the receiver is dropped must come back
        // through its promise as a canceled error, with the message attached.
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // Polling an empty receiver registers a want, so the send below is allowed.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let promise = tx.try_send(Custom(43)).unwrap();
        drop(rx);

        let fulfilled = promise.await;
        let err = fulfilled
            .expect("fulfilled")
            .expect_err("promise should error");
        match (err.error.is_canceled(), err.message) {
            (true, Some(_)) => (),
            e => panic!("expected Error::Cancel(_), found {:?}", e),
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    #[allow(clippy::let_underscore_future)]
    async fn sender_checks_for_want_on_send() {
        // Only the first message may be queued without a want; later sends
        // must wait until the receiver polls an empty channel.
        let (mut tx, mut rx) = channel::<Custom, ()>();

        let _ = tx.try_send(Custom(1)).expect("1 buffered");
        tx.try_send(Custom(2)).expect_err("2 not ready");

        assert!(PollOnce(&mut rx).await.is_some(), "rx once");

        // Receiving a message does not by itself signal a new want.
        tx.try_send(Custom(2)).expect_err("2 still not ready");

        // Polling an empty channel does.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let _ = tx.try_send(Custom(2)).expect("2 ready");
    }

    #[test]
    #[allow(clippy::let_underscore_future)]
    fn unbounded_sender_doesnt_bound_on_want() {
        // An unbounded sender queues without wants, but still fails once the
        // receiver is gone.
        let (tx, rx) = channel::<Custom, ()>();
        let mut tx = tx.unbound();

        let _ = tx.try_send(Custom(1)).unwrap();
        let _ = tx.try_send(Custom(2)).unwrap();
        let _ = tx.try_send(Custom(3)).unwrap();

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }
}