1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use http::{Request, Response, Uri};
8use http_body_util::BodyExt;
9use hyper::body::Incoming;
10use hyper_util::client::legacy::connect::Connection;
11use hyper_util::rt::TokioIo;
12use rustls::pki_types::ServerName;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use tokio::net::TcpStream;
15use tokio_rustls::TlsConnector;
16use tower::Service;
17
18pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
19pub type Body = http_body_util::combinators::BoxBody<Bytes, BoxError>;
20pub type HttpService = tower::util::BoxService<Request<Body>, Response<Body>, BoxError>;
21
22pub(crate) type UpstreamClient = hyper_util::client::legacy::Client<UpstreamConnector, Body>;
23pub(crate) type UpstreamScheme = ::http::uri::Scheme;
24
25#[derive(Clone, Debug)]
29pub struct UpstreamTarget {
30 pub authority: ::http::uri::Authority,
31 pub scheme: ::http::uri::Scheme,
32}
33
34pub fn full_body(data: impl Into<Bytes>) -> Body {
35 http_body_util::Full::new(data.into())
36 .map_err(|e| match e {})
37 .boxed()
38}
39
40pub fn empty_body() -> Body {
41 http_body_util::Empty::new().map_err(|e| match e {}).boxed()
42}
43
44pub(crate) fn incoming_to_body(incoming: Incoming) -> Body {
46 incoming.map_err(|e| -> BoxError { Box::new(e) }).boxed()
47}
48
49pub(crate) enum UpstreamIo {
51 Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
52 Plain(TcpStream),
53}
54
55impl Connection for UpstreamIo {
56 fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
57 match self {
58 UpstreamIo::Tls(tls) => {
59 let mut connected = hyper_util::client::legacy::connect::Connected::new();
60 if tls.get_ref().1.alpn_protocol() == Some(b"h2") {
61 connected = connected.negotiated_h2();
62 }
63 connected
64 }
65 UpstreamIo::Plain(_) => hyper_util::client::legacy::connect::Connected::new(),
66 }
67 }
68}
69
70impl AsyncRead for UpstreamIo {
71 fn poll_read(
72 self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 buf: &mut ReadBuf<'_>,
75 ) -> Poll<io::Result<()>> {
76 match self.get_mut() {
77 UpstreamIo::Tls(s) => Pin::new(s).poll_read(cx, buf),
78 UpstreamIo::Plain(s) => Pin::new(s).poll_read(cx, buf),
79 }
80 }
81}
82
83impl AsyncWrite for UpstreamIo {
84 fn poll_write(
85 self: Pin<&mut Self>,
86 cx: &mut Context<'_>,
87 buf: &[u8],
88 ) -> Poll<io::Result<usize>> {
89 match self.get_mut() {
90 UpstreamIo::Tls(s) => Pin::new(s).poll_write(cx, buf),
91 UpstreamIo::Plain(s) => Pin::new(s).poll_write(cx, buf),
92 }
93 }
94
95 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96 match self.get_mut() {
97 UpstreamIo::Tls(s) => Pin::new(s).poll_flush(cx),
98 UpstreamIo::Plain(s) => Pin::new(s).poll_flush(cx),
99 }
100 }
101
102 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
103 match self.get_mut() {
104 UpstreamIo::Tls(s) => Pin::new(s).poll_shutdown(cx),
105 UpstreamIo::Plain(s) => Pin::new(s).poll_shutdown(cx),
106 }
107 }
108}
109
110#[derive(Clone)]
112pub(crate) struct UpstreamConnector {
113 pub tls: TlsConnector,
114}
115
116impl Service<Uri> for UpstreamConnector {
117 type Response = TokioIo<UpstreamIo>;
118 type Error = BoxError;
119 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
120
121 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122 Poll::Ready(Ok(()))
123 }
124
125 fn call(&mut self, uri: Uri) -> Self::Future {
126 let tls = self.tls.clone();
127 let is_plain = uri.scheme_str() == Some("http");
128 Box::pin(async move {
129 let host = uri.host().ok_or("missing host in URI")?;
130 let default_port = if is_plain { 80 } else { 443 };
131 let port = uri.port_u16().unwrap_or(default_port);
132 let tcp = TcpStream::connect((host, port)).await?;
133 if is_plain {
134 Ok(TokioIo::new(UpstreamIo::Plain(tcp)))
135 } else {
136 let server_name: ServerName<'static> = host.to_string().try_into()?;
137 let tls_stream = tls.connect(server_name, tcp).await?;
138 Ok(TokioIo::new(UpstreamIo::Tls(Box::new(tls_stream))))
139 }
140 })
141 }
142}
143
144pub(crate) struct ForwardService {
146 client: UpstreamClient,
147 authority: ::http::uri::Authority,
148 scheme: UpstreamScheme,
149}
150
151impl ForwardService {
152 pub(crate) fn new(
153 client: UpstreamClient,
154 authority: ::http::uri::Authority,
155 scheme: UpstreamScheme,
156 ) -> Self {
157 Self {
158 client,
159 authority,
160 scheme,
161 }
162 }
163}
164
165impl Service<Request<Body>> for ForwardService {
166 type Response = Response<Body>;
167 type Error = BoxError;
168 type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;
169
170 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
171 Poll::Ready(Ok(()))
172 }
173
174 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
175 let (authority, scheme) = req
176 .extensions()
177 .get::<UpstreamTarget>()
178 .map(|t| (t.authority.clone(), t.scheme.clone()))
179 .unwrap_or_else(|| (self.authority.clone(), self.scheme.clone()));
180
181 let mut parts = req.uri().clone().into_parts();
182 parts.scheme = Some(scheme);
183 parts.authority = Some(authority);
184 if let Ok(uri) = ::http::Uri::from_parts(parts) {
185 *req.uri_mut() = uri;
186 }
187
188 let fut = self.client.request(req);
189 Box::pin(async move {
190 let resp = fut.await?;
191 Ok(resp.map(incoming_to_body))
192 })
193 }
194}