monolake_services/http/handlers/
upstream.rs1use std::{
50 convert::Infallible,
51 net::{SocketAddr, ToSocketAddrs},
52 time::Duration,
53};
54
55use bytes::Bytes;
56use http::{header, HeaderMap, HeaderValue, Request, StatusCode};
57use monoio::net::TcpStream;
58use monoio_http::common::{
59 body::{Body, HttpBody},
60 error::HttpError,
61};
62#[cfg(feature = "tls")]
63use monoio_transports::connectors::{TlsConnector, TlsStream};
64use monoio_transports::{
65 connectors::{Connector, TcpConnector},
66 http::{HttpConnection, HttpConnector},
67};
68use monolake_core::{
69 context::{PeerAddr, RemoteAddr},
70 http::ResponseWithContinue,
71 listener::AcceptedAddr,
72};
73use service_async::{AsyncMakeService, MakeService, ParamMaybeRef, ParamRef, Service};
74use tracing::{debug, info};
75
76use crate::http::{generate_response, HttpVersion};
77
/// Upstream connector for plain TCP: an [`HttpConnector`] over [`TcpConnector`] that is
/// given a [`SocketAddr`] when connecting and yields connections on a [`TcpStream`].
type PooledHttpConnector = HttpConnector<TcpConnector, SocketAddr, TcpStream>;
/// Upstream connector for TLS: an [`HttpConnector`] over `TlsConnector<TcpConnector>` keyed
/// by `TcpTlsAddr`. Only compiled when the `tls` feature is enabled.
#[cfg(feature = "tls")]
type PooledHttpsConnector = HttpConnector<
    TlsConnector<TcpConnector>,
    monoio_transports::connectors::TcpTlsAddr,
    TlsStream<TcpStream>,
