Skip to main content

noxy/
http.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use http::{Request, Response, Uri};
8use http_body_util::BodyExt;
9use hyper::body::Incoming;
10use hyper_util::client::legacy::connect::Connection;
11use hyper_util::rt::TokioIo;
12use rustls::pki_types::ServerName;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use tokio::net::TcpStream;
15use tokio_rustls::TlsConnector;
16use tower::Service;
17
/// Boxed, thread-safe (`Send + Sync`) error type used by the body and service
/// aliases in this module.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Type-erased HTTP body yielding [`Bytes`] chunks and failing with [`BoxError`].
pub type Body = http_body_util::combinators::BoxBody<Bytes, BoxError>;
/// Type-erased tower service mapping `Request<Body>` to `Response<Body>`.
pub type HttpService = tower::util::BoxService<Request<Body>, Response<Body>, BoxError>;
21
/// hyper-util client that opens upstream connections through [`UpstreamConnector`]
/// and sends [`Body`] request bodies.
pub(crate) type UpstreamClient = hyper_util::client::legacy::Client<UpstreamConnector, Body>;
/// URI scheme placed on requests rewritten for the upstream.
pub(crate) type UpstreamScheme = ::http::uri::Scheme;
24
25/// Request extension that overrides the default upstream for a single request.
26/// Set by the [`Router`](crate::middleware::Router) middleware; read by
27/// [`ForwardService`] before forwarding.
28#[derive(Clone, Debug)]
29pub struct UpstreamTarget {
30    pub authority: ::http::uri::Authority,
31    pub scheme: ::http::uri::Scheme,
32}
33
34pub fn full_body(data: impl Into<Bytes>) -> Body {
35    http_body_util::Full::new(data.into())
36        .map_err(|e| match e {})
37        .boxed()
38}
39
40pub fn empty_body() -> Body {
41    http_body_util::Empty::new().map_err(|e| match e {}).boxed()
42}
43
44/// Convert a hyper `Incoming` body into our boxed body type.
45pub(crate) fn incoming_to_body(incoming: Incoming) -> Body {
46    incoming.map_err(|e| -> BoxError { Box::new(e) }).boxed()
47}
48
/// Upstream I/O wrapper supporting both TLS and plain TCP connections.
pub(crate) enum UpstreamIo {
    /// TLS session over TCP. Boxed, presumably to keep the enum close to the
    /// size of a bare `TcpStream` (the TLS stream also carries the rustls
    /// connection state).
    Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
    /// Unencrypted TCP stream.
    Plain(TcpStream),
}
54
55impl Connection for UpstreamIo {
56    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
57        match self {
58            UpstreamIo::Tls(tls) => {
59                let mut connected = hyper_util::client::legacy::connect::Connected::new();
60                if tls.get_ref().1.alpn_protocol() == Some(b"h2") {
61                    connected = connected.negotiated_h2();
62                }
63                connected
64            }
65            UpstreamIo::Plain(_) => hyper_util::client::legacy::connect::Connected::new(),
66        }
67    }
68}
69
70impl AsyncRead for UpstreamIo {
71    fn poll_read(
72        self: Pin<&mut Self>,
73        cx: &mut Context<'_>,
74        buf: &mut ReadBuf<'_>,
75    ) -> Poll<io::Result<()>> {
76        match self.get_mut() {
77            UpstreamIo::Tls(s) => Pin::new(s).poll_read(cx, buf),
78            UpstreamIo::Plain(s) => Pin::new(s).poll_read(cx, buf),
79        }
80    }
81}
82
83impl AsyncWrite for UpstreamIo {
84    fn poll_write(
85        self: Pin<&mut Self>,
86        cx: &mut Context<'_>,
87        buf: &[u8],
88    ) -> Poll<io::Result<usize>> {
89        match self.get_mut() {
90            UpstreamIo::Tls(s) => Pin::new(s).poll_write(cx, buf),
91            UpstreamIo::Plain(s) => Pin::new(s).poll_write(cx, buf),
92        }
93    }
94
95    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96        match self.get_mut() {
97            UpstreamIo::Tls(s) => Pin::new(s).poll_flush(cx),
98            UpstreamIo::Plain(s) => Pin::new(s).poll_flush(cx),
99        }
100    }
101
102    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
103        match self.get_mut() {
104            UpstreamIo::Tls(s) => Pin::new(s).poll_shutdown(cx),
105            UpstreamIo::Plain(s) => Pin::new(s).poll_shutdown(cx),
106        }
107    }
108}
109
110/// Connector that establishes TLS connections to upstream hosts.
111#[derive(Clone)]
112pub(crate) struct UpstreamConnector {
113    pub tls: TlsConnector,
114}
115
116impl Service<Uri> for UpstreamConnector {
117    type Response = TokioIo<UpstreamIo>;
118    type Error = BoxError;
119    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
120
121    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122        Poll::Ready(Ok(()))
123    }
124
125    fn call(&mut self, uri: Uri) -> Self::Future {
126        let tls = self.tls.clone();
127        let is_plain = uri.scheme_str() == Some("http");
128        Box::pin(async move {
129            let host = uri.host().ok_or("missing host in URI")?;
130            let default_port = if is_plain { 80 } else { 443 };
131            let port = uri.port_u16().unwrap_or(default_port);
132            let tcp = TcpStream::connect((host, port)).await?;
133            if is_plain {
134                Ok(TokioIo::new(UpstreamIo::Plain(tcp)))
135            } else {
136                let server_name: ServerName<'static> = host.to_string().try_into()?;
137                let tls_stream = tls.connect(server_name, tcp).await?;
138                Ok(TokioIo::new(UpstreamIo::Tls(Box::new(tls_stream))))
139            }
140        })
141    }
142}
143
/// Tower service that forwards requests to upstream via hyper-util's pooled client.
pub(crate) struct ForwardService {
    /// Client used to send every forwarded request.
    client: UpstreamClient,
    /// Default upstream authority, used when a request carries no
    /// [`UpstreamTarget`] extension.
    authority: ::http::uri::Authority,
    /// Default upstream scheme, paired with `authority` under the same rule.
    scheme: UpstreamScheme,
}
150
151impl ForwardService {
152    pub(crate) fn new(
153        client: UpstreamClient,
154        authority: ::http::uri::Authority,
155        scheme: UpstreamScheme,
156    ) -> Self {
157        Self {
158            client,
159            authority,
160            scheme,
161        }
162    }
163}
164
165impl Service<Request<Body>> for ForwardService {
166    type Response = Response<Body>;
167    type Error = BoxError;
168    type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;
169
170    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
171        Poll::Ready(Ok(()))
172    }
173
174    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
175        let (authority, scheme) = req
176            .extensions()
177            .get::<UpstreamTarget>()
178            .map(|t| (t.authority.clone(), t.scheme.clone()))
179            .unwrap_or_else(|| (self.authority.clone(), self.scheme.clone()));
180
181        let mut parts = req.uri().clone().into_parts();
182        parts.scheme = Some(scheme);
183        parts.authority = Some(authority);
184        if let Ok(uri) = ::http::Uri::from_parts(parts) {
185            *req.uri_mut() = uri;
186        }
187
188        let fut = self.client.request(req);
189        Box::pin(async move {
190            let resp = fut.await?;
191            Ok(resp.map(incoming_to_body))
192        })
193    }
194}