Skip to main content

ferrotunnel_http/
proxy.rs

1use bytes::Bytes;
2use ferrotunnel_core::stream::VirtualStream;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::server::conn::{http1, http2};
6use hyper::{Request, Response, StatusCode};
7use hyper_util::rt::{TokioExecutor, TokioIo};
8use hyper_util::service::TowerToHyperService;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tower::{Layer, Service};
14
15use crate::pool::{ConnectionPool, PoolConfig};
16#[derive(Debug)]
17pub enum ProxyError {
18    Hyper(hyper::Error),
19    Custom(String),
20}
21
22impl std::fmt::Display for ProxyError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            ProxyError::Hyper(e) => write!(f, "Hyper error: {e}"),
26            ProxyError::Custom(s) => write!(f, "Proxy error: {s}"),
27        }
28    }
29}
30
31impl std::error::Error for ProxyError {}
32
33impl From<hyper::Error> for ProxyError {
34    fn from(e: hyper::Error) -> Self {
35        ProxyError::Hyper(e)
36    }
37}
38
39impl From<std::convert::Infallible> for ProxyError {
40    fn from(_: std::convert::Infallible) -> Self {
41        unreachable!()
42    }
43}
44
45use tracing::error;
46
47type BoxBody = http_body_util::combinators::BoxBody<Bytes, ProxyError>;
48
49/// Service that forwards requests to a local TCP port.
50#[derive(Clone)]
51pub struct LocalProxyService {
52    pool: Arc<ConnectionPool>,
53    use_h2: bool,
54}
55
56impl LocalProxyService {
57    pub fn new(target_addr: String) -> Self {
58        let pool = Arc::new(ConnectionPool::new(target_addr, PoolConfig::default()));
59        Self {
60            pool,
61            use_h2: false,
62        }
63    }
64
65    pub fn with_pool(pool: Arc<ConnectionPool>) -> Self {
66        Self {
67            pool,
68            use_h2: false,
69        }
70    }
71
72    /// Create a service that uses HTTP/2 for forwarding (required for gRPC).
73    pub fn with_pool_h2(pool: Arc<ConnectionPool>) -> Self {
74        Self { pool, use_h2: true }
75    }
76}
77
78use hyper::body::Body;
79
80impl<B> Service<Request<B>> for LocalProxyService
81where
82    B: Body<Data = Bytes> + Send + Sync + 'static,
83    B::Error: Into<ProxyError> + std::error::Error + Send + Sync + 'static,
84{
85    type Response = Response<BoxBody>;
86    type Error = hyper::Error;
87    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
88
89    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        Poll::Ready(Ok(()))
91    }
92
93    #[allow(clippy::too_many_lines)]
94    fn call(&mut self, mut req: Request<B>) -> Self::Future {
95        let pool = self.pool.clone();
96        let use_h2 = self.use_h2;
97        Box::pin(async move {
98            // gRPC path: forward over HTTP/2, which preserves trailers
99            if use_h2 {
100                let req = req.map(|b| {
101                    b.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync + 'static>)
102                        .boxed()
103                });
104                let mut sender = match pool.acquire_h2().await {
105                    Ok(s) => s,
106                    Err(e) => {
107                        error!("Failed to acquire HTTP/2 connection from pool: {e}");
108                        return Ok(error_response(
109                            StatusCode::BAD_GATEWAY,
110                            &format!("Failed to connect to local service: {e}"),
111                        ));
112                    }
113                };
114                return match sender.send_request(req).await {
115                    Ok(res) => {
116                        let (parts, body) = res.into_parts();
117                        Ok(Response::from_parts(
118                            parts,
119                            body.map_err(Into::into).boxed(),
120                        ))
121                    }
122                    Err(e) => {
123                        error!("Failed to proxy gRPC request: {e}");
124                        Ok(error_response(StatusCode::BAD_GATEWAY, "Proxy error"))
125                    }
126                };
127            }
128
129            let is_upgrade = req
130                .headers()
131                .get(hyper::header::UPGRADE)
132                .and_then(|v| v.to_str().ok())
133                .is_some_and(|v| v.eq_ignore_ascii_case("websocket"));
134
135            let server_upgrade = if is_upgrade {
136                Some(hyper::upgrade::on(&mut req))
137            } else {
138                None
139            };
140
141            // Try to acquire connection from pool
142            let mut sender = match pool.acquire_h1().await {
143                Ok(s) => s,
144                Err(e) => {
145                    error!("Failed to acquire connection from pool: {e}");
146                    return Ok(error_response(
147                        StatusCode::BAD_GATEWAY,
148                        &format!("Failed to connect to local service: {e}"),
149                    ));
150                }
151            };
152
153            let req = req.map(|b| {
154                b.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync + 'static>)
155                    .boxed()
156            });
157
158            match sender.send_request(req).await {
159                Ok(res) => {
160                    if is_upgrade && res.status() == StatusCode::SWITCHING_PROTOCOLS {
161                        // Don't return upgraded connections to pool
162                        let upstream_headers = res.headers().clone();
163                        let local_upgrade = hyper::upgrade::on(res);
164
165                        if let Some(server_upgrade) = server_upgrade {
166                            tokio::spawn(async move {
167                                let (local_result, server_result) =
168                                    tokio::join!(local_upgrade, server_upgrade);
169
170                                let local_upgraded = match local_result {
171                                    Ok(u) => u,
172                                    Err(e) => {
173                                        error!("Local upgrade failed: {e}");
174                                        return;
175                                    }
176                                };
177                                let server_upgraded = match server_result {
178                                    Ok(u) => u,
179                                    Err(e) => {
180                                        error!("Server upgrade failed: {e}");
181                                        return;
182                                    }
183                                };
184
185                                let mut local_io = TokioIo::new(local_upgraded);
186                                let mut server_io = TokioIo::new(server_upgraded);
187                                let _ =
188                                    tokio::io::copy_bidirectional(&mut local_io, &mut server_io)
189                                        .await;
190                            });
191                        }
192
193                        let mut builder =
194                            Response::builder().status(StatusCode::SWITCHING_PROTOCOLS);
195                        for (key, value) in &upstream_headers {
196                            builder = builder.header(key, value);
197                        }
198                        Ok(builder
199                            .body(
200                                Full::new(Bytes::new())
201                                    .map_err(|_| ProxyError::Custom("unreachable".into()))
202                                    .boxed(),
203                            )
204                            .unwrap_or_else(|_| {
205                                error_response(
206                                    StatusCode::INTERNAL_SERVER_ERROR,
207                                    "Failed to build upgrade response",
208                                )
209                            }))
210                    } else {
211                        // Return connection to pool for reuse
212                        pool.release_h1(sender).await;
213
214                        let (parts, body) = res.into_parts();
215                        let boxed_body = body.map_err(Into::into).boxed();
216                        Ok(Response::from_parts(parts, boxed_body))
217                    }
218                }
219                Err(e) => {
220                    // Don't return broken connections to pool
221                    error!("Failed to proxy request: {e}");
222                    Ok(error_response(StatusCode::BAD_GATEWAY, "Proxy error"))
223                }
224            }
225        })
226    }
227}
228
/// Static body for the common "Proxy error" response, served by
/// [`error_response`] without allocating.
const MSG_PROXY_ERROR: &[u8] = "Proxy error".as_bytes();
/// Static body for the fallback response used when building an error response fails.
const MSG_INTERNAL_ERROR: &[u8] = "Internal error".as_bytes();
232
233/// Builds a plain-text error response. Shared by proxy and CLI dashboard middleware.
234/// Uses static bytes for common messages to avoid allocation.
235pub fn error_response(status: StatusCode, msg: &str) -> Response<BoxBody> {
236    let bytes = if msg == "Proxy error" {
237        Bytes::from_static(MSG_PROXY_ERROR)
238    } else {
239        Bytes::copy_from_slice(msg.as_bytes())
240    };
241    Response::builder()
242        .status(status)
243        .body(
244            Full::new(bytes)
245                .map_err(|_| ProxyError::Custom("Error construction failed".into()))
246                .boxed(),
247        )
248        .unwrap_or_else(|_| {
249            Response::new(
250                Full::new(Bytes::from_static(MSG_INTERNAL_ERROR))
251                    .map_err(|_| ProxyError::Custom("Error construction failed".into()))
252                    .boxed(),
253            )
254        })
255}
256
257#[derive(Clone)]
258pub struct HttpProxy<L> {
259    target_addr: String,
260    layer: L,
261    pool: Arc<ConnectionPool>,
262}
263
264impl HttpProxy<tower::layer::util::Identity> {
265    pub fn new(target_addr: String) -> Self {
266        let pool = Arc::new(ConnectionPool::new(
267            target_addr.clone(),
268            PoolConfig::default(),
269        ));
270        Self {
271            target_addr,
272            layer: tower::layer::util::Identity::new(),
273            pool,
274        }
275    }
276
277    pub fn with_pool_config(target_addr: String, pool_config: PoolConfig) -> Self {
278        let pool = Arc::new(ConnectionPool::new(target_addr.clone(), pool_config));
279        Self {
280            target_addr,
281            layer: tower::layer::util::Identity::new(),
282            pool,
283        }
284    }
285}
286
287impl<L> HttpProxy<L> {
288    pub fn with_layer<NewL>(self, layer: NewL) -> HttpProxy<NewL> {
289        HttpProxy {
290            target_addr: self.target_addr,
291            layer,
292            pool: self.pool,
293        }
294    }
295
296    pub fn handle_stream(&self, stream: VirtualStream)
297    where
298        L: Layer<LocalProxyService> + Clone + Send + 'static,
299        L::Service: Service<Request<Incoming>, Response = Response<BoxBody>, Error = hyper::Error>
300            + Send
301            + Clone
302            + 'static,
303        <L::Service as Service<Request<Incoming>>>::Future: Send,
304    {
305        let service = self
306            .layer
307            .clone()
308            .layer(LocalProxyService::with_pool(self.pool.clone()));
309        let hyper_service = TowerToHyperService::new(service);
310        let io = TokioIo::new(stream);
311
312        tokio::spawn(async move {
313            let _ = http1::Builder::new()
314                .serve_connection(io, hyper_service)
315                .with_upgrades()
316                .await;
317        });
318    }
319
320    /// Serve an incoming gRPC `VirtualStream` over HTTP/2.
321    ///
322    /// gRPC requires HTTP/2 end-to-end so that trailers (`grpc-status`,
323    /// `grpc-message`) are propagated correctly. This method uses a
324    /// dedicated HTTP/2 connection pool (always acquired via `acquire_h2()`)
325    /// to forward requests to the local service.
326    pub fn handle_grpc_stream(&self, stream: VirtualStream)
327    where
328        L: Layer<LocalProxyService> + Clone + Send + 'static,
329        L::Service: Service<Request<Incoming>, Response = Response<BoxBody>, Error = hyper::Error>
330            + Send
331            + Clone
332            + 'static,
333        <L::Service as Service<Request<Incoming>>>::Future: Send,
334    {
335        let grpc_pool = Arc::new(ConnectionPool::new(
336            self.target_addr.clone(),
337            PoolConfig::default(),
338        ));
339        let service = self
340            .layer
341            .clone()
342            .layer(LocalProxyService::with_pool_h2(grpc_pool));
343        let hyper_service = TowerToHyperService::new(service);
344        let io = TokioIo::new(stream);
345
346        tokio::spawn(async move {
347            let _ = http2::Builder::new(TokioExecutor::new())
348                .serve_connection(io, hyper_service)
349                .await;
350        });
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::BodyExt;
    use hyper::{body::Bytes, Request};
    use tower::Service;

    /// Returns a loopback address with no listener by binding an ephemeral port
    /// and releasing it, instead of assuming a hard-coded port is closed.
    fn unused_local_addr() -> String {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("failed to bind an ephemeral port");
        let addr = listener
            .local_addr()
            .expect("failed to read the ephemeral address");
        drop(listener);
        addr.to_string()
    }

    /// Collects a response body into a UTF-8 string.
    async fn body_text(response: Response<BoxBody>) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[test]
    fn test_proxy_error_to_string_contains_message() {
        // A real `hyper::Error` is impractical to construct, so only `Custom` is covered.
        let err = ProxyError::Custom("test error".to_string());
        assert!(err.to_string().contains("test error"));
    }

    #[test]
    fn test_proxy_error_custom_display() {
        let err = ProxyError::Custom("connection failed".to_string());
        let display = format!("{err}");
        assert!(display.contains("Proxy error"));
        assert!(display.contains("connection failed"));
    }

    #[test]
    fn test_local_proxy_service_new() {
        // Construction with a fresh pool must not panic.
        let _service = LocalProxyService::new("127.0.0.1:8080".to_string());
    }

    #[test]
    fn test_local_proxy_service_clone() {
        let service = LocalProxyService::new("localhost:3000".to_string());
        let _cloned = service.clone();
    }

    #[tokio::test]
    async fn test_proxy_connection_error() {
        let mut service = LocalProxyService::new(unused_local_addr());

        let req = Request::builder()
            .uri("http://example.com")
            .body(Full::new(Bytes::from("test")))
            .unwrap();

        // Connection failures are reported as a 502 response, not an `Err`.
        let response = service.call(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
        assert!(body_text(response).await.contains("Failed to connect"));
    }

    #[tokio::test]
    async fn test_error_response_bad_gateway() {
        let resp = error_response(StatusCode::BAD_GATEWAY, "Backend unavailable");
        assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(body_text(resp).await, "Backend unavailable");
    }

    #[tokio::test]
    async fn test_error_response_not_found() {
        let resp = error_response(StatusCode::NOT_FOUND, "Route not found");
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(body_text(resp).await, "Route not found");
    }

    #[tokio::test]
    async fn test_error_response_internal_error() {
        let resp = error_response(StatusCode::INTERNAL_SERVER_ERROR, "Unexpected error");
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body_text(resp).await, "Unexpected error");
    }

    #[tokio::test]
    async fn test_error_response_static_proxy_error_body() {
        // Exercises the static-bytes path for the common message.
        let resp = error_response(StatusCode::BAD_GATEWAY, "Proxy error");
        assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
        assert_eq!(body_text(resp).await, "Proxy error");
    }

    #[test]
    fn test_http_proxy_new() {
        let proxy = HttpProxy::new("127.0.0.1:8080".to_string());
        assert_eq!(proxy.target_addr, "127.0.0.1:8080");
    }

    #[test]
    fn test_http_proxy_with_layer() {
        let proxy = HttpProxy::new("127.0.0.1:8080".to_string());
        let layered = proxy.with_layer(tower::layer::util::Identity::new());
        // Swapping the layer must keep the target address.
        assert_eq!(layered.target_addr, "127.0.0.1:8080");
    }
}