Skip to main content

nono_proxy/
reverse.rs

1//! Reverse proxy handler (Mode 2 — Credential Injection).
2//!
3//! Routes requests by path prefix to upstream APIs, injecting credentials
4//! from the keystore. The agent uses `http://localhost:PORT/openai/v1/chat`
5//! and the proxy rewrites to `https://api.openai.com/v1/chat` with the
6//! real credential injected.
7//!
8//! Supports multiple injection modes:
9//! - `header`: Inject into HTTP header (e.g., `Authorization: Bearer ...`)
10//! - `url_path`: Replace pattern in URL path (e.g., Telegram `/bot{}/`)
11//! - `query_param`: Add/replace query parameter (e.g., `?api_key=...`)
12//! - `basic_auth`: HTTP Basic Authentication
13//!
14//! Streaming responses (SSE, MCP Streamable HTTP, A2A JSON-RPC) are
15//! forwarded without buffering.
16
17use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::route::RouteStore;
23use crate::token;
24use std::net::SocketAddr;
25use std::time::Duration;
26use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
27use tokio::net::TcpStream;
28use tokio_rustls::TlsConnector;
29use tracing::{debug, warn};
30use zeroize::Zeroizing;
31
/// Upper bound on an accepted request body: 16 MiB (`16 << 20` bytes).
/// Requests whose `Content-Length` exceeds this are refused, so a malicious
/// header cannot force a huge allocation.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// How long a single TCP connect attempt to an upstream may take.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
37
38/// Handle a non-CONNECT HTTP request (reverse proxy mode).
39///
40/// Reads the full HTTP request from the client, matches path prefix to
41/// a configured route, injects credentials, and forwards to the upstream.
42/// Shared context passed from the server to the reverse proxy handler.
43pub struct ReverseProxyCtx<'a> {
44    /// Route store for upstream URL, L7 filtering, and per-route TLS
45    pub route_store: &'a RouteStore,
46    /// Credential store for service lookups (optional injection)
47    pub credential_store: &'a CredentialStore,
48    /// Session token for authentication
49    pub session_token: &'a Zeroizing<String>,
50    /// Host filter for upstream validation
51    pub filter: &'a ProxyFilter,
52    /// Shared TLS connector
53    pub tls_connector: &'a TlsConnector,
54    /// Shared network audit sink for session metadata capture
55    pub audit_log: Option<&'a audit::SharedAuditLog>,
56}
57
58/// Handle a non-CONNECT HTTP request (reverse proxy mode).
59///
60/// `buffered_body` contains any bytes the BufReader read ahead beyond the
61/// headers. These are prepended to the body read from the stream to prevent
62/// data loss.
63///
64/// ## Phantom Token Pattern
65///
66/// The client (SDK) sends the session token as its "API key". The proxy:
67/// 1. Extracts the service from the path (e.g., `/openai/v1/chat` → `openai`)
68/// 2. Looks up which header that service uses (e.g., `Authorization` or `x-api-key`)
69/// 3. Validates the phantom token from that header
70/// 4. Replaces it with the real credential from keyring
71pub async fn handle_reverse_proxy(
72    first_line: &str,
73    stream: &mut TcpStream,
74    remaining_header: &[u8],
75    ctx: &ReverseProxyCtx<'_>,
76    buffered_body: &[u8],
77) -> Result<()> {
78    // Parse method, path, and HTTP version
79    let (method, path, version) = parse_request_line(first_line)?;
80    debug!("Reverse proxy: {} {}", method, path);
81
82    // Extract service prefix from path (e.g., "/openai/v1/chat" -> ("openai", "/v1/chat"))
83    let (service, upstream_path) = parse_service_prefix(&path)?;
84    let route = ctx
85        .route_store
86        .get(&service)
87        .ok_or_else(|| ProxyError::UnknownService {
88            prefix: service.clone(),
89        })?;
90    let static_cred = ctx.credential_store.get(&service);
91    let oauth2_route = ctx.credential_store.get_oauth2(&service);
92
93    // L7 endpoint filtering runs for all reverse-proxy routes, whether or not
94    // they inject a credential.
95    if !route.endpoint_rules.is_allowed(&method, &upstream_path) {
96        let reason = format!(
97            "endpoint denied: {} {} on service '{}'",
98            method, upstream_path, service
99        );
100        warn!("{}", reason);
101        audit::log_denied(
102            ctx.audit_log,
103            audit::ProxyMode::Reverse,
104            &service,
105            0,
106            &reason,
107        );
108        send_error(stream, 403, "Forbidden").await?;
109        return Ok(());
110    }
111
112    if let Some(oauth2_route) = oauth2_route {
113        return handle_oauth2_credential(
114            oauth2_route,
115            route,
116            &service,
117            &upstream_path,
118            &method,
119            &version,
120            stream,
121            remaining_header,
122            buffered_body,
123            ctx,
124        )
125        .await;
126    }
127
128    let cred = static_cred;
129
130    // Authenticate the request. Every reverse proxy request must prove
131    // possession of the session token, regardless of whether a credential
132    // is configured — this is the localhost auth boundary.
133    if let Some(cred) = cred {
134        if let Err(e) = validate_phantom_token_for_mode(
135            &cred.proxy_inject_mode,
136            remaining_header,
137            &upstream_path,
138            &cred.proxy_header_name,
139            cred.proxy_path_pattern.as_deref(),
140            cred.proxy_query_param_name.as_deref(),
141            ctx.session_token,
142        ) {
143            audit::log_denied(
144                ctx.audit_log,
145                audit::ProxyMode::Reverse,
146                &service,
147                0,
148                &e.to_string(),
149            );
150            send_error(stream, 401, "Unauthorized").await?;
151            return Ok(());
152        }
153    } else if let Err(e) = token::validate_proxy_auth(remaining_header, ctx.session_token) {
154        audit::log_denied(
155            ctx.audit_log,
156            audit::ProxyMode::Reverse,
157            &service,
158            0,
159            &e.to_string(),
160        );
161        send_error(stream, 407, "Proxy Authentication Required").await?;
162        return Ok(());
163    }
164
165    let transformed_path = if let Some(cred) = cred {
166        let cleaned_path = strip_proxy_artifacts(
167            &upstream_path,
168            &cred.proxy_inject_mode,
169            &cred.inject_mode,
170            cred.proxy_path_pattern.as_deref(),
171            cred.proxy_query_param_name.as_deref(),
172        );
173        transform_path_for_mode(
174            &cred.inject_mode,
175            &cleaned_path,
176            cred.path_pattern.as_deref(),
177            cred.path_replacement.as_deref(),
178            cred.query_param_name.as_deref(),
179            &cred.raw_credential,
180        )?
181    } else {
182        upstream_path.clone()
183    };
184
185    let upstream_url = format!(
186        "{}{}",
187        route.upstream.trim_end_matches('/'),
188        transformed_path
189    );
190    debug!("Forwarding to upstream: {} {}", method, upstream_url);
191
192    let (upstream_scheme, upstream_host, upstream_port, upstream_path_full) =
193        parse_upstream_url(&upstream_url)?;
194    let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
195    if !check.result.is_allowed() {
196        let reason = check.result.reason();
197        warn!("Upstream host denied by filter: {}", reason);
198        send_error(stream, 403, "Forbidden").await?;
199        audit::log_denied(
200            ctx.audit_log,
201            audit::ProxyMode::Reverse,
202            &service,
203            0,
204            &reason,
205        );
206        return Ok(());
207    }
208    if let Err(reason) =
209        validate_http_upstream_target(upstream_scheme, &upstream_host, &check.resolved_addrs)
210    {
211        warn!("{}", reason);
212        send_error(stream, 502, "Bad Gateway").await?;
213        audit::log_denied(
214            ctx.audit_log,
215            audit::ProxyMode::Reverse,
216            &service,
217            0,
218            &reason,
219        );
220        return Ok(());
221    }
222
223    let strip_header = cred.map(|c| c.proxy_header_name.as_str()).unwrap_or("");
224    let filtered_headers = filter_headers(remaining_header, strip_header);
225    let content_length = extract_content_length(remaining_header);
226    let body = match read_request_body(stream, content_length, buffered_body).await? {
227        Some(body) => body,
228        None => return Ok(()),
229    };
230
231    let upstream_authority = format_host_header(upstream_scheme, &upstream_host, upstream_port);
232    let mut request = Zeroizing::new(format!(
233        "{} {} {}\r\nHost: {}\r\n",
234        method, upstream_path_full, version, upstream_authority
235    ));
236
237    if let Some(cred) = cred {
238        inject_credential_for_mode(cred, &mut request);
239    }
240
241    let auth_header_lower = cred.map(|c| c.header_name.to_lowercase());
242    for (name, value) in &filtered_headers {
243        if let (Some(cred), Some(header_lower)) = (cred, auth_header_lower.as_ref()) {
244            if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
245                && name.to_lowercase() == *header_lower
246            {
247                continue;
248            }
249        }
250        request.push_str(&format!("{}: {}\r\n", name, value));
251    }
252
253    request.push_str("Connection: close\r\n");
254    if !body.is_empty() {
255        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
256    }
257    request.push_str("\r\n");
258
259    let status_code = match upstream_scheme {
260        UpstreamScheme::Https => {
261            let connector = route.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
262            let mut tls_stream = match connect_upstream_tls(
263                &upstream_host,
264                upstream_port,
265                &check.resolved_addrs,
266                connector,
267            )
268            .await
269            {
270                Ok(s) => s,
271                Err(e) => {
272                    warn!("Upstream connection failed: {}", e);
273                    send_error(stream, 502, "Bad Gateway").await?;
274                    audit::log_denied(
275                        ctx.audit_log,
276                        audit::ProxyMode::Reverse,
277                        &service,
278                        0,
279                        &e.to_string(),
280                    );
281                    return Ok(());
282                }
283            };
284
285            write_upstream_request(&mut tls_stream, &request, &body).await?;
286            stream_response(&mut tls_stream, stream).await?
287        }
288        UpstreamScheme::Http => {
289            let mut upstream_stream =
290                match connect_upstream_tcp(&upstream_host, upstream_port, &check.resolved_addrs)
291                    .await
292                {
293                    Ok(s) => s,
294                    Err(e) => {
295                        warn!("Upstream connection failed: {}", e);
296                        send_error(stream, 502, "Bad Gateway").await?;
297                        audit::log_denied(
298                            ctx.audit_log,
299                            audit::ProxyMode::Reverse,
300                            &service,
301                            0,
302                            &e.to_string(),
303                        );
304                        return Ok(());
305                    }
306                };
307
308            write_upstream_request(&mut upstream_stream, &request, &body).await?;
309            stream_response(&mut upstream_stream, stream).await?
310        }
311    };
312    audit::log_reverse_proxy(
313        ctx.audit_log,
314        &service,
315        &method,
316        &upstream_path,
317        status_code,
318    );
319    Ok(())
320}
321
322/// Handle a reverse proxy request using an OAuth2 token cache.
323///
324/// Retrieves a (possibly refreshed) access token from the cache and injects
325/// it as `Authorization: Bearer <token>`. The agent authenticates with the
326/// session token via the `Authorization: Bearer <phantom>` header, which is
327/// validated and then replaced with the real OAuth2 access token.
328#[allow(clippy::too_many_arguments)]
329async fn handle_oauth2_credential(
330    oauth2_route: &crate::credential::OAuth2Route,
331    route: &crate::route::LoadedRoute,
332    service: &str,
333    upstream_path: &str,
334    method: &str,
335    version: &str,
336    stream: &mut TcpStream,
337    remaining_header: &[u8],
338    buffered_body: &[u8],
339    ctx: &ReverseProxyCtx<'_>,
340) -> Result<()> {
341    // Get (possibly refreshed) OAuth2 access token
342    let access_token = oauth2_route.cache.get_or_refresh().await;
343
344    // Validate session token from Authorization header (phantom token pattern).
345    // OAuth2 routes still require the agent to authenticate with the session
346    // token — this prevents unauthorized access to the token-exchanged credential.
347    if let Err(e) = validate_phantom_token(remaining_header, "Authorization", ctx.session_token) {
348        audit::log_denied(
349            ctx.audit_log,
350            audit::ProxyMode::Reverse,
351            service,
352            0,
353            &e.to_string(),
354        );
355        send_error(stream, 401, "Unauthorized").await?;
356        return Ok(());
357    }
358
359    let upstream_url = format!(
360        "{}{}",
361        oauth2_route.upstream.trim_end_matches('/'),
362        upstream_path
363    );
364    debug!("OAuth2 forwarding to upstream: {} {}", method, upstream_url);
365
366    let (upstream_scheme, upstream_host, upstream_port, upstream_path_full) =
367        parse_upstream_url(&upstream_url)?;
368    // DNS resolve + host check via the filter
369    let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
370    if !check.result.is_allowed() {
371        let reason = check.result.reason();
372        warn!("Upstream host denied by filter: {}", reason);
373        send_error(stream, 403, "Forbidden").await?;
374        audit::log_denied(
375            ctx.audit_log,
376            audit::ProxyMode::Reverse,
377            service,
378            0,
379            &reason,
380        );
381        return Ok(());
382    }
383    if let Err(reason) =
384        validate_http_upstream_target(upstream_scheme, &upstream_host, &check.resolved_addrs)
385    {
386        warn!("{}", reason);
387        send_error(stream, 502, "Bad Gateway").await?;
388        audit::log_denied(
389            ctx.audit_log,
390            audit::ProxyMode::Reverse,
391            service,
392            0,
393            &reason,
394        );
395        return Ok(());
396    }
397
398    // Collect remaining request headers, stripping the client-supplied
399    // Authorization header that carries the phantom token.
400    let filtered_headers = filter_headers(remaining_header, "Authorization");
401    let content_length = extract_content_length(remaining_header);
402
403    // Read request body
404    let body = match read_request_body(stream, content_length, buffered_body).await? {
405        Some(body) => body,
406        None => return Ok(()),
407    };
408
409    // Build upstream request with Bearer token injection
410    let upstream_authority = format_host_header(upstream_scheme, &upstream_host, upstream_port);
411    let mut request = Zeroizing::new(format!(
412        "{} {} {}\r\nHost: {}\r\n",
413        method, upstream_path_full, version, upstream_authority
414    ));
415
416    // Inject OAuth2 access token as Authorization: Bearer
417    request.push_str(&format!(
418        "Authorization: Bearer {}\r\n",
419        access_token.as_str()
420    ));
421
422    // Forward filtered headers (auth headers already stripped by filter_headers)
423    for (name, value) in &filtered_headers {
424        request.push_str(&format!("{}: {}\r\n", name, value));
425    }
426
427    if !body.is_empty() {
428        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
429    }
430    request.push_str("\r\n");
431
432    let status_code = match upstream_scheme {
433        UpstreamScheme::Https => {
434            let connector = route.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
435            let mut tls_stream = match connect_upstream_tls(
436                &upstream_host,
437                upstream_port,
438                &check.resolved_addrs,
439                connector,
440            )
441            .await
442            {
443                Ok(s) => s,
444                Err(e) => {
445                    warn!("Upstream connection failed: {}", e);
446                    send_error(stream, 502, "Bad Gateway").await?;
447                    audit::log_denied(
448                        ctx.audit_log,
449                        audit::ProxyMode::Reverse,
450                        service,
451                        0,
452                        &e.to_string(),
453                    );
454                    return Ok(());
455                }
456            };
457
458            write_upstream_request(&mut tls_stream, &request, &body).await?;
459            stream_response(&mut tls_stream, stream).await?
460        }
461        UpstreamScheme::Http => {
462            let mut upstream_stream =
463                match connect_upstream_tcp(&upstream_host, upstream_port, &check.resolved_addrs)
464                    .await
465                {
466                    Ok(s) => s,
467                    Err(e) => {
468                        warn!("Upstream connection failed: {}", e);
469                        send_error(stream, 502, "Bad Gateway").await?;
470                        audit::log_denied(
471                            ctx.audit_log,
472                            audit::ProxyMode::Reverse,
473                            service,
474                            0,
475                            &e.to_string(),
476                        );
477                        return Ok(());
478                    }
479                };
480
481            write_upstream_request(&mut upstream_stream, &request, &body).await?;
482            stream_response(&mut upstream_stream, stream).await?
483        }
484    };
485
486    audit::log_reverse_proxy(ctx.audit_log, service, method, upstream_path, status_code);
487    Ok(())
488}
489
490async fn write_upstream_request<S>(stream: &mut S, request: &str, body: &[u8]) -> Result<()>
491where
492    S: AsyncWrite + Unpin,
493{
494    stream.write_all(request.as_bytes()).await?;
495    if !body.is_empty() {
496        stream.write_all(body).await?;
497    }
498    stream.flush().await?;
499    Ok(())
500}
501
502/// Read request body from the client stream with size limit.
503///
504/// `buffered_body` contains bytes the BufReader read ahead beyond headers.
505async fn read_request_body(
506    stream: &mut TcpStream,
507    content_length: Option<usize>,
508    buffered_body: &[u8],
509) -> Result<Option<Vec<u8>>> {
510    if let Some(len) = content_length {
511        if len > MAX_REQUEST_BODY {
512            send_error(stream, 413, "Payload Too Large").await?;
513            return Ok(None);
514        }
515        let mut buf = Vec::with_capacity(len);
516        let pre = buffered_body.len().min(len);
517        buf.extend_from_slice(&buffered_body[..pre]);
518        let remaining = len - pre;
519        if remaining > 0 {
520            let mut rest = vec![0u8; remaining];
521            stream.read_exact(&mut rest).await?;
522            buf.extend_from_slice(&rest);
523        }
524        Ok(Some(buf))
525    } else {
526        Ok(Some(Vec::new()))
527    }
528}
529
530/// Stream the upstream TLS response back to the client.
531///
532/// Returns the HTTP status code parsed from the first chunk.
533async fn stream_response<S>(tls_stream: &mut S, stream: &mut TcpStream) -> Result<u16>
534where
535    S: AsyncRead + AsyncWrite + Unpin,
536{
537    let mut response_buf = [0u8; 8192];
538    let mut status_code: u16 = 502;
539    let mut first_chunk = true;
540
541    loop {
542        let n = match tls_stream.read(&mut response_buf).await {
543            Ok(0) => break,
544            Ok(n) => n,
545            Err(e) => {
546                debug!("Upstream read error: {}", e);
547                break;
548            }
549        };
550
551        if first_chunk {
552            status_code = parse_response_status(&response_buf[..n]);
553            first_chunk = false;
554        }
555
556        stream.write_all(&response_buf[..n]).await?;
557        stream.flush().await?;
558    }
559
560    Ok(status_code)
561}
562
563/// Parse an HTTP request line into (method, path, version).
564fn parse_request_line(line: &str) -> Result<(String, String, String)> {
565    let parts: Vec<&str> = line.split_whitespace().collect();
566    if parts.len() < 3 {
567        return Err(ProxyError::HttpParse(format!(
568            "malformed request line: {}",
569            line
570        )));
571    }
572    Ok((
573        parts[0].to_string(),
574        parts[1].to_string(),
575        parts[2].to_string(),
576    ))
577}
578
579/// Extract service prefix from path.
580///
581/// "/openai/v1/chat/completions" -> ("openai", "/v1/chat/completions")
582/// "/anthropic/v1/messages" -> ("anthropic", "/v1/messages")
583fn parse_service_prefix(path: &str) -> Result<(String, String)> {
584    let trimmed = path.strip_prefix('/').unwrap_or(path);
585    if let Some((prefix, rest)) = trimmed.split_once('/') {
586        Ok((prefix.to_string(), format!("/{}", rest)))
587    } else {
588        // No sub-path, just the prefix
589        Ok((trimmed.to_string(), "/".to_string()))
590    }
591}
592
593/// Validate the phantom token from the service's auth header.
594///
595/// The SDK sends the session token as its "API key" in the standard auth header
596/// for that service (e.g., `Authorization: Bearer <token>` for OpenAI,
597/// `x-api-key: <token>` for Anthropic). We validate the token matches the
598/// session token before swapping in the real credential.
599fn validate_phantom_token(
600    header_bytes: &[u8],
601    header_name: &str,
602    session_token: &Zeroizing<String>,
603) -> Result<()> {
604    let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
605    let header_name_lower = header_name.to_lowercase();
606
607    for line in header_str.lines() {
608        let lower = line.to_lowercase();
609        if lower.starts_with(&format!("{}:", header_name_lower)) {
610            let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
611
612            // Handle "Bearer <token>" format (strip "Bearer " prefix if present)
613            // Use case-insensitive check, then slice original value by length
614            let value_lower = value.to_lowercase();
615            let token_value = if value_lower.starts_with("bearer ") {
616                // "bearer ".len() == 7
617                value[7..].trim()
618            } else {
619                value
620            };
621
622            if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
623                return Ok(());
624            }
625            warn!("Invalid phantom token in {} header", header_name);
626            return Err(ProxyError::InvalidToken);
627        }
628    }
629
630    warn!(
631        "Missing {} header for phantom token validation",
632        header_name
633    );
634    Err(ProxyError::InvalidToken)
635}
636
/// Filter headers, removing hop-by-hop and proxy-internal headers.
///
/// Always strips:
/// - `Host` (rewritten to upstream)
/// - `Content-Length` (re-added after body is read)
/// - `Proxy-Authorization` (hop-by-hop, contains session token)
/// - the remaining hop-by-hop headers (`Connection`, `Proxy-Connection`,
///   `Keep-Alive`, `TE`, `Trailer`, `Transfer-Encoding`, `Upgrade`). The body
///   is re-framed with the proxy's own `Content-Length`, so forwarding e.g.
///   `Transfer-Encoding` would give the upstream conflicting framing.
///
/// When `cred_header` is non-empty, also strips that header (it contains
/// the phantom token that must not be forwarded alongside the real credential).
/// When `cred_header` is empty (no-credential route), all other headers
/// including `Authorization` are passed through to the upstream.
///
/// Names are compared ASCII case-insensitively after trimming, the same
/// trimming applied to the returned `(name, value)` pairs. A line such as
/// `" Host: x"` therefore cannot slip through as a `Host` header. Lines
/// without a `:` (including the blank terminator line) are dropped, and
/// non-UTF-8 input yields no headers.
fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    // Headers that are never forwarded upstream (lowercase).
    const ALWAYS_STRIPPED: &[&str] = &[
        "host",
        "content-length",
        "connection",
        "proxy-authorization",
        "proxy-connection",
        "keep-alive",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];

    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");
    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim(), value.trim()))
        .filter(|(name, _)| {
            let is_cred_header = !cred_header.is_empty() && name.eq_ignore_ascii_case(cred_header);
            let is_stripped = ALWAYS_STRIPPED
                .iter()
                .any(|stripped| name.eq_ignore_ascii_case(stripped));
            !is_cred_header && !is_stripped
        })
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect()
}
675
/// Return the value of the first `Content-Length:` header line.
///
/// Matching is case-insensitive on the header name. The first matching line
/// decides: if its value does not parse as a `usize`, the result is `None`
/// even if a later line would parse. Non-UTF-8 input also yields `None`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    const NAME: &str = "content-length:";

    let text = std::str::from_utf8(header_bytes).ok()?;
    let header_line = text
        .lines()
        .find(|line| matches!(line.get(..NAME.len()), Some(p) if p.eq_ignore_ascii_case(NAME)))?;
    header_line[NAME.len()..].trim().parse().ok()
}
687
/// URL scheme of an upstream target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UpstreamScheme {
    /// Plain-text HTTP; only permitted for loopback targets.
    Http,
    /// HTTP over TLS.
    Https,
}

