1use std::future::Future;
2use std::ops::Deref as _;
3use std::task::{Context, Poll};
4
5use http::{Request, Response};
6use http_body::Body as HttpBody;
7
8pub use crate::info::HttpProtocol;
9use crate::BoxError;
10use chateau::client::conn::Connection;
11use chateau::client::pool::{PoolableConnection, Pooled};
12
/// An HTTP-shaped [`tower::Service`]: it accepts `Request<ReqBody>` and resolves to
/// `Response<Self::ResBody>`.
///
/// A blanket implementation in this module covers every
/// `tower::Service<Request<ReqBody>>` whose response is `Response<B>` with
/// `B: http_body::Body` and whose error converts into [`BoxError`]. Any such service can
/// therefore be used wherever an `HttpService` is expected.
pub trait HttpService<ReqBody> {
    /// Body type of the responses produced by this service.
    type ResBody: HttpBody;

    /// Error produced by the service; must be convertible into [`BoxError`].
    type Error: Into<BoxError>;

    /// Future returned by [`HttpService::call`], resolving to the response or an error.
    type Future: Future<Output = Result<Response<Self::ResBody>, Self::Error>>;

    /// Readiness check; the blanket impl forwards to [`tower::Service::poll_ready`].
    #[doc(hidden)]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;

    /// Dispatch a request; the blanket impl forwards to [`tower::Service::call`].
    #[doc(hidden)]
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future;
}
34
35impl<T, BIn, BOut> HttpService<BIn> for T
36where
37 T: tower::Service<Request<BIn>, Response = Response<BOut>>,
38 BOut: HttpBody,
39 T::Error: Into<BoxError>,
40{
41 type ResBody = BOut;
42
43 type Error = T::Error;
44 type Future = T::Future;
45
46 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
47 tower::Service::poll_ready(self, cx)
48 }
49
50 fn call(&mut self, req: Request<BIn>) -> Self::Future {
51 tower::Service::call(self, req)
52 }
53}
54
/// A [`Connection`] carrying `http::Request<B>` that can also report which HTTP protocol
/// it speaks.
pub trait HttpConnectionInfo<B>: Connection<http::Request<B>> {
    /// The HTTP protocol used on this connection.
    ///
    /// The request-check middleware in this file compares this against
    /// [`HttpProtocol::Http1`] / [`HttpProtocol::Http2`] to decide whether (and how) to
    /// rewrite a request.
    fn version(&self) -> HttpProtocol;
}
60
61impl<C, B> HttpConnectionInfo<B> for Pooled<C, http::Request<B>>
62where
63 C: HttpConnectionInfo<B> + PoolableConnection<http::Request<B>>,
64 B: Send,
65{
66 fn version(&self) -> HttpProtocol {
67 self.deref().version()
68 }
69}
70
#[cfg(feature = "client")]
pub(super) mod http1 {
    //! Request-target normalization for requests sent over HTTP/1.x connections.

    use std::fmt;
    use std::task::{Context, Poll};

    use ::http;
    use http::Uri;
    use tower::util::MapRequest;
    use tower::ServiceExt;

    use crate::service::http::HttpProtocol;

    use super::HttpConnectionInfo;

    /// The rewrite step [`Http1ChecksService`] runs on every `(connection, request)` pair.
    type PreprocessFn<C, B> = fn((C, http::Request<B>)) -> (C, http::Request<B>);

