rama_http_core/client/
dispatch.rs

1use std::pin::Pin;
2use std::sync::atomic::{self, AtomicBool};
3use std::task::{Context, Poll};
4
5use pin_project_lite::pin_project;
6use rama_core::error::BoxError;
7use rama_http_types::dep::http_body::Body;
8use rama_http_types::{Request, Response};
9use tokio::sync::{mpsc, oneshot};
10use tracing::trace;
11
12use crate::{body::Incoming, proto::h2::client::ResponseFutMap};
13
14pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, TrySendError<T>>>;
15pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>;
16
17/// An error when calling `try_send_request`.
18///
19/// There is a possibility of an error occurring on a connection in-between the
20/// time that a request is queued and when it is actually written to the IO
21/// transport. If that happens, it is safe to return the request back to the
22/// caller, as it was never fully sent.
23#[derive(Debug)]
24pub struct TrySendError<T> {
25    pub(crate) error: crate::Error,
26    pub(crate) message: Option<T>,
27}
28
29pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
30    let (tx, rx) = mpsc::unbounded_channel();
31    let (giver, taker) = want::new();
32    let tx = Sender {
33        buffered_once: AtomicBool::new(false),
34        giver,
35        inner: tx,
36    };
37    let rx = Receiver { inner: rx, taker };
38    (tx, rx)
39}
40
41/// A bounded sender of requests and callbacks for when responses are ready.
42///
43/// While the inner sender is unbounded, the Giver is used to determine
44/// if the Receiver is ready for another request.
45pub(crate) struct Sender<T, U> {
46    /// One message is always allowed, even if the Receiver hasn't asked
47    /// for it yet. This boolean keeps track of whether we've sent one
48    /// without notice.
49    buffered_once: AtomicBool,
50    /// The Giver helps watch that the Receiver side has been polled
51    /// when the queue is empty. This helps us know when a request and
52    /// response have been fully processed, and a connection is ready
53    /// for more.
54    giver: want::Giver,
55    /// Actually bounded by the Giver, plus `buffered_once`.
56    inner: mpsc::UnboundedSender<Envelope<T, U>>,
57}
58
59/// An unbounded version.
60///
61/// Cannot poll the Giver, but can still use it to determine if the Receiver
62/// has been dropped. However, this version can be cloned.
63pub(crate) struct UnboundedSender<T, U> {
64    /// Only used for `is_closed`, since mpsc::UnboundedSender cannot be checked.
65    giver: want::SharedGiver,
66    inner: mpsc::UnboundedSender<Envelope<T, U>>,
67}
68
69impl<T, U> Sender<T, U> {
70    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
71        self.giver
72            .poll_want(cx)
73            .map_err(|_| crate::Error::new_closed())
74    }
75
76    pub(crate) fn is_ready(&self) -> bool {
77        self.giver.is_wanting()
78    }
79
80    pub(crate) fn is_closed(&self) -> bool {
81        self.giver.is_canceled()
82    }
83
84    fn can_send(&self) -> bool {
85        // If the receiver is ready *now*, then of course we can send.
86        //
87        // If the receiver isn't ready yet, but we don't have anything
88        // in the channel yet, then allow one message.
89        self.giver.give() || !self.buffered_once.swap(true, atomic::Ordering::AcqRel)
90    }
91
92    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
93        if !self.can_send() {
94            return Err(val);
95        }
96        let (tx, rx) = oneshot::channel();
97        self.inner
98            .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
99            .map(move |_| rx)
100            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
101    }
102
103    pub(crate) fn send(&self, val: T) -> Result<Promise<U>, T> {
104        if !self.can_send() {
105            return Err(val);
106        }
107        let (tx, rx) = oneshot::channel();
108        self.inner
109            .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
110            .map(move |_| rx)
111            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
112    }
113
114    pub(crate) fn unbound(self) -> UnboundedSender<T, U> {
115        UnboundedSender {
116            giver: self.giver.shared(),
117            inner: self.inner,
118        }
119    }
120}
121
122impl<T, U> UnboundedSender<T, U> {
123    pub(crate) fn is_ready(&self) -> bool {
124        !self.giver.is_canceled()
125    }
126
127    pub(crate) fn is_closed(&self) -> bool {
128        self.giver.is_canceled()
129    }
130
131    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
132        let (tx, rx) = oneshot::channel();
133        self.inner
134            .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
135            .map(move |_| rx)
136            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
137    }
138
139    pub(crate) fn send(&self, val: T) -> Result<Promise<U>, T> {
140        let (tx, rx) = oneshot::channel();
141        self.inner
142            .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
143            .map(move |_| rx)
144            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
145    }
146}
147
148impl<T, U> Clone for UnboundedSender<T, U> {
149    fn clone(&self) -> Self {
150        UnboundedSender {
151            giver: self.giver.clone(),
152            inner: self.inner.clone(),
153        }
154    }
155}
156
157pub(crate) struct Receiver<T, U> {
158    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
159    taker: want::Taker,
160}
161
162impl<T, U> Receiver<T, U> {
163    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<(T, Callback<T, U>)>> {
164        match self.inner.poll_recv(cx) {
165            Poll::Ready(item) => {
166                Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
167            }
168            Poll::Pending => {
169                self.taker.want();
170                Poll::Pending
171            }
172        }
173    }
174
175    pub(crate) fn close(&mut self) {
176        self.taker.cancel();
177        self.inner.close();
178    }
179
180    pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> {
181        use futures_util::FutureExt;
182        match self.inner.recv().now_or_never() {
183            Some(Some(mut env)) => env.0.take(),
184            _ => None,
185        }
186    }
187}
188
189impl<T, U> Drop for Receiver<T, U> {
190    fn drop(&mut self) {
191        // Notify the giver about the closure first, before dropping
192        // the mpsc::Receiver.
193        self.taker.cancel();
194    }
195}
196
197struct Envelope<T, U>(Option<(T, Callback<T, U>)>);
198
199impl<T, U> Drop for Envelope<T, U> {
200    fn drop(&mut self) {
201        if let Some((val, cb)) = self.0.take() {
202            cb.send(Err(TrySendError {
203                error: crate::Error::new_canceled().with("connection closed"),
204                message: Some(val),
205            }));
206        }
207    }
208}
209
210pub(crate) enum Callback<T, U> {
211    #[allow(unused)]
212    Retry(Option<oneshot::Sender<Result<U, TrySendError<T>>>>),
213    NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>),
214}
215
216impl<T, U> Drop for Callback<T, U> {
217    fn drop(&mut self) {
218        match self {
219            Callback::Retry(tx) => {
220                if let Some(tx) = tx.take() {
221                    let _ = tx.send(Err(TrySendError {
222                        error: dispatch_gone(),
223                        message: None,
224                    }));
225                }
226            }
227            Callback::NoRetry(tx) => {
228                if let Some(tx) = tx.take() {
229                    let _ = tx.send(Err(dispatch_gone()));
230                }
231            }
232        }
233    }
234}
235
236#[cold]
237fn dispatch_gone() -> crate::Error {
238    // FIXME(nox): What errors do we want here?
239    crate::Error::new_user_dispatch_gone().with(if std::thread::panicking() {
240        "user code panicked"
241    } else {
242        "runtime dropped the dispatch task"
243    })
244}
245
246impl<T, U> Callback<T, U> {
247    pub(crate) fn is_canceled(&self) -> bool {
248        match *self {
249            Callback::Retry(Some(ref tx)) => tx.is_closed(),
250            Callback::NoRetry(Some(ref tx)) => tx.is_closed(),
251            _ => unreachable!(),
252        }
253    }
254
255    pub(crate) fn poll_canceled(&mut self, cx: &mut Context<'_>) -> Poll<()> {
256        match *self {
257            Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx),
258            Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx),
259            _ => unreachable!(),
260        }
261    }
262
263    pub(crate) fn send(mut self, val: Result<U, TrySendError<T>>) {
264        match self {
265            Callback::Retry(ref mut tx) => {
266                let _ = tx.take().unwrap().send(val);
267            }
268            Callback::NoRetry(ref mut tx) => {
269                let _ = tx.take().unwrap().send(val.map_err(|e| e.error));
270            }
271        }
272    }
273}
274
275impl<T> TrySendError<T> {
276    /// Take the message from this error.
277    ///
278    /// The message will not always have been recovered. If an error occurs
279    /// after the message has been serialized onto the connection, it will not
280    /// be available here.
281    pub fn take_message(&mut self) -> Option<T> {
282        self.message.take()
283    }
284
285    /// Consumes this to return the inner error.
286    pub fn into_error(self) -> crate::Error {
287        self.error
288    }
289}
290
291pin_project! {
292    pub struct SendWhen<B>
293    where
294        B: Body,
295        B: Send,
296        B: 'static,
297        B: Unpin,
298        B::Data: Send,
299        B::Data: 'static,
300        B::Error: Into<BoxError>,
301    {
302        #[pin]
303        pub(crate) when: ResponseFutMap<B>,
304        #[pin]
305        pub(crate) call_back: Option<Callback<Request<B>, Response<Incoming>>>,
306    }
307}
308
309impl<B> Future for SendWhen<B>
310where
311    B: Body<Data: Send + 'static, Error: Into<BoxError>> + Send + 'static + Unpin,
312{
313    type Output = ();
314
315    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
316        let mut this = self.project();
317
318        let mut call_back = this.call_back.take().expect("polled after complete");
319
320        match Pin::new(&mut this.when).poll(cx) {
321            Poll::Ready(Ok(res)) => {
322                call_back.send(Ok(res));
323                Poll::Ready(())
324            }
325            Poll::Pending => {
326                // check if the callback is canceled
327                match call_back.poll_canceled(cx) {
328                    Poll::Ready(v) => v,
329                    Poll::Pending => {
330                        // Move call_back back to struct before return
331                        this.call_back.set(Some(call_back));
332                        return Poll::Pending;
333                    }
334                };
335                trace!("send_when canceled");
336                Poll::Ready(())
337            }
338            Poll::Ready(Err((error, message))) => {
339                call_back.send(Err(TrySendError { error, message }));
340                Poll::Ready(())
341            }
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{Callback, Receiver, channel};

    /// Minimal payload type for the channel tests.
    #[derive(Debug)]
    // The wrapped value only appears in `Debug` output, so it is never read.
    struct Custom(#[allow(dead_code)] i32);

    // Test-only convenience so a `Receiver` can be awaited directly.
    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.get_mut().poll_recv(cx)
        }
    }

    /// Polls the wrapped future exactly once. The output is `Some(())` if it
    /// was ready and `None` if it was still pending.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let was_ready = Pin::new(&mut *self.0).poll(cx).is_ready();
            Poll::Ready(was_ready.then_some(()))
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // Poll the receiver once before sending, so that it registers demand.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let promise = tx.try_send(Custom(43)).unwrap();
        drop(rx);

        let err = promise
            .await
            .expect("fulfilled")
            .expect_err("promise should error");
        let outcome = (err.error.is_canceled(), err.message);
        assert!(
            matches!(outcome, (true, Some(_))),
            "expected Error::Cancel(_), found {:?}",
            outcome
        );
    }

    #[cfg(not(miri))]
    #[tokio::test]
    // The returned promises are intentionally discarded.
    #[allow(clippy::let_underscore_future)]
    async fn sender_checks_for_want_on_send() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // The first message may be buffered without demand; the second may not.
        let _ = tx.try_send(Custom(1)).expect("1 buffered");
        tx.try_send(Custom(2)).expect_err("2 not ready");

        assert!(PollOnce(&mut rx).await.is_some(), "rx once");

        // Message 1 has now been received, but the single buffered slot is
        // spent for the lifetime of the channel.
        tx.try_send(Custom(2)).expect_err("2 still not ready");

        // Polling an empty receiver registers demand, which lets message 2 through.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let _ = tx.try_send(Custom(2)).expect("2 ready");
    }

    #[test]
    // The returned promises are intentionally discarded.
    #[allow(clippy::let_underscore_future)]
    fn unbounded_sender_doesnt_bound_on_want() {
        let (bounded, rx) = channel::<Custom, ()>();
        let mut tx = bounded.unbound();

        for n in 1..=3 {
            let _ = tx.try_send(Custom(n)).unwrap();
        }

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }
}