hyper_util/client/pool/
negotiate.rs

1//! Negotiate a pool of services
2//!
3//! The negotiate pool allows for a service that can decide between two service
4//! types based on an intermediate return value. It differs from typical
5//! routing since it doesn't depend on the request, but the response.
6//!
//! The original use case is to support ALPN upgrades to HTTP/2, with a fallback
8//! to HTTP/1.
9//!
10//! # Example
11//!
12//! ```rust,ignore
13//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
14//! # struct Conn;
15//! # impl Conn { fn negotiated_protocol(&self) -> &[u8] { b"h2" } }
16//! # let some_tls_connector = tower::service::service_fn(|_| async move {
17//! #     Ok::<_, std::convert::Infallible>(Conn)
18//! # });
19//! # let http1_layer = tower::layer::layer_fn(|s| s);
20//! # let http2_layer = tower::layer::layer_fn(|s| s);
21//! let mut pool = hyper_util::client::pool::negotiate::builder()
22//!     .connect(some_tls_connector)
23//!     .inspect(|c| c.negotiated_protocol() == b"h2")
24//!     .fallback(http1_layer)
25//!     .upgrade(http2_layer)
26//!     .build();
27//!
28//! // connect
29//! let mut svc = pool.call(http::Uri::from_static("https://hyper.rs")).await?;
30//! svc.ready().await;
31//!
32//! // http1 or http2 is now set up
33//! # let some_http_req = http::Request::new(());
34//! let resp = svc.call(some_http_req).await?;
35//! # Ok(())
36//! # }
37//! ```
38
39pub use self::internal::builder;
40
41#[cfg(docsrs)]
42pub use self::internal::Builder;
43#[cfg(docsrs)]
44pub use self::internal::Negotiate;
45#[cfg(docsrs)]
46pub use self::internal::Negotiated;
47
48mod internal {
49    use std::future::Future;
50    use std::pin::Pin;
51    use std::sync::{Arc, Mutex};
52    use std::task::{self, Poll};
53
54    use futures_core::ready;
55    use pin_project_lite::pin_project;
56    use tower_layer::Layer;
57    use tower_service::Service;
58
    /// Boxed, type-erased error that is `Send + Sync`.
    ///
    /// Used as the `Error` type of the services defined in this module, and as
    /// the carrier for the private `UseOther` sentinel.
    type BoxError = Box<dyn std::error::Error + Send + Sync>;
60
61    /// A negotiating pool over an inner make service.
62    ///
63    /// Created with [`builder()`].
64    ///
65    /// # Unnameable
66    ///
67    /// This type is normally unnameable, forbidding naming of the type within
68    /// code. The type is exposed in the documentation to show which methods
69    /// can be publicly called.
70    #[derive(Clone)]
71    pub struct Negotiate<L, R> {
72        left: L,
73        right: R,
74    }
75
    /// A negotiated service returned by [`Negotiate`].
    ///
    /// Exactly one of the two inner services is present: the fallback one
    /// (`L`) or the upgraded one (`R`). Use `fallback_ref`/`upgraded_ref` (or
    /// their `_mut` variants) to find out which, or drive it directly through
    /// its `Service` implementation, which delegates to the active variant.
    ///
    /// # Unnameable
    ///
    /// This type is normally unnameable, forbidding naming of the type within
    /// code. The type is exposed in the documentation to show which methods
    /// can be publicly called.
    #[derive(Clone, Debug)]
    pub enum Negotiated<L, R> {
        // The fallback (left) path was selected.
        #[doc(hidden)]
        Fallback(L),
        // The upgrade (right) path was selected.
        #[doc(hidden)]
        Upgraded(R),
    }
90
    // Future returned by `Negotiate::call`.
    //
    // It owns its own `left` and `right` service handles so it can start a new
    // call on either side whenever the state machine switches paths.
    pin_project! {
        pub struct Negotiating<Dst, L, R>
        where
            L: Service<Dst>,
            R: Service<()>,
        {
            // Which in-flight future is currently being driven.
            #[pin]
            state: State<Dst, L::Future, R::Future>,
            // Fallback service; called with the destination when the eager
            // attempt reports the `UseOther` sentinel.
            left: L,
            // Upgrade service; called again when the fallback attempt reports
            // the `UseOther` sentinel.
            right: R,
        }
    }
103
    // State machine driven by `Negotiating::poll`.
    pin_project! {
        #[project = StateProj]
        enum State<Dst, FL, FR> {
            // Initial state: the right (upgrade) future started in
            // `Negotiate::call`. `dst` is held back so it can be handed to the
            // left service if this attempt yields the `UseOther` sentinel.
            Eager {
                #[pin]
                future: FR,
                dst: Option<Dst>,
            },
            // The left (fallback) future is being driven.
            Fallback {
                #[pin]
                future: FL,
            },
            // The right future is being driven again after the fallback
            // attempt yielded the `UseOther` sentinel.
            Upgrade {
                #[pin]
                future: FR,
            }
        }
    }
122
    // Response future of `Negotiated`'s `Service` impl: wraps the future of
    // whichever inner service handled the request.
    pin_project! {
        #[project = NegotiatedProj]
        pub enum NegotiatedFuture<L, R> {
            // Future produced by the fallback service.
            Fallback {
                #[pin]
                future: L
            },
            // Future produced by the upgraded service.
            Upgraded {
                #[pin]
                future: R
            },
        }
    }
136
    /// A builder to configure a `Negotiate`.
    ///
    /// Each type parameter tracks one configuration slot. A fresh builder from
    /// [`builder()`] starts with placeholder marker types in every slot, and
    /// each setter swaps the corresponding parameter for the supplied value's
    /// type.
    ///
    /// # Unnameable
    ///
    /// This type is normally unnameable, forbidding naming of the type within
    /// code. The type is exposed in the documentation to show which methods
    /// can be publicly called.
    #[derive(Debug)]
    pub struct Builder<C, I, L, R> {
        // Initial connector service.
        connect: C,
        // Predicate over the connector's response; `true` selects the upgrade
        // path.
        inspect: I,
        // Layer applied on the fallback path.
        fallback: L,
        // Layer applied on the upgrade path.
        upgrade: R,
    }
151
    // Marker types filling the slots of a freshly created `Builder` (see
    // `builder()`); each one is replaced when the matching setter is called.

    // Placeholder for the `connect` slot.
    #[derive(Debug)]
    pub struct WantsConnect;
    // Placeholder for the `inspect` slot.
    #[derive(Debug)]
    pub struct WantsInspect;
    // Placeholder for the `fallback` slot.
    #[derive(Debug)]
    pub struct WantsFallback;
    // Placeholder for the `upgrade` slot.
    #[derive(Debug)]
    pub struct WantsUpgrade;
160
161    /// Start a builder to construct a `Negotiate` pool.
162    pub fn builder() -> Builder<WantsConnect, WantsInspect, WantsFallback, WantsUpgrade> {
163        Builder {
164            connect: WantsConnect,
165            inspect: WantsInspect,
166            fallback: WantsFallback,
167            upgrade: WantsUpgrade,
168        }
169    }
170
171    impl<C, I, L, R> Builder<C, I, L, R> {
172        /// Provide the initial connector.
173        pub fn connect<CC>(self, connect: CC) -> Builder<CC, I, L, R> {
174            Builder {
175                connect,
176                inspect: self.inspect,
177                fallback: self.fallback,
178                upgrade: self.upgrade,
179            }
180        }
181
182        /// Provide the inspector that determines the result of the negotiation.
183        pub fn inspect<II>(self, inspect: II) -> Builder<C, II, L, R> {
184            Builder {
185                connect: self.connect,
186                inspect,
187                fallback: self.fallback,
188                upgrade: self.upgrade,
189            }
190        }
191
192        /// Provide the layer to fallback to if negotiation fails.
193        pub fn fallback<LL>(self, fallback: LL) -> Builder<C, I, LL, R> {
194            Builder {
195                connect: self.connect,
196                inspect: self.inspect,
197                fallback,
198                upgrade: self.upgrade,
199            }
200        }
201
202        /// Provide the layer to upgrade to if negotiation succeeds.
203        pub fn upgrade<RR>(self, upgrade: RR) -> Builder<C, I, L, RR> {
204            Builder {
205                connect: self.connect,
206                inspect: self.inspect,
207                fallback: self.fallback,
208                upgrade,
209            }
210        }
211
212        /// Build the `Negotiate` pool.
213        pub fn build<Dst>(self) -> Negotiate<L::Service, R::Service>
214        where
215            C: Service<Dst>,
216            C::Error: Into<BoxError>,
217            L: Layer<Inspector<C, C::Response, I>>,
218            L::Service: Service<Dst> + Clone,
219            <L::Service as Service<Dst>>::Error: Into<BoxError>,
220            R: Layer<Inspected<C::Response>>,
221            R::Service: Service<()> + Clone,
222            <R::Service as Service<()>>::Error: Into<BoxError>,
223            I: Fn(&C::Response) -> bool + Clone,
224        {
225            let Builder {
226                connect,
227                inspect,
228                fallback,
229                upgrade,
230            } = self;
231
232            let slot = Arc::new(Mutex::new(None));
233            let wrapped = Inspector {
234                svc: connect,
235                inspect,
236                slot: slot.clone(),
237            };
238            let left = fallback.layer(wrapped);
239
240            let right = upgrade.layer(Inspected { slot });
241
242            Negotiate { left, right }
243        }
244    }
245
246    impl<L, R> Negotiate<L, R> {
247        /// Get a mutable reference to the fallback service.
248        pub fn fallback_mut(&mut self) -> &mut L {
249            &mut self.left
250        }
251
252        /// Get a mutable reference to the upgrade service.
253        pub fn upgrade_mut(&mut self) -> &mut R {
254            &mut self.right
255        }
256    }
257
258    impl<L, R, Target> Service<Target> for Negotiate<L, R>
259    where
260        L: Service<Target> + Clone,
261        L::Error: Into<BoxError>,
262        R: Service<()> + Clone,
263        R::Error: Into<BoxError>,
264    {
265        type Response = Negotiated<L::Response, R::Response>;
266        type Error = BoxError;
267        type Future = Negotiating<Target, L, R>;
268
269        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
270            self.left.poll_ready(cx).map_err(Into::into)
271        }
272
273        fn call(&mut self, dst: Target) -> Self::Future {
274            let left = self.left.clone();
275            Negotiating {
276                state: State::Eager {
277                    future: self.right.call(()),
278                    dst: Some(dst),
279                },
280                // place clone, take original that we already polled-ready.
281                left: std::mem::replace(&mut self.left, left),
282                right: self.right.clone(),
283            }
284        }
285    }
286
287    impl<Dst, L, R> Future for Negotiating<Dst, L, R>
288    where
289        L: Service<Dst>,
290        L::Error: Into<BoxError>,
291        R: Service<()>,
292        R::Error: Into<BoxError>,
293    {
294        type Output = Result<Negotiated<L::Response, R::Response>, BoxError>;
295
296        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
297            // States:
298            // - `Eager`: try the "right" path first; on `UseOther` sentinel, fall back to left.
299            // - `Fallback`: try the left path; on `UseOther` sentinel, upgrade back to right.
300            // - `Upgrade`: retry the right path after a fallback.
301            // If all fail, give up.
302            let mut me = self.project();
303            loop {
304                match me.state.as_mut().project() {
305                    StateProj::Eager { future, dst } => match ready!(future.poll(cx)) {
306                        Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
307                        Err(err) => {
308                            let err = err.into();
309                            if UseOther::is(&*err) {
310                                let dst = dst.take().unwrap();
311                                let f = me.left.call(dst);
312                                me.state.set(State::Fallback { future: f });
313                                continue;
314                            } else {
315                                return Poll::Ready(Err(err));
316                            }
317                        }
318                    },
319                    StateProj::Fallback { future } => match ready!(future.poll(cx)) {
320                        Ok(out) => return Poll::Ready(Ok(Negotiated::Fallback(out))),
321                        Err(err) => {
322                            let err = err.into();
323                            if UseOther::is(&*err) {
324                                let f = me.right.call(());
325                                me.state.set(State::Upgrade { future: f });
326                                continue;
327                            } else {
328                                return Poll::Ready(Err(err));
329                            }
330                        }
331                    },
332                    StateProj::Upgrade { future } => match ready!(future.poll(cx)) {
333                        Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
334                        Err(err) => return Poll::Ready(Err(err.into())),
335                    },
336                }
337            }
338        }
339    }
340
341    impl<L, R> Negotiated<L, R> {
342        // Could be useful?
343        #[cfg(test)]
344        pub(super) fn is_fallback(&self) -> bool {
345            matches!(self, Negotiated::Fallback(_))
346        }
347
348        #[cfg(test)]
349        pub(super) fn is_upgraded(&self) -> bool {
350            matches!(self, Negotiated::Upgraded(_))
351        }
352
353        // TODO: are these the correct methods? Or .as_ref().fallback(), etc?
354
355        /// Get a reference to the fallback service if this is it.
356        pub fn fallback_ref(&self) -> Option<&L> {
357            if let Negotiated::Fallback(ref left) = self {
358                Some(left)
359            } else {
360                None
361            }
362        }
363
364        /// Get a mutable reference to the fallback service if this is it.
365        pub fn fallback_mut(&mut self) -> Option<&mut L> {
366            if let Negotiated::Fallback(ref mut left) = self {
367                Some(left)
368            } else {
369                None
370            }
371        }
372
373        /// Get a reference to the upgraded service if this is it.
374        pub fn upgraded_ref(&self) -> Option<&R> {
375            if let Negotiated::Upgraded(ref right) = self {
376                Some(right)
377            } else {
378                None
379            }
380        }
381
382        /// Get a mutable reference to the upgraded service if this is it.
383        pub fn upgraded_mut(&mut self) -> Option<&mut R> {
384            if let Negotiated::Upgraded(ref mut right) = self {
385                Some(right)
386            } else {
387                None
388            }
389        }
390    }
391
392    impl<L, R, Req, Res, E> Service<Req> for Negotiated<L, R>
393    where
394        L: Service<Req, Response = Res, Error = E>,
395        R: Service<Req, Response = Res, Error = E>,
396    {
397        type Response = Res;
398        type Error = E;
399        type Future = NegotiatedFuture<L::Future, R::Future>;
400
401        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
402            match self {
403                Negotiated::Fallback(ref mut s) => s.poll_ready(cx),
404                Negotiated::Upgraded(ref mut s) => s.poll_ready(cx),
405            }
406        }
407
408        fn call(&mut self, req: Req) -> Self::Future {
409            match self {
410                Negotiated::Fallback(ref mut s) => NegotiatedFuture::Fallback {
411                    future: s.call(req),
412                },
413                Negotiated::Upgraded(ref mut s) => NegotiatedFuture::Upgraded {
414                    future: s.call(req),
415                },
416            }
417        }
418    }
419
420    impl<L, R, Out> Future for NegotiatedFuture<L, R>
421    where
422        L: Future<Output = Out>,
423        R: Future<Output = Out>,
424    {
425        type Output = Out;
426
427        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
428            match self.project() {
429                NegotiatedProj::Fallback { future } => future.poll(cx),
430                NegotiatedProj::Upgraded { future } => future.poll(cx),
431            }
432        }
433    }
434
435    // ===== internal =====
436
    /// Service wrapper around the connector that applies the `inspect`
    /// predicate to every successful response.
    ///
    /// When the predicate returns `true`, the response is stored in `slot` and
    /// the call resolves to the `UseOther` sentinel error; otherwise the
    /// response is returned unchanged (see `InspectFuture`).
    pub struct Inspector<M, S, I> {
        // The wrapped connector service.
        svc: M,
        // Predicate deciding whether a response belongs to the upgrade path.
        inspect: I,
        // Hand-off cell shared with the `Inspected` service.
        slot: Arc<Mutex<Option<S>>>,
    }
