1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::{Request, State};
5use axum::http::{Method, StatusCode};
6use axum::response::{IntoResponse, Response};
7use axum::Router;
8use hyper_util::rt::TokioIo;
9use hyper_util::server::conn::auto::Builder;
10use log::{debug, error};
11use rustls::ServerConfig;
12use tokio::net::TcpStream;
13use tokio_rustls::TlsAcceptor;
14use tower::Service;
15
16use crate::endpoints::{handle_internal_request, InternalResponse};
17use crate::state::{GenericProxyState, MitmCa, RegistryState};
18
19#[derive(Clone)]
21pub struct HttpProxyState<S: RegistryState + Clone = GenericProxyState> {
22 pub proxy_state: Arc<S>,
23 pub mitm_ca: Arc<MitmCa>,
24 pub upstream_hosts: Arc<Vec<String>>,
25}
26
27pub async fn handle_proxy_request<R: RegistryState + Clone + 'static>(
29 State(state): State<HttpProxyState<R>>,
30 request: Request,
31) -> Response {
32 let method = request.method().clone();
33 let uri = request.uri().clone();
34
35 if method == Method::CONNECT {
36 handle_connect(state, request).await
38 } else if uri.scheme().is_some() {
39 handle_forward_request(state, request).await
41 } else {
42 Response::builder()
44 .status(StatusCode::BAD_REQUEST)
45 .body(Body::from("Invalid proxy request"))
46 .unwrap()
47 }
48}
49
50pub fn is_proxy_request(request: &Request) -> bool {
52 request.method() == Method::CONNECT || request.uri().scheme().is_some()
53}
54
55async fn handle_connect<R: RegistryState + Clone + 'static>(
57 state: HttpProxyState<R>,
58 request: Request,
59) -> Response {
60 let target = request.uri().to_string();
61
62 let (host, port) = if let Some(authority) = request.uri().authority() {
64 let h = authority.host().to_string();
65 let p = authority.port_u16().unwrap_or(443);
66 (h, p)
67 } else if let Some(colon_pos) = target.rfind(':') {
68 let h = target[..colon_pos].to_string();
69 let p: u16 = target[colon_pos + 1..].parse().unwrap_or(443);
70 (h, p)
71 } else {
72 (target.clone(), 443u16)
73 };
74
75 let should_intercept = state.upstream_hosts.iter().any(|upstream_host| {
77 host == upstream_host.as_str() || host.ends_with(&format!(".{}", upstream_host))
78 });
79
80 if should_intercept {
81 debug!("HTTP proxy CONNECT MITM interception for {}:{}", host, port);
82 } else {
83 debug!("HTTP proxy CONNECT tunnel to {}:{}", host, port);
84 }
85
86 let host_clone = host.clone();
88 tokio::spawn(async move {
89 match hyper::upgrade::on(request).await {
90 Ok(upgraded) => {
91 let stream = TokioIo::new(upgraded);
93
94 let result = if should_intercept {
95 handle_connect_mitm(stream, &host_clone, state.proxy_state, state.mitm_ca).await
96 } else {
97 handle_connect_passthrough(stream, &host_clone, port).await
98 };
99
100 if let Err(e) = result {
101 debug!("CONNECT tunnel error: {}", e);
102 }
103 }
104 Err(e) => {
105 error!("Connection upgrade failed: {}", e);
106 }
107 }
108 });
109
110 Response::builder()
112 .status(StatusCode::OK)
113 .body(Body::empty())
114 .unwrap()
115}
116
117async fn handle_connect_mitm<S, R>(
119 stream: S,
120 host: &str,
121 state: Arc<R>,
122 mitm_ca: Arc<MitmCa>,
123) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
124where
125 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
126 R: RegistryState + 'static,
127{
128 use tokio::io::{AsyncBufReadExt, BufReader};
129
130 let (cert_pem, key_pem) = mitm_ca.sign_domain_cert(host)?;
132
133 let certs = rustls_pemfile::certs(&mut cert_pem.as_slice()).collect::<Result<Vec<_>, _>>()?;
135 let key =
136 rustls_pemfile::private_key(&mut key_pem.as_slice())?.ok_or("No private key found")?;
137
138 let server_config = ServerConfig::builder()
140 .with_no_client_auth()
141 .with_single_cert(certs, key)?;
142
143 let acceptor = TlsAcceptor::from(std::sync::Arc::new(server_config));
144
145 let tls_stream = match acceptor.accept(stream).await {
148 Ok(s) => s,
149 Err(e) => {
150 error!("TLS handshake failed for {}: {}", host, e);
151 return Err(e.into());
152 }
153 };
154
155 debug!(" TLS handshake completed for {}", host);
156
157 let (read_half, mut write_half) = tokio::io::split(tls_stream);
159 let mut buf_reader = BufReader::new(read_half);
160
161 loop {
163 let mut request_line = String::new();
165
166 match buf_reader.read_line(&mut request_line).await {
167 Ok(0) => break, Ok(_) => {}
169 Err(e) => {
170 debug!("Error reading from TLS stream: {}", e);
171 break;
172 }
173 }
174
175 if request_line.trim().is_empty() {
176 continue;
177 }
178
179 let parts: Vec<&str> = request_line.split_whitespace().collect();
181 if parts.len() < 3 {
182 debug!("Invalid request line: {}", request_line.trim());
183 break;
184 }
185
186 let method = parts[0];
187 let path = parts[1];
188
189 let mut headers = Vec::new();
191 loop {
192 let mut line = String::new();
193 match buf_reader.read_line(&mut line).await {
194 Ok(0) => break,
195 Ok(_) => {
196 if line.trim().is_empty() {
197 break;
198 }
199 headers.push(line.trim().to_string());
200 }
201 Err(_) => break,
202 }
203 }
204
205 let expects_continue = headers.iter().any(|h| {
207 h.to_lowercase().starts_with("expect:") && h.to_lowercase().contains("100-continue")
208 });
209
210 if expects_continue {
212 tokio::io::AsyncWriteExt::write_all(&mut write_half, b"HTTP/1.1 100 Continue\r\n\r\n")
213 .await?;
214 tokio::io::AsyncWriteExt::flush(&mut write_half).await?;
215 debug!(" Sent 100 Continue for {}", path);
216 }
217
218 let content_length: usize = headers
220 .iter()
221 .find(|h| h.to_lowercase().starts_with("content-length:"))
222 .and_then(|h| h.split(':').nth(1))
223 .and_then(|s| s.trim().parse().ok())
224 .unwrap_or(0);
225
226 let body = if content_length > 0 {
228 let mut body = vec![0u8; content_length];
229 tokio::io::AsyncReadExt::read_exact(&mut buf_reader, &mut body).await?;
230 body
231 } else {
232 Vec::new()
233 };
234
235 debug!(" MITM {} https://{}{}", method, host, path);
236
237 let header_pairs: Vec<(String, String)> = headers
239 .iter()
240 .filter_map(|h| {
241 let pos = h.find(':')?;
242 Some((h[..pos].trim().to_string(), h[pos + 1..].trim().to_string()))
243 })
244 .collect();
245
246 debug!(" -> Handling internally");
248 let response =
249 handle_internal_request(state.as_ref(), method, path, &header_pairs, &body).await;
250
251 let status_line = format!("HTTP/1.1 {} OK\r\n", response.status);
253 tokio::io::AsyncWriteExt::write_all(&mut write_half, status_line.as_bytes()).await?;
254
255 for (name, value) in &response.headers {
256 let header_line = format!("{}: {}\r\n", name, value);
257 tokio::io::AsyncWriteExt::write_all(&mut write_half, header_line.as_bytes()).await?;
258 }
259
260 let content_length_header = format!("content-length: {}\r\n", response.body.len());
261 tokio::io::AsyncWriteExt::write_all(&mut write_half, content_length_header.as_bytes())
262 .await?;
263 tokio::io::AsyncWriteExt::write_all(&mut write_half, b"connection: keep-alive\r\n").await?;
264 tokio::io::AsyncWriteExt::write_all(&mut write_half, b"\r\n").await?;
265 tokio::io::AsyncWriteExt::write_all(&mut write_half, &response.body).await?;
266 tokio::io::AsyncWriteExt::flush(&mut write_half).await?;
267
268 debug!(" <- {} ({} bytes)", response.status, response.body.len());
269
270 let should_close = headers.iter().any(|h| {
272 h.to_lowercase().starts_with("connection:") && h.to_lowercase().contains("close")
273 });
274
275 if should_close {
276 break;
277 }
278 }
279
280 Ok(())
281}
282
283async fn handle_connect_passthrough<S>(
285 stream: S,
286 host: &str,
287 port: u16,
288) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
289where
290 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
291{
292 let upstream_addr = format!("{}:{}", host, port);
293 match TcpStream::connect(&upstream_addr).await {
294 Ok(upstream) => {
295 let (mut client_read, mut client_write) = tokio::io::split(stream);
296 let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream);
297
298 tokio::select! {
299 result = tokio::io::copy(&mut client_read, &mut upstream_write) => {
300 if let Err(e) = result {
301 debug!("CONNECT tunnel client->upstream error: {}", e);
302 }
303 }
304 result = tokio::io::copy(&mut upstream_read, &mut client_write) => {
305 if let Err(e) = result {
306 debug!("CONNECT tunnel upstream->client error: {}", e);
307 }
308 }
309 }
310 }
311 Err(e) => {
312 error!("HTTP proxy: failed to connect to {}: {}", upstream_addr, e);
313 }
316 }
317
318 Ok(())
319}
320
321async fn handle_forward_request<R: RegistryState + Clone + 'static>(
323 state: HttpProxyState<R>,
324 request: Request,
325) -> Response {
326 let method = request.method().clone();
327 let url = request.uri().to_string();
328
329 debug!("HTTP proxy {} request to {}", method, url);
330
331 let should_intercept = url::Url::parse(&url).ok().is_some_and(|parsed| {
333 parsed.host_str().is_some_and(|url_host| {
334 state.upstream_hosts.iter().any(|upstream_host| {
335 url_host == upstream_host.as_str()
336 || url_host.ends_with(&format!(".{}", upstream_host))
337 })
338 })
339 });
340
341 let (parts, body) = request.into_parts();
343 let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
344 Ok(b) => b.to_vec(),
345 Err(e) => {
346 error!("Failed to read request body: {}", e);
347 return Response::builder()
348 .status(StatusCode::BAD_REQUEST)
349 .body(Body::from("Failed to read body"))
350 .unwrap();
351 }
352 };
353
354 if should_intercept {
355 let parsed = match url::Url::parse(&url) {
357 Ok(u) => u,
358 Err(_) => {
359 return Response::builder()
360 .status(StatusCode::BAD_REQUEST)
361 .body(Body::from("Invalid URL"))
362 .unwrap();
363 }
364 };
365 let path = parsed.path();
366 let query = parsed
367 .query()
368 .map(|q| format!("?{}", q))
369 .unwrap_or_default();
370 let full_path = format!("{}{}", path, query);
371
372 debug!(" -> Handling internally: {}", full_path);
373
374 let header_pairs: Vec<(String, String)> = parts
376 .headers
377 .iter()
378 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
379 .collect();
380
381 let internal_response = handle_internal_request(
382 state.proxy_state.as_ref(),
383 method.as_str(),
384 &full_path,
385 &header_pairs,
386 &body_bytes,
387 )
388 .await;
389
390 convert_internal_response(internal_response)
391 } else {
392 let client = state.proxy_state.client();
394
395 let request_builder = match method {
396 Method::GET => client.get(&url),
397 Method::POST => client.post(&url).body(body_bytes),
398 Method::PUT => client.put(&url).body(body_bytes),
399 Method::DELETE => client.delete(&url),
400 Method::HEAD => client.head(&url),
401 _ => {
402 return Response::builder()
403 .status(StatusCode::METHOD_NOT_ALLOWED)
404 .body(Body::empty())
405 .unwrap();
406 }
407 };
408
409 let mut request_builder = request_builder;
411 for (name, value) in parts.headers.iter() {
412 let name_str = name.to_string().to_lowercase();
413 if ![
415 "host",
416 "connection",
417 "proxy-connection",
418 "proxy-authorization",
419 "te",
420 "trailer",
421 "transfer-encoding",
422 "upgrade",
423 "expect",
424 ]
425 .contains(&name_str.as_str())
426 && let Ok(val_str) = value.to_str()
427 {
428 request_builder = request_builder.header(name.clone(), val_str);
429 }
430 }
431
432 match request_builder.send().await {
433 Ok(upstream_response) => {
434 let status = upstream_response.status();
435 let mut response_builder = Response::builder().status(status);
436
437 for (key, value) in upstream_response.headers().iter() {
439 if key != "transfer-encoding" && key != "connection" {
440 response_builder = response_builder.header(key.clone(), value.clone());
441 }
442 }
443
444 match upstream_response.bytes().await {
445 Ok(body_bytes) => response_builder.body(Body::from(body_bytes)).unwrap(),
446 Err(e) => {
447 error!("Failed to read upstream response: {}", e);
448 Response::builder()
449 .status(StatusCode::BAD_GATEWAY)
450 .body(Body::empty())
451 .unwrap()
452 }
453 }
454 }
455 Err(e) => {
456 error!("HTTP proxy: upstream request failed: {}", e);
457 Response::builder()
458 .status(StatusCode::BAD_GATEWAY)
459 .body(Body::empty())
460 .unwrap()
461 }
462 }
463 }
464}
465
466fn convert_internal_response(internal: InternalResponse) -> Response {
468 let mut builder = Response::builder()
469 .status(StatusCode::from_u16(internal.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR));
470
471 for (name, value) in internal.headers {
472 builder = builder.header(name, value);
473 }
474
475 builder.body(Body::from(internal.body)).unwrap()
476}
477
478pub async fn serve_stream<S, R>(stream: S, app: Router, proxy_state: HttpProxyState<R>)
480where
481 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
482 R: RegistryState + Clone + 'static,
483{
484 use std::convert::Infallible;
485
486 use hyper::service::service_fn;
487 use hyper_util::rt::TokioExecutor;
488
489 let service = service_fn(move |request: Request<hyper::body::Incoming>| {
490 let mut app = app.clone();
491 let proxy_state = proxy_state.clone();
492
493 async move {
494 let (parts, body) = request.into_parts();
495 let body = Body::new(body);
496 let request = Request::from_parts(parts, body);
497
498 if is_proxy_request(&request) {
499 let response = handle_proxy_request(State(proxy_state), request).await;
500 Ok::<_, Infallible>(response)
501 } else {
502 let response = app.call(request).await.into_response();
503 Ok::<_, Infallible>(response)
504 }
505 }
506 });
507
508 let io = TokioIo::new(stream);
509 if let Err(e) = Builder::new(TokioExecutor::new())
510 .serve_connection_with_upgrades(io, service)
511 .await
512 {
513 debug!("Connection error: {}", e);
514 }
515}
516
517pub async fn handle_proxy_connection<R>(
519 stream: TcpStream,
520 app: Router,
521 proxy_state: HttpProxyState<R>,
522 tls_acceptor: Option<TlsAcceptor>,
523) where
524 R: RegistryState + Clone + 'static,
525{
526 if let Some(tls_acceptor) = tls_acceptor {
527 let tls_stream = match tls_acceptor.accept(stream).await {
528 Ok(s) => s,
529 Err(e) => {
530 debug!("TLS handshake error: {}", e);
531 return;
532 }
533 };
534 serve_stream(tls_stream, app, proxy_state).await;
535 } else {
536 serve_stream(stream, app, proxy_state).await;
537 }
538}