Skip to main content

cargo_overlay_registry/
http_proxy.rs

1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::{Request, State};
5use axum::http::{Method, StatusCode};
6use axum::response::{IntoResponse, Response};
7use axum::Router;
8use hyper_util::rt::TokioIo;
9use hyper_util::server::conn::auto::Builder;
10use log::{debug, error};
11use rustls::ServerConfig;
12use tokio::net::TcpStream;
13use tokio_rustls::TlsAcceptor;
14use tower::Service;
15
16use crate::endpoints::{handle_internal_request, InternalResponse};
17use crate::state::{GenericProxyState, MitmCa, RegistryState};
18
19/// Shared state for the HTTP proxy functionality
20#[derive(Clone)]
21pub struct HttpProxyState<S: RegistryState + Clone = GenericProxyState> {
22    pub proxy_state: Arc<S>,
23    pub mitm_ca: Arc<MitmCa>,
24    pub upstream_hosts: Arc<Vec<String>>,
25}
26
27/// Handle incoming requests - routes CONNECT and proxy-style requests
28pub async fn handle_proxy_request<R: RegistryState + Clone + 'static>(
29    State(state): State<HttpProxyState<R>>,
30    request: Request,
31) -> Response {
32    let method = request.method().clone();
33    let uri = request.uri().clone();
34
35    if method == Method::CONNECT {
36        // Handle CONNECT request for HTTPS tunneling
37        handle_connect(state, request).await
38    } else if uri.scheme().is_some() {
39        // Proxy-style request with absolute URL (e.g., GET http://example.com/path)
40        handle_forward_request(state, request).await
41    } else {
42        // This shouldn't happen - regular requests go to axum routes
43        Response::builder()
44            .status(StatusCode::BAD_REQUEST)
45            .body(Body::from("Invalid proxy request"))
46            .unwrap()
47    }
48}
49
50/// Check if a request should be handled by the proxy layer
51pub fn is_proxy_request(request: &Request) -> bool {
52    request.method() == Method::CONNECT || request.uri().scheme().is_some()
53}
54
55/// Handle CONNECT method for HTTPS tunneling
56async fn handle_connect<R: RegistryState + Clone + 'static>(
57    state: HttpProxyState<R>,
58    request: Request,
59) -> Response {
60    let target = request.uri().to_string();
61
62    // Parse target as host:port
63    let (host, port) = if let Some(authority) = request.uri().authority() {
64        let h = authority.host().to_string();
65        let p = authority.port_u16().unwrap_or(443);
66        (h, p)
67    } else if let Some(colon_pos) = target.rfind(':') {
68        let h = target[..colon_pos].to_string();
69        let p: u16 = target[colon_pos + 1..].parse().unwrap_or(443);
70        (h, p)
71    } else {
72        (target.clone(), 443u16)
73    };
74
75    // Check if this is an upstream registry domain that we should intercept
76    let should_intercept = state.upstream_hosts.iter().any(|upstream_host| {
77        host == upstream_host.as_str() || host.ends_with(&format!(".{}", upstream_host))
78    });
79
80    if should_intercept {
81        debug!("HTTP proxy CONNECT MITM interception for {}:{}", host, port);
82    } else {
83        debug!("HTTP proxy CONNECT tunnel to {}:{}", host, port);
84    }
85
86    // Spawn task to handle the upgraded connection
87    let host_clone = host.clone();
88    tokio::spawn(async move {
89        match hyper::upgrade::on(request).await {
90            Ok(upgraded) => {
91                // TokioIo wraps the upgraded connection to implement tokio's AsyncRead/AsyncWrite
92                let stream = TokioIo::new(upgraded);
93
94                let result = if should_intercept {
95                    handle_connect_mitm(stream, &host_clone, state.proxy_state, state.mitm_ca).await
96                } else {
97                    handle_connect_passthrough(stream, &host_clone, port).await
98                };
99
100                if let Err(e) = result {
101                    debug!("CONNECT tunnel error: {}", e);
102                }
103            }
104            Err(e) => {
105                error!("Connection upgrade failed: {}", e);
106            }
107        }
108    });
109
110    // Return 200 Connection Established
111    Response::builder()
112        .status(StatusCode::OK)
113        .body(Body::empty())
114        .unwrap()
115}
116
117/// Handle CONNECT with MITM TLS interception for upstream registry domains
118async fn handle_connect_mitm<S, R>(
119    stream: S,
120    host: &str,
121    state: Arc<R>,
122    mitm_ca: Arc<MitmCa>,
123) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
124where
125    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
126    R: RegistryState + 'static,
127{
128    use tokio::io::{AsyncBufReadExt, BufReader};
129
130    // Generate a certificate for the target domain, signed by our CA
131    let (cert_pem, key_pem) = mitm_ca.sign_domain_cert(host)?;
132
133    // Parse the certificate and key
134    let certs = rustls_pemfile::certs(&mut cert_pem.as_slice()).collect::<Result<Vec<_>, _>>()?;
135    let key =
136        rustls_pemfile::private_key(&mut key_pem.as_slice())?.ok_or("No private key found")?;
137
138    // Build TLS server config
139    let server_config = ServerConfig::builder()
140        .with_no_client_auth()
141        .with_single_cert(certs, key)?;
142
143    let acceptor = TlsAcceptor::from(std::sync::Arc::new(server_config));
144
145    // Wrap the stream for TLS
146    // Note: For upgraded connections, we need to use a type-erased wrapper
147    let tls_stream = match acceptor.accept(stream).await {
148        Ok(s) => s,
149        Err(e) => {
150            error!("TLS handshake failed for {}: {}", host, e);
151            return Err(e.into());
152        }
153    };
154
155    debug!("  TLS handshake completed for {}", host);
156
157    // Split the TLS stream for reading and writing
158    let (read_half, mut write_half) = tokio::io::split(tls_stream);
159    let mut buf_reader = BufReader::new(read_half);
160
161    // Now handle HTTP requests over the TLS connection
162    loop {
163        // Read the HTTP request line
164        let mut request_line = String::new();
165
166        match buf_reader.read_line(&mut request_line).await {
167            Ok(0) => break, // Connection closed
168            Ok(_) => {}
169            Err(e) => {
170                debug!("Error reading from TLS stream: {}", e);
171                break;
172            }
173        }
174
175        if request_line.trim().is_empty() {
176            continue;
177        }
178
179        // Parse request line
180        let parts: Vec<&str> = request_line.split_whitespace().collect();
181        if parts.len() < 3 {
182            debug!("Invalid request line: {}", request_line.trim());
183            break;
184        }
185
186        let method = parts[0];
187        let path = parts[1];
188
189        // Read headers
190        let mut headers = Vec::new();
191        loop {
192            let mut line = String::new();
193            match buf_reader.read_line(&mut line).await {
194                Ok(0) => break,
195                Ok(_) => {
196                    if line.trim().is_empty() {
197                        break;
198                    }
199                    headers.push(line.trim().to_string());
200                }
201                Err(_) => break,
202            }
203        }
204
205        // Check for Expect: 100-continue header
206        let expects_continue = headers.iter().any(|h| {
207            h.to_lowercase().starts_with("expect:") && h.to_lowercase().contains("100-continue")
208        });
209
210        // Send 100 Continue response if requested before reading body
211        if expects_continue {
212            tokio::io::AsyncWriteExt::write_all(&mut write_half, b"HTTP/1.1 100 Continue\r\n\r\n")
213                .await?;
214            tokio::io::AsyncWriteExt::flush(&mut write_half).await?;
215            debug!("  Sent 100 Continue for {}", path);
216        }
217
218        // Get content length
219        let content_length: usize = headers
220            .iter()
221            .find(|h| h.to_lowercase().starts_with("content-length:"))
222            .and_then(|h| h.split(':').nth(1))
223            .and_then(|s| s.trim().parse().ok())
224            .unwrap_or(0);
225
226        // Read body if present
227        let body = if content_length > 0 {
228            let mut body = vec![0u8; content_length];
229            tokio::io::AsyncReadExt::read_exact(&mut buf_reader, &mut body).await?;
230            body
231        } else {
232            Vec::new()
233        };
234
235        debug!("  MITM {} https://{}{}", method, host, path);
236
237        // Convert headers to the format expected by handle_internal_request
238        let header_pairs: Vec<(String, String)> = headers
239            .iter()
240            .filter_map(|h| {
241                let pos = h.find(':')?;
242                Some((h[..pos].trim().to_string(), h[pos + 1..].trim().to_string()))
243            })
244            .collect();
245
246        // Handle internally
247        debug!("    -> Handling internally");
248        let response =
249            handle_internal_request(state.as_ref(), method, path, &header_pairs, &body).await;
250
251        // Write response
252        let status_line = format!("HTTP/1.1 {} OK\r\n", response.status);
253        tokio::io::AsyncWriteExt::write_all(&mut write_half, status_line.as_bytes()).await?;
254
255        for (name, value) in &response.headers {
256            let header_line = format!("{}: {}\r\n", name, value);
257            tokio::io::AsyncWriteExt::write_all(&mut write_half, header_line.as_bytes()).await?;
258        }
259
260        let content_length_header = format!("content-length: {}\r\n", response.body.len());
261        tokio::io::AsyncWriteExt::write_all(&mut write_half, content_length_header.as_bytes())
262            .await?;
263        tokio::io::AsyncWriteExt::write_all(&mut write_half, b"connection: keep-alive\r\n").await?;
264        tokio::io::AsyncWriteExt::write_all(&mut write_half, b"\r\n").await?;
265        tokio::io::AsyncWriteExt::write_all(&mut write_half, &response.body).await?;
266        tokio::io::AsyncWriteExt::flush(&mut write_half).await?;
267
268        debug!("    <- {} ({} bytes)", response.status, response.body.len());
269
270        // Check for Connection: close
271        let should_close = headers.iter().any(|h| {
272            h.to_lowercase().starts_with("connection:") && h.to_lowercase().contains("close")
273        });
274
275        if should_close {
276            break;
277        }
278    }
279
280    Ok(())
281}
282
283/// Handle CONNECT with direct passthrough (no interception)
284async fn handle_connect_passthrough<S>(
285    stream: S,
286    host: &str,
287    port: u16,
288) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
289where
290    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
291{
292    let upstream_addr = format!("{}:{}", host, port);
293    match TcpStream::connect(&upstream_addr).await {
294        Ok(upstream) => {
295            let (mut client_read, mut client_write) = tokio::io::split(stream);
296            let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream);
297
298            tokio::select! {
299                result = tokio::io::copy(&mut client_read, &mut upstream_write) => {
300                    if let Err(e) = result {
301                        debug!("CONNECT tunnel client->upstream error: {}", e);
302                    }
303                }
304                result = tokio::io::copy(&mut upstream_read, &mut client_write) => {
305                    if let Err(e) = result {
306                        debug!("CONNECT tunnel upstream->client error: {}", e);
307                    }
308                }
309            }
310        }
311        Err(e) => {
312            error!("HTTP proxy: failed to connect to {}: {}", upstream_addr, e);
313            // Can't really send a response here since client expects raw tunnel
314            // Just close the connection
315        }
316    }
317
318    Ok(())
319}
320
321/// Handle regular HTTP proxy request (forward proxy with absolute URL)
322async fn handle_forward_request<R: RegistryState + Clone + 'static>(
323    state: HttpProxyState<R>,
324    request: Request,
325) -> Response {
326    let method = request.method().clone();
327    let url = request.uri().to_string();
328
329    debug!("HTTP proxy {} request to {}", method, url);
330
331    // Check if this is an upstream registry URL that should be intercepted
332    let should_intercept = url::Url::parse(&url).ok().is_some_and(|parsed| {
333        parsed.host_str().is_some_and(|url_host| {
334            state.upstream_hosts.iter().any(|upstream_host| {
335                url_host == upstream_host.as_str()
336                    || url_host.ends_with(&format!(".{}", upstream_host))
337            })
338        })
339    });
340
341    // Get the body
342    let (parts, body) = request.into_parts();
343    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
344        Ok(b) => b.to_vec(),
345        Err(e) => {
346            error!("Failed to read request body: {}", e);
347            return Response::builder()
348                .status(StatusCode::BAD_REQUEST)
349                .body(Body::from("Failed to read body"))
350                .unwrap();
351        }
352    };
353
354    if should_intercept {
355        // Handle internally
356        let parsed = match url::Url::parse(&url) {
357            Ok(u) => u,
358            Err(_) => {
359                return Response::builder()
360                    .status(StatusCode::BAD_REQUEST)
361                    .body(Body::from("Invalid URL"))
362                    .unwrap();
363            }
364        };
365        let path = parsed.path();
366        let query = parsed
367            .query()
368            .map(|q| format!("?{}", q))
369            .unwrap_or_default();
370        let full_path = format!("{}{}", path, query);
371
372        debug!("  -> Handling internally: {}", full_path);
373
374        // Convert headers
375        let header_pairs: Vec<(String, String)> = parts
376            .headers
377            .iter()
378            .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
379            .collect();
380
381        let internal_response = handle_internal_request(
382            state.proxy_state.as_ref(),
383            method.as_str(),
384            &full_path,
385            &header_pairs,
386            &body_bytes,
387        )
388        .await;
389
390        convert_internal_response(internal_response)
391    } else {
392        // Passthrough to upstream
393        let client = state.proxy_state.client();
394
395        let request_builder = match method {
396            Method::GET => client.get(&url),
397            Method::POST => client.post(&url).body(body_bytes),
398            Method::PUT => client.put(&url).body(body_bytes),
399            Method::DELETE => client.delete(&url),
400            Method::HEAD => client.head(&url),
401            _ => {
402                return Response::builder()
403                    .status(StatusCode::METHOD_NOT_ALLOWED)
404                    .body(Body::empty())
405                    .unwrap();
406            }
407        };
408
409        // Forward relevant headers
410        let mut request_builder = request_builder;
411        for (name, value) in parts.headers.iter() {
412            let name_str = name.to_string().to_lowercase();
413            // Skip hop-by-hop headers
414            if ![
415                "host",
416                "connection",
417                "proxy-connection",
418                "proxy-authorization",
419                "te",
420                "trailer",
421                "transfer-encoding",
422                "upgrade",
423                "expect",
424            ]
425            .contains(&name_str.as_str())
426                && let Ok(val_str) = value.to_str()
427            {
428                request_builder = request_builder.header(name.clone(), val_str);
429            }
430        }
431
432        match request_builder.send().await {
433            Ok(upstream_response) => {
434                let status = upstream_response.status();
435                let mut response_builder = Response::builder().status(status);
436
437                // Copy headers
438                for (key, value) in upstream_response.headers().iter() {
439                    if key != "transfer-encoding" && key != "connection" {
440                        response_builder = response_builder.header(key.clone(), value.clone());
441                    }
442                }
443
444                match upstream_response.bytes().await {
445                    Ok(body_bytes) => response_builder.body(Body::from(body_bytes)).unwrap(),
446                    Err(e) => {
447                        error!("Failed to read upstream response: {}", e);
448                        Response::builder()
449                            .status(StatusCode::BAD_GATEWAY)
450                            .body(Body::empty())
451                            .unwrap()
452                    }
453                }
454            }
455            Err(e) => {
456                error!("HTTP proxy: upstream request failed: {}", e);
457                Response::builder()
458                    .status(StatusCode::BAD_GATEWAY)
459                    .body(Body::empty())
460                    .unwrap()
461            }
462        }
463    }
464}
465
466/// Convert InternalResponse to axum Response
467fn convert_internal_response(internal: InternalResponse) -> Response {
468    let mut builder = Response::builder()
469        .status(StatusCode::from_u16(internal.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR));
470
471    for (name, value) in internal.headers {
472        builder = builder.header(name, value);
473    }
474
475    builder.body(Body::from(internal.body)).unwrap()
476}
477
478/// Serve HTTP requests on any stream type with proxy support
479pub async fn serve_stream<S, R>(stream: S, app: Router, proxy_state: HttpProxyState<R>)
480where
481    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
482    R: RegistryState + Clone + 'static,
483{
484    use std::convert::Infallible;
485
486    use hyper::service::service_fn;
487    use hyper_util::rt::TokioExecutor;
488
489    let service = service_fn(move |request: Request<hyper::body::Incoming>| {
490        let mut app = app.clone();
491        let proxy_state = proxy_state.clone();
492
493        async move {
494            let (parts, body) = request.into_parts();
495            let body = Body::new(body);
496            let request = Request::from_parts(parts, body);
497
498            if is_proxy_request(&request) {
499                let response = handle_proxy_request(State(proxy_state), request).await;
500                Ok::<_, Infallible>(response)
501            } else {
502                let response = app.call(request).await.into_response();
503                Ok::<_, Infallible>(response)
504            }
505        }
506    });
507
508    let io = TokioIo::new(stream);
509    if let Err(e) = Builder::new(TokioExecutor::new())
510        .serve_connection_with_upgrades(io, service)
511        .await
512    {
513        debug!("Connection error: {}", e);
514    }
515}
516
517/// Handle a proxy connection, optionally with TLS
518pub async fn handle_proxy_connection<R>(
519    stream: TcpStream,
520    app: Router,
521    proxy_state: HttpProxyState<R>,
522    tls_acceptor: Option<TlsAcceptor>,
523) where
524    R: RegistryState + Clone + 'static,
525{
526    if let Some(tls_acceptor) = tls_acceptor {
527        let tls_stream = match tls_acceptor.accept(stream).await {
528            Ok(s) => s,
529            Err(e) => {
530                debug!("TLS handshake error: {}", e);
531                return;
532            }
533        };
534        serve_stream(tls_stream, app, proxy_state).await;
535    } else {
536        serve_stream(stream, app, proxy_state).await;
537    }
538}