hyperdriver/service/
http.rs

1use std::future::Future;
2use std::ops::Deref as _;
3use std::task::{Context, Poll};
4
5use http::{Request, Response};
6use http_body::Body as HttpBody;
7
8pub use crate::info::HttpProtocol;
9use crate::BoxError;
10use chateau::client::conn::Connection;
11use chateau::client::pool::{PoolableConnection, Pooled};
12
13/// An asynchronous function from `Request` to `Response`.
14pub trait HttpService<ReqBody> {
15    /// The `HttpBody` body of the `http::Response`.
16    type ResBody: HttpBody;
17
18    /// The error type that can occur within this `Service`.
19    ///
20    /// Note: Returning an `Error` to a hyper server will cause the connection
21    /// to be abruptly aborted. In most cases, it is better to return a `Response`
22    /// with a 4xx or 5xx status code.
23    type Error: Into<BoxError>;
24
25    /// The `Future` returned by this `Service`.
26    type Future: Future<Output = Result<Response<Self::ResBody>, Self::Error>>;
27
28    #[doc(hidden)]
29    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
30
31    #[doc(hidden)]
32    fn call(&mut self, req: Request<ReqBody>) -> Self::Future;
33}
34
35impl<T, BIn, BOut> HttpService<BIn> for T
36where
37    T: tower::Service<Request<BIn>, Response = Response<BOut>>,
38    BOut: HttpBody,
39    T::Error: Into<BoxError>,
40{
41    type ResBody = BOut;
42
43    type Error = T::Error;
44    type Future = T::Future;
45
46    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
47        tower::Service::poll_ready(self, cx)
48    }
49
50    fn call(&mut self, req: Request<BIn>) -> Self::Future {
51        tower::Service::call(self, req)
52    }
53}
54
55/// A trait to specialize connection info for HTTP connections
56pub trait HttpConnectionInfo<B>: Connection<http::Request<B>> {
57    /// Return the protocol version for this connection.
58    fn version(&self) -> HttpProtocol;
59}
60
61impl<C, B> HttpConnectionInfo<B> for Pooled<C, http::Request<B>>
62where
63    C: HttpConnectionInfo<B> + PoolableConnection<http::Request<B>>,
64    B: Send,
65{
66    fn version(&self) -> HttpProtocol {
67        self.deref().version()
68    }
69}
70
#[cfg(feature = "client")]
pub(super) mod http1 {

    use std::fmt;
    use std::task::{Context, Poll};

    use ::http;
    use http::Uri;
    use tower::util::MapRequest;
    use tower::ServiceExt;

    use crate::service::http::HttpProtocol;

    use super::HttpConnectionInfo;

    /// Function-pointer type of the request rewriter installed by [`Http1ChecksService`].
    type PreprocessFn<C, B> = fn((C, http::Request<B>)) -> (C, http::Request<B>);