>;
85
/// HTTP service that forwards incoming requests to the upstream named by the request URI.
///
/// Requests are sent through `http_connector`; when the `tls` feature is enabled, requests
/// whose URI scheme is `https` are sent through `https_connector` instead.
#[derive(Default)]
pub struct UpstreamHandler {
    /// Connector used for upstream connections over plain TCP.
    http_connector: PooledHttpConnector,
    /// Connector used for upstream connections over TLS.
    #[cfg(feature = "tls")]
    https_connector: PooledHttpsConnector,
    /// Timeout settings for upstream connections.
    pub http_upstream_timeout: HttpUpstreamTimeout,
}
101
102impl UpstreamHandler {
103 #[cfg(not(feature = "tls"))]
104 pub fn new(http_connector: PooledHttpConnector, http_upstream_timeout: HttpUpstreamTimeout) -> Self {
105 UpstreamHandler {
106 http_connector,
107 http_upstream_timeout,
108 }
109 }
110
111 #[cfg(feature = "tls")]
112 pub fn new(
113 connector: PooledHttpConnector,
114 tls_connector: PooledHttpsConnector,
115 http_upstream_timeout: HttpUpstreamTimeout,
116 ) -> Self {
117 UpstreamHandler {
118 http_connector: connector,
119 https_connector: tls_connector,
120 http_upstream_timeout,
121 }
122 }
123
124 pub const fn factory(
125 http_upstream_timeout: HttpUpstreamTimeout,
126 version: HttpVersion,
127 ) -> UpstreamHandlerFactory {
128 UpstreamHandlerFactory {
129 http_upstream_timeout,
130 version,
131 }
132 }
133}
134
135impl<CX, B> Service<(Request<B>, CX)> for UpstreamHandler
136where
137 CX: ParamRef<PeerAddr> + ParamMaybeRef<Option<RemoteAddr>>,
138 B: Body<Data = Bytes, Error = HttpError>,
140 HttpError: From<B::Error>,
141{
142 type Response = ResponseWithContinue<HttpBody>;
143 type Error = Infallible;
144
145 async fn call(&self, (mut req, ctx): (Request<B>, CX)) -> Result<Self::Response, Self::Error> {
146 add_xff_header(req.headers_mut(), &ctx);
147 #[cfg(feature = "tls")]
148 if req.uri().scheme() == Some(&http::uri::Scheme::HTTPS) {
149 return self.send_https_request(req).await;
150 }
151 self.send_http_request(req).await
152 }
153}
154
155impl UpstreamHandler {
156 async fn send_http_request<B>(
157 &self,
158 mut req: Request<B>,
159 ) -> Result<ResponseWithContinue<HttpBody>, Infallible>
160 where
161 B: Body<Data = Bytes, Error = HttpError>,
162 HttpError: From<B::Error>,
163 {
164 let Some(host) = req.uri().host() else {
165 info!("invalid uri which does not contain host: {:?}", req.uri());
166 return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
167 };
168 let port = req.uri().port_u16().unwrap_or(80);
169 let mut iter = match (host, port).to_socket_addrs() {
170 Ok(iter) => iter,
171 Err(e) => {
172 info!("convert invalid uri: {:?} with error: {:?}", req.uri(), e);
173 return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
174 }
175 };
176 let Some(key) = iter.next() else {
177 info!("unable to resolve host: {host}");
178 return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
179 };
180 debug!("key: {:?}", key);
181 let mut conn = match self.http_connector.connect(key).await {
182 Ok(conn) => {
183 match &conn {
184 HttpConnection::Http1(_) => {
185 *req.version_mut() = http::Version::HTTP_11;
186 }
187 HttpConnection::Http2(_) => {
188 *req.version_mut() = http::Version::HTTP_2;
189 req.headers_mut().remove(http::header::HOST);
190 }
191 }
192 conn
193 }
194 Err(e) => {
195 info!("connect upstream error: {:?}", e);
196 return Ok((generate_response(StatusCode::BAD_GATEWAY, true), true));
197 }
198 };
199
200 match conn.send_request(req).await {
201 (Ok(resp), _) => Ok((resp, true)),
202 (Err(_e), _) => Ok((generate_response(StatusCode::BAD_GATEWAY, false), true)),
205 }
206 }
207
208 #[cfg(feature = "tls")]
209 async fn send_https_request<B>(
210 &self,
211 req: Request<B>,
212 ) -> Result<ResponseWithContinue<HttpBody>, Infallible>
213 where
214 B: Body<Data = Bytes, Error = HttpError>,
215 HttpError: From<B::Error>,
216 {
217 let key = match req.uri().try_into() {
218 Ok(key) => key,
219 Err(e) => {
220 info!("convert invalid uri: {:?} with error: {:?}", req.uri(), e);
221 return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
222 }
223 };
224 debug!("key: {:?}", key);
225 let connect = match self.http_upstream_timeout.connect_timeout {
226 Some(connect_timeout) => {
227 match monoio::time::timeout(connect_timeout, self.https_connector.connect(key))
228 .await
229 {
230 Ok(x) => x,
231 Err(_) => {
232 info!("connect upstream timeout");
233 return Ok((generate_response(StatusCode::BAD_GATEWAY, true), true));
234 }
235 }
236 }
237 None => self.https_connector.connect(key).await,
238 };
239
240 let mut conn = match connect {
241 Ok(conn) => conn,
242 Err(e) => {
243 info!("connect upstream error: {:?}", e);
244 return Ok((generate_response(StatusCode::BAD_GATEWAY, true), true));
245 }
246 };
247
248 match conn.send_request(req).await {
249 (Ok(resp), _) => Ok((resp, true)),
250 (Err(_e), _) => Ok((generate_response(StatusCode::BAD_GATEWAY, false), true)),
253 }
254 }
255}
256
/// Factory that builds [`UpstreamHandler`] instances.
pub struct UpstreamHandlerFactory {
    /// Timeout settings copied into each handler; the read timeout is also applied to the
    /// connectors the factory creates.
    http_upstream_timeout: HttpUpstreamTimeout,
    /// Selects which connector constructor is used (HTTP/2 only, HTTP/1.1 only, or the
    /// connector's default).
    version: HttpVersion,
}
261
262impl UpstreamHandlerFactory {
263 pub fn new(
264 http_upstream_timeout: HttpUpstreamTimeout,
265 version: HttpVersion,
266 ) -> UpstreamHandlerFactory {
267 UpstreamHandlerFactory {
268 http_upstream_timeout,
269 version,
270 }
271 }
272}
273
// Builds the upstream connectors for a factory and binds them to the identifiers supplied
// by the caller.
//
// Arguments:
//   $self            - the factory; its `version` and `http_upstream_timeout` are read.
//   $http_connector  - name to bind the plain-HTTP connector to.
//   $https_connector - name to bind the TLS connector to (only bound with the `tls` feature).
//   $old_service     - an `Option` holding the previous handler; when it is `Some`, its
//                      connection pools are transferred into the new connectors.
macro_rules! create_connectors {
    ($self:ident, $http_connector:ident, $https_connector:ident, $old_service:ident) => {
        // The configured version decides which connector constructor is used.
        let mut $http_connector = match $self.version {
            HttpVersion::Auto => PooledHttpConnector::default(),
            HttpVersion::Http11 => PooledHttpConnector::build_tcp_http1_only(),
            HttpVersion::Http2 => PooledHttpConnector::build_tcp_http2_only(),
        };
        $http_connector.set_read_timeout($self.http_upstream_timeout.read_timeout);

        #[cfg(feature = "tls")]
        let mut $https_connector = match $self.version {
            HttpVersion::Auto => PooledHttpsConnector::default(),
            HttpVersion::Http11 => PooledHttpsConnector::build_tls_http1_only(),
            HttpVersion::Http2 => PooledHttpsConnector::build_tls_http2_only(),
        };
        #[cfg(feature = "tls")]
        $https_connector.set_read_timeout($self.http_upstream_timeout.read_timeout);

        // A failed pool transfer is only logged; the new connectors are still returned.
        if let Some($old_service) = $old_service {
            match PooledHttpConnector::transfer_pool(
                &$old_service.http_connector,
                &mut $http_connector,
            ) {
                Err(e) => tracing::error!("Failed to transfer pool: {:?}", e),
                Ok(_) => tracing::trace!("Transferred HTTP pool from old service to new service"),
            }
            #[cfg(feature = "tls")]
            match PooledHttpsConnector::transfer_pool(
                &$old_service.https_connector,
                &mut $https_connector,
            ) {
                Err(e) => tracing::error!("Failed to transfer pool: {:?}", e),
                Ok(_) => tracing::trace!("Transferred HTTPS pool from old service to new service"),
            }
        }
    };
}
333impl MakeService for UpstreamHandlerFactory {
335 type Service = UpstreamHandler;
336 type Error = Infallible;
337 fn make_via_ref(&self, old: Option<&Self::Service>) -> Result<Self::Service, Self::Error> {
338 create_connectors!(self, http_connector, https_connector, old);
339 Ok(UpstreamHandler {
340 http_connector,
341 #[cfg(feature = "tls")]
342 https_connector,
343 http_upstream_timeout: self.http_upstream_timeout,
344 })
345 }
346}
347
348impl AsyncMakeService for UpstreamHandlerFactory {
349 type Service = UpstreamHandler;
350 type Error = Infallible;
351
352 async fn make_via_ref(
353 &self,
354 old: Option<&Self::Service>,
355 ) -> Result<Self::Service, Self::Error> {
356 create_connectors!(self, http_connector, https_connector, old);
357 Ok(UpstreamHandler {
358 http_connector,
359 #[cfg(feature = "tls")]
360 https_connector,
361 http_upstream_timeout: self.http_upstream_timeout,
362 })
363 }
364}
365
/// Timeout settings for upstream connections. Every field is optional; the derived
/// `Default` leaves all of them unset.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct HttpUpstreamTimeout {
    /// Optional limit on how long connecting to an upstream may take; `None` means no
    /// limit is applied.
    pub connect_timeout: Option<Duration>,
    /// Optional read timeout handed to the upstream connectors via `set_read_timeout`.
    pub read_timeout: Option<Duration>,
}
374
375fn add_xff_header<CX>(headers: &mut HeaderMap, ctx: &CX)
376where
377 CX: ParamRef<PeerAddr> + ParamMaybeRef<Option<RemoteAddr>>,
378{
379 let peer_addr = ParamRef::<PeerAddr>::param_ref(ctx);
380 let remote_addr = ParamMaybeRef::<Option<RemoteAddr>>::param_maybe_ref(ctx);
381 let addr = remote_addr
382 .and_then(|addr| addr.as_ref().map(|x| &x.0))
383 .unwrap_or(&peer_addr.0);
384
385 match addr {
386 AcceptedAddr::Tcp(addr) => {
387 if let Ok(value) = HeaderValue::from_maybe_shared(Bytes::from(addr.ip().to_string())) {
388 headers.insert(header::FORWARDED, value);
389 }
390 }
391 AcceptedAddr::Unix(addr) => {
392 if let Some(path) = addr.as_pathname().and_then(|s| s.to_str()) {
393 if let Ok(value) = HeaderValue::from_str(path) {
394 headers.insert(header::FORWARDED, value);
395 }
396 }
397 }
398 }
399}