1use bytes::Bytes;
2use ferrotunnel_core::stream::VirtualStream;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::server::conn::{http1, http2};
6use hyper::{Request, Response, StatusCode};
7use hyper_util::rt::{TokioExecutor, TokioIo};
8use hyper_util::service::TowerToHyperService;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tower::{Layer, Service};
14
15use crate::pool::{ConnectionPool, PoolConfig};
/// Error type carried by the proxied response bodies in this module (see `BoxBody`).
#[derive(Debug)]
pub enum ProxyError {
    /// An error reported by hyper.
    Hyper(hyper::Error),
    /// Any other proxy failure, described by a free-form message.
    Custom(String),
}
21
22impl std::fmt::Display for ProxyError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 ProxyError::Hyper(e) => write!(f, "Hyper error: {e}"),
26 ProxyError::Custom(s) => write!(f, "Proxy error: {s}"),
27 }
28 }
29}
30
// Uses the default trait methods, so `source()` returns `None` even for the `Hyper`
// variant; the wrapped hyper error's text is included by the `Display` impl instead.
impl std::error::Error for ProxyError {}
32
33impl From<hyper::Error> for ProxyError {
34 fn from(e: hyper::Error) -> Self {
35 ProxyError::Hyper(e)
36 }
37}
38
39impl From<std::convert::Infallible> for ProxyError {
40 fn from(_: std::convert::Infallible) -> Self {
41 unreachable!()
42 }
43}
44
45use tracing::error;
46
/// Type-erased response body used throughout this module: `Bytes` data frames with
/// [`ProxyError`] as the error type.
type BoxBody = http_body_util::combinators::BoxBody<Bytes, ProxyError>;
48
/// Tower service that forwards requests to the local target through a
/// [`ConnectionPool`].
#[derive(Clone)]
pub struct LocalProxyService {
    // Pool of connections to the local target; clones of the service share it via `Arc`.
    pool: Arc<ConnectionPool>,
    // When true, requests go over pooled HTTP/2 connections; otherwise over HTTP/1.
    use_h2: bool,
}
55
56impl LocalProxyService {
57 pub fn new(target_addr: String) -> Self {
58 let pool = Arc::new(ConnectionPool::new(target_addr, PoolConfig::default()));
59 Self {
60 pool,
61 use_h2: false,
62 }
63 }
64
65 pub fn with_pool(pool: Arc<ConnectionPool>) -> Self {
66 Self {
67 pool,
68 use_h2: false,
69 }
70 }
71
72 pub fn with_pool_h2(pool: Arc<ConnectionPool>) -> Self {
74 Self { pool, use_h2: true }
75 }
76}
77
78use hyper::body::Body;
79
/// Forwards each request to the local target through the pooled connections.
///
/// * When built with `use_h2 == true`, the request is sent over an HTTP/2 connection
///   from `acquire_h2`; upgrades are not handled on this path.
/// * Otherwise the request is sent over an HTTP/1 connection from `acquire_h1`.
///   A request carrying `Upgrade: websocket` that the target answers with
///   `101 Switching Protocols` is bridged. Both sides are upgraded, and bytes are
///   copied in both directions on a spawned task.
impl<B> Service<Request<B>> for LocalProxyService
where
    B: Body<Data = Bytes> + Send + Sync + 'static,
    // NOTE(review): `Into<ProxyError>` is not exercised below (request-body errors are
    // boxed as `dyn Error`); confirm whether this bound is still required.
    B::Error: Into<ProxyError> + std::error::Error + Send + Sync + 'static,
{
    type Response = Response<BoxBody>;
    type Error = hyper::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready; connections are only acquired from the pool inside `call`.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Sends `req` to the local target and returns its response.
    ///
    /// Every path resolves to `Ok`. Failing to acquire a connection or to send the
    /// request produces a `502 Bad Gateway` response built by [`error_response`].
    // The inline HTTP/1 upgrade bridging is what makes this function long.
    #[allow(clippy::too_many_lines)]
    fn call(&mut self, mut req: Request<B>) -> Self::Future {
        // Copy what the future needs so it owns its state and does not borrow `self`.
        let pool = self.pool.clone();
        let use_h2 = self.use_h2;
        Box::pin(async move {
            if use_h2 {
                // HTTP/2 path (the one `handle_grpc_stream` uses): erase the request
                // body's error type into a boxed `dyn Error`.
                let req = req.map(|b| {
                    b.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync + 'static>)
                        .boxed()
                });
                let mut sender = match pool.acquire_h2().await {
                    Ok(s) => s,
                    Err(e) => {
                        error!("Failed to acquire HTTP/2 connection from pool: {e}");
                        return Ok(error_response(
                            StatusCode::BAD_GATEWAY,
                            &format!("Failed to connect to local service: {e}"),
                        ));
                    }
                };
                // NOTE(review): the HTTP/2 sender is never handed back to the pool.
                // Presumably `acquire_h2` hands out clones of a shared multiplexed
                // sender; confirm against the pool implementation.
                return match sender.send_request(req).await {
                    Ok(res) => {
                        let (parts, body) = res.into_parts();
                        Ok(Response::from_parts(
                            parts,
                            body.map_err(Into::into).boxed(),
                        ))
                    }
                    Err(e) => {
                        error!("Failed to proxy gRPC request: {e}");
                        Ok(error_response(StatusCode::BAD_GATEWAY, "Proxy error"))
                    }
                };
            }

            // HTTP/1 path. Only `Upgrade: websocket` (case-insensitive) counts as an
            // upgrade request.
            let is_upgrade = req
                .headers()
                .get(hyper::header::UPGRADE)
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| v.eq_ignore_ascii_case("websocket"));

            // Take the client-side upgrade handle while `req` is still owned here.
            // `req` is moved into `send_request` below.
            let server_upgrade = if is_upgrade {
                Some(hyper::upgrade::on(&mut req))
            } else {
                None
            };

            let mut sender = match pool.acquire_h1().await {
                Ok(s) => s,
                Err(e) => {
                    error!("Failed to acquire connection from pool: {e}");
                    return Ok(error_response(
                        StatusCode::BAD_GATEWAY,
                        &format!("Failed to connect to local service: {e}"),
                    ));
                }
            };

            let req = req.map(|b| {
                b.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync + 'static>)
                    .boxed()
            });

            match sender.send_request(req).await {
                Ok(res) => {
                    if is_upgrade && res.status() == StatusCode::SWITCHING_PROTOCOLS {
                        // The target accepted the upgrade. Keep its headers for the
                        // client-facing 101, then take the upstream upgrade handle.
                        // The upgrade consumes this connection, so `sender` is not
                        // returned to the pool.
                        let upstream_headers = res.headers().clone();
                        let local_upgrade = hyper::upgrade::on(res);

                        if let Some(server_upgrade) = server_upgrade {
                            // hyper completes the client-side upgrade only after the 101
                            // returned below has been sent, so the bridge must run on a
                            // separate task rather than be awaited here.
                            tokio::spawn(async move {
                                let (local_result, server_result) =
                                    tokio::join!(local_upgrade, server_upgrade);

                                let local_upgraded = match local_result {
                                    Ok(u) => u,
                                    Err(e) => {
                                        error!("Local upgrade failed: {e}");
                                        return;
                                    }
                                };
                                let server_upgraded = match server_result {
                                    Ok(u) => u,
                                    Err(e) => {
                                        error!("Server upgrade failed: {e}");
                                        return;
                                    }
                                };

                                // Copy bytes in both directions; the outcome of the copy
                                // is intentionally ignored.
                                let mut local_io = TokioIo::new(local_upgraded);
                                let mut server_io = TokioIo::new(server_upgraded);
                                let _ =
                                    tokio::io::copy_bidirectional(&mut local_io, &mut server_io)
                                        .await;
                            });
                        }

                        // Mirror the target's 101 (status and all headers) back to the client.
                        let mut builder =
                            Response::builder().status(StatusCode::SWITCHING_PROTOCOLS);
                        for (key, value) in &upstream_headers {
                            builder = builder.header(key, value);
                        }
                        Ok(builder
                            .body(
                                Full::new(Bytes::new())
                                    .map_err(|_| ProxyError::Custom("unreachable".into()))
                                    .boxed(),
                            )
                            .unwrap_or_else(|_| {
                                error_response(
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    "Failed to build upgrade response",
                                )
                            }))
                    } else {
                        // Ordinary response: return the connection to the pool and stream
                        // the body back to the client.
                        // NOTE(review): the sender is released before this response body
                        // has been read; confirm `acquire_h1` only hands out ready senders.
                        pool.release_h1(sender).await;

                        let (parts, body) = res.into_parts();
                        let boxed_body = body.map_err(Into::into).boxed();
                        Ok(Response::from_parts(parts, boxed_body))
                    }
                }
                Err(e) => {
                    // The failed connection's sender is dropped, not returned to the pool.
                    error!("Failed to proxy request: {e}");
                    Ok(error_response(StatusCode::BAD_GATEWAY, "Proxy error"))
                }
            }
        })
    }
}
228
/// Body of the generic `"Proxy error"` response; served from static memory by
/// `error_response` to avoid an allocation.
const MSG_PROXY_ERROR: &[u8] = b"Proxy error";
/// Body of the fallback response `error_response` returns if building a response fails.
const MSG_INTERNAL_ERROR: &[u8] = b"Internal error";
232
233pub fn error_response(status: StatusCode, msg: &str) -> Response<BoxBody> {
236 let bytes = if msg == "Proxy error" {
237 Bytes::from_static(MSG_PROXY_ERROR)
238 } else {
239 Bytes::copy_from_slice(msg.as_bytes())
240 };
241 Response::builder()
242 .status(status)
243 .body(
244 Full::new(bytes)
245 .map_err(|_| ProxyError::Custom("Error construction failed".into()))
246 .boxed(),
247 )
248 .unwrap_or_else(|_| {
249 Response::new(
250 Full::new(Bytes::from_static(MSG_INTERNAL_ERROR))
251 .map_err(|_| ProxyError::Custom("Error construction failed".into()))
252 .boxed(),
253 )
254 })
255}
256
/// Serves tunnel streams by running an HTTP server on each stream and forwarding the
/// requests to the local target through `layer`-wrapped [`LocalProxyService`]s.
#[derive(Clone)]
pub struct HttpProxy<L> {
    // Address of the local service requests are forwarded to.
    target_addr: String,
    // Tower layer applied around each `LocalProxyService`.
    layer: L,
    // Connection pool to the target, shared via `Arc`.
    pool: Arc<ConnectionPool>,
}
263
264impl HttpProxy<tower::layer::util::Identity> {
265 pub fn new(target_addr: String) -> Self {
266 let pool = Arc::new(ConnectionPool::new(
267 target_addr.clone(),
268 PoolConfig::default(),
269 ));
270 Self {
271 target_addr,
272 layer: tower::layer::util::Identity::new(),
273 pool,
274 }
275 }
276
277 pub fn with_pool_config(target_addr: String, pool_config: PoolConfig) -> Self {
278 let pool = Arc::new(ConnectionPool::new(target_addr.clone(), pool_config));
279 Self {
280 target_addr,
281 layer: tower::layer::util::Identity::new(),
282 pool,
283 }
284 }
285}
286
287impl<L> HttpProxy<L> {
288 pub fn with_layer<NewL>(self, layer: NewL) -> HttpProxy<NewL> {
289 HttpProxy {
290 target_addr: self.target_addr,
291 layer,
292 pool: self.pool,
293 }
294 }
295
296 pub fn handle_stream(&self, stream: VirtualStream)
297 where
298 L: Layer<LocalProxyService> + Clone + Send + 'static,
299 L::Service: Service<Request<Incoming>, Response = Response<BoxBody>, Error = hyper::Error>
300 + Send
301 + Clone
302 + 'static,
303 <L::Service as Service<Request<Incoming>>>::Future: Send,
304 {
305 let service = self
306 .layer
307 .clone()
308 .layer(LocalProxyService::with_pool(self.pool.clone()));
309 let hyper_service = TowerToHyperService::new(service);
310 let io = TokioIo::new(stream);
311
312 tokio::spawn(async move {
313 let _ = http1::Builder::new()
314 .serve_connection(io, hyper_service)
315 .with_upgrades()
316 .await;
317 });
318 }
319
320 pub fn handle_grpc_stream(&self, stream: VirtualStream)
327 where
328 L: Layer<LocalProxyService> + Clone + Send + 'static,
329 L::Service: Service<Request<Incoming>, Response = Response<BoxBody>, Error = hyper::Error>
330 + Send
331 + Clone
332 + 'static,
333 <L::Service as Service<Request<Incoming>>>::Future: Send,
334 {
335 let grpc_pool = Arc::new(ConnectionPool::new(
336 self.target_addr.clone(),
337 PoolConfig::default(),
338 ));
339 let service = self
340 .layer
341 .clone()
342 .layer(LocalProxyService::with_pool_h2(grpc_pool));
343 let hyper_service = TowerToHyperService::new(service);
344 let io = TokioIo::new(stream);
345
346 tokio::spawn(async move {
347 let _ = http2::Builder::new(TokioExecutor::new())
348 .serve_connection(io, hyper_service)
349 .await;
350 });
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::BodyExt;
    use hyper::{body::Bytes, Request};
    use tower::Service;

    /// Returns a loopback address with no listener, so connecting to it is refused.
    ///
    /// An ephemeral port is bound and released immediately. This avoids depending on
    /// a hard-coded port happening to be unused on the test machine.
    fn unused_local_addr() -> String {
        let listener = std::net::TcpListener::bind("127.0.0.1:0")
            .expect("binding an ephemeral loopback port should succeed");
        let addr = listener
            .local_addr()
            .expect("a bound listener should report its address");
        drop(listener);
        addr.to_string()
    }

    #[test]
    fn test_proxy_error_display_includes_message() {
        let err = ProxyError::Custom("test error".to_string());
        assert!(err.to_string().contains("test error"));
    }

    #[test]
    fn test_proxy_error_custom_display() {
        let err = ProxyError::Custom("connection failed".to_string());
        let display = format!("{err}");
        assert!(display.contains("Proxy error"));
        assert!(display.contains("connection failed"));
    }

    #[test]
    fn test_local_proxy_service_new() {
        let service = LocalProxyService::new("127.0.0.1:8080".to_string());
        assert!(!service.use_h2);
    }

    #[test]
    fn test_local_proxy_service_with_pool_h2() {
        let pool = Arc::new(ConnectionPool::new(
            "127.0.0.1:8080".to_string(),
            PoolConfig::default(),
        ));
        let service = LocalProxyService::with_pool_h2(pool.clone());
        assert!(service.use_h2);
        assert!(Arc::ptr_eq(&service.pool, &pool));
    }

    #[test]
    fn test_local_proxy_service_clone() {
        let service = LocalProxyService::new("localhost:3000".to_string());
        let cloned = service.clone();
        // Clones share the same pool instead of creating a new one.
        assert!(Arc::ptr_eq(&service.pool, &cloned.pool));
    }

    #[tokio::test]
    async fn test_proxy_connection_error() {
        let mut service = LocalProxyService::new(unused_local_addr());

        let req = Request::builder()
            .uri("http://example.com")
            .body(Full::new(Bytes::from("test")))
            .unwrap();

        let response = service.call(req).await.unwrap();

        assert_eq!(response.status(), StatusCode::BAD_GATEWAY);

        let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
        assert!(body_str.contains("Failed to connect"));
    }

    #[test]
    fn test_error_response_bad_gateway() {
        let resp = error_response(StatusCode::BAD_GATEWAY, "Backend unavailable");
        assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn test_error_response_body_matches_message() {
        let resp = error_response(StatusCode::BAD_GATEWAY, "Proxy error");
        assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], MSG_PROXY_ERROR);

        let resp = error_response(StatusCode::NOT_FOUND, "Route not found");
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Route not found");
    }

    #[test]
    fn test_error_response_not_found() {
        let resp = error_response(StatusCode::NOT_FOUND, "Route not found");
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_error_response_internal_error() {
        let resp = error_response(StatusCode::INTERNAL_SERVER_ERROR, "Unexpected error");
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_http_proxy_new() {
        let proxy = HttpProxy::new("127.0.0.1:8080".to_string());
        assert_eq!(proxy.target_addr, "127.0.0.1:8080");
    }

    #[test]
    fn test_http_proxy_with_layer() {
        let proxy = HttpProxy::new("127.0.0.1:8080".to_string());
        let layered = proxy.with_layer(tower::layer::util::Identity::new());
        assert_eq!(layered.target_addr, "127.0.0.1:8080");
    }
}