monolake_services/http/handlers/
upstream.rs

1//! Upstream proxy handling and request forwarding module.
2//!
3//! This module provides components for proxying HTTP and HTTPS requests to upstream servers,
4//! leveraging high-performance HTTP client implementations optimized for use with monoio's
5//! asynchronous runtime and io_uring.
6//!
7//! # Key Components
8//!
9//! - [`UpstreamHandler`]: The main service component responsible for proxying requests. It utilizes
10//!   the `HttpConnector` for efficient connection management and request handling.
11//! - [`UpstreamHandlerFactory`]: A factory for creating and updating `UpstreamHandler` instances.
12//! - [`HttpUpstreamTimeout`]: Configuration for various timeout settings in upstream communication.
13//!
14//! # Features
15//!
16//! - HTTP and HTTPS request proxying using optimized connectors
17//! - Connection pooling for efficient resource usage, provided by `HttpConnector`
18//! - Support for both HTTP/1.1 and HTTP/2 protocols
19//! - Configurable timeout settings
20//! - TLS support (enabled with the `tls` feature flag)
21//! - X-Forwarded-For header management
22//! - Leverages monoio's native IO traits built on top of io_uring for high performance
23//!
24//! # HTTP Connector Usage
25//!
26//! The `UpstreamHandler` utilizes `HttpConnector`, which provides:
27//!
28//! - Unified interface for HTTP/1.1 and HTTP/2 connections
29//! - Built-in connection pooling for efficient reuse of established connections
30//! - Optimized for monoio's asynchronous runtime and io_uring
31//! - TLS support for secure HTTPS connections
32//!
33//! # Error Handling
34//!
35//! - Connection errors result in 502 Bad Gateway responses
36//! - Invalid URIs or unresolvable hosts result in 400 Bad Request responses
37//! - Timeouts are handled gracefully, returning appropriate error responses
38//!
39//! # Performance Considerations
40//!
41//! - Utilizes `HttpConnector`'s connection pooling to reduce the overhead of creating new
42//!   connections
43//! - Employs efficient async I/O operations leveraging io_uring for improved performance
44//! - Supports both HTTP/1.1 and HTTP/2, allowing for protocol-specific optimizations
45//!
46//! # Feature Flags
47//!
48//! - `tls`: Enables TLS support for HTTPS connections to upstream servers
49use std::{
50    convert::Infallible,
51    net::{SocketAddr, ToSocketAddrs},
52    time::Duration,
53};
54
55use bytes::Bytes;
56use http::{header, HeaderMap, HeaderValue, Request, StatusCode};
57use monoio::net::TcpStream;
58use monoio_http::common::{
59    body::{Body, HttpBody},
60    error::HttpError,
61};
62#[cfg(feature = "tls")]
63use monoio_transports::connectors::{TlsConnector, TlsStream};
64use monoio_transports::{
65    connectors::{Connector, TcpConnector},
66    http::{HttpConnection, HttpConnector},
67};
68use monolake_core::{
69    context::{PeerAddr, RemoteAddr},
70    http::ResponseWithContinue,
71    listener::AcceptedAddr,
72};
73use service_async::{AsyncMakeService, MakeService, ParamMaybeRef, ParamRef, Service};
74use tracing::{debug, info};
75
76use crate::http::{generate_response, HttpVersion};
77
78type PooledHttpConnector = HttpConnector<TcpConnector, SocketAddr, TcpStream>;
79#[cfg(feature = "tls")]
80type PooledHttpsConnector = HttpConnector<
81    TlsConnector<TcpConnector>,
82    monoio_transports::connectors::TcpTlsAddr,
83    TlsStream<TcpStream>,
84>;
85
86/// Handles proxying of HTTP and HTTPS requests to upstream servers.
87///
88/// `UpstreamHandler` is responsible for forwarding incoming requests to appropriate
89/// upstream servers, handling both HTTP and HTTPS protocols. It manages connection
90/// pooling, timeout settings, and error handling.
91///
92/// For implementation details and example usage, see the
93/// [module level documentation](crate::http::handlers::upstream).
94#[derive(Default)]
95pub struct UpstreamHandler {
96    http_connector: PooledHttpConnector,
97    #[cfg(feature = "tls")]
98    https_connector: PooledHttpsConnector,
99    pub http_upstream_timeout: HttpUpstreamTimeout,
100}
101
102impl UpstreamHandler {
103    #[cfg(not(feature = "tls"))]
104    pub fn new(http_connector: PooledHttpConnector, http_upstream_timeout: HttpUpstreamTimeout) -> Self {
105        UpstreamHandler {
106            http_connector,
107            http_upstream_timeout,
108        }
109    }
110
111    #[cfg(feature = "tls")]
112    pub fn new(
113        connector: PooledHttpConnector,
114        tls_connector: PooledHttpsConnector,
115        http_upstream_timeout: HttpUpstreamTimeout,
116    ) -> Self {
117        UpstreamHandler {
118            http_connector: connector,
119            https_connector: tls_connector,
120            http_upstream_timeout,
121        }
122    }
123
124    pub const fn factory(
125        http_upstream_timeout: HttpUpstreamTimeout,
126        version: HttpVersion,
127    ) -> UpstreamHandlerFactory {
128        UpstreamHandlerFactory {
129            http_upstream_timeout,
130            version,
131        }
132    }
133}
134
135impl<CX, B> Service<(Request<B>, CX)> for UpstreamHandler
136where
137    CX: ParamRef<PeerAddr> + ParamMaybeRef<Option<RemoteAddr>>,
138    // B: Body,
139    B: Body<Data = Bytes, Error = HttpError>,
140    HttpError: From<B::Error>,
141{
142    type Response = ResponseWithContinue<HttpBody>;
143    type Error = Infallible;
144
145    async fn call(&self, (mut req, ctx): (Request<B>, CX)) -> Result<Self::Response, Self::Error> {
146        add_xff_header(req.headers_mut(), &ctx);
147        #[cfg(feature = "tls")]
148        if req.uri().scheme() == Some(&http::uri::Scheme::HTTPS) {
149            return self.send_https_request(req).await;
150        }
151        self.send_http_request(req).await
152    }
153}
154
155impl UpstreamHandler {
156    async fn send_http_request<B>(
157        &self,
158        mut req: Request<B>,
159    ) -> Result<ResponseWithContinue<HttpBody>, Infallible>
160    where
161        B: Body<Data = Bytes, Error = HttpError>,
162        HttpError: From<B::Error>,
163    {
164        let Some(host) = req.uri().host() else {
165            info!("invalid uri which does not contain host: {:?}", req.uri());
166            return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
167        };
168        let port = req.uri().port_u16().unwrap_or(80);
169        let mut iter = match (host, port).to_socket_addrs() {
170            Ok(iter) => iter,
171            Err(e) => {
172                info!("convert invalid uri: {:?} with error: {:?}", req.uri(), e);
173                return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
174            }
175        };
176        let Some(key) = iter.next() else {
177            info!("unable to resolve host: {host}");
178            return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
179        };
180        debug!("key: {:?}", key);
181        let mut conn = match self.http_connector.connect(key).await {
182            Ok(conn) => {
183                match &conn {
184                    HttpConnection::Http1(_) => {
185                        *req.version_mut() = http::Version::HTTP_11;
186                    }
187                    HttpConnection::Http2(_) => {
188                        *req.version_mut() = http::Version::HTTP_2;
189                        req.headers_mut().remove(http::header::HOST);
190                    }
191                }
192                conn
193            }
194            Err(e) => {
195                info!("connect upstream error: {:?}", e);
196                return Ok((generate_response(StatusCode::BAD_GATEWAY, true), true));
197            }
198        };
199
200        match conn.send_request(req).await {
201            (Ok(resp), _) => Ok((resp, true)),
202            // Bad gateway should not affect inbound connection.
203            // It should still be keepalive.
204            (Err(_e), _) => Ok((generate_response(StatusCode::BAD_GATEWAY, false), true)),
205        }
206    }
207
208    #[cfg(feature = "tls")]
209    async fn send_https_request<B>(
210        &self,
211        req: Request<B>,
212    ) -> Result<ResponseWithContinue<HttpBody>, Infallible>
213    where
214        B: Body<Data = Bytes, Error = HttpError>,
215        HttpError: From<B::Error>,
216    {
217        let key = match req.uri().try_into() {
218            Ok(key) => key,
219            Err(e) => {
220                info!("convert invalid uri: {:?} with error: {:?}", req.uri(), e);
221                return Ok((generate_response(StatusCode::BAD_REQUEST, true), true));
222            }
223        };
224        debug!("key: {:?}", key);
225        let connect = match self.http_upstream_timeout.connect_timeout {
226            Some(connect_timeout) => {
227                match monoio::time::timeout(connect_timeout, self.https_connector.connect(key))
228                    .await
229                {
230                    Ok(x) => x,
231                    Err(_) => {
232                        info!("connect upstream timeout");
233                        return Ok((generate_response(StatusCode::BAD_GATEWAY, true), true));
234                    }
235                }
236            }
237            None => self.https_connector.connect(key).await,
238        };
239
240        let mut conn = match connect {
241            Ok(conn) => conn,
242            Err(e) => {
243                info!("connect upstream error: {:?}", e);
244                return Ok((generate_response(StatusCode::BAD_GATEWAY, true), true));
245            }
246        };
247
248        match conn.send_request(req).await {
249            (Ok(resp), _) => Ok((resp, true)),
250            // Bad gateway should not affect inbound connection.
251            // It should still be keepalive.
252            (Err(_e), _) => Ok((generate_response(StatusCode::BAD_GATEWAY, false), true)),
253        }
254    }
255}
256
257pub struct UpstreamHandlerFactory {
258    http_upstream_timeout: HttpUpstreamTimeout,
259    version: HttpVersion,
260}
261
262impl UpstreamHandlerFactory {
263    pub fn new(
264        http_upstream_timeout: HttpUpstreamTimeout,
265        version: HttpVersion,
266    ) -> UpstreamHandlerFactory {
267        UpstreamHandlerFactory {
268            http_upstream_timeout,
269            version,
270        }
271    }
272}
273
/// Builds the upstream connectors for a new [`UpstreamHandler`].
///
/// Binds them to the caller-supplied identifiers `$http_connector` and, with the `tls` feature,
/// `$https_connector`.
///
/// * `$self`: the factory; its `version` selects the connector flavor and its
///   `http_upstream_timeout.read_timeout` is applied to every connector.
/// * `$old_service`: an `Option` of the previous handler. When it is `Some`, its connection
///   pools are transferred into the new connectors so established connections can be reused.
macro_rules! create_connectors {
    ($self:ident, $http_connector:ident, $https_connector:ident, $old_service:ident) => {
        let read_timeout = $self.http_upstream_timeout.read_timeout;

        let mut $http_connector = match $self.version {
            HttpVersion::Http2 => PooledHttpConnector::build_tcp_http2_only(),
            // No support for upgrades to HTTP/2.
            HttpVersion::Http11 => PooledHttpConnector::build_tcp_http1_only(),
            // Default to HTTP/1.1.
            HttpVersion::Auto => PooledHttpConnector::default(),
        };
        $http_connector.set_read_timeout(read_timeout);

        #[cfg(feature = "tls")]
        let mut $https_connector = match $self.version {
            // ALPN advertised with h2.
            HttpVersion::Http2 => PooledHttpsConnector::build_tls_http2_only(),
            // ALPN advertised with http1.1.
            HttpVersion::Http11 => PooledHttpsConnector::build_tls_http1_only(),
            // ALPN advertised with h2/http1.1.
            HttpVersion::Auto => PooledHttpsConnector::default(),
        };
        #[cfg(feature = "tls")]
        $https_connector.set_read_timeout(read_timeout);

        // Move the previous handler's pools over to avoid opening new connections. Transfer is
        // only supported when the protocol and timeout settings are the same; on failure the
        // new (empty) pool is kept and the error is logged.
        if let Some($old_service) = $old_service {
            if let Err(e) = PooledHttpConnector::transfer_pool(
                &$old_service.http_connector,
                &mut $http_connector,
            ) {
                tracing::error!("Failed to transfer pool: {:?}", e);
            } else {
                tracing::trace!("Transferred HTTP pool from old service to new service");
            }

            #[cfg(feature = "tls")]
            if let Err(e) = PooledHttpsConnector::transfer_pool(
                &$old_service.https_connector,
                &mut $https_connector,
            ) {
                tracing::error!("Failed to transfer pool: {:?}", e);
            } else {
                tracing::trace!("Transferred HTTPS pool from old service to new service");
            }
        }
    };
}
333// HttpCoreService is a Service and a MakeService.
334impl MakeService for UpstreamHandlerFactory {
335    type Service = UpstreamHandler;
336    type Error = Infallible;
337    fn make_via_ref(&self, old: Option<&Self::Service>) -> Result<Self::Service, Self::Error> {
338        create_connectors!(self, http_connector, https_connector, old);
339        Ok(UpstreamHandler {
340            http_connector,
341            #[cfg(feature = "tls")]
342            https_connector,
343            http_upstream_timeout: self.http_upstream_timeout,
344        })
345    }
346}
347
348impl AsyncMakeService for UpstreamHandlerFactory {
349    type Service = UpstreamHandler;
350    type Error = Infallible;
351
352    async fn make_via_ref(
353        &self,
354        old: Option<&Self::Service>,
355    ) -> Result<Self::Service, Self::Error> {
356        create_connectors!(self, http_connector, https_connector, old);
357        Ok(UpstreamHandler {
358            http_connector,
359            #[cfg(feature = "tls")]
360            https_connector,
361            http_upstream_timeout: self.http_upstream_timeout,
362        })
363    }
364}
365
/// Timeout settings for communication with upstream servers.
///
/// Both fields default to `None`. A `None` connect timeout means connecting is not
/// time-bounded. `read_timeout` is handed unchanged to the connectors' `set_read_timeout`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct HttpUpstreamTimeout {
    /// Upper bound on how long connecting to an upstream may take.
    /// Like Nginx's `proxy_connect_timeout`.
    pub connect_timeout: Option<Duration>,
    /// Read timeout for upstream responses, applied to the upstream connectors.
    pub read_timeout: Option<Duration>,
}
374
375fn add_xff_header<CX>(headers: &mut HeaderMap, ctx: &CX)
376where
377    CX: ParamRef<PeerAddr> + ParamMaybeRef<Option<RemoteAddr>>,
378{
379    let peer_addr = ParamRef::<PeerAddr>::param_ref(ctx);
380    let remote_addr = ParamMaybeRef::<Option<RemoteAddr>>::param_maybe_ref(ctx);
381    let addr = remote_addr
382        .and_then(|addr| addr.as_ref().map(|x| &x.0))
383        .unwrap_or(&peer_addr.0);
384
385    match addr {
386        AcceptedAddr::Tcp(addr) => {
387            if let Ok(value) = HeaderValue::from_maybe_shared(Bytes::from(addr.ip().to_string())) {
388                headers.insert(header::FORWARDED, value);
389            }
390        }
391        AcceptedAddr::Unix(addr) => {
392            if let Some(path) = addr.as_pathname().and_then(|s| s.to_str()) {
393                if let Ok(value) = HeaderValue::from_str(path) {
394                    headers.insert(header::FORWARDED, value);
395                }
396            }
397        }
398    }
399}