    /// Tower middleware that rewrites the request target for HTTP/1 connections.
    ///
    /// When the connection reports [`HttpProtocol::Http1`]:
    ///
    /// * `CONNECT` requests are reduced to authority-form (`host:port`);
    /// * requests whose URI has both a scheme and an authority are reduced to
    ///   origin-form (path and query only);
    /// * any other URI (already origin-form, asterisk-form, ...) is left untouched.
    ///
    /// Requests on connections speaking any other protocol pass through unchanged.
    pub struct Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        inner: MapRequest<S, PreprocessFn<C, B>>,
    }

    // Written by hand: `#[derive(Debug)]` would also demand `C: Debug` and `B: Debug`,
    // although neither is stored (they only appear inside the fn-pointer type).
    impl<S, C, B> fmt::Debug for Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)> + fmt::Debug,
        C: HttpConnectionInfo<B>,
    {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Http1ChecksService")
                .field("inner", &self.inner)
                .finish()
        }
    }

    impl<S, C, B> tower::Service<(C, http::Request<B>)> for Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        type Response = S::Response;

        type Error = S::Error;

        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: (C, http::Request<B>)) -> Self::Future {
            self.inner.call(req)
        }
    }

    impl<S, C, B> Clone for Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)> + Clone,
        C: HttpConnectionInfo<B>,
    {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
            }
        }
    }

    impl<S, C, B> Http1ChecksService<S, C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        /// Wrap `service` so every request is normalized by the HTTP/1 checks first.
        pub fn new(service: S) -> Self {
            Self {
                inner: service.map_request(check_http1_request),
            }
        }
    }

    /// [`tower::layer::Layer`] producing [`Http1ChecksService`].
    pub struct Http1ChecksLayer<C, B> {
        processor: std::marker::PhantomData<fn(C, B)>,
    }

    impl<C, B> Http1ChecksLayer<C, B> {
        /// Create the layer; it holds no state.
        pub fn new() -> Self {
            Self {
                processor: std::marker::PhantomData,
            }
        }
    }

    // `Default`, `Clone` and `Debug` are implemented by hand so they do not require
    // `C` or `B` to implement those traits, which derives would.
    impl<C, B> Default for Http1ChecksLayer<C, B> {
        fn default() -> Self {
            Self::new()
        }
    }

    impl<C, B> Clone for Http1ChecksLayer<C, B> {
        fn clone(&self) -> Self {
            Self::new()
        }
    }

    impl<C, B> fmt::Debug for Http1ChecksLayer<C, B> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Http1ChecksLayer").finish()
        }
    }

    impl<C, B, S> tower::layer::Layer<S> for Http1ChecksLayer<C, B>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        type Service = Http1ChecksService<S, C, B>;

        fn layer(&self, service: S) -> Self::Service {
            Http1ChecksService::new(service)
        }
    }

    /// Normalize the request target of `req` for an HTTP/1 connection.
    ///
    /// Non-HTTP/1 connections are returned untouched. On HTTP/1:
    /// * `CONNECT` with an authority -> authority-form; without one it is left unchanged
    ///   (with a warning) instead of panicking;
    /// * a URI with scheme and authority -> origin-form;
    /// * anything else is already a valid relative target and is left unchanged.
    fn check_http1_request<C, B>((conn, mut req): (C, http::Request<B>)) -> (C, http::Request<B>)
    where
        C: HttpConnectionInfo<B>,
    {
        if conn.version() != HttpProtocol::Http1 {
            return (conn, req);
        }

        if req.method() == http::Method::CONNECT {
            // `authority_form` strips the scheme, so no scheme-based decision can follow it.
            if req.uri().authority().is_some() {
                authority_form(req.uri_mut());
            } else {
                tracing::warn!(
                    "HTTP/1 CONNECT request has no authority; leaving request target unchanged"
                );
            }
        } else if req.uri().scheme().is_some() && req.uri().authority().is_some() {
            origin_form(req.uri_mut());
        }
        // Otherwise the URI is already relative (origin-form, asterisk-form, ...) and is
        // sent as-is.

        (conn, req)
    }

    /// Replace `uri` with just its authority (`host[:port]`).
    ///
    /// # Panics
    ///
    /// Panics if `uri` has no authority; callers must check first.
    fn authority_form(uri: &mut Uri) {
        *uri = match uri.authority() {
            Some(auth) => {
                let mut parts = ::http::uri::Parts::default();
                parts.authority = Some(auth.clone());
                Uri::from_parts(parts).expect("authority is valid")
            }
            None => {
                unreachable!("authority_form with relative uri");
            }
        };
    }

    /// Replace `uri` with just its path and query, dropping scheme and authority
    /// (`/` when there is no path).
    fn origin_form(uri: &mut Uri) {
        let path = match uri.path_and_query() {
            Some(path) if path.as_str() != "/" => {
                let mut parts = ::http::uri::Parts::default();
                parts.path_and_query = Some(path.clone());
                Uri::from_parts(parts).expect("path is valid uri")
            }
            _none_or_just_slash => {
                debug_assert!(Uri::default() == "/");
                Uri::default()
            }
        };
        *uri = path
    }

    #[cfg(test)]
    mod tests {

        use super::*;

        #[test]
        fn test_origin_form() {
            let mut uri = "http://example.com".parse().unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/");

            let mut uri = "/some/path/here".parse().unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/some/path/here");

            let mut uri = "http://example.com:8080/some/path?query#fragment"
                .parse()
                .unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/some/path?query");

            let mut uri = "/".parse().unwrap();
            origin_form(&mut uri);
            assert_eq!(uri, "/");
        }

        #[test]
        fn test_authority_form() {
            let mut uri = "https://example.com:443/some/path".parse().unwrap();
            authority_form(&mut uri);
            assert_eq!(uri, "example.com:443");

            let mut uri = "example.com:8080".parse().unwrap();
            authority_form(&mut uri);
            assert_eq!(uri, "example.com:8080");
        }
    }
}
312
#[cfg(feature = "client")]
pub(super) mod http2 {
    //! Request validation and header sanitization for requests sent over HTTP/2.