442
    // Future returned by `Inspector::call`: drives the connector's future, then
    // runs the predicate on a successful result.
    pin_project! {
        pub struct InspectFuture<F, S, I> {
            // The connector's in-flight future.
            #[pin]
            future: F,
            // Clone of the inspector's predicate.
            inspect: I,
            // Shared cell that receives the response when the predicate
            // returns `true`.
            slot: Arc<Mutex<Option<S>>>,
        }
    }
451
452    impl<M: Clone, S, I: Clone> Clone for Inspector<M, S, I> {
453        fn clone(&self) -> Self {
454            Self {
455                svc: self.svc.clone(),
456                inspect: self.inspect.clone(),
457                slot: self.slot.clone(),
458            }
459        }
460    }
461
462    impl<M, S, I, Target> Service<Target> for Inspector<M, S, I>
463    where
464        M: Service<Target, Response = S>,
465        M::Error: Into<BoxError>,
466        I: Clone + Fn(&S) -> bool,
467    {
468        type Response = M::Response;
469        type Error = BoxError;
470        type Future = InspectFuture<M::Future, S, I>;
471
472        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
473            self.svc.poll_ready(cx).map_err(Into::into)
474        }
475
476        fn call(&mut self, dst: Target) -> Self::Future {
477            InspectFuture {
478                future: self.svc.call(dst),
479                inspect: self.inspect.clone(),
480                slot: self.slot.clone(),
481            }
482        }
483    }
484
485    impl<F, I, S, E> Future for InspectFuture<F, S, I>
486    where
487        F: Future<Output = Result<S, E>>,
488        E: Into<BoxError>,
489        I: Fn(&S) -> bool,
490    {
491        type Output = Result<S, BoxError>;
492
493        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
494            let me = self.project();
495            let s = ready!(me.future.poll(cx)).map_err(Into::into)?;
496            Poll::Ready(if (me.inspect)(&s) {
497                *me.slot.lock().unwrap() = Some(s);
498                Err(UseOther.into())
499            } else {
500                Ok(s)
501            })
502        }
503    }
504
    /// Service that hands out the value currently stored in the shared slot.
    ///
    /// The slot is filled by `InspectFuture` when the inspector accepts a
    /// response.
    pub struct Inspected<S> {
        // Hand-off cell shared with the `Inspector` side.
        slot: Arc<Mutex<Option<S>>>,
    }
508
509    impl<S, Target> Service<Target> for Inspected<S> {
510        type Response = S;
511        type Error = BoxError;
512        type Future = std::future::Ready<Result<S, BoxError>>;
513
514        fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
515            if self.slot.lock().unwrap().is_some() {
516                Poll::Ready(Ok(()))
517            } else {
518                Poll::Ready(Err(UseOther.into()))
519            }
520        }
521
522        fn call(&mut self, _dst: Target) -> Self::Future {
523            let s = self
524                .slot
525                .lock()
526                .unwrap()
527                .take()
528                .ok_or_else(|| UseOther.into());
529            std::future::ready(s)
530        }
531    }
532
533    impl<S> Clone for Inspected<S> {
534        fn clone(&self) -> Inspected<S> {
535            Inspected {
536                slot: self.slot.clone(),
537            }
538        }
539    }
540
541    #[derive(Debug)]
542    struct UseOther;
543
544    impl std::fmt::Display for UseOther {
545        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546            f.write_str("sentinel error; using other")
547        }
548    }
549
550    impl std::error::Error for UseOther {}
551
552    impl UseOther {
553        fn is(err: &(dyn std::error::Error + 'static)) -> bool {
554            let mut source = Some(err);
555            while let Some(err) = source {
556                if err.is::<UseOther>() {
557                    return true;
558                }
559                source = err.source();
560            }
561            false
562        }
563    }
564}
565
566#[cfg(test)]
567mod tests {
568    use futures_util::future;
569    use tower_service::Service;
570    use tower_test::assert_request_eq;
571
572    #[tokio::test]
573    async fn not_negotiated_falls_back_to_left() {
574        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
575
576        let mut negotiate = super::builder()
577            .connect(mock_svc)
578            .inspect(|_: &&str| false)
579            .fallback(layer_fn(|s| s))
580            .upgrade(layer_fn(|s| s))
581            .build();
582
583        crate::common::future::poll_fn(|cx| negotiate.poll_ready(cx))
584            .await
585            .unwrap();
586
587        let fut = negotiate.call(());
588        let nsvc = future::join(fut, async move {
589            assert_request_eq!(handle, ()).send_response("one");
590        })
591        .await
592        .0
593        .expect("call");
594        assert!(nsvc.is_fallback());
595    }
596
597    #[tokio::test]
598    async fn negotiated_uses_right() {
599        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
600
601        let mut negotiate = super::builder()
602            .connect(mock_svc)
603            .inspect(|_: &&str| true)
604            .fallback(layer_fn(|s| s))
605            .upgrade(layer_fn(|s| s))
606            .build();
607
608        crate::common::future::poll_fn(|cx| negotiate.poll_ready(cx))
609            .await
610            .unwrap();
611
612        let fut = negotiate.call(());
613        let nsvc = future::join(fut, async move {
614            assert_request_eq!(handle, ()).send_response("one");
615        })
616        .await
617        .0
618        .expect("call");
619
620        assert!(nsvc.is_upgraded());
621    }
622
    // Test helper: wrap a closure in `LayerFn` so it can be used as a layer.
    fn layer_fn<F>(f: F) -> LayerFn<F> {
        LayerFn(f)
    }
626
    // Minimal closure-backed layer used by the tests in this module.
    #[derive(Clone)]
    struct LayerFn<F>(F);
629
630    impl<F, S, Out> tower_layer::Layer<S> for LayerFn<F>
631    where
632        F: Fn(S) -> Out,
633    {
634        type Service = Out;
635        fn layer(&self, inner: S) -> Self::Service {
636            (self.0)(inner)
637        }
638    }
639}