Skip to main content

nono_proxy/
reverse.rs

1//! Reverse proxy handler (Mode 2 — Credential Injection).
2//!
3//! Routes requests by path prefix to upstream APIs, injecting credentials
4//! from the keystore. The agent uses `http://localhost:PORT/openai/v1/chat`
5//! and the proxy rewrites to `https://api.openai.com/v1/chat` with the
6//! real credential injected.
7//!
8//! Supports multiple injection modes:
9//! - `header`: Inject into HTTP header (e.g., `Authorization: Bearer ...`)
10//! - `url_path`: Replace pattern in URL path (e.g., Telegram `/bot{}/`)
11//! - `query_param`: Add/replace query parameter (e.g., `?api_key=...`)
12//! - `basic_auth`: HTTP Basic Authentication
13//!
14//! Streaming responses (SSE, MCP Streamable HTTP, A2A JSON-RPC) are
15//! forwarded without buffering.
16
17use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound on an accepted request body: 16 MiB (`16 << 20` bytes).
/// Requests declaring a larger `Content-Length` are refused, so a hostile
/// header cannot make the proxy allocate unbounded memory.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// How long a single upstream TCP connect attempt may take before it is
/// abandoned (30 seconds).
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
36
37/// Handle a non-CONNECT HTTP request (reverse proxy mode).
38///
39/// Reads the full HTTP request from the client, matches path prefix to
40/// a configured route, injects credentials, and forwards to the upstream.
41/// Shared context passed from the server to the reverse proxy handler.
42pub struct ReverseProxyCtx<'a> {
43    /// Credential store for service lookups
44    pub credential_store: &'a CredentialStore,
45    /// Session token for authentication
46    pub session_token: &'a Zeroizing<String>,
47    /// Host filter for upstream validation
48    pub filter: &'a ProxyFilter,
49    /// Shared TLS connector
50    pub tls_connector: &'a TlsConnector,
51    /// Shared network audit sink for session metadata capture
52    pub audit_log: Option<&'a audit::SharedAuditLog>,
53}
54
55/// Handle a non-CONNECT HTTP request (reverse proxy mode).
56///
57/// `buffered_body` contains any bytes the BufReader read ahead beyond the
58/// headers. These are prepended to the body read from the stream to prevent
59/// data loss.
60///
61/// ## Phantom Token Pattern
62///
63/// The client (SDK) sends the session token as its "API key". The proxy:
64/// 1. Extracts the service from the path (e.g., `/openai/v1/chat` → `openai`)
65/// 2. Looks up which header that service uses (e.g., `Authorization` or `x-api-key`)
66/// 3. Validates the phantom token from that header
67/// 4. Replaces it with the real credential from keyring
68pub async fn handle_reverse_proxy(
69    first_line: &str,
70    stream: &mut TcpStream,
71    remaining_header: &[u8],
72    ctx: &ReverseProxyCtx<'_>,
73    buffered_body: &[u8],
74) -> Result<()> {
75    // Parse method, path, and HTTP version
76    let (method, path, version) = parse_request_line(first_line)?;
77    debug!("Reverse proxy: {} {}", method, path);
78
79    // Extract service prefix from path (e.g., "/openai/v1/chat" -> ("openai", "/v1/chat"))
80    let (service, upstream_path) = parse_service_prefix(&path)?;
81
82    // Look up credential for service
83    let cred = ctx
84        .credential_store
85        .get(&service)
86        .ok_or_else(|| ProxyError::UnknownService {
87            prefix: service.clone(),
88        })?;
89
90    // L7 endpoint filtering: check method+path against rules before any
91    // credential operations. Denied endpoints get 403 immediately.
92    if !cred.endpoint_rules.is_allowed(&method, &upstream_path) {
93        let reason = format!(
94            "endpoint denied: {} {} on service '{}'",
95            method, upstream_path, service
96        );
97        warn!("{}", reason);
98        audit::log_denied(
99            ctx.audit_log,
100            audit::ProxyMode::Reverse,
101            &service,
102            0,
103            &reason,
104        );
105        send_error(stream, 403, "Forbidden").await?;
106        return Ok(());
107    }
108
109    // Validate phantom token based on injection mode.
110    // For header/basic_auth modes: validate from Authorization/x-api-key header
111    // For url_path mode: validate from URL path pattern
112    // For query_param mode: validate from query parameter
113    if let Err(e) = validate_phantom_token_for_mode(
114        &cred.inject_mode,
115        remaining_header,
116        &upstream_path,
117        &cred.header_name,
118        cred.path_pattern.as_deref(),
119        cred.query_param_name.as_deref(),
120        ctx.session_token,
121    ) {
122        audit::log_denied(
123            ctx.audit_log,
124            audit::ProxyMode::Reverse,
125            &service,
126            0,
127            &e.to_string(),
128        );
129        send_error(stream, 401, "Unauthorized").await?;
130        return Ok(());
131    }
132
133    // Transform the path based on injection mode (url_path and query_param modes)
134    let transformed_path = transform_path_for_mode(
135        &cred.inject_mode,
136        &upstream_path,
137        cred.path_pattern.as_deref(),
138        cred.path_replacement.as_deref(),
139        cred.query_param_name.as_deref(),
140        &cred.raw_credential,
141    )?;
142
143    // Parse upstream URL with potentially transformed path
144    let upstream_url = format!(
145        "{}{}",
146        cred.upstream.trim_end_matches('/'),
147        transformed_path
148    );
149    debug!("Forwarding to upstream: {} {}", method, upstream_url);
150
151    let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
152
153    // DNS resolve + host check via the filter
154    let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
155    if !check.result.is_allowed() {
156        let reason = check.result.reason();
157        warn!("Upstream host denied by filter: {}", reason);
158        send_error(stream, 403, "Forbidden").await?;
159        audit::log_denied(
160            ctx.audit_log,
161            audit::ProxyMode::Reverse,
162            &service,
163            0,
164            &reason,
165        );
166        return Ok(());
167    }
168
169    // Collect remaining request headers (excluding X-Nono-Token and Host)
170    let filtered_headers = filter_headers(remaining_header);
171    let content_length = extract_content_length(remaining_header);
172
173    // Read request body if present, with size limit.
174    // `buffered_body` may contain bytes the BufReader read ahead beyond
175    // headers; we prepend those to avoid data loss.
176    let body = if let Some(len) = content_length {
177        if len > MAX_REQUEST_BODY {
178            send_error(stream, 413, "Payload Too Large").await?;
179            return Ok(());
180        }
181        let mut buf = Vec::with_capacity(len);
182        let pre = buffered_body.len().min(len);
183        buf.extend_from_slice(&buffered_body[..pre]);
184        let remaining = len - pre;
185        if remaining > 0 {
186            let mut rest = vec![0u8; remaining];
187            stream.read_exact(&mut rest).await?;
188            buf.extend_from_slice(&rest);
189        }
190        buf
191    } else {
192        Vec::new()
193    };
194
195    // Connect to upstream over TLS using pre-resolved addresses
196    let upstream_result = connect_upstream_tls(
197        &upstream_host,
198        upstream_port,
199        &check.resolved_addrs,
200        ctx.tls_connector,
201    )
202    .await;
203    let mut tls_stream = match upstream_result {
204        Ok(s) => s,
205        Err(e) => {
206            warn!("Upstream connection failed: {}", e);
207            send_error(stream, 502, "Bad Gateway").await?;
208            audit::log_denied(
209                ctx.audit_log,
210                audit::ProxyMode::Reverse,
211                &service,
212                0,
213                &e.to_string(),
214            );
215            return Ok(());
216        }
217    };
218
219    // Build the upstream request into a Zeroizing buffer since it may contain
220    // credential values. This ensures credentials are zeroed from heap memory
221    // when the buffer is dropped.
222    let mut request = Zeroizing::new(format!(
223        "{} {} {}\r\nHost: {}\r\n",
224        method, upstream_path_full, version, upstream_host
225    ));
226
227    // Inject credential based on mode
228    inject_credential_for_mode(cred, &mut request);
229
230    // Forward filtered headers (excluding auth headers that we're replacing)
231    let auth_header_lower = cred.header_name.to_lowercase();
232    for (name, value) in &filtered_headers {
233        // Skip the auth header if we're using header/basic_auth mode
234        // (we already injected our own)
235        if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
236            && name.to_lowercase() == auth_header_lower
237        {
238            continue;
239        }
240        request.push_str(&format!("{}: {}\r\n", name, value));
241    }
242
243    // Content-Length for body
244    if !body.is_empty() {
245        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
246    }
247    request.push_str("\r\n");
248
249    tls_stream.write_all(request.as_bytes()).await?;
250    if !body.is_empty() {
251        tls_stream.write_all(&body).await?;
252    }
253    tls_stream.flush().await?;
254
255    // Stream the response back to the client without buffering.
256    // This handles SSE (text/event-stream), chunked transfer, and regular responses.
257    let mut response_buf = [0u8; 8192];
258    let mut status_code: u16 = 502;
259    let mut first_chunk = true;
260
261    loop {
262        let n = match tls_stream.read(&mut response_buf).await {
263            Ok(0) => break,
264            Ok(n) => n,
265            Err(e) => {
266                debug!("Upstream read error: {}", e);
267                break;
268            }
269        };
270
271        // Parse status from first chunk. The HTTP status line format is:
272        // "HTTP/1.1 200 OK\r\n..." — we need the 3-digit code after the
273        // first space. We scan up to 32 bytes (enough for any valid status line).
274        if first_chunk {
275            status_code = parse_response_status(&response_buf[..n]);
276            first_chunk = false;
277        }
278
279        stream.write_all(&response_buf[..n]).await?;
280        stream.flush().await?;
281    }
282
283    audit::log_reverse_proxy(
284        ctx.audit_log,
285        &service,
286        &method,
287        &upstream_path,
288        status_code,
289    );
290    Ok(())
291}
292
293/// Parse an HTTP request line into (method, path, version).
294fn parse_request_line(line: &str) -> Result<(String, String, String)> {
295    let parts: Vec<&str> = line.split_whitespace().collect();
296    if parts.len() < 3 {
297        return Err(ProxyError::HttpParse(format!(
298            "malformed request line: {}",
299            line
300        )));
301    }
302    Ok((
303        parts[0].to_string(),
304        parts[1].to_string(),
305        parts[2].to_string(),
306    ))
307}
308
309/// Extract service prefix from path.
310///
311/// "/openai/v1/chat/completions" -> ("openai", "/v1/chat/completions")
312/// "/anthropic/v1/messages" -> ("anthropic", "/v1/messages")
313fn parse_service_prefix(path: &str) -> Result<(String, String)> {
314    let trimmed = path.strip_prefix('/').unwrap_or(path);
315    if let Some((prefix, rest)) = trimmed.split_once('/') {
316        Ok((prefix.to_string(), format!("/{}", rest)))
317    } else {
318        // No sub-path, just the prefix
319        Ok((trimmed.to_string(), "/".to_string()))
320    }
321}
322
323/// Validate the phantom token from the service's auth header.
324///
325/// The SDK sends the session token as its "API key" in the standard auth header
326/// for that service (e.g., `Authorization: Bearer <token>` for OpenAI,
327/// `x-api-key: <token>` for Anthropic). We validate the token matches the
328/// session token before swapping in the real credential.
329fn validate_phantom_token(
330    header_bytes: &[u8],
331    header_name: &str,
332    session_token: &Zeroizing<String>,
333) -> Result<()> {
334    let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
335    let header_name_lower = header_name.to_lowercase();
336
337    for line in header_str.lines() {
338        let lower = line.to_lowercase();
339        if lower.starts_with(&format!("{}:", header_name_lower)) {
340            let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
341
342            // Handle "Bearer <token>" format (strip "Bearer " prefix if present)
343            // Use case-insensitive check, then slice original value by length
344            let value_lower = value.to_lowercase();
345            let token_value = if value_lower.starts_with("bearer ") {
346                // "bearer ".len() == 7
347                value[7..].trim()
348            } else {
349                value
350            };
351
352            if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
353                return Ok(());
354            }
355            warn!("Invalid phantom token in {} header", header_name);
356            return Err(ProxyError::InvalidToken);
357        }
358    }
359
360    warn!(
361        "Missing {} header for phantom token validation",
362        header_name
363    );
364    Err(ProxyError::InvalidToken)
365}
366
/// Collect the client's headers as trimmed `(name, value)` pairs, dropping
/// those the proxy rewrites itself.
///
/// Dropped: `Host` (rewritten to the upstream), `Content-Length` (re-added
/// after the body is read), and the `Authorization`, `x-api-key` and
/// `x-goog-api-key` auth headers (the phantom token is validated, never
/// forwarded). Blank lines and lines without a `:` are skipped, and header
/// names are matched case-insensitively. Non-UTF-8 input yields no headers.
fn filter_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
    const DROPPED_PREFIXES: [&str; 5] = [
        "host:",
        "content-length:",
        "authorization:",
        "x-api-key:",
        "x-goog-api-key:",
    ];

    let text = std::str::from_utf8(header_bytes).unwrap_or("");
    text.lines()
        .filter(|line| !line.trim().is_empty())
        .filter(|line| {
            let lowered = line.to_lowercase();
            !DROPPED_PREFIXES
                .iter()
                .any(|prefix| lowered.starts_with(*prefix))
        })
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
394
/// Read the `Content-Length` header (case-insensitive) from raw header bytes.
///
/// Only the first `Content-Length` line is considered. Returns `None` if the
/// bytes are not UTF-8, the header is absent, or its value is not a valid
/// `usize`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(header_bytes).ok()?;
    let line = text
        .lines()
        .find(|candidate| candidate.to_lowercase().starts_with("content-length:"))?;
    let (_, raw) = line.split_once(':')?;
    raw.trim().parse::<usize>().ok()
}
406
407/// Parse an upstream URL into (host, port, path).
408fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
409    let parsed = url::Url::parse(url_str)
410        .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
411
412    let scheme = parsed.scheme();
413    if scheme != "https" && scheme != "http" {
414        return Err(ProxyError::HttpParse(format!(
415            "unsupported URL scheme: {}",
416            url_str
417        )));
418    }
419
420    let host = parsed
421        .host_str()
422        .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
423        .to_string();
424
425    let default_port = if scheme == "https" { 443 } else { 80 };
426    let port = parsed.port().unwrap_or(default_port);
427
428    let path = parsed.path().to_string();
429    let path = if path.is_empty() {
430        "/".to_string()
431    } else {
432        path
433    };
434
435    // Include query string if present
436    let path_with_query = if let Some(query) = parsed.query() {
437        format!("{}?{}", path, query)
438    } else {
439        path
440    };
441
442    Ok((host, port, path_with_query))
443}
444
445/// Connect to an upstream host over TLS using pre-resolved addresses.
446///
447/// Uses the pre-resolved `SocketAddr`s from the filter check to prevent
448/// DNS rebinding TOCTOU. Falls back to hostname resolution only if no
449/// pre-resolved addresses are available.
450///
451/// The `TlsConnector` is shared across all connections (created once at
452/// server startup with the system root certificate store).
453async fn connect_upstream_tls(
454    host: &str,
455    port: u16,
456    resolved_addrs: &[SocketAddr],
457    connector: &TlsConnector,
458) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
459    let tcp = if resolved_addrs.is_empty() {
460        // Fallback: no pre-resolved addresses (shouldn't happen in practice)
461        let addr = format!("{}:{}", host, port);
462        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
463            Ok(Ok(s)) => s,
464            Ok(Err(e)) => {
465                return Err(ProxyError::UpstreamConnect {
466                    host: host.to_string(),
467                    reason: e.to_string(),
468                });
469            }
470            Err(_) => {
471                return Err(ProxyError::UpstreamConnect {
472                    host: host.to_string(),
473                    reason: "connection timed out".to_string(),
474                });
475            }
476        }
477    } else {
478        connect_to_resolved(resolved_addrs, host).await?
479    };
480
481    let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
482        ProxyError::UpstreamConnect {
483            host: host.to_string(),
484            reason: "invalid server name for TLS".to_string(),
485        }
486    })?;
487
488    let tls_stream =
489        connector
490            .connect(server_name, tcp)
491            .await
492            .map_err(|e| ProxyError::UpstreamConnect {
493                host: host.to_string(),
494                reason: format!("TLS handshake failed: {}", e),
495            })?;
496
497    Ok(tls_stream)
498}
499
500/// Connect to one of the pre-resolved socket addresses with timeout.
501async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
502    let mut last_err = None;
503    for addr in addrs {
504        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
505            Ok(Ok(stream)) => return Ok(stream),
506            Ok(Err(e)) => {
507                debug!("Connect to {} failed: {}", addr, e);
508                last_err = Some(e.to_string());
509            }
510            Err(_) => {
511                debug!("Connect to {} timed out", addr);
512                last_err = Some("connection timed out".to_string());
513            }
514        }
515    }
516    Err(ProxyError::UpstreamConnect {
517        host: host.to_string(),
518        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
519    })
520}
521
/// Parse HTTP status code from the first response chunk.
///
/// Looks for the `HTTP/x.y NNN` pattern in (up to) the first 64 bytes of the
/// first line and returns `NNN`. Returns 502 if the chunk does not start with
/// a valid status line (upstream sent garbage or incomplete data).
///
/// The line is inspected as raw bytes. A non-UTF-8 reason phrase, or a
/// multi-byte character cut by the 64-byte window, therefore does not hide
/// an otherwise valid status code.
///
/// The code must be exactly three ASCII digits, per the HTTP status-line
/// grammar. `str::parse::<u16>` alone would also accept forms like `+12`.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    // Find the end of the first line (or use full data if no newline)
    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(64)];

    // Split on ASCII whitespace: ["HTTP/1.1", "200", "OK"]
    let mut tokens = first_line
        .split(|b| b.is_ascii_whitespace())
        .filter(|token| !token.is_empty());

    match (tokens.next(), tokens.next()) {
        (Some(version), Some(code))
            if version.starts_with(b"HTTP/")
                && code.len() == 3
                && code.iter().all(u8::is_ascii_digit) =>
        {
            code.iter()
                .fold(0u16, |acc, &digit| acc * 10 + u16::from(digit - b'0'))
        }
        _ => FALLBACK,
    }
}
550
551/// Send an HTTP error response.
552async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
553    let body = format!("{{\"error\":\"{}\"}}", reason);
554    let response = format!(
555        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
556        status,
557        reason,
558        body.len(),
559        body
560    );
561    stream.write_all(response.as_bytes()).await?;
562    stream.flush().await?;
563    Ok(())
564}
565
566// ============================================================================
567// Injection mode helpers
568// ============================================================================
569
570/// Validate phantom token based on injection mode.
571///
572/// Different modes extract the phantom token from different locations:
573/// - `Header`/`BasicAuth`: From the auth header (Authorization, x-api-key, etc.)
574/// - `UrlPath`: From the URL path pattern (e.g., `/bot<token>/getMe`)
575/// - `QueryParam`: From the query parameter (e.g., `?api_key=<token>`)
576fn validate_phantom_token_for_mode(
577    mode: &InjectMode,
578    header_bytes: &[u8],
579    path: &str,
580    header_name: &str,
581    path_pattern: Option<&str>,
582    query_param_name: Option<&str>,
583    session_token: &Zeroizing<String>,
584) -> Result<()> {
585    match mode {
586        InjectMode::Header | InjectMode::BasicAuth => {
587            // Validate from header (existing behavior)
588            validate_phantom_token(header_bytes, header_name, session_token)
589        }
590        InjectMode::UrlPath => {
591            // Validate from URL path
592            let pattern = path_pattern.ok_or_else(|| {
593                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
594            })?;
595            validate_phantom_token_in_path(path, pattern, session_token)
596        }
597        InjectMode::QueryParam => {
598            // Validate from query parameter
599            let param_name = query_param_name.ok_or_else(|| {
600                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
601            })?;
602            validate_phantom_token_in_query(path, param_name, session_token)
603        }
604    }
605}
606
607/// Validate phantom token embedded in URL path.
608///
609/// Extracts the token from the path using the pattern (e.g., `/bot{}/` matches
610/// `/bot<token>/getMe` and extracts `<token>`).
611fn validate_phantom_token_in_path(
612    path: &str,
613    pattern: &str,
614    session_token: &Zeroizing<String>,
615) -> Result<()> {
616    // Split pattern on {} to get prefix and suffix
617    let parts: Vec<&str> = pattern.split("{}").collect();
618    if parts.len() != 2 {
619        return Err(ProxyError::HttpParse(format!(
620            "invalid path_pattern '{}': must contain exactly one {{}}",
621            pattern
622        )));
623    }
624    let (prefix, suffix) = (parts[0], parts[1]);
625
626    // Find the token in the path
627    if let Some(start) = path.find(prefix) {
628        let after_prefix = start + prefix.len();
629
630        // Handle empty suffix case (token extends to end of path or next '/' or '?')
631        let end_offset = if suffix.is_empty() {
632            path[after_prefix..]
633                .find(['/', '?'])
634                .unwrap_or(path[after_prefix..].len())
635        } else {
636            match path[after_prefix..].find(suffix) {
637                Some(offset) => offset,
638                None => {
639                    warn!("Missing phantom token in URL path (pattern: {})", pattern);
640                    return Err(ProxyError::InvalidToken);
641                }
642            }
643        };
644
645        let token = &path[after_prefix..after_prefix + end_offset];
646        if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
647            return Ok(());
648        }
649        warn!("Invalid phantom token in URL path");
650        return Err(ProxyError::InvalidToken);
651    }
652
653    warn!("Missing phantom token in URL path (pattern: {})", pattern);
654    Err(ProxyError::InvalidToken)
655}
656
657/// Validate phantom token in query parameter.
658fn validate_phantom_token_in_query(
659    path: &str,
660    param_name: &str,
661    session_token: &Zeroizing<String>,
662) -> Result<()> {
663    // Parse query string from path
664    if let Some(query_start) = path.find('?') {
665        let query = &path[query_start + 1..];
666        for pair in query.split('&') {
667            if let Some((name, value)) = pair.split_once('=') {
668                if name == param_name {
669                    // URL-decode the value
670                    let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
671                    if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
672                        return Ok(());
673                    }
674                    warn!("Invalid phantom token in query parameter '{}'", param_name);
675                    return Err(ProxyError::InvalidToken);
676                }
677            }
678        }
679    }
680
681    warn!("Missing phantom token in query parameter '{}'", param_name);
682    Err(ProxyError::InvalidToken)
683}
684
685/// Transform URL path based on injection mode.
686///
687/// - `UrlPath`: Replace phantom token with real credential in path
688/// - `QueryParam`: Add/replace query parameter with real credential
689/// - `Header`/`BasicAuth`: No path transformation needed
690fn transform_path_for_mode(
691    mode: &InjectMode,
692    path: &str,
693    path_pattern: Option<&str>,
694    path_replacement: Option<&str>,
695    query_param_name: Option<&str>,
696    credential: &Zeroizing<String>,
697) -> Result<String> {
698    match mode {
699        InjectMode::Header | InjectMode::BasicAuth => {
700            // No path transformation needed
701            Ok(path.to_string())
702        }
703        InjectMode::UrlPath => {
704            let pattern = path_pattern.ok_or_else(|| {
705                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
706            })?;
707            let replacement = path_replacement.unwrap_or(pattern);
708            transform_url_path(path, pattern, replacement, credential)
709        }
710        InjectMode::QueryParam => {
711            let param_name = query_param_name.ok_or_else(|| {
712                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
713            })?;
714            transform_query_param(path, param_name, credential)
715        }
716    }
717}
718
719/// Transform URL path by replacing phantom token pattern with real credential.
720///
721/// Example: `/bot<phantom>/getMe` with pattern `/bot{}/` becomes `/bot<real>/getMe`
722fn transform_url_path(
723    path: &str,
724    pattern: &str,
725    replacement: &str,
726    credential: &Zeroizing<String>,
727) -> Result<String> {
728    // Split pattern on {} to get prefix and suffix
729    let parts: Vec<&str> = pattern.split("{}").collect();
730    if parts.len() != 2 {
731        return Err(ProxyError::HttpParse(format!(
732            "invalid path_pattern '{}': must contain exactly one {{}}",
733            pattern
734        )));
735    }
736    let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
737
738    // Split replacement on {}
739    let repl_parts: Vec<&str> = replacement.split("{}").collect();
740    if repl_parts.len() != 2 {
741        return Err(ProxyError::HttpParse(format!(
742            "invalid path_replacement '{}': must contain exactly one {{}}",
743            replacement
744        )));
745    }
746    let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
747
748    // Find and replace the token in the path
749    if let Some(start) = path.find(pattern_prefix) {
750        let after_prefix = start + pattern_prefix.len();
751
752        // Handle empty suffix case (token extends to end of path or next '/' or '?')
753        let end_offset = if pattern_suffix.is_empty() {
754            // Find the next path segment delimiter or end of path
755            path[after_prefix..]
756                .find(['/', '?'])
757                .unwrap_or(path[after_prefix..].len())
758        } else {
759            // Find the suffix in the remaining path
760            match path[after_prefix..].find(pattern_suffix) {
761                Some(offset) => offset,
762                None => {
763                    return Err(ProxyError::HttpParse(format!(
764                        "path '{}' does not match pattern '{}'",
765                        path, pattern
766                    )));
767                }
768            }
769        };
770
771        let before = &path[..start];
772        let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
773        return Ok(format!(
774            "{}{}{}{}{}",
775            before,
776            repl_prefix,
777            credential.as_str(),
778            repl_suffix,
779            after
780        ));
781    }
782
783    Err(ProxyError::HttpParse(format!(
784        "path '{}' does not match pattern '{}'",
785        path, pattern
786    )))
787}
788
789/// Transform query string by adding or replacing a parameter with the credential.
790fn transform_query_param(
791    path: &str,
792    param_name: &str,
793    credential: &Zeroizing<String>,
794) -> Result<String> {
795    let encoded_value = urlencoding::encode(credential.as_str());
796
797    if let Some(query_start) = path.find('?') {
798        let base_path = &path[..query_start];
799        let query = &path[query_start + 1..];
800
801        // Check if parameter already exists
802        let mut found = false;
803        let new_query: Vec<String> = query
804            .split('&')
805            .map(|pair| {
806                if let Some((name, _)) = pair.split_once('=') {
807                    if name == param_name {
808                        found = true;
809                        return format!("{}={}", param_name, encoded_value);
810                    }
811                }
812                pair.to_string()
813            })
814            .collect();
815
816        if found {
817            Ok(format!("{}?{}", base_path, new_query.join("&")))
818        } else {
819            // Append the parameter
820            Ok(format!(
821                "{}?{}&{}={}",
822                base_path, query, param_name, encoded_value
823            ))
824        }
825    } else {
826        // No query string, add one
827        Ok(format!("{}?{}={}", path, param_name, encoded_value))
828    }
829}
830
831/// Inject credential into request based on mode.
832///
833/// For header/basic_auth modes, adds the credential header.
834/// For url_path/query_param modes, the credential is already in the path.
835fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
836    match cred.inject_mode {
837        InjectMode::Header | InjectMode::BasicAuth => {
838            // Inject credential header
839            request.push_str(&format!(
840                "{}: {}\r\n",
841                cred.header_name,
842                cred.header_value.as_str()
843            ));
844        }
845        InjectMode::UrlPath | InjectMode::QueryParam => {
846            // Credential is already injected into the URL path/query
847            // No header injection needed
848        }
849    }
850}
851
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for the reverse proxy's request parsing, phantom-token
    //! validation and credential-placement helpers.

    use super::*;

    /// Wrap a literal in `Zeroizing`, the same container used for tokens and
    /// credentials at runtime.
    fn secret(value: &str) -> Zeroizing<String> {
        Zeroizing::new(value.to_owned())
    }

    /// Run header-based phantom-token validation and report whether it accepted
    /// `expected` as the session token.
    fn header_token_accepted(raw_headers: &[u8], header_name: &str, expected: &str) -> bool {
        validate_phantom_token(raw_headers, header_name, &secret(expected)).is_ok()
    }

    // ============================================================================
    // Request line and service prefix parsing
    // ============================================================================

    #[test]
    fn test_parse_request_line() {
        let parsed = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        let (verb, target, proto) = parsed;
        assert_eq!(verb, "POST");
        assert_eq!(target, "/openai/v1/chat");
        assert_eq!(proto, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        let outcome = parse_request_line("GET");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let (name, rest) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(name, "openai");
        assert_eq!(rest, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (name, rest) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(name, "anthropic");
        assert_eq!(rest, "/");
    }

    // ============================================================================
    // Header phantom-token validation
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        assert!(header_token_accepted(
            b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n",
            "Authorization",
            "secret123",
        ));
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        assert!(!header_token_accepted(
            b"Authorization: Bearer wrong\r\n\r\n",
            "Authorization",
            "secret123",
        ));
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        assert!(header_token_accepted(
            b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n",
            "x-api-key",
            "secret123",
        ));
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        assert!(header_token_accepted(
            b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n",
            "x-goog-api-key",
            "secret123",
        ));
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        assert!(!header_token_accepted(
            b"Content-Type: application/json\r\n\r\n",
            "Authorization",
            "secret123",
        ));
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        assert!(header_token_accepted(
            b"AUTHORIZATION: Bearer secret123\r\n\r\n",
            "Authorization",
            "secret123",
        ));
    }

    // ============================================================================
    // Header filtering and Content-Length extraction
    // ============================================================================

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let kept = filter_headers(
            b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n",
        );
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].0, "Content-Type");
        assert_eq!(kept[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let kept = filter_headers(b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n");
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_x_goog_api_key() {
        let kept =
            filter_headers(b"x-goog-api-key: gemini-key\r\nContent-Type: application/json\r\n\r\n");
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    #[test]
    fn test_extract_content_length() {
        let length = extract_content_length(
            b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n",
        );
        assert_eq!(length, Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let length = extract_content_length(b"Content-Type: application/json\r\n\r\n");
        assert_eq!(length, None);
    }

    // ============================================================================
    // Upstream URL parsing
    // ============================================================================

    #[test]
    fn test_parse_upstream_url_https() {
        let parsed = parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        let (host, port, path) = parsed;
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let parsed = parse_upstream_url("http://localhost:8080/api").unwrap();
        let (host, port, path) = parsed;
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let parsed = parse_upstream_url("https://api.anthropic.com").unwrap();
        let (host, port, path) = parsed;
        assert_eq!(host, "api.anthropic.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        let outcome = parse_upstream_url("ftp://example.com");
        assert!(outcome.is_err());
    }

    // ============================================================================
    // Upstream response status parsing (502 is the fallback for unparseable input)
    // ============================================================================

    #[test]
    fn test_parse_response_status_200() {
        let status =
            parse_response_status(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n");
        assert_eq!(status, 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let status = parse_response_status(b"HTTP/1.1 404 Not Found\r\n\r\n");
        assert_eq!(status, 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        let status = parse_response_status(b"not an http response");
        assert_eq!(status, 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        let status = parse_response_status(b"");
        assert_eq!(status, 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        let status = parse_response_status(b"HTTP/1.1 ");
        assert_eq!(status, 502);
    }

    // ============================================================================
    // URL Path Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let session = secret("session123");
        let outcome = validate_phantom_token_in_path("/bot/session123/getMe", "/bot/{}/", &session);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let session = secret("session123");
        let outcome =
            validate_phantom_token_in_path("/bot/wrong_token/getMe", "/bot/{}/", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let session = secret("session123");
        let outcome = validate_phantom_token_in_path("/api/getMe", "/bot/{}/", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let real = secret("real_token");
        let rewritten =
            transform_url_path("/bot/phantom_token/getMe", "/bot/{}/", "/bot/{}/", &real).unwrap();
        assert_eq!(rewritten, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let real = secret("real_token");
        let rewritten = transform_url_path(
            "/api/v1/phantom_token/chat",
            "/api/v1/{}/",
            "/v2/bot/{}/",
            &real,
        )
        .unwrap();
        assert_eq!(rewritten, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let real = secret("real_token");
        let rewritten =
            transform_url_path("/bot/phantom_token", "/bot/{}", "/bot/{}", &real).unwrap();
        assert_eq!(rewritten, "/bot/real_token");
    }

    // ============================================================================
    // Query Param Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let session = secret("session123");
        let outcome = validate_phantom_token_in_query(
            "/api/data?api_key=session123&other=value",
            "api_key",
            &session,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let session = secret("session123");
        let outcome =
            validate_phantom_token_in_query("/api/data?api_key=wrong_token", "api_key", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let session = secret("session123");
        let outcome = validate_phantom_token_in_query("/api/data?other=value", "api_key", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let session = secret("session123");
        let outcome = validate_phantom_token_in_query("/api/data", "api_key", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let session = secret("token with spaces");
        let outcome = validate_phantom_token_in_query(
            "/api/data?api_key=token%20with%20spaces",
            "api_key",
            &session,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let real = secret("real_key");
        let rewritten = transform_query_param("/api/data", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let real = secret("real_key");
        let rewritten = transform_query_param("/api/data?other=value", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let real = secret("real_key");
        let rewritten =
            transform_query_param("/api/data?api_key=phantom&other=value", "api_key", &real)
                .unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let real = secret("key with spaces");
        let rewritten = transform_query_param("/api/data", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=key%20with%20spaces");
    }
}