    use std::fmt;
    use std::task::{Context, Poll};

    use ::http;

    use crate::service::http::HttpProtocol;

    use super::HttpConnectionInfo;

    /// Connection-specific header fields that must not appear in an HTTP/2 message
    /// (RFC 9113 section 8.2.2). `te` is handled separately because it is allowed with
    /// the value `trailers`.
    const CONNECTION_HEADERS: [http::HeaderName; 5] = [
        http::header::CONNECTION,
        http::HeaderName::from_static("proxy-connection"),
        http::HeaderName::from_static("keep-alive"),
        http::header::TRANSFER_ENCODING,
        http::header::UPGRADE,
    ];

    /// Error returned by [`Http2ChecksService`].
    #[derive(Debug, thiserror::Error)]
    pub enum HttpRequestError<E> {
        /// The request method is refused on HTTP/2 by this middleware (currently `CONNECT`).
        #[error("Invalid HTTP method for HTTP/2: {0}")]
        InvalidMethod(http::Method),

        /// The wrapped service failed; its error is forwarded unchanged.
        #[error(transparent)]
        Connection(E),
    }

    /// Tower middleware that validates and sanitizes requests sent over HTTP/2 connections.
    ///
    /// On connections reporting [`HttpProtocol::Http2`], `CONNECT` requests are rejected
    /// without calling the inner service. Other requests get their version set to HTTP/2,
    /// and connection-specific headers, a non-`trailers` `te` header, and `host` removed.
    /// Requests on other connections pass through unchanged.
    #[derive(Debug, Clone)]
    pub struct Http2ChecksService<S> {
        inner: S,
    }

    impl<S> Http2ChecksService<S> {
        /// Wrap `inner` with the HTTP/2 request checks.
        pub fn new(inner: S) -> Self {
            Self { inner }
        }
    }

    /// Validate and sanitize a request about to be sent on `conn`.
    ///
    /// # Errors
    ///
    /// Returns [`HttpRequestError::InvalidMethod`] for a `CONNECT` request on an HTTP/2
    /// connection. `HttpRequestError::Connection` is never produced here; `E` only matches
    /// the service's error type.
    fn check_http2_request<C, B, E>(
        (conn, mut req): (C, http::Request<B>),
    ) -> Result<(C, http::Request<B>), HttpRequestError<E>>
    where
        C: HttpConnectionInfo<B>,
    {
        if conn.version() != HttpProtocol::Http2 {
            return Ok((conn, req));
        }

        if req.method() == http::Method::CONNECT {
            tracing::warn!("CONNECT method not allowed on HTTP/2");
            return Err(HttpRequestError::InvalidMethod(http::Method::CONNECT));
        }

        *req.version_mut() = http::Version::HTTP_2;

        let headers = req.headers_mut();

        for connection_header in &CONNECTION_HEADERS {
            if headers.remove(connection_header).is_some() {
                tracing::warn!(
                    "removed illegal connection header {:?} from HTTP/2 request",
                    connection_header
                );
            }
        }

        // `te` may only carry `trailers` in HTTP/2 (the `h2` crate rejects anything else),
        // so drop the header if any of its values is different.
        if headers
            .get_all(http::header::TE)
            .iter()
            .any(|value| value != "trailers")
        {
            headers.remove(http::header::TE);
            tracing::warn!(
                "removed `te` header with a value other than \"trailers\" from HTTP/2 request"
            );
        }

        if headers.remove(http::header::HOST).is_some() {
            tracing::warn!("removed illegal header `host` from HTTP/2 request");
        }

        Ok((conn, req))
    }

    impl<S, C, B> tower::Service<(C, http::Request<B>)> for Http2ChecksService<S>
    where
        S: tower::Service<(C, http::Request<B>)>,
        C: HttpConnectionInfo<B>,
    {
        type Response = S::Response;

        type Error = HttpRequestError<S::Error>;

        type Future = self::future::Http2ChecksFuture<S, C, B>;

        #[inline]
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner
                .poll_ready(cx)
                .map_err(HttpRequestError::Connection)
        }

