Skip to main content

hyper_util/client/pool/
negotiate.rs

1//! Negotiate a pool of services
2//!
//! The negotiate pool provides a service that decides between two service
//! types based on an intermediate return value. Unlike typical routing, the
//! decision depends not on the request but on the response.
6//!
//! The original use case is to support ALPN upgrades to HTTP/2, with a fallback
//! to HTTP/1.
9//!
10//! # Example
11//!
12//! ```rust,ignore
13//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
14//! # struct Conn;
15//! # impl Conn { fn negotiated_protocol(&self) -> &[u8] { b"h2" } }
16//! # let some_tls_connector = tower::service::service_fn(|_| async move {
17//! #     Ok::<_, std::convert::Infallible>(Conn)
18//! # });
19//! # let http1_layer = tower::layer::layer_fn(|s| s);
20//! # let http2_layer = tower::layer::layer_fn(|s| s);
21//! let mut pool = hyper_util::client::pool::negotiate::builder()
22//!     .connect(some_tls_connector)
23//!     .inspect(|c| c.negotiated_protocol() == b"h2")
24//!     .fallback(http1_layer)
25//!     .upgrade(http2_layer)
26//!     .build();
27//!
28//! // connect
29//! let mut svc = pool.call(http::Uri::from_static("https://hyper.rs")).await?;
//! svc.ready().await?;
31//!
32//! // http1 or http2 is now set up
33//! # let some_http_req = http::Request::new(());
34//! let resp = svc.call(some_http_req).await?;
35//! # Ok(())
36//! # }
37//! ```
38
39pub use self::internal::builder;
40
41#[cfg(docsrs)]
42pub use self::internal::Builder;
43#[cfg(docsrs)]
44pub use self::internal::Negotiate;
45#[cfg(docsrs)]
46pub use self::internal::Negotiated;
47
48mod internal {
49    use std::future::Future;
50    use std::pin::Pin;
51    use std::sync::{Arc, Mutex};
52    use std::task::{self, ready, Poll};
53
54    use pin_project_lite::pin_project;
55    use tower_layer::Layer;
56    use tower_service::Service;
57
58    type BoxError = Box<dyn std::error::Error + Send + Sync>;
59
60    /// A negotiating pool over an inner make service.
61    ///
62    /// Created with [`builder()`].
63    ///
64    /// # Unnameable
65    ///
66    /// This type is normally unnameable, forbidding naming of the type within
67    /// code. The type is exposed in the documentation to show which methods
68    /// can be publicly called.
69    #[derive(Clone)]
70    pub struct Negotiate<L, R> {
71        left: L,
72        right: R,
73    }
74
75    /// A negotiated service returned by [`Negotiate`].
76    ///
77    /// # Unnameable
78    ///
79    /// This type is normally unnameable, forbidding naming of the type within
80    /// code. The type is exposed in the documentation to show which methods
81    /// can be publicly called.
82    #[derive(Clone, Debug)]
83    pub enum Negotiated<L, R> {
84        #[doc(hidden)]
85        Fallback(L),
86        #[doc(hidden)]
87        Upgraded(R),
88    }
89
    // Future returned by `Negotiate::call`.
    //
    // It starts on the right (upgrade) path and moves between the right and
    // left (fallback) paths whenever a `UseOther` sentinel error comes back;
    // the transitions are implemented in its `Future` impl.
    pin_project! {
        pub struct Negotiating<Dst, L, R>
        where
            L: Service<Dst>,
            R: Service<()>,
        {
            // Current step of the negotiation.
            #[pin]
            state: State<Dst, L::Future, R::Future>,
            // Fallback service; `Negotiate::call` stores here the instance it
            // previously polled ready.
            left: L,
            // Upgrade service (a clone taken in `Negotiate::call`).
            right: R,
        }
    }

    pin_project! {
        // Steps of the `Negotiating` state machine.
        #[project = StateProj]
        enum State<Dst, FL, FR> {
            // Polling the right (upgrade) future first. `dst` is kept so the
            // left service can be called with it if the right path answers
            // with `UseOther`.
            Eager {
                #[pin]
                future: FR,
                dst: Option<Dst>,
            },
            // Polling the left (fallback) future.
            Fallback {
                #[pin]
                future: FL,
            },
            // Polling the right future again, after the left path answered
            // with `UseOther`.
            Upgrade {
                #[pin]
                future: FR,
            }
        }
    }

    // Response future of `Negotiated`: wraps the future of whichever inner
    // service `Negotiated::call` dispatched to.
    pin_project! {
        #[project = NegotiatedProj]
        pub enum NegotiatedFuture<L, R> {
            Fallback {
                #[pin]
                future: L
            },
            Upgraded {
                #[pin]
                future: R
            },
        }
    }
135
136    /// A builder to configure a `Negotiate`.
137    ///
138    /// # Unnameable
139    ///
140    /// This type is normally unnameable, forbidding naming of the type within
141    /// code. The type is exposed in the documentation to show which methods
142    /// can be publicly called.
143    #[derive(Debug)]
144    pub struct Builder<C, I, L, R> {
145        connect: C,
146        inspect: I,
147        fallback: L,
148        upgrade: R,
149    }
150
151    #[derive(Debug)]
152    pub struct WantsConnect;
153    #[derive(Debug)]
154    pub struct WantsInspect;
155    #[derive(Debug)]
156    pub struct WantsFallback;
157    #[derive(Debug)]
158    pub struct WantsUpgrade;
159
160    /// Start a builder to construct a `Negotiate` pool.
161    pub fn builder() -> Builder<WantsConnect, WantsInspect, WantsFallback, WantsUpgrade> {
162        Builder {
163            connect: WantsConnect,
164            inspect: WantsInspect,
165            fallback: WantsFallback,
166            upgrade: WantsUpgrade,
167        }
168    }
169
170    impl<C, I, L, R> Builder<C, I, L, R> {
171        /// Provide the initial connector.
172        pub fn connect<CC>(self, connect: CC) -> Builder<CC, I, L, R> {
173            Builder {
174                connect,
175                inspect: self.inspect,
176                fallback: self.fallback,
177                upgrade: self.upgrade,
178            }
179        }
180
181        /// Provide the inspector that determines the result of the negotiation.
182        pub fn inspect<II>(self, inspect: II) -> Builder<C, II, L, R> {
183            Builder {
184                connect: self.connect,
185                inspect,
186                fallback: self.fallback,
187                upgrade: self.upgrade,
188            }
189        }
190
191        /// Provide the layer to fallback to if negotiation fails.
192        pub fn fallback<LL>(self, fallback: LL) -> Builder<C, I, LL, R> {
193            Builder {
194                connect: self.connect,
195                inspect: self.inspect,
196                fallback,
197                upgrade: self.upgrade,
198            }
199        }
200
201        /// Provide the layer to upgrade to if negotiation succeeds.
202        pub fn upgrade<RR>(self, upgrade: RR) -> Builder<C, I, L, RR> {
203            Builder {
204                connect: self.connect,
205                inspect: self.inspect,
206                fallback: self.fallback,
207                upgrade,
208            }
209        }
210
211        /// Build the `Negotiate` pool.
212        pub fn build<Dst>(self) -> Negotiate<L::Service, R::Service>
213        where
214            C: Service<Dst>,
215            C::Error: Into<BoxError>,
216            L: Layer<Inspector<C, C::Response, I>>,
217            L::Service: Service<Dst> + Clone,
218            <L::Service as Service<Dst>>::Error: Into<BoxError>,
219            R: Layer<Inspected<C::Response>>,
220            R::Service: Service<()> + Clone,
221            <R::Service as Service<()>>::Error: Into<BoxError>,
222            I: Fn(&C::Response) -> bool + Clone,
223        {
224            let Builder {
225                connect,
226                inspect,
227                fallback,
228                upgrade,
229            } = self;
230
231            let slot = Arc::new(Mutex::new(None));
232            let wrapped = Inspector {
233                svc: connect,
234                inspect,
235                slot: slot.clone(),
236            };
237            let left = fallback.layer(wrapped);
238
239            let right = upgrade.layer(Inspected { slot });
240
241            Negotiate { left, right }
242        }
243    }
244
245    impl<L, R> Negotiate<L, R> {
246        /// Get a mutable reference to the fallback service.
247        pub fn fallback_mut(&mut self) -> &mut L {
248            &mut self.left
249        }
250
251        /// Get a mutable reference to the upgrade service.
252        pub fn upgrade_mut(&mut self) -> &mut R {
253            &mut self.right
254        }
255    }
256
    // `Negotiate` is a make-service over a destination `Target`. Each call
    // returns a `Negotiating` future. That future tries the right (upgrade)
    // service first and switches to the left (fallback) service when it sees
    // the `UseOther` sentinel error.
    impl<L, R, Target> Service<Target> for Negotiate<L, R>
    where
        L: Service<Target> + Clone,
        L::Error: Into<BoxError>,
        R: Service<()> + Clone,
        R::Error: Into<BoxError>,
    {
        type Response = Negotiated<L::Response, R::Response>;
        type Error = BoxError;
        type Future = Negotiating<Target, L, R>;

        // Only the left (fallback) service is polled for readiness here.
        // NOTE(review): the right service is called in `call` without a prior
        // `poll_ready`. This is presumably intentional: `Inspected::poll_ready`
        // errors with `UseOther` whenever its slot is empty. Confirm before
        // changing.
        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.left.poll_ready(cx).map_err(Into::into)
        }

        fn call(&mut self, dst: Target) -> Self::Future {
            // Fresh clone that stays in `self` for subsequent calls.
            let left = self.left.clone();
            Negotiating {
                // Start on the right (upgrade) path. `dst` is kept so the left
                // service can be called with it if the right path says `UseOther`.
                state: State::Eager {
                    future: self.right.call(()),
                    dst: Some(dst),
                },
                // place clone, take original that we already polled-ready.
                left: std::mem::replace(&mut self.left, left),
                right: self.right.clone(),
            }
        }
    }
285
286    impl<Dst, L, R> Future for Negotiating<Dst, L, R>
287    where
288        L: Service<Dst>,
289        L::Error: Into<BoxError>,
290        R: Service<()>,
291        R::Error: Into<BoxError>,
292    {
293        type Output = Result<Negotiated<L::Response, R::Response>, BoxError>;
294
295        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
296            // States:
297            // - `Eager`: try the "right" path first; on `UseOther` sentinel, fall back to left.
298            // - `Fallback`: try the left path; on `UseOther` sentinel, upgrade back to right.
299            // - `Upgrade`: retry the right path after a fallback.
300            // If all fail, give up.
301            let mut me = self.project();
302            loop {
303                match me.state.as_mut().project() {
304                    StateProj::Eager { future, dst } => match ready!(future.poll(cx)) {
305                        Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
306                        Err(err) => {
307                            let err = err.into();
308                            if UseOther::is(&*err) {
309                                let dst = dst.take().unwrap();
310                                let f = me.left.call(dst);
311                                me.state.set(State::Fallback { future: f });
312                                continue;
313                            } else {
314                                return Poll::Ready(Err(err));
315                            }
316                        }
317                    },
318                    StateProj::Fallback { future } => match ready!(future.poll(cx)) {
319                        Ok(out) => return Poll::Ready(Ok(Negotiated::Fallback(out))),
320                        Err(err) => {
321                            let err = err.into();
322                            if UseOther::is(&*err) {
323                                let f = me.right.call(());
324                                me.state.set(State::Upgrade { future: f });
325                                continue;
326                            } else {
327                                return Poll::Ready(Err(err));
328                            }
329                        }
330                    },
331                    StateProj::Upgrade { future } => match ready!(future.poll(cx)) {
332                        Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
333                        Err(err) => return Poll::Ready(Err(err.into())),
334                    },
335                }
336            }
337        }
338    }
339
340    impl<L, R> Negotiated<L, R> {
341        // Could be useful?
342        #[cfg(test)]
343        pub(super) fn is_fallback(&self) -> bool {
344            matches!(self, Negotiated::Fallback(_))
345        }
346
347        #[cfg(test)]
348        pub(super) fn is_upgraded(&self) -> bool {
349            matches!(self, Negotiated::Upgraded(_))
350        }
351
352        // TODO: are these the correct methods? Or .as_ref().fallback(), etc?
353
354        /// Get a reference to the fallback service if this is it.
355        pub fn fallback_ref(&self) -> Option<&L> {
356            if let Negotiated::Fallback(ref left) = self {
357                Some(left)
358            } else {
359                None
360            }
361        }
362
363        /// Get a mutable reference to the fallback service if this is it.
364        pub fn fallback_mut(&mut self) -> Option<&mut L> {
365            if let Negotiated::Fallback(ref mut left) = self {
366                Some(left)
367            } else {
368                None
369            }
370        }
371
372        /// Get a reference to the upgraded service if this is it.
373        pub fn upgraded_ref(&self) -> Option<&R> {
374            if let Negotiated::Upgraded(ref right) = self {
375                Some(right)
376            } else {
377                None
378            }
379        }
380
381        /// Get a mutable reference to the upgraded service if this is it.
382        pub fn upgraded_mut(&mut self) -> Option<&mut R> {
383            if let Negotiated::Upgraded(ref mut right) = self {
384                Some(right)
385            } else {
386                None
387            }
388        }
389    }
390
391    impl<L, R, Req, Res, E> Service<Req> for Negotiated<L, R>
392    where
393        L: Service<Req, Response = Res, Error = E>,
394        R: Service<Req, Response = Res, Error = E>,
395    {
396        type Response = Res;
397        type Error = E;
398        type Future = NegotiatedFuture<L::Future, R::Future>;
399
400        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
401            match self {
402                Negotiated::Fallback(ref mut s) => s.poll_ready(cx),
403                Negotiated::Upgraded(ref mut s) => s.poll_ready(cx),
404            }
405        }
406
407        fn call(&mut self, req: Req) -> Self::Future {
408            match self {
409                Negotiated::Fallback(ref mut s) => NegotiatedFuture::Fallback {
410                    future: s.call(req),
411                },
412                Negotiated::Upgraded(ref mut s) => NegotiatedFuture::Upgraded {
413                    future: s.call(req),
414                },
415            }
416        }
417    }
418
419    impl<L, R, Out> Future for NegotiatedFuture<L, R>
420    where
421        L: Future<Output = Out>,
422        R: Future<Output = Out>,
423    {
424        type Output = Out;
425
426        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
427            match self.project() {
428                NegotiatedProj::Fallback { future } => future.poll(cx),
429                NegotiatedProj::Upgraded { future } => future.poll(cx),
430            }
431        }
432    }
433
434    // ===== internal =====
435
    // Wraps the connector service `M`. Every connection `S` it yields is passed
    // to the `inspect` predicate; connections it accepts are parked in `slot`
    // (see `InspectFuture`).
    pub struct Inspector<M, S, I> {
        // The wrapped connector.
        svc: M,
        // Predicate deciding whether a connection goes to the upgrade side.
        inspect: I,
        // Hand-off slot; `Builder::build` shares the same `Arc` with `Inspected`.
        slot: Arc<Mutex<Option<S>>>,
    }

    // Future returned by `Inspector::call`. It carries its own clones of the
    // predicate and the hand-off slot.
    pin_project! {
        pub struct InspectFuture<F, S, I> {
            #[pin]
            future: F,
            inspect: I,
            slot: Arc<Mutex<Option<S>>>,
        }
    }
450
451    impl<M: Clone, S, I: Clone> Clone for Inspector<M, S, I> {
452        fn clone(&self) -> Self {
453            Self {
454                svc: self.svc.clone(),
455                inspect: self.inspect.clone(),
456                slot: self.slot.clone(),
457            }
458        }
459    }
460
461    impl<M, S, I, Target> Service<Target> for Inspector<M, S, I>
462    where
463        M: Service<Target, Response = S>,
464        M::Error: Into<BoxError>,
465        I: Clone + Fn(&S) -> bool,
466    {
467        type Response = M::Response;
468        type Error = BoxError;
469        type Future = InspectFuture<M::Future, S, I>;
470
471        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
472            self.svc.poll_ready(cx).map_err(Into::into)
473        }
474
475        fn call(&mut self, dst: Target) -> Self::Future {
476            InspectFuture {
477                future: self.svc.call(dst),
478                inspect: self.inspect.clone(),
479                slot: self.slot.clone(),
480            }
481        }
482    }
483
484    impl<F, I, S, E> Future for InspectFuture<F, S, I>
485    where
486        F: Future<Output = Result<S, E>>,
487        E: Into<BoxError>,
488        I: Fn(&S) -> bool,
489    {
490        type Output = Result<S, BoxError>;
491
492        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
493            let me = self.project();
494            let s = ready!(me.future.poll(cx)).map_err(Into::into)?;
495            Poll::Ready(if (me.inspect)(&s) {
496                *me.slot.lock().unwrap() = Some(s);
497                Err(UseOther.into())
498            } else {
499                Ok(s)
500            })
501        }
502    }
503
504    pub struct Inspected<S> {
505        slot: Arc<Mutex<Option<S>>>,
506    }
507
508    impl<S, Target> Service<Target> for Inspected<S> {
509        type Response = S;
510        type Error = BoxError;
511        type Future = std::future::Ready<Result<S, BoxError>>;
512
513        fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
514            if self.slot.lock().unwrap().is_some() {
515                Poll::Ready(Ok(()))
516            } else {
517                Poll::Ready(Err(UseOther.into()))
518            }
519        }
520
521        fn call(&mut self, _dst: Target) -> Self::Future {
522            let s = self
523                .slot
524                .lock()
525                .unwrap()
526                .take()
527                .ok_or_else(|| UseOther.into());
528            std::future::ready(s)
529        }
530    }
531
532    impl<S> Clone for Inspected<S> {
533        fn clone(&self) -> Inspected<S> {
534            Inspected {
535                slot: self.slot.clone(),
536            }
537        }
538    }
539
540    #[derive(Debug)]
541    struct UseOther;
542
543    impl std::fmt::Display for UseOther {
544        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545            f.write_str("sentinel error; using other")
546        }
547    }
548
549    impl std::error::Error for UseOther {}
550
551    impl UseOther {
552        fn is(err: &(dyn std::error::Error + 'static)) -> bool {
553            let mut source = Some(err);
554            while let Some(err) = source {
555                if err.is::<UseOther>() {
556                    return true;
557                }
558                source = err.source();
559            }
560            false
561        }
562    }
563}
564
#[cfg(test)]
mod tests {
    use futures_util::future;
    use tower_service::Service;
    use tower_test::assert_request_eq;

    /// Runs one negotiation against a mocked connector whose inspector always
    /// answers `inspect_answer`. Returns `(is_fallback, is_upgraded)` for the
    /// resulting service.
    async fn negotiate_once(inspect_answer: bool) -> (bool, bool) {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();

        let mut negotiate = super::builder()
            .connect(mock_svc)
            .inspect(move |_: &&str| inspect_answer)
            .fallback(layer_fn(|s| s))
            .upgrade(layer_fn(|s| s))
            .build();

        std::future::poll_fn(|cx| negotiate.poll_ready(cx))
            .await
            .unwrap();

        // The mocked connector completes only once this side responds to it.
        let respond = async move {
            assert_request_eq!(handle, ()).send_response("one");
        };
        let (outcome, ()) = future::join(negotiate.call(()), respond).await;
        let negotiated = outcome.expect("call");
        (negotiated.is_fallback(), negotiated.is_upgraded())
    }

    #[tokio::test]
    async fn not_negotiated_falls_back_to_left() {
        let (is_fallback, _) = negotiate_once(false).await;
        assert!(is_fallback);
    }

    #[tokio::test]
    async fn negotiated_uses_right() {
        let (_, is_upgraded) = negotiate_once(true).await;
        assert!(is_upgraded);
    }

    /// Minimal `tower_layer::Layer` backed by a closure.
    #[derive(Clone)]
    struct LayerFn<F>(F);

    fn layer_fn<F>(f: F) -> LayerFn<F> {
        LayerFn(f)
    }

    impl<F, S, Out> tower_layer::Layer<S> for LayerFn<F>
    where
        F: Fn(S) -> Out,
    {
        type Service = Out;

        fn layer(&self, inner: S) -> Self::Service {
            let LayerFn(make) = self;
            make(inner)
        }
    }
}
638}