/// Check that an upstream target may be contacted with the given scheme.
///
/// HTTPS targets are always accepted here. Plain HTTP is accepted only when
/// `is_local_only_target` reports the target as loopback-only, so unencrypted
/// upstream traffic never leaves the host.
///
/// # Errors
///
/// Returns a human-readable denial reason for a non-local HTTP target.
fn validate_http_upstream_target(
    scheme: UpstreamScheme,
    host: &str,
    resolved_addrs: &[SocketAddr],
) -> std::result::Result<(), String> {
    if matches!(scheme, UpstreamScheme::Https) {
        return Ok(());
    }

    if is_local_only_target(host, resolved_addrs) {
        Ok(())
    } else {
        Err(format!(
            "refusing insecure http upstream for non-local host '{}'; http is only allowed for loopback addresses",
            host
        ))
    }
}

/// Return `true` when the target is loopback-only.
///
/// With resolved addresses, *every* address must be loopback. Without them,
/// `host` must be a loopback IP literal. IPv6 literals may arrive bracketed
/// (`[::1]`, as `format_host_header` also anticipates), so brackets are
/// stripped before parsing. Unresolved hostnames are treated as non-local.
fn is_local_only_target(host: &str, resolved_addrs: &[SocketAddr]) -> bool {
    if !resolved_addrs.is_empty() {
        return resolved_addrs.iter().all(|addr| addr.ip().is_loopback());
    }

    let literal = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    matches!(literal.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}

/// Build the `Host` header value for an upstream request.
///
/// The port is omitted when it is the scheme's default (80 / 443). Unbracketed
/// IPv6 literals (hosts containing `:`) are wrapped in `[...]`.
fn format_host_header(scheme: UpstreamScheme, host: &str, port: u16) -> String {
    let default_port = match scheme {
        UpstreamScheme::Http => 80,
        UpstreamScheme::Https => 443,
    };
    let bracketed_host = if host.contains(':') && !host.starts_with('[') {
        format!("[{}]", host)
    } else {
        host.to_string()
    };

    if port == default_port {
        bracketed_host
    } else {
        format!("{}:{}", bracketed_host, port)
    }
}
743
744fn parse_upstream_url(url_str: &str) -> Result<(UpstreamScheme, String, u16, String)> {
745    let parsed = url::Url::parse(url_str)
746        .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
747
748    let scheme = match parsed.scheme() {
749        "https" => UpstreamScheme::Https,
750        "http" => UpstreamScheme::Http,
751        _ => {
752            return Err(ProxyError::HttpParse(format!(
753                "unsupported URL scheme: {}",
754                url_str
755            )));
756        }
757    };
758
759    let host = parsed
760        .host_str()
761        .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
762        .to_string();
763
764    let default_port = if matches!(scheme, UpstreamScheme::Https) {
765        443
766    } else {
767        80
768    };
769    let port = parsed.port().unwrap_or(default_port);
770
771    let path = parsed.path().to_string();
772    let path = if path.is_empty() {
773        "/".to_string()
774    } else {
775        path
776    };
777
778    // Include query string if present
779    let path_with_query = if let Some(query) = parsed.query() {
780        format!("{}?{}", path, query)
781    } else {
782        path
783    };
784
785    Ok((scheme, host, port, path_with_query))
786}
787
788/// Connect to an upstream host over TLS using pre-resolved addresses.
789///
790/// Uses the pre-resolved `SocketAddr`s from the filter check to prevent
791/// DNS rebinding TOCTOU. Falls back to hostname resolution only if no
792/// pre-resolved addresses are available.
793///
794/// The `TlsConnector` is shared across all connections (created once at
795/// server startup with the system root certificate store).
796async fn connect_upstream_tls(
797    host: &str,
798    port: u16,
799    resolved_addrs: &[SocketAddr],
800    connector: &TlsConnector,
801) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
802    let tcp = if resolved_addrs.is_empty() {
803        // Fallback: no pre-resolved addresses (shouldn't happen in practice)
804        let addr = format!("{}:{}", host, port);
805        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
806            Ok(Ok(s)) => s,
807            Ok(Err(e)) => {
808                return Err(ProxyError::UpstreamConnect {
809                    host: host.to_string(),
810                    reason: e.to_string(),
811                });
812            }
813            Err(_) => {
814                return Err(ProxyError::UpstreamConnect {
815                    host: host.to_string(),
816                    reason: "connection timed out".to_string(),
817                });
818            }
819        }
820    } else {
821        connect_to_resolved(resolved_addrs, host).await?
822    };
823
824    let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
825        ProxyError::UpstreamConnect {
826            host: host.to_string(),
827            reason: "invalid server name for TLS".to_string(),
828        }
829    })?;
830
831    let tls_stream =
832        connector
833            .connect(server_name, tcp)
834            .await
835            .map_err(|e| ProxyError::UpstreamConnect {
836                host: host.to_string(),
837                reason: format!("TLS handshake failed: {}", e),
838            })?;
839
840    Ok(tls_stream)
841}
842
843async fn connect_upstream_tcp(
844    host: &str,
845    port: u16,
846    resolved_addrs: &[SocketAddr],
847) -> Result<TcpStream> {
848    if resolved_addrs.is_empty() {
849        let addr = format!("{}:{}", host, port);
850        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
851            Ok(Ok(s)) => Ok(s),
852            Ok(Err(e)) => Err(ProxyError::UpstreamConnect {
853                host: host.to_string(),
854                reason: e.to_string(),
855            }),
856            Err(_) => Err(ProxyError::UpstreamConnect {
857                host: host.to_string(),
858                reason: "connection timed out".to_string(),
859            }),
860        }
861    } else {
862        connect_to_resolved(resolved_addrs, host).await
863    }
864}
865
866/// Connect to one of the pre-resolved socket addresses with timeout.
867async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
868    let mut last_err = None;
869    for addr in addrs {
870        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
871            Ok(Ok(stream)) => return Ok(stream),
872            Ok(Err(e)) => {
873                debug!("Connect to {} failed: {}", addr, e);
874                last_err = Some(e.to_string());
875            }
876            Err(_) => {
877                debug!("Connect to {} timed out", addr);
878                last_err = Some("connection timed out".to_string());
879            }
880        }
881    }
882    Err(ProxyError::UpstreamConnect {
883        host: host.to_string(),
884        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
885    })
886}
887
/// Parse the HTTP status code from the first response chunk.
///
/// Expects a status line of the form `HTTP/x.y NNN [reason]`. Only the first
/// 64 bytes of the first line are inspected. Parsing works on raw bytes, so a
/// non-UTF-8 reason phrase, or a multi-byte character cut at the 64-byte
/// limit, cannot hide an otherwise valid status code.
///
/// Returns 502 when the data does not start with a valid status line. That
/// covers a missing `HTTP/` version token, a missing status code, or a code
/// that is not exactly three ASCII digits in the range 100–599. In those cases
/// the upstream sent garbage or incomplete data.
fn parse_response_status(data: &[u8]) -> u16 {
    /// Status reported when the upstream response has no valid status line.
    const INVALID_STATUS: u16 = 502;
    /// Upper bound on how much of the first line is examined.
    const MAX_STATUS_LINE: usize = 64;

    /// Extract a valid status code from a (possibly truncated) status line.
    fn status_from_line(line: &[u8]) -> Option<u16> {
        // Split on whitespace: ["HTTP/1.1", "200", "OK"]
        let mut tokens = line
            .split(|b: &u8| b.is_ascii_whitespace())
            .filter(|token| !token.is_empty());

        let version = tokens.next()?;
        if !version.starts_with(b"HTTP/") {
            return None;
        }

        match tokens.next()? {
            // Exactly three ASCII digits; rejects forms such as "+12" that
            // `u16::from_str` would otherwise accept.
            &[h, t, o] if h.is_ascii_digit() && t.is_ascii_digit() && o.is_ascii_digit() => {
                let code =
                    u16::from(h - b'0') * 100 + u16::from(t - b'0') * 10 + u16::from(o - b'0');
                if (100..=599).contains(&code) {
                    Some(code)
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    // Find the end of the first line (or use full data if no newline)
    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(MAX_STATUS_LINE)];

    status_from_line(first_line).unwrap_or(INVALID_STATUS)
}
916
917/// Send an HTTP error response.
918async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
919    let body = format!("{{\"error\":\"{}\"}}", reason);
920    let response = format!(
921        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
922        status,
923        reason,
924        body.len(),
925        body
926    );
927    stream.write_all(response.as_bytes()).await?;
928    stream.flush().await?;
929    Ok(())
930}
931
932// ============================================================================
933// Injection mode helpers
934// ============================================================================
935
936/// Validate phantom token based on injection mode.
937///
938/// Different modes extract the phantom token from different locations:
939/// - `Header`/`BasicAuth`: From the auth header (Authorization, x-api-key, etc.)
940/// - `UrlPath`: From the URL path pattern (e.g., `/bot<token>/getMe`)
941/// - `QueryParam`: From the query parameter (e.g., `?api_key=<token>`)
942fn validate_phantom_token_for_mode(
943    mode: &InjectMode,
944    header_bytes: &[u8],
945    path: &str,
946    header_name: &str,
947    path_pattern: Option<&str>,
948    query_param_name: Option<&str>,
949    session_token: &Zeroizing<String>,
950) -> Result<()> {
951    match mode {
952        InjectMode::Header | InjectMode::BasicAuth => {
953            // Validate from header (existing behavior)
954            validate_phantom_token(header_bytes, header_name, session_token)
955        }
956        InjectMode::UrlPath => {
957            // Validate from URL path
958            let pattern = path_pattern.ok_or_else(|| {
959                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
960            })?;
961            validate_phantom_token_in_path(path, pattern, session_token)
962        }
963        InjectMode::QueryParam => {
964            // Validate from query parameter
965            let param_name = query_param_name.ok_or_else(|| {
966                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
967            })?;
968            validate_phantom_token_in_query(path, param_name, session_token)
969        }
970    }
971}
972
973/// Validate phantom token embedded in URL path.
974///
975/// Extracts the token from the path using the pattern (e.g., `/bot{}/` matches
976/// `/bot<token>/getMe` and extracts `<token>`).
977fn validate_phantom_token_in_path(
978    path: &str,
979    pattern: &str,
980    session_token: &Zeroizing<String>,
981) -> Result<()> {
982    // Split pattern on {} to get prefix and suffix
983    let parts: Vec<&str> = pattern.split("{}").collect();
984    if parts.len() != 2 {
985        return Err(ProxyError::HttpParse(format!(
986            "invalid path_pattern '{}': must contain exactly one {{}}",
987            pattern
988        )));
989    }
990    let (prefix, suffix) = (parts[0], parts[1]);
991
992    // Find the token in the path
993    if let Some(start) = path.find(prefix) {
994        let after_prefix = start + prefix.len();
995
996        // Handle empty suffix case (token extends to end of path or next '/' or '?')
997        let end_offset = if suffix.is_empty() {
998            path[after_prefix..]
999                .find(['/', '?'])
1000                .unwrap_or(path[after_prefix..].len())
1001        } else {
1002            match path[after_prefix..].find(suffix) {
1003                Some(offset) => offset,
1004                None => {
1005                    warn!("Missing phantom token in URL path (pattern: {})", pattern);
1006                    return Err(ProxyError::InvalidToken);
1007                }
1008            }
1009        };
1010
1011        let token = &path[after_prefix..after_prefix + end_offset];
1012        if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
1013            return Ok(());
1014        }
1015        warn!("Invalid phantom token in URL path");
1016        return Err(ProxyError::InvalidToken);
1017    }
1018
1019    warn!("Missing phantom token in URL path (pattern: {})", pattern);
1020    Err(ProxyError::InvalidToken)
1021}
1022
1023/// Validate phantom token in query parameter.
1024fn validate_phantom_token_in_query(
1025    path: &str,
1026    param_name: &str,
1027    session_token: &Zeroizing<String>,
1028) -> Result<()> {
1029    // Parse query string from path
1030    if let Some(query_start) = path.find('?') {
1031        let query = &path[query_start + 1..];
1032        for pair in query.split('&') {
1033            if let Some((name, value)) = pair.split_once('=') {
1034                if name == param_name {
1035                    // URL-decode the value
1036                    let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
1037                    if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
1038                        return Ok(());
1039                    }
1040                    warn!("Invalid phantom token in query parameter '{}'", param_name);
1041                    return Err(ProxyError::InvalidToken);
1042                }
1043            }
1044        }
1045    }
1046
1047    warn!("Missing phantom token in query parameter '{}'", param_name);
1048    Err(ProxyError::InvalidToken)
1049}
1050
1051/// Transform URL path based on injection mode.
1052///
1053/// - `UrlPath`: Replace phantom token with real credential in path
1054/// - `QueryParam`: Add/replace query parameter with real credential
1055/// - `Header`/`BasicAuth`: No path transformation needed
1056fn transform_path_for_mode(
1057    mode: &InjectMode,
1058    path: &str,
1059    path_pattern: Option<&str>,
1060    path_replacement: Option<&str>,
1061    query_param_name: Option<&str>,
1062    credential: &Zeroizing<String>,
1063) -> Result<String> {
1064    match mode {
1065        InjectMode::Header | InjectMode::BasicAuth => {
1066            // No path transformation needed
1067            Ok(path.to_string())
1068        }
1069        InjectMode::UrlPath => {
1070            let pattern = path_pattern.ok_or_else(|| {
1071                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
1072            })?;
1073            let replacement = path_replacement.unwrap_or(pattern);
1074            transform_url_path(path, pattern, replacement, credential)
1075        }
1076        InjectMode::QueryParam => {
1077            let param_name = query_param_name.ok_or_else(|| {
1078                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
1079            })?;
1080            transform_query_param(path, param_name, credential)
1081        }
1082    }
1083}
1084
1085/// Transform URL path by replacing phantom token pattern with real credential.
1086///
1087/// Example: `/bot<phantom>/getMe` with pattern `/bot{}/` becomes `/bot<real>/getMe`
1088fn transform_url_path(
1089    path: &str,
1090    pattern: &str,
1091    replacement: &str,
1092    credential: &Zeroizing<String>,
1093) -> Result<String> {
1094    // Split pattern on {} to get prefix and suffix
1095    let parts: Vec<&str> = pattern.split("{}").collect();
1096    if parts.len() != 2 {
1097        return Err(ProxyError::HttpParse(format!(
1098            "invalid path_pattern '{}': must contain exactly one {{}}",
1099            pattern
1100        )));
1101    }
1102    let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
1103
1104    // Split replacement on {}
1105    let repl_parts: Vec<&str> = replacement.split("{}").collect();
1106    if repl_parts.len() != 2 {
1107        return Err(ProxyError::HttpParse(format!(
1108            "invalid path_replacement '{}': must contain exactly one {{}}",
1109            replacement
1110        )));
1111    }
1112    let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
1113
1114    // Find and replace the token in the path
1115    if let Some(start) = path.find(pattern_prefix) {
1116        let after_prefix = start + pattern_prefix.len();
1117
1118        // Handle empty suffix case (token extends to end of path or next '/' or '?')
1119        let end_offset = if pattern_suffix.is_empty() {
1120            // Find the next path segment delimiter or end of path
1121            path[after_prefix..]
1122                .find(['/', '?'])
1123                .unwrap_or(path[after_prefix..].len())
1124        } else {
1125            // Find the suffix in the remaining path
1126            match path[after_prefix..].find(pattern_suffix) {
1127                Some(offset) => offset,
1128                None => {
1129                    return Err(ProxyError::HttpParse(format!(
1130                        "path '{}' does not match pattern '{}'",
1131                        path, pattern
1132                    )));
1133                }
1134            }
1135        };
1136
1137        let before = &path[..start];
1138        let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
1139        return Ok(format!(
1140            "{}{}{}{}{}",
1141            before,
1142            repl_prefix,
1143            credential.as_str(),
1144            repl_suffix,
1145            after
1146        ));
1147    }
1148
1149    Err(ProxyError::HttpParse(format!(
1150        "path '{}' does not match pattern '{}'",
1151        path, pattern
1152    )))
1153}
1154
1155/// Transform query string by adding or replacing a parameter with the credential.
1156fn transform_query_param(
1157    path: &str,
1158    param_name: &str,
1159    credential: &Zeroizing<String>,
1160) -> Result<String> {
1161    let encoded_value = urlencoding::encode(credential.as_str());
1162
1163    if let Some(query_start) = path.find('?') {
1164        let base_path = &path[..query_start];
1165        let query = &path[query_start + 1..];
1166
1167        // Check if parameter already exists
1168        let mut found = false;
1169        let new_query: Vec<String> = query
1170            .split('&')
1171            .map(|pair| {
1172                if let Some((name, _)) = pair.split_once('=') {
1173                    if name == param_name {
1174                        found = true;
1175                        return format!("{}={}", param_name, encoded_value);
1176                    }
1177                }
1178                pair.to_string()
1179            })
1180            .collect();
1181
1182        if found {
1183            Ok(format!("{}?{}", base_path, new_query.join("&")))
1184        } else {
1185            // Append the parameter
1186            Ok(format!(
1187                "{}?{}&{}={}",
1188                base_path, query, param_name, encoded_value
1189            ))
1190        }
1191    } else {
1192        // No query string, add one
1193        Ok(format!("{}?{}={}", path, param_name, encoded_value))
1194    }
1195}
1196
1197/// Strip proxy-side artifacts from the path when proxy and upstream modes differ.
1198///
1199/// When the proxy validates the phantom token using a different injection mode
1200/// than the upstream (e.g., proxy uses `url_path` or `query_param` while upstream
1201/// uses `header`), the proxy-side token is embedded in the URL. This function
1202/// removes it before the path is forwarded to the upstream, preventing phantom
1203/// token leakage.
1204///
1205/// When both modes are the same, the upstream transform handles replacement
1206/// (phantom → real credential), so no stripping is needed.
1207fn strip_proxy_artifacts(
1208    path: &str,
1209    proxy_mode: &InjectMode,
1210    upstream_mode: &InjectMode,
1211    proxy_path_pattern: Option<&str>,
1212    proxy_query_param_name: Option<&str>,
1213) -> String {
1214    // Only strip when modes differ — same-mode cases are handled by the
1215    // upstream transform which replaces the phantom token with the real one.
1216    if proxy_mode == upstream_mode {
1217        return path.to_string();
1218    }
1219
1220    match proxy_mode {
1221        InjectMode::UrlPath => {
1222            if let Some(pattern) = proxy_path_pattern {
1223                strip_proxy_path_token(path, pattern)
1224            } else {
1225                path.to_string()
1226            }
1227        }
1228        InjectMode::QueryParam => {
1229            if let Some(param_name) = proxy_query_param_name {
1230                strip_proxy_query_param(path, param_name)
1231            } else {
1232                path.to_string()
1233            }
1234        }
1235        // Header and BasicAuth modes don't embed artifacts in the URL path.
1236        InjectMode::Header | InjectMode::BasicAuth => path.to_string(),
1237    }
1238}
1239
/// Remove a phantom token path segment matched by the given pattern.
///
/// `pattern` must contain exactly one `{}`. If it does not, or it cannot be
/// located in `path`, the path is returned unchanged.
///
/// Matching rules:
/// - The pattern's leading text is matched at its first occurrence in `path`.
/// - The token ends at the pattern's trailing text or, when there is none, at
///   the next `/` or `?` or the end of the path.
///
/// The whole `prefix + token + suffix` span is removed. The remaining pieces
/// are re-joined with exactly one `/` between path segments. A following
/// query string (`?…`) is re-attached directly, never behind an inserted `/`.
/// The result always starts with `/`.
///
/// Examples:
/// - `/TOKEN123/api/v1/pods` with pattern `/{}/` → `/api/v1/pods`
/// - `/api/botTOKEN?x=1` with pattern `/bot{}` → `/api?x=1`
fn strip_proxy_path_token(path: &str, pattern: &str) -> String {
    let parts: Vec<&str> = pattern.split("{}").collect();
    let (prefix, suffix) = match parts.as_slice() {
        [prefix, suffix] => (*prefix, *suffix),
        _ => return path.to_string(),
    };

    // `find` returns the leftmost match, so a path that begins with the
    // prefix is always matched at position 0.
    let start = match path.find(prefix) {
        Some(start) => start,
        None => return path.to_string(),
    };
    let rest = &path[start + prefix.len()..];
    let token_len = if suffix.is_empty() {
        rest.find(['/', '?']).unwrap_or(rest.len())
    } else {
        match rest.find(suffix) {
            Some(offset) => offset,
            None => return path.to_string(),
        }
    };

    let before = &path[..start];
    let after = &rest[token_len + suffix.len()..];

    // Join with exactly one separator. This avoids "/prefixapi" (missing
    // slash) and "/api//v1" (double slash) when the stripped segment was
    // mid-path. Never insert a slash before a query string: "/api" + "?x=1"
    // must stay "/api?x=1", not become the different resource "/api/?x=1".
    let joined = match (before.ends_with('/'), after.starts_with('/')) {
        (true, true) => format!("{}{}", before, &after[1..]),
        (false, false) if !before.is_empty() && !after.is_empty() && !after.starts_with('?') => {
            format!("{}/{}", before, after)
        }
        _ => format!("{}{}", before, after),
    };

    if joined.starts_with('/') {
        joined
    } else {
        format!("/{}", joined)
    }
}
1295
/// Remove a phantom token query parameter from the URL.
///
/// Every `name=value` pair whose name equals `param_name` is dropped. All
/// other pairs, including ones without `=`, keep their original order. If no
/// pairs remain, the `?` is dropped as well. A path without a query string is
/// returned unchanged.
///
/// Example: path `/api/v1/pods?token=XXX&limit=10` → `/api/v1/pods?limit=10`
fn strip_proxy_query_param(path: &str, param_name: &str) -> String {
    let (base_path, query) = match path.split_once('?') {
        Some(split) => split,
        None => return path.to_string(),
    };

    let kept: Vec<&str> = query
        .split('&')
        .filter(|pair| !matches!(pair.split_once('='), Some((name, _)) if name == param_name))
        .collect();

    if kept.is_empty() {
        base_path.to_string()
    } else {
        format!("{}?{}", base_path, kept.join("&"))
    }
}
1322
1323/// Inject credential into request based on mode.
1324///
1325/// For header/basic_auth modes, adds the credential header.
1326/// For url_path/query_param modes, the credential is already in the path.
1327fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
1328    match cred.inject_mode {
1329        InjectMode::Header | InjectMode::BasicAuth => {
1330            // Inject credential header
1331            request.push_str(&format!(
1332                "{}: {}\r\n",
1333                cred.header_name,
1334                cred.header_value.as_str()
1335            ));
1336        }
1337        InjectMode::UrlPath | InjectMode::QueryParam => {
1338            // Credential is already injected into the URL path/query
1339            // No header injection needed
1340        }
1341    }
1342}
1343
1344#[cfg(test)]
1345#[allow(clippy::unwrap_used)]
1346mod tests {
1347    use super::*;
1348
1349    #[test]
1350    fn test_parse_request_line() {
1351        let (method, path, version) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
1352        assert_eq!(method, "POST");
1353        assert_eq!(path, "/openai/v1/chat");
1354        assert_eq!(version, "HTTP/1.1");
1355    }
1356
1357    #[test]
1358    fn test_parse_request_line_malformed() {
1359        assert!(parse_request_line("GET").is_err());
1360    }
1361
1362    #[test]
1363    fn test_parse_service_prefix() {
1364        let (service, path) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
1365        assert_eq!(service, "openai");
1366        assert_eq!(path, "/v1/chat/completions");
1367    }
1368
1369    #[test]
1370    fn test_parse_service_prefix_no_subpath() {
1371        let (service, path) = parse_service_prefix("/anthropic").unwrap();
1372        assert_eq!(service, "anthropic");
1373        assert_eq!(path, "/");
1374    }
1375
1376    #[test]
1377    fn test_validate_phantom_token_bearer_valid() {
1378        let token = Zeroizing::new("secret123".to_string());
1379        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
1380        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
1381    }
1382
1383    #[test]
1384    fn test_validate_phantom_token_bearer_invalid() {
1385        let token = Zeroizing::new("secret123".to_string());
1386        let header = b"Authorization: Bearer wrong\r\n\r\n";
1387        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
1388    }
1389
1390    #[test]
1391    fn test_validate_phantom_token_x_api_key_valid() {
1392        let token = Zeroizing::new("secret123".to_string());
1393        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
1394        assert!(validate_phantom_token(header, "x-api-key", &token).is_ok());
1395    }
1396
1397    #[test]
1398    fn test_validate_phantom_token_x_goog_api_key_valid() {
1399        let token = Zeroizing::new("secret123".to_string());
1400        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
1401        assert!(validate_phantom_token(header, "x-goog-api-key", &token).is_ok());
1402    }
1403
1404    #[test]
1405    fn test_validate_phantom_token_missing() {
1406        let token = Zeroizing::new("secret123".to_string());
1407        let header = b"Content-Type: application/json\r\n\r\n";
1408        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
1409    }
1410
1411    #[test]
1412    fn test_validate_phantom_token_case_insensitive_header() {
1413        let token = Zeroizing::new("secret123".to_string());
1414        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
1415        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
1416    }
1417
1418    #[test]
1419    fn test_filter_headers_removes_host_auth() {
1420        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
1421        let filtered = filter_headers(header, "Authorization");
1422        assert_eq!(filtered.len(), 2);
1423        assert_eq!(filtered[0].0, "Content-Type");
1424        assert_eq!(filtered[1].0, "Accept");
1425    }
1426
1427    #[test]
1428    fn test_filter_headers_removes_x_api_key() {
1429        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
1430        let filtered = filter_headers(header, "x-api-key");
1431        assert_eq!(filtered.len(), 1);
1432        assert_eq!(filtered[0].0, "Content-Type");
1433    }
1434
1435    #[test]
1436    fn test_filter_headers_removes_custom_header() {
1437        let header = b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n";
1438        let filtered = filter_headers(header, "PRIVATE-TOKEN");
1439        assert_eq!(filtered.len(), 1);
1440        assert_eq!(filtered[0].0, "Content-Type");
1441    }
1442
1443    #[test]
1444    fn test_extract_content_length() {
1445        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
1446        assert_eq!(extract_content_length(header), Some(42));
1447    }
1448
1449    #[test]
1450    fn test_extract_content_length_missing() {
1451        let header = b"Content-Type: application/json\r\n\r\n";
1452        assert_eq!(extract_content_length(header), None);
1453    }
1454
1455    #[test]
1456    fn test_parse_upstream_url_https() {
1457        let (scheme, host, port, path) =
1458            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
1459        assert_eq!(scheme, UpstreamScheme::Https);
1460        assert_eq!(host, "api.openai.com");
1461        assert_eq!(port, 443);
1462        assert_eq!(path, "/v1/chat/completions");
1463    }
1464
1465    #[test]
1466    fn test_parse_upstream_url_http_with_port() {
1467        let (scheme, host, port, path) = parse_upstream_url("http://localhost:8080/api").unwrap();
1468        assert_eq!(scheme, UpstreamScheme::Http);
1469        assert_eq!(host, "localhost");
1470        assert_eq!(port, 8080);
1471        assert_eq!(path, "/api");
1472    }
1473
1474    #[test]
1475    fn test_parse_upstream_url_no_path() {
1476        let (scheme, host, port, path) = parse_upstream_url("https://api.anthropic.com").unwrap();
1477        assert_eq!(scheme, UpstreamScheme::Https);
1478        assert_eq!(host, "api.anthropic.com");
1479        assert_eq!(port, 443);
1480        assert_eq!(path, "/");
1481    }
1482
1483    #[test]
1484    fn test_parse_upstream_url_invalid_scheme() {
1485        assert!(parse_upstream_url("ftp://example.com").is_err());
1486    }
1487
1488    #[test]
1489    fn test_validate_http_upstream_target_rejects_non_local_host() {
1490        let err = validate_http_upstream_target(UpstreamScheme::Http, "api.example.com", &[])
1491            .expect_err("non-local http upstream should be rejected");
1492        assert!(err.contains("refusing insecure http upstream"));
1493    }
1494
1495    #[test]
1496    fn test_validate_http_upstream_target_allows_loopback() {
1497        let loopback = [SocketAddr::from(([127, 0, 0, 1], 8080))];
1498        assert!(validate_http_upstream_target(UpstreamScheme::Http, "127.0.0.1", &[]).is_ok());
1499        assert!(validate_http_upstream_target(UpstreamScheme::Http, "::1", &[]).is_ok());
1500        assert!(
1501            validate_http_upstream_target(UpstreamScheme::Http, "localhost", &loopback).is_ok()
1502        );
1503    }
1504
1505    #[test]
1506    fn test_validate_http_upstream_target_rejects_unspecified_addresses() {
1507        let unspecified = [SocketAddr::from(([0, 0, 0, 0], 8080))];
1508        let err = validate_http_upstream_target(UpstreamScheme::Http, "0.0.0.0", &[])
1509            .expect_err("unspecified http upstream should be rejected");
1510        assert!(err.contains("loopback addresses"));
1511
1512        let err = validate_http_upstream_target(UpstreamScheme::Http, "localhost", &unspecified)
1513            .expect_err("localhost resolving to unspecified should be rejected");
1514        assert!(err.contains("loopback addresses"));
1515    }
1516
1517    #[test]
1518    fn test_validate_http_upstream_target_rejects_localhost_resolving_non_loopback() {
1519        let poisoned = [SocketAddr::from(([203, 0, 113, 10], 8080))];
1520        let err = validate_http_upstream_target(UpstreamScheme::Http, "localhost", &poisoned)
1521            .expect_err("localhost resolving off-host should be rejected");
1522        assert!(err.contains("refusing insecure http upstream"));
1523    }
1524
1525    #[test]
1526    fn test_format_host_header_uses_port_for_non_default_http() {
1527        assert_eq!(
1528            format_host_header(UpstreamScheme::Http, "localhost", 8080),
1529            "localhost:8080"
1530        );
1531    }
1532
1533    #[test]
1534    fn test_format_host_header_omits_default_https_port() {
1535        assert_eq!(
1536            format_host_header(UpstreamScheme::Https, "api.openai.com", 443),
1537            "api.openai.com"
1538        );
1539    }
1540
1541    #[test]
1542    fn test_format_host_header_brackets_ipv6() {
1543        assert_eq!(
1544            format_host_header(UpstreamScheme::Http, "::1", 8080),
1545            "[::1]:8080"
1546        );
1547    }
1548
1549    #[test]
1550    fn test_parse_response_status_200() {
1551        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
1552        assert_eq!(parse_response_status(data), 200);
1553    }
1554
1555    #[test]
1556    fn test_parse_response_status_404() {
1557        let data = b"HTTP/1.1 404 Not Found\r\n\r\n";
1558        assert_eq!(parse_response_status(data), 404);
1559    }
1560
1561    #[test]
1562    fn test_parse_response_status_garbage() {
1563        let data = b"not an http response";
1564        assert_eq!(parse_response_status(data), 502);
1565    }
1566
1567    #[test]
1568    fn test_parse_response_status_empty() {
1569        assert_eq!(parse_response_status(b""), 502);
1570    }
1571
1572    #[test]
1573    fn test_parse_response_status_partial() {
1574        let data = b"HTTP/1.1 ";
1575        assert_eq!(parse_response_status(data), 502);
1576    }
1577
1578    // ============================================================================
1579    // URL Path Injection Mode Tests
1580    // ============================================================================
1581
1582    #[test]
1583    fn test_validate_phantom_token_in_path_valid() {
1584        let token = Zeroizing::new("session123".to_string());
1585        let path = "/bot/session123/getMe";
1586        let pattern = "/bot/{}/";
1587        assert!(validate_phantom_token_in_path(path, pattern, &token).is_ok());
1588    }
1589
1590    #[test]
1591    fn test_validate_phantom_token_in_path_invalid() {
1592        let token = Zeroizing::new("session123".to_string());
1593        let path = "/bot/wrong_token/getMe";
1594        let pattern = "/bot/{}/";
1595        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
1596    }
1597
1598    #[test]
1599    fn test_validate_phantom_token_in_path_missing() {
1600        let token = Zeroizing::new("session123".to_string());
1601        let path = "/api/getMe";
1602        let pattern = "/bot/{}/";
1603        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
1604    }
1605
1606    #[test]
1607    fn test_transform_url_path_basic() {
1608        let credential = Zeroizing::new("real_token".to_string());
1609        let path = "/bot/phantom_token/getMe";
1610        let pattern = "/bot/{}/";
1611        let replacement = "/bot/{}/";
1612        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
1613        assert_eq!(result, "/bot/real_token/getMe");
1614    }
1615
1616    #[test]
1617    fn test_transform_url_path_different_replacement() {
1618        let credential = Zeroizing::new("real_token".to_string());
1619        let path = "/api/v1/phantom_token/chat";
1620        let pattern = "/api/v1/{}/";
1621        let replacement = "/v2/bot/{}/";
1622        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
1623        assert_eq!(result, "/v2/bot/real_token/chat");
1624    }
1625
1626    #[test]
1627    fn test_transform_url_path_no_trailing_slash() {
1628        let credential = Zeroizing::new("real_token".to_string());
1629        let path = "/bot/phantom_token";
1630        let pattern = "/bot/{}";
1631        let replacement = "/bot/{}";
1632        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
1633        assert_eq!(result, "/bot/real_token");
1634    }
1635
1636    // ============================================================================
1637    // Query Param Injection Mode Tests
1638    // ============================================================================
1639
1640    #[test]
1641    fn test_validate_phantom_token_in_query_valid() {
1642        let token = Zeroizing::new("session123".to_string());
1643        let path = "/api/data?api_key=session123&other=value";
1644        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
1645    }
1646
1647    #[test]
1648    fn test_validate_phantom_token_in_query_invalid() {
1649        let token = Zeroizing::new("session123".to_string());
1650        let path = "/api/data?api_key=wrong_token";
1651        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
1652    }
1653
1654    #[test]
1655    fn test_validate_phantom_token_in_query_missing_param() {
1656        let token = Zeroizing::new("session123".to_string());
1657        let path = "/api/data?other=value";
1658        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
1659    }
1660
1661    #[test]
1662    fn test_validate_phantom_token_in_query_no_query_string() {
1663        let token = Zeroizing::new("session123".to_string());
1664        let path = "/api/data";
1665        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
1666    }
1667
1668    #[test]
1669    fn test_validate_phantom_token_in_query_url_encoded() {
1670        let token = Zeroizing::new("token with spaces".to_string());
1671        let path = "/api/data?api_key=token%20with%20spaces";
1672        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
1673    }
1674
1675    #[test]
1676    fn test_transform_query_param_add_to_no_query() {
1677        let credential = Zeroizing::new("real_key".to_string());
1678        let path = "/api/data";
1679        let result = transform_query_param(path, "api_key", &credential).unwrap();
1680        assert_eq!(result, "/api/data?api_key=real_key");
1681    }
1682
1683    #[test]
1684    fn test_transform_query_param_add_to_existing_query() {
1685        let credential = Zeroizing::new("real_key".to_string());
1686        let path = "/api/data?other=value";
1687        let result = transform_query_param(path, "api_key", &credential).unwrap();
1688        assert_eq!(result, "/api/data?other=value&api_key=real_key");
1689    }
1690
1691    #[test]
1692    fn test_transform_query_param_replace_existing() {
1693        let credential = Zeroizing::new("real_key".to_string());
1694        let path = "/api/data?api_key=phantom&other=value";
1695        let result = transform_query_param(path, "api_key", &credential).unwrap();
1696        assert_eq!(result, "/api/data?api_key=real_key&other=value");
1697    }
1698
1699    #[test]
1700    fn test_transform_query_param_url_encodes_special_chars() {
1701        let credential = Zeroizing::new("key with spaces".to_string());
1702        let path = "/api/data";
1703        let result = transform_query_param(path, "api_key", &credential).unwrap();
1704        assert_eq!(result, "/api/data?api_key=key%20with%20spaces");
1705    }
1706
1707    #[test]
1708    fn test_validate_phantom_token_uses_proxy_mode_over_upstream_mode() {
1709        let token = Zeroizing::new("session123".to_string());
1710        let header = b"Authorization: Bearer session123\r\n\r\n";
1711        let path = "/api/data?api_key=wrong";
1712
1713        // Simulate split config where proxy-side mode is header while upstream
1714        // mode might be query_param.
1715        let result = validate_phantom_token_for_mode(
1716            &InjectMode::Header,
1717            header,
1718            path,
1719            "Authorization",
1720            None,
1721            Some("api_key"),
1722            &token,
1723        );
1724
1725        assert!(result.is_ok());
1726    }
1727
1728    #[test]
1729    fn test_transform_path_uses_upstream_mode_independently() {
1730        let credential = Zeroizing::new("real_key".to_string());
1731        let path = "/api/data?api_key=phantom";
1732
1733        // Simulate split config where upstream mode is query_param.
1734        let transformed = transform_path_for_mode(
1735            &InjectMode::QueryParam,
1736            path,
1737            None,
1738            None,
1739            Some("api_key"),
1740            &credential,
1741        )
1742        .expect("query-param transform should succeed");
1743
1744        assert_eq!(transformed, "/api/data?api_key=real_key");
1745    }
1746
1747    // ========================================================================
1748    // Proxy artifact stripping tests
1749    // ========================================================================
1750
1751    #[test]
1752    fn test_strip_proxy_path_token_basic() {
1753        // Pattern: /{}/  — token is the first path segment
1754        let result = strip_proxy_path_token("/PHANTOM123/api/v1/pods", "/{}/");
1755        assert_eq!(result, "/api/v1/pods");
1756    }
1757
1758    #[test]
1759    fn test_strip_proxy_path_token_nested_pattern() {
1760        // Pattern: /auth/{}/  — token is in a nested segment
1761        let result = strip_proxy_path_token("/auth/PHANTOM123/api/v1/pods", "/auth/{}/");
1762        assert_eq!(result, "/api/v1/pods");
1763    }
1764
1765    #[test]
1766    fn test_strip_proxy_path_token_no_trailing_slash() {
1767        // Pattern: /{}  — token at end of path with no trailing content
1768        let result = strip_proxy_path_token("/PHANTOM123", "/{}");
1769        assert_eq!(result, "/");
1770    }
1771
1772    #[test]
1773    fn test_strip_proxy_path_token_preserves_query() {
1774        // Pattern: /{}/  — should preserve query string after stripping
1775        let result = strip_proxy_path_token("/PHANTOM123/api?limit=10", "/{}/");
1776        assert_eq!(result, "/api?limit=10");
1777    }
1778
1779    #[test]
1780    fn test_strip_proxy_path_token_no_match() {
1781        // Pattern doesn't match — return path unchanged
1782        let result = strip_proxy_path_token("/api/v1/pods", "/auth/{}/");
1783        assert_eq!(result, "/api/v1/pods");
1784    }
1785
1786    #[test]
1787    fn test_strip_proxy_path_token_mid_path_slash_join() {
1788        // Token in the middle: before="/api" after="data" must join with "/"
1789        let result = strip_proxy_path_token("/api/k8s/PHANTOM/data", "/k8s/{}/");
1790        assert_eq!(result, "/api/data");
1791    }
1792
1793    #[test]
1794    fn test_strip_proxy_path_token_no_double_slash() {
1795        // Before ends with "/" and after starts with "/" — collapse to one
1796        let result = strip_proxy_path_token("/prefix/PHANTOM//suffix", "/prefix/{}/");
1797        assert_eq!(result, "/suffix");
1798    }
1799
1800    #[test]
1801    fn test_strip_proxy_query_param_only_param() {
1802        let result = strip_proxy_query_param("/api/v1/pods?token=PHANTOM123", "token");
1803        assert_eq!(result, "/api/v1/pods");
1804    }
1805
1806    #[test]
1807    fn test_strip_proxy_query_param_with_other_params() {
1808        let result = strip_proxy_query_param("/api/v1/pods?token=PHANTOM123&limit=10", "token");
1809        assert_eq!(result, "/api/v1/pods?limit=10");
1810    }
1811
1812    #[test]
1813    fn test_strip_proxy_query_param_middle() {
1814        let result =
1815            strip_proxy_query_param("/api/v1/pods?limit=10&token=PHANTOM123&watch=true", "token");
1816        assert_eq!(result, "/api/v1/pods?limit=10&watch=true");
1817    }
1818
1819    #[test]
1820    fn test_strip_proxy_query_param_no_match() {
1821        let result = strip_proxy_query_param("/api/v1/pods?limit=10", "token");
1822        assert_eq!(result, "/api/v1/pods?limit=10");
1823    }
1824
1825    #[test]
1826    fn test_strip_proxy_query_param_no_query_string() {
1827        let result = strip_proxy_query_param("/api/v1/pods", "token");
1828        assert_eq!(result, "/api/v1/pods");
1829    }
1830
1831    #[test]
1832    fn test_strip_proxy_artifacts_same_mode_noop() {
1833        // When proxy and upstream use the same mode, no stripping (upstream transform handles it)
1834        let path = "/PHANTOM123/api/v1/pods";
1835        let result = strip_proxy_artifacts(
1836            path,
1837            &InjectMode::UrlPath,
1838            &InjectMode::UrlPath,
1839            Some("/{}/"),
1840            None,
1841        );
1842        assert_eq!(result, path);
1843    }
1844
1845    #[test]
1846    fn test_strip_proxy_artifacts_url_path_to_header() {
1847        // Proxy uses url_path, upstream uses header — must strip path token
1848        let result = strip_proxy_artifacts(
1849            "/PHANTOM123/api/v1/pods",
1850            &InjectMode::UrlPath,
1851            &InjectMode::Header,
1852            Some("/{}/"),
1853            None,
1854        );
1855        assert_eq!(result, "/api/v1/pods");
1856    }
1857
1858    #[test]
1859    fn test_strip_proxy_artifacts_query_param_to_header() {
1860        // Proxy uses query_param, upstream uses header — must strip query param
1861        let result = strip_proxy_artifacts(
1862            "/api/v1/pods?token=PHANTOM123",
1863            &InjectMode::QueryParam,
1864            &InjectMode::Header,
1865            None,
1866            Some("token"),
1867        );
1868        assert_eq!(result, "/api/v1/pods");
1869    }
1870
1871    #[test]
1872    fn test_strip_proxy_artifacts_header_to_query_param() {
1873        // Proxy uses header, upstream uses query_param — no URL artifacts to strip
1874        let path = "/api/v1/pods";
1875        let result = strip_proxy_artifacts(
1876            path,
1877            &InjectMode::Header,
1878            &InjectMode::QueryParam,
1879            None,
1880            None,
1881        );
1882        assert_eq!(result, path);
1883    }
1884
1885    #[test]
1886    fn test_end_to_end_url_path_proxy_header_upstream() {
1887        // Full flow: proxy validates via url_path, upstream injects via header.
1888        // The path token must be stripped before forwarding.
1889        let token = Zeroizing::new("session456".to_string());
1890        let credential = Zeroizing::new("real_bearer_token".to_string());
1891        let path = "/session456/api/v1/namespaces";
1892
1893        // 1. Proxy-side validation succeeds
1894        assert!(validate_phantom_token_for_mode(
1895            &InjectMode::UrlPath,
1896            b"\r\n\r\n", // no auth header needed for url_path mode
1897            path,
1898            "Authorization",
1899            Some("/{}/"),
1900            None,
1901            &token,
1902        )
1903        .is_ok());
1904
1905        // 2. Strip proxy artifacts
1906        let cleaned = strip_proxy_artifacts(
1907            path,
1908            &InjectMode::UrlPath,
1909            &InjectMode::Header,
1910            Some("/{}/"),
1911            None,
1912        );
1913        assert_eq!(cleaned, "/api/v1/namespaces");
1914
1915        // 3. Upstream transform (header mode = no path change)
1916        let transformed =
1917            transform_path_for_mode(&InjectMode::Header, &cleaned, None, None, None, &credential)
1918                .unwrap();
1919        assert_eq!(transformed, "/api/v1/namespaces");
1920    }
1921
1922    #[test]
1923    fn test_end_to_end_query_param_proxy_header_upstream() {
1924        // Full flow: proxy validates via query_param, upstream injects via header.
1925        let token = Zeroizing::new("session789".to_string());
1926        let credential = Zeroizing::new("real_bearer_token".to_string());
1927        let path = "/api/v1/pods?token=session789&limit=100";
1928
1929        // 1. Proxy-side validation succeeds
1930        assert!(validate_phantom_token_for_mode(
1931            &InjectMode::QueryParam,
1932            b"\r\n\r\n",
1933            path,
1934            "Authorization",
1935            None,
1936            Some("token"),
1937            &token,
1938        )
1939        .is_ok());
1940
1941        // 2. Strip proxy artifacts
1942        let cleaned = strip_proxy_artifacts(
1943            path,
1944            &InjectMode::QueryParam,
1945            &InjectMode::Header,
1946            None,
1947            Some("token"),
1948        );
1949        assert_eq!(cleaned, "/api/v1/pods?limit=100");
1950
1951        // 3. Upstream transform (header mode = no path change)
1952        let transformed =
1953            transform_path_for_mode(&InjectMode::Header, &cleaned, None, None, None, &credential)
1954                .unwrap();
1955        assert_eq!(transformed, "/api/v1/pods?limit=100");
1956    }
1957}