        #[inline]
        fn call(&mut self, req: (C, http::Request<B>)) -> Self::Future {
            // A rejected request never reaches the inner service; its error is returned
            // from the future instead.
            match check_http2_request(req) {
                Ok(req) => self::future::Http2ChecksFuture::new(self.inner.call(req)),
                Err(error) => self::future::Http2ChecksFuture::error(error),
            }
        }
    }

    mod future {
        //! The future returned by [`Http2ChecksService`](super::Http2ChecksService).

        use std::{
            fmt,
            future::Future,
            pin::Pin,
            task::{ready, Context, Poll},
        };

        use super::HttpRequestError;
        use pin_project::pin_project;

        /// Either the inner service's in-flight future, or an error produced before the
        /// inner service was called.
        #[pin_project(project=Http2ChecksStateProject)]
        enum Http2ChecksState<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            Service(#[pin] S::Future),
            // `None` once the error has been handed out.
            Error(Option<HttpRequestError<S::Error>>),
        }

        /// Future resolving to the inner service's response, or to an [`HttpRequestError`].
        #[pin_project]
        pub struct Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            #[pin]
            state: Http2ChecksState<S, C, B>,
        }

        // Manual impl: the wrapped future and error are not required to be `Debug`.
        impl<S, C, B> fmt::Debug for Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_struct("Http2ChecksFuture").finish()
            }
        }

        impl<S, C, B> Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            /// Wrap the inner service's future.
            pub(super) fn new(future: S::Future) -> Self {
                Self {
                    state: Http2ChecksState::Service(future),
                }
            }

            /// A future that yields `error` on its first poll.
            pub(super) fn error(error: HttpRequestError<S::Error>) -> Self {
                Self {
                    state: Http2ChecksState::Error(Some(error)),
                }
            }
        }

        impl<S, C, B> Future for Http2ChecksFuture<S, C, B>
        where
            S: tower::Service<(C, http::Request<B>)>,
        {
            type Output = Result<S::Response, HttpRequestError<S::Error>>;

            /// # Panics
            ///
            /// Panics if polled again after an error has already been returned.
            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let this = self.project();
                match this.state.project() {
                    Http2ChecksStateProject::Service(future) => {
                        Poll::Ready(ready!(future.poll(cx)).map_err(HttpRequestError::Connection))
                    }
                    Http2ChecksStateProject::Error(error) => Poll::Ready(Err(error
                        .take()
                        .expect("Http2ChecksFuture Error polled after completion"))),
                }
            }
        }
    }

    /// [`tower::layer::Layer`] producing [`Http2ChecksService`].
    #[derive(Default, Clone)]
    pub struct Http2ChecksLayer {
        _marker: (),
    }

    impl Http2ChecksLayer {
        /// Create the layer; it holds no state.
        pub fn new() -> Self {
            Self { _marker: () }
        }
    }

    impl fmt::Debug for Http2ChecksLayer {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Http2ChecksLayer").finish()
        }
    }

    impl<S> tower::layer::Layer<S> for Http2ChecksLayer {
        type Service = Http2ChecksService<S>;

        fn layer(&self, inner: S) -> Self::Service {
            Http2ChecksService::new(inner)
        }
    }
}
514
#[cfg(test)]
// `Svc` and `NotASvc` exist only for the compile-time assertions below and are never
// constructed, which would otherwise produce `dead_code` warnings.
#[allow(dead_code)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::Empty;
    use std::{convert::Infallible, future::Ready};

    /// A `tower::Service` with an `http_body::Body` response body: must be an `HttpService`.
    struct Svc;

    impl tower::Service<http::Request<Empty<Bytes>>> for Svc {
        type Response = http::Response<Empty<Bytes>>;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: http::Request<Empty<Bytes>>) -> Self::Future {
            assert_eq!(req.version(), http::Version::HTTP_11);
            std::future::ready(Ok(http::Response::new(Empty::new())))
        }
    }

    static_assertions::assert_impl_all!(Svc: HttpService<Empty<Bytes>, ResBody=Empty<Bytes>, Error=Infallible>);

    /// A `tower::Service` whose response body `()` is not an `http_body::Body`, so the
    /// blanket impl must NOT make it an `HttpService`.
    struct NotASvc;

    // Must target `NotASvc`: implementing it for `Svc` leaves `NotASvc` with no service
    // impl at all, and the negative assertion below would pass for the wrong reason.
    impl tower::Service<http::Request<()>> for NotASvc {
        type Response = http::Response<()>;
        type Error = Infallible;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: http::Request<()>) -> Self::Future {
            assert_eq!(req.version(), http::Version::HTTP_11);
            std::future::ready(Ok(http::Response::new(())))
        }
    }

    static_assertions::assert_not_impl_all!(NotASvc: HttpService<(), ResBody=(), Error=Infallible>);
}