    /// A service that prepares requests for HTTP/1.1 connections.
    ///
    /// Before a request reaches the inner service, its URI is rewritten into the
    /// request-target form HTTP/1.1 expects: authority-form for `CONNECT`, and
    /// origin-form for absolute URIs. Requests on connections that do not report
    /// [`HttpProtocol::Http1`] are forwarded unchanged.
    #[derive(Debug)]
    pub struct Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        inner: MapRequest<S, PreprocessFn<C, B>>,
    }

    impl<S, C, B> tower::Service<(C, http::Request<B>)> for Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        type Response = S::Response;

        type Error = S::Error;

        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: (C, http::Request<B>)) -> Self::Future {
            self.inner.call(req)
        }
    }

    impl<S, C, B> Clone for Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)> + Clone,
        C: HttpConnectionInfo<B>,
    {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
            }
        }
    }

    impl<S, C, B> Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        /// Create a new `Http1ChecksService` wrapping `service`.
        pub fn new(service: S) -> Self {
            Self {
                inner: service.map_request(check_http1_request),
            }
        }
    }

    /// A layer that wraps services in [`Http1ChecksService`].
    pub struct Http1ChecksLayer<C, B> {
        processor: std::marker::PhantomData<fn(C, B)>,
    }

    impl<C, B> Http1ChecksLayer<C, B> {
        /// Create a new `Http1ChecksLayer`.
        pub fn new() -> Self {
            Self {
                processor: std::marker::PhantomData,
            }
        }
    }

    impl<C, B> Default for Http1ChecksLayer<C, B> {
        fn default() -> Self {
            Self::new()
        }
    }

    // Manual impls: deriving would needlessly require `C: Clone`/`B: Clone`
    // (and `Debug`), even though only a `PhantomData` is stored.
    impl<C, B> Clone for Http1ChecksLayer<C, B> {
        fn clone(&self) -> Self {
            Self::new()
        }
    }

    impl<C, B> fmt::Debug for Http1ChecksLayer<C, B> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Http1ChecksLayer").finish()
        }
    }

    impl<C, B, S> tower::layer::Layer<S> for Http1ChecksLayer<C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        type Service = Http1ChecksService<S, C, B>;

        fn layer(&self, service: S) -> Self::Service {
            Http1ChecksService::new(service)
        }
    }

    /// Rewrite the request URI into the request-target form HTTP/1.1 expects.
    ///
    /// Only requests on connections reporting [`HttpProtocol::Http1`] are touched:
    ///
    /// * `CONNECT` requests are converted to authority-form (RFC 9112, section 3.2.3).
    /// * Requests whose URI has both a scheme and an authority (absolute URIs) are
    ///   converted to origin-form (RFC 9112, section 3.2.1).
    /// * Any other URI (origin-form, asterisk-form, authority-only) is left as-is.
    fn check_http1_request<C, B>((conn, mut req): (C, http::Request<B>)) -> (C, http::Request<B>)
    where
        C: HttpConnectionInfo<B>,
    {
        if conn.version() != HttpProtocol::Http1 {
            return (conn, req);
        }

        if req.method() == http::Method::CONNECT {
            authority_form(req.uri_mut());
        } else if req.uri().scheme().is_some() && req.uri().authority().is_some() {
            origin_form(req.uri_mut());
        }
        // Otherwise there is no scheme or authority to strip: the URI is already
        // in a form an HTTP/1.1 request line accepts, so it is sent unchanged.

        (conn, req)
    }

    /// Convert the URI to authority-form (`host[:port]`), as used by HTTP/1.1
    /// `CONNECT` requests.
    ///
    /// Scheme, path and query are dropped. A URI without an authority cannot be
    /// converted. Because the URI comes from the caller's request, it is left
    /// unchanged and a warning is logged instead of panicking.
    fn authority_form(uri: &mut Uri) {
        match uri.authority().cloned() {
            Some(authority) => {
                let mut parts = ::http::uri::Parts::default();
                parts.authority = Some(authority);
                *uri = Uri::from_parts(parts).expect("authority is valid");
            }
            None => {
                tracing::warn!(
                    "CONNECT request URI {} has no authority; sending it unchanged",
                    uri
                );
            }
        }
    }

    /// Convert the URI to origin-form, if it is not already.
    ///
    /// This form of the URI has no scheme or authority, and contains just
    /// the path and query (an empty path becomes `/`). It is the usual form
    /// for HTTP/1 requests sent directly to an origin server.
    fn origin_form(uri: &mut Uri) {
        let path = match uri.path_and_query() {
            Some(path) if path.as_str() != "/" => {
                let mut parts = ::http::uri::Parts::default();
                parts.path_and_query = Some(path.clone());
                Uri::from_parts(parts).expect("path is valid uri")
            }
            _none_or_just_slash => {
                debug_assert!(Uri::default() == "/");
                Uri::default()
            }
        };
        *uri = path
    }

    #[cfg(test)]
    mod tests {

        use super::*;

        #[test]
        fn test_origin_form() {
            let mut uri = "http://example.com".parse().unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/");

            let mut uri = "/some/path/here".parse().unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/some/path/here");

            let mut uri = "http://example.com:8080/some/path?query#fragment"
                .parse()
                .unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/some/path?query");

            let mut uri = "/".parse().unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/");
        }

        #[test]
        fn test_authority_form() {
            let mut uri = "http://example.com:8080/some/path?query".parse().unwrap();
            authority_form(&mut uri);
            assert_eq!(uri.to_string(), "example.com:8080");

            let mut uri = "example.com:443".parse().unwrap();
            authority_form(&mut uri);
            assert_eq!(uri.to_string(), "example.com:443");

            // No authority to extract: the URI is kept as-is instead of panicking.
            let mut uri = "/some/path".parse().unwrap();
            authority_form(&mut uri);
            assert_eq!(uri.to_string(), "/some/path");
        }
    }
}
312
#[cfg(feature = "client")]
pub(super) mod http2 {
    use std::fmt;
    use std::task::{Context, Poll};

    use ::http;

    use crate::service::http::HttpProtocol;

    use super::HttpConnectionInfo;

    // Connection-specific header fields. HTTP/2 forbids these (RFC 9113,
    // section 8.2.2), so they are stripped from requests on HTTP/2 connections.
    const CONNECTION_HEADERS: [http::HeaderName; 5] = [
        http::header::CONNECTION,
        http::HeaderName::from_static("proxy-connection"),
        http::HeaderName::from_static("keep-alive"),
        http::header::TRANSFER_ENCODING,
        http::header::UPGRADE,
    ];

    /// Error returned by [`Http2ChecksService`].
    #[derive(Debug, thiserror::Error)]
    pub enum HttpRequestError<E> {
        /// The request used a method this service refuses on HTTP/2
        /// connections (currently only `CONNECT`).
        #[error("Invalid HTTP method for HTTP/2: {0}")]
        InvalidMethod(http::Method),

        /// The wrapped service failed, either while becoming ready or while
        /// producing the response.
        #[error(transparent)]
        Connection(E),
    }

    /// A service that adapts requests for HTTP/2 connections before handing
    /// them to the wrapped service, rejecting requests it will not send over HTTP/2.
    #[derive(Debug, Clone)]
    pub struct Http2ChecksService<S> {
        inner: S,
    }

    impl<S> Http2ChecksService<S> {
        /// Wrap `inner` so every request first goes through the HTTP/2 checks.
        pub fn new(inner: S) -> Self {
            Http2ChecksService { inner }
        }
    }

    /// Prepare a request for its connection.
    ///
    /// Connections that are not HTTP/2 pass the request through untouched. On an
    /// HTTP/2 connection, a `CONNECT` request is rejected with
    /// [`HttpRequestError::InvalidMethod`]. Any other request has its version set to
    /// HTTP/2, and its connection-specific headers and `host` header removed. Each
    /// removal is logged as a warning.
    fn check_http2_request<C, B, E>(
        (conn, mut req): (C, http::Request<B>),
    ) -> Result<(C, http::Request<B>), HttpRequestError<E>>
    where
        C: HttpConnectionInfo<B>,
    {
        if conn.version() != HttpProtocol::Http2 {
            return Ok((conn, req));
        }

        if req.method() == http::Method::CONNECT {
            tracing::warn!("CONNECT method not allowed on HTTP/2");
            return Err(HttpRequestError::InvalidMethod(http::Method::CONNECT));
        }

        *req.version_mut() = http::Version::HTTP_2;

        let headers = req.headers_mut();
        for name in CONNECTION_HEADERS.iter() {
            if headers.remove(name).is_some() {
                tracing::warn!(
                    "removed illegal connection header {:?} from HTTP/2 request",
                    name
                );
            }
        }
        if headers.remove(http::header::HOST).is_some() {
            tracing::warn!("removed illegal header `host` from HTTP/2 request");
        }

        Ok((conn, req))
    }

    impl<S, C, B> tower::Service<(C, http::Request<B>)> for Http2ChecksService<S>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        type Response = S::Response;
        type Error = HttpRequestError<S::Error>;
        type Future = self::future::Http2ChecksFuture<S, C, B>;

        #[inline]
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            match self.inner.poll_ready(cx) {
                Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                Poll::Ready(Err(failure)) => Poll::Ready(Err(HttpRequestError::Connection(failure))),
                Poll::Pending => Poll::Pending,
            }
        }

        #[inline]
        fn call(&mut self, req: (C, http::Request<B>)) -> Self::Future {
            let checked = check_http2_request::<C, B, S::Error>(req);
            match checked {
                Ok(request) => self::future::Http2ChecksFuture::new(self.inner.call(request)),
                Err(rejection) => self::future::Http2ChecksFuture::error(rejection),
            }
        }
    }

    mod future {
        use std::{
            fmt,
            future::Future,
            pin::Pin,
            task::{Context, Poll},
        };

        use super::HttpRequestError;
        use pin_project::pin_project;

        // Either the inner service's pending future, or an error decided up
        // front by the checks (taken out on the first poll).
        #[pin_project(project = Http2ChecksStateProject)]
        enum Http2ChecksState<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            Service(#[pin] S::Future),
            Error(Option<HttpRequestError<S::Error>>),
        }

        /// Future returned when applying checks for HTTP/2 connections and requests.
        ///
        /// Resolves to the inner service's response, or immediately to the
        /// rejection produced by the checks.
        #[pin_project]
        pub struct Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            #[pin]
            state: Http2ChecksState<S, C, B>,
        }

        impl<S, C, B> fmt::Debug for Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("Http2ChecksFuture")
            }
        }

        impl<S, C, B> Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            /// Drive the inner service's `future` to completion.
            pub(super) fn new(future: S::Future) -> Self {
                Self::from_state(Http2ChecksState::Service(future))
            }

            /// Resolve to `error` on the first poll.
            pub(super) fn error(error: HttpRequestError<S::Error>) -> Self {
                Self::from_state(Http2ChecksState::Error(Some(error)))
            }

            // Shared constructor for both states.
            fn from_state(state: Http2ChecksState<S, C, B>) -> Self {
                Http2ChecksFuture { state }
            }
        }

        impl<S, C, B> Future for Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            type Output = Result<S::Response, HttpRequestError<S::Error>>;

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                match self.project().state.project() {
                    Http2ChecksStateProject::Service(pending) => {
                        pending.poll(cx).map_err(HttpRequestError::Connection)
                    }
                    Http2ChecksStateProject::Error(slot) => {
                        let rejection = slot
                            .take()
                            .expect("Http2ChecksFuture Error polled after completion");
                        Poll::Ready(Err(rejection))
                    }
                }
            }
        }
    }

    /// A `Layer` that wraps services in [`Http2ChecksService`].
    #[derive(Default, Clone)]
    pub struct Http2ChecksLayer {
        _marker: (),
    }

    impl Http2ChecksLayer {
        /// Create a new `Http2ChecksLayer`.
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl fmt::Debug for Http2ChecksLayer {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("Http2ChecksLayer")
        }
    }

    impl<S> tower::layer::Layer<S> for Http2ChecksLayer {
        type Service = Http2ChecksService<S>;

        fn layer(&self, inner: S) -> Self::Service {
            Http2ChecksService { inner }
        }
    }
}
514
#[cfg(test)]
// The services below are never constructed; they exist only for the
// compile-time trait assertions.
#[allow(dead_code)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::Empty;
    use std::{convert::Infallible, future::Ready};

    /// A `tower::Service` whose response body (`Empty<Bytes>`) implements
    /// `http_body::Body`, so the blanket impl must make it an `HttpService`.
    struct Svc;

    impl tower::Service<http::Request<Empty<Bytes>>> for Svc {
        type Response = http::Response<Empty<Bytes>>;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: http::Request<Empty<Bytes>>) -> Self::Future {
            assert_eq!(req.version(), http::Version::HTTP_11);
            std::future::ready(Ok(http::Response::new(Empty::new())))
        }
    }

    static_assertions::assert_impl_all!(Svc: HttpService<Empty<Bytes>, ResBody=Empty<Bytes>, Error=Infallible>);

    /// A `tower::Service` whose response body (`()`) is NOT an `http_body::Body`,
    /// so the blanket impl must not make it an `HttpService`.
    struct NotASvc;

    // This impl must be on `NotASvc`. Without it, the negative assertion below
    // would hold trivially. With it, the assertion checks that the blanket impl's
    // `BOut: HttpBody` bound rejects a `()` response body.
    impl tower::Service<http::Request<()>> for NotASvc {
        type Response = http::Response<()>;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: http::Request<()>) -> Self::Future {
            assert_eq!(req.version(), http::Version::HTTP_11);
            std::future::ready(Ok(http::Response::new(())))
        }
    }

    static_assertions::assert_not_impl_all!(NotASvc: HttpService<(), ResBody=(), Error=Infallible>);
}