Skip to main content

nono_proxy/
reverse.rs

1//! Reverse proxy handler (Mode 2 — Credential Injection).
2//!
3//! Routes requests by path prefix to upstream APIs, injecting credentials
4//! from the keystore. The agent uses `http://localhost:PORT/openai/v1/chat`
5//! and the proxy rewrites to `https://api.openai.com/v1/chat` with the
6//! real credential injected.
7//!
8//! Supports multiple injection modes:
9//! - `header`: Inject into HTTP header (e.g., `Authorization: Bearer ...`)
10//! - `url_path`: Replace pattern in URL path (e.g., Telegram `/bot{}/`)
11//! - `query_param`: Add/replace query parameter (e.g., `?api_key=...`)
12//! - `basic_auth`: HTTP Basic Authentication
13//!
14//! Streaming responses (SSE, MCP Streamable HTTP, A2A JSON-RPC) are
15//! forwarded without buffering.
16
17use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound on a buffered request body: 16 MiB.
///
/// Requests announcing a larger `Content-Length` are rejected instead of
/// being read into memory.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// How long a single TCP connect attempt to an upstream may take.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
37/// Handle a non-CONNECT HTTP request (reverse proxy mode).
38///
39/// Reads the full HTTP request from the client, matches path prefix to
40/// a configured route, injects credentials, and forwards to the upstream.
41/// Shared context passed from the server to the reverse proxy handler.
42pub struct ReverseProxyCtx<'a> {
43    /// Credential store for service lookups
44    pub credential_store: &'a CredentialStore,
45    /// Session token for authentication
46    pub session_token: &'a Zeroizing<String>,
47    /// Host filter for upstream validation
48    pub filter: &'a ProxyFilter,
49    /// Shared TLS connector
50    pub tls_connector: &'a TlsConnector,
51    /// Shared network audit sink for session metadata capture
52    pub audit_log: Option<&'a audit::SharedAuditLog>,
53}
54
55/// Handle a non-CONNECT HTTP request (reverse proxy mode).
56///
57/// `buffered_body` contains any bytes the BufReader read ahead beyond the
58/// headers. These are prepended to the body read from the stream to prevent
59/// data loss.
60///
61/// ## Phantom Token Pattern
62///
63/// The client (SDK) sends the session token as its "API key". The proxy:
64/// 1. Extracts the service from the path (e.g., `/openai/v1/chat` → `openai`)
65/// 2. Looks up which header that service uses (e.g., `Authorization` or `x-api-key`)
66/// 3. Validates the phantom token from that header
67/// 4. Replaces it with the real credential from keyring
68pub async fn handle_reverse_proxy(
69    first_line: &str,
70    stream: &mut TcpStream,
71    remaining_header: &[u8],
72    ctx: &ReverseProxyCtx<'_>,
73    buffered_body: &[u8],
74) -> Result<()> {
75    // Parse method, path, and HTTP version
76    let (method, path, version) = parse_request_line(first_line)?;
77    debug!("Reverse proxy: {} {}", method, path);
78
79    // Extract service prefix from path (e.g., "/openai/v1/chat" -> ("openai", "/v1/chat"))
80    let (service, upstream_path) = parse_service_prefix(&path)?;
81
82    // Look up credential for service
83    let cred = ctx
84        .credential_store
85        .get(&service)
86        .ok_or_else(|| ProxyError::UnknownService {
87            prefix: service.clone(),
88        })?;
89
90    // L7 endpoint filtering: check method+path against rules before any
91    // credential operations. Denied endpoints get 403 immediately.
92    if !cred.endpoint_rules.is_allowed(&method, &upstream_path) {
93        let reason = format!(
94            "endpoint denied: {} {} on service '{}'",
95            method, upstream_path, service
96        );
97        warn!("{}", reason);
98        audit::log_denied(
99            ctx.audit_log,
100            audit::ProxyMode::Reverse,
101            &service,
102            0,
103            &reason,
104        );
105        send_error(stream, 403, "Forbidden").await?;
106        return Ok(());
107    }
108
109    // Validate phantom token based on injection mode.
110    // For header/basic_auth modes: validate from Authorization/x-api-key header
111    // For url_path mode: validate from URL path pattern
112    // For query_param mode: validate from query parameter
113    if let Err(e) = validate_phantom_token_for_mode(
114        &cred.inject_mode,
115        remaining_header,
116        &upstream_path,
117        &cred.header_name,
118        cred.path_pattern.as_deref(),
119        cred.query_param_name.as_deref(),
120        ctx.session_token,
121    ) {
122        audit::log_denied(
123            ctx.audit_log,
124            audit::ProxyMode::Reverse,
125            &service,
126            0,
127            &e.to_string(),
128        );
129        send_error(stream, 401, "Unauthorized").await?;
130        return Ok(());
131    }
132
133    // Transform the path based on injection mode (url_path and query_param modes)
134    let transformed_path = transform_path_for_mode(
135        &cred.inject_mode,
136        &upstream_path,
137        cred.path_pattern.as_deref(),
138        cred.path_replacement.as_deref(),
139        cred.query_param_name.as_deref(),
140        &cred.raw_credential,
141    )?;
142
143    // Parse upstream URL with potentially transformed path
144    let upstream_url = format!(
145        "{}{}",
146        cred.upstream.trim_end_matches('/'),
147        transformed_path
148    );
149    debug!("Forwarding to upstream: {} {}", method, upstream_url);
150
151    let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
152
153    // DNS resolve + host check via the filter
154    let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
155    if !check.result.is_allowed() {
156        let reason = check.result.reason();
157        warn!("Upstream host denied by filter: {}", reason);
158        send_error(stream, 403, "Forbidden").await?;
159        audit::log_denied(
160            ctx.audit_log,
161            audit::ProxyMode::Reverse,
162            &service,
163            0,
164            &reason,
165        );
166        return Ok(());
167    }
168
169    // Collect remaining request headers (excluding X-Nono-Token and Host)
170    let filtered_headers = filter_headers(remaining_header, &cred.header_name);
171    let content_length = extract_content_length(remaining_header);
172
173    // Read request body if present, with size limit.
174    // `buffered_body` may contain bytes the BufReader read ahead beyond
175    // headers; we prepend those to avoid data loss.
176    let body = if let Some(len) = content_length {
177        if len > MAX_REQUEST_BODY {
178            send_error(stream, 413, "Payload Too Large").await?;
179            return Ok(());
180        }
181        let mut buf = Vec::with_capacity(len);
182        let pre = buffered_body.len().min(len);
183        buf.extend_from_slice(&buffered_body[..pre]);
184        let remaining = len - pre;
185        if remaining > 0 {
186            let mut rest = vec![0u8; remaining];
187            stream.read_exact(&mut rest).await?;
188            buf.extend_from_slice(&rest);
189        }
190        buf
191    } else {
192        Vec::new()
193    };
194
195    // Connect to upstream over TLS using pre-resolved addresses.
196    // Use the per-route TLS connector (with custom CA) if configured,
197    // otherwise fall back to the shared default connector.
198    let connector = cred.tls_connector.as_ref().unwrap_or(ctx.tls_connector);
199    let upstream_result = connect_upstream_tls(
200        &upstream_host,
201        upstream_port,
202        &check.resolved_addrs,
203        connector,
204    )
205    .await;
206    let mut tls_stream = match upstream_result {
207        Ok(s) => s,
208        Err(e) => {
209            warn!("Upstream connection failed: {}", e);
210            send_error(stream, 502, "Bad Gateway").await?;
211            audit::log_denied(
212                ctx.audit_log,
213                audit::ProxyMode::Reverse,
214                &service,
215                0,
216                &e.to_string(),
217            );
218            return Ok(());
219        }
220    };
221
222    // Build the upstream request into a Zeroizing buffer since it may contain
223    // credential values. This ensures credentials are zeroed from heap memory
224    // when the buffer is dropped.
225    let mut request = Zeroizing::new(format!(
226        "{} {} {}\r\nHost: {}\r\n",
227        method, upstream_path_full, version, upstream_host
228    ));
229
230    // Inject credential based on mode
231    inject_credential_for_mode(cred, &mut request);
232
233    // Forward filtered headers (excluding auth headers that we're replacing)
234    let auth_header_lower = cred.header_name.to_lowercase();
235    for (name, value) in &filtered_headers {
236        // Skip the auth header if we're using header/basic_auth mode
237        // (we already injected our own)
238        if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
239            && name.to_lowercase() == auth_header_lower
240        {
241            continue;
242        }
243        request.push_str(&format!("{}: {}\r\n", name, value));
244    }
245
246    // Content-Length for body
247    if !body.is_empty() {
248        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
249    }
250    request.push_str("\r\n");
251
252    tls_stream.write_all(request.as_bytes()).await?;
253    if !body.is_empty() {
254        tls_stream.write_all(&body).await?;
255    }
256    tls_stream.flush().await?;
257
258    // Stream the response back to the client without buffering.
259    // This handles SSE (text/event-stream), chunked transfer, and regular responses.
260    let mut response_buf = [0u8; 8192];
261    let mut status_code: u16 = 502;
262    let mut first_chunk = true;
263
264    loop {
265        let n = match tls_stream.read(&mut response_buf).await {
266            Ok(0) => break,
267            Ok(n) => n,
268            Err(e) => {
269                debug!("Upstream read error: {}", e);
270                break;
271            }
272        };
273
274        // Parse status from first chunk. The HTTP status line format is:
275        // "HTTP/1.1 200 OK\r\n..." — we need the 3-digit code after the
276        // first space. We scan up to 32 bytes (enough for any valid status line).
277        if first_chunk {
278            status_code = parse_response_status(&response_buf[..n]);
279            first_chunk = false;
280        }
281
282        stream.write_all(&response_buf[..n]).await?;
283        stream.flush().await?;
284    }
285
286    audit::log_reverse_proxy(
287        ctx.audit_log,
288        &service,
289        &method,
290        &upstream_path,
291        status_code,
292    );
293    Ok(())
294}
295
296/// Parse an HTTP request line into (method, path, version).
297fn parse_request_line(line: &str) -> Result<(String, String, String)> {
298    let parts: Vec<&str> = line.split_whitespace().collect();
299    if parts.len() < 3 {
300        return Err(ProxyError::HttpParse(format!(
301            "malformed request line: {}",
302            line
303        )));
304    }
305    Ok((
306        parts[0].to_string(),
307        parts[1].to_string(),
308        parts[2].to_string(),
309    ))
310}
311
312/// Extract service prefix from path.
313///
314/// "/openai/v1/chat/completions" -> ("openai", "/v1/chat/completions")
315/// "/anthropic/v1/messages" -> ("anthropic", "/v1/messages")
316fn parse_service_prefix(path: &str) -> Result<(String, String)> {
317    let trimmed = path.strip_prefix('/').unwrap_or(path);
318    if let Some((prefix, rest)) = trimmed.split_once('/') {
319        Ok((prefix.to_string(), format!("/{}", rest)))
320    } else {
321        // No sub-path, just the prefix
322        Ok((trimmed.to_string(), "/".to_string()))
323    }
324}
325
326/// Validate the phantom token from the service's auth header.
327///
328/// The SDK sends the session token as its "API key" in the standard auth header
329/// for that service (e.g., `Authorization: Bearer <token>` for OpenAI,
330/// `x-api-key: <token>` for Anthropic). We validate the token matches the
331/// session token before swapping in the real credential.
332fn validate_phantom_token(
333    header_bytes: &[u8],
334    header_name: &str,
335    session_token: &Zeroizing<String>,
336) -> Result<()> {
337    let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
338    let header_name_lower = header_name.to_lowercase();
339
340    for line in header_str.lines() {
341        let lower = line.to_lowercase();
342        if lower.starts_with(&format!("{}:", header_name_lower)) {
343            let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
344
345            // Handle "Bearer <token>" format (strip "Bearer " prefix if present)
346            // Use case-insensitive check, then slice original value by length
347            let value_lower = value.to_lowercase();
348            let token_value = if value_lower.starts_with("bearer ") {
349                // "bearer ".len() == 7
350                value[7..].trim()
351            } else {
352                value
353            };
354
355            if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
356                return Ok(());
357            }
358            warn!("Invalid phantom token in {} header", header_name);
359            return Err(ProxyError::InvalidToken);
360        }
361    }
362
363    warn!(
364        "Missing {} header for phantom token validation",
365        header_name
366    );
367    Err(ProxyError::InvalidToken)
368}
369
/// Select the client headers that are forwarded to the upstream.
///
/// Dropped: `Host` (rewritten for the upstream), `Content-Length` (re-added
/// once the body has been read), the route's credential header (so the
/// phantom token never travels alongside the real credential), blank lines,
/// and any line without a `:` separator. Names and values are returned
/// trimmed, in their original order. Input that is not valid UTF-8 yields
/// an empty list.
fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    let text = match std::str::from_utf8(header_bytes) {
        Ok(text) => text,
        Err(_) => return Vec::new(),
    };

    let mut blocked_prefix = cred_header.to_lowercase();
    blocked_prefix.push(':');

    text.lines()
        .filter(|raw| !raw.trim().is_empty())
        .filter(|raw| {
            let folded = raw.to_lowercase();
            let dropped = folded.starts_with("host:")
                || folded.starts_with("content-length:")
                || folded.starts_with(blocked_prefix.as_str());
            !dropped
        })
        .filter_map(|raw| raw.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
396
/// Extract Content-Length value from raw headers.
///
/// Only the first `Content-Length` line is considered; the name is matched
/// ASCII-case-insensitively. The value must consist solely of ASCII digits
/// (surrounding whitespace is ignored). Rust's integer parser would also
/// accept a leading `+`, which is not a valid Content-Length, so that form
/// is rejected explicitly.
///
/// Returns `None` when the headers are not valid UTF-8, the header is
/// absent, or its value is empty, non-numeric, or too large for `usize`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let header_str = std::str::from_utf8(header_bytes).ok()?;
    let value = header_str.lines().find_map(|line| {
        let (name, value) = line.split_once(':')?;
        if name.eq_ignore_ascii_case("content-length") {
            Some(value.trim())
        } else {
            None
        }
    })?;

    if value.is_empty() || !value.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    value.parse().ok()
}
408
409/// Parse an upstream URL into (host, port, path).
410fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
411    let parsed = url::Url::parse(url_str)
412        .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
413
414    let scheme = parsed.scheme();
415    if scheme != "https" && scheme != "http" {
416        return Err(ProxyError::HttpParse(format!(
417            "unsupported URL scheme: {}",
418            url_str
419        )));
420    }
421
422    let host = parsed
423        .host_str()
424        .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
425        .to_string();
426
427    let default_port = if scheme == "https" { 443 } else { 80 };
428    let port = parsed.port().unwrap_or(default_port);
429
430    let path = parsed.path().to_string();
431    let path = if path.is_empty() {
432        "/".to_string()
433    } else {
434        path
435    };
436
437    // Include query string if present
438    let path_with_query = if let Some(query) = parsed.query() {
439        format!("{}?{}", path, query)
440    } else {
441        path
442    };
443
444    Ok((host, port, path_with_query))
445}
446
447/// Connect to an upstream host over TLS using pre-resolved addresses.
448///
449/// Uses the pre-resolved `SocketAddr`s from the filter check to prevent
450/// DNS rebinding TOCTOU. Falls back to hostname resolution only if no
451/// pre-resolved addresses are available.
452///
453/// The `TlsConnector` is shared across all connections (created once at
454/// server startup with the system root certificate store).
455async fn connect_upstream_tls(
456    host: &str,
457    port: u16,
458    resolved_addrs: &[SocketAddr],
459    connector: &TlsConnector,
460) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
461    let tcp = if resolved_addrs.is_empty() {
462        // Fallback: no pre-resolved addresses (shouldn't happen in practice)
463        let addr = format!("{}:{}", host, port);
464        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
465            Ok(Ok(s)) => s,
466            Ok(Err(e)) => {
467                return Err(ProxyError::UpstreamConnect {
468                    host: host.to_string(),
469                    reason: e.to_string(),
470                });
471            }
472            Err(_) => {
473                return Err(ProxyError::UpstreamConnect {
474                    host: host.to_string(),
475                    reason: "connection timed out".to_string(),
476                });
477            }
478        }
479    } else {
480        connect_to_resolved(resolved_addrs, host).await?
481    };
482
483    let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
484        ProxyError::UpstreamConnect {
485            host: host.to_string(),
486            reason: "invalid server name for TLS".to_string(),
487        }
488    })?;
489
490    let tls_stream =
491        connector
492            .connect(server_name, tcp)
493            .await
494            .map_err(|e| ProxyError::UpstreamConnect {
495                host: host.to_string(),
496                reason: format!("TLS handshake failed: {}", e),
497            })?;
498
499    Ok(tls_stream)
500}
501
502/// Connect to one of the pre-resolved socket addresses with timeout.
503async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
504    let mut last_err = None;
505    for addr in addrs {
506        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
507            Ok(Ok(stream)) => return Ok(stream),
508            Ok(Err(e)) => {
509                debug!("Connect to {} failed: {}", addr, e);
510                last_err = Some(e.to_string());
511            }
512            Err(_) => {
513                debug!("Connect to {} timed out", addr);
514                last_err = Some("connection timed out".to_string());
515            }
516        }
517    }
518    Err(ProxyError::UpstreamConnect {
519        host: host.to_string(),
520        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
521    })
522}
523
/// Pull the HTTP status code out of the first chunk of a response.
///
/// Considers only the first line (at most its first 64 bytes) and expects
/// `HTTP/<version> <3-character code> ...`. Anything else — no status line,
/// non-UTF-8 bytes, a code that is not three characters long or does not
/// parse — yields 502.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    // The first line ends at the first CR or LF, or at the end of the chunk.
    let eol = data
        .iter()
        .position(|&b| matches!(b, b'\r' | b'\n'))
        .unwrap_or(data.len());
    let head = &data[..eol.min(64)];

    let text = match std::str::from_utf8(head) {
        Ok(text) => text,
        Err(_) => return FALLBACK,
    };

    // Tokens look like ["HTTP/1.1", "200", "OK"].
    let mut tokens = text.split_whitespace();
    match (tokens.next(), tokens.next()) {
        (Some(proto), Some(code)) if proto.starts_with("HTTP/") && code.len() == 3 => {
            code.parse().unwrap_or(FALLBACK)
        }
        _ => FALLBACK,
    }
}
552
553/// Send an HTTP error response.
554async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
555    let body = format!("{{\"error\":\"{}\"}}", reason);
556    let response = format!(
557        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
558        status,
559        reason,
560        body.len(),
561        body
562    );
563    stream.write_all(response.as_bytes()).await?;
564    stream.flush().await?;
565    Ok(())
566}
567
568// ============================================================================
569// Injection mode helpers
570// ============================================================================
571
572/// Validate phantom token based on injection mode.
573///
574/// Different modes extract the phantom token from different locations:
575/// - `Header`/`BasicAuth`: From the auth header (Authorization, x-api-key, etc.)
576/// - `UrlPath`: From the URL path pattern (e.g., `/bot<token>/getMe`)
577/// - `QueryParam`: From the query parameter (e.g., `?api_key=<token>`)
578fn validate_phantom_token_for_mode(
579    mode: &InjectMode,
580    header_bytes: &[u8],
581    path: &str,
582    header_name: &str,
583    path_pattern: Option<&str>,
584    query_param_name: Option<&str>,
585    session_token: &Zeroizing<String>,
586) -> Result<()> {
587    match mode {
588        InjectMode::Header | InjectMode::BasicAuth => {
589            // Validate from header (existing behavior)
590            validate_phantom_token(header_bytes, header_name, session_token)
591        }
592        InjectMode::UrlPath => {
593            // Validate from URL path
594            let pattern = path_pattern.ok_or_else(|| {
595                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
596            })?;
597            validate_phantom_token_in_path(path, pattern, session_token)
598        }
599        InjectMode::QueryParam => {
600            // Validate from query parameter
601            let param_name = query_param_name.ok_or_else(|| {
602                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
603            })?;
604            validate_phantom_token_in_query(path, param_name, session_token)
605        }
606    }
607}
608
609/// Validate phantom token embedded in URL path.
610///
611/// Extracts the token from the path using the pattern (e.g., `/bot{}/` matches
612/// `/bot<token>/getMe` and extracts `<token>`).
613fn validate_phantom_token_in_path(
614    path: &str,
615    pattern: &str,
616    session_token: &Zeroizing<String>,
617) -> Result<()> {
618    // Split pattern on {} to get prefix and suffix
619    let parts: Vec<&str> = pattern.split("{}").collect();
620    if parts.len() != 2 {
621        return Err(ProxyError::HttpParse(format!(
622            "invalid path_pattern '{}': must contain exactly one {{}}",
623            pattern
624        )));
625    }
626    let (prefix, suffix) = (parts[0], parts[1]);
627
628    // Find the token in the path
629    if let Some(start) = path.find(prefix) {
630        let after_prefix = start + prefix.len();
631
632        // Handle empty suffix case (token extends to end of path or next '/' or '?')
633        let end_offset = if suffix.is_empty() {
634            path[after_prefix..]
635                .find(['/', '?'])
636                .unwrap_or(path[after_prefix..].len())
637        } else {
638            match path[after_prefix..].find(suffix) {
639                Some(offset) => offset,
640                None => {
641                    warn!("Missing phantom token in URL path (pattern: {})", pattern);
642                    return Err(ProxyError::InvalidToken);
643                }
644            }
645        };
646
647        let token = &path[after_prefix..after_prefix + end_offset];
648        if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
649            return Ok(());
650        }
651        warn!("Invalid phantom token in URL path");
652        return Err(ProxyError::InvalidToken);
653    }
654
655    warn!("Missing phantom token in URL path (pattern: {})", pattern);
656    Err(ProxyError::InvalidToken)
657}
658
659/// Validate phantom token in query parameter.
660fn validate_phantom_token_in_query(
661    path: &str,
662    param_name: &str,
663    session_token: &Zeroizing<String>,
664) -> Result<()> {
665    // Parse query string from path
666    if let Some(query_start) = path.find('?') {
667        let query = &path[query_start + 1..];
668        for pair in query.split('&') {
669            if let Some((name, value)) = pair.split_once('=') {
670                if name == param_name {
671                    // URL-decode the value
672                    let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
673                    if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
674                        return Ok(());
675                    }
676                    warn!("Invalid phantom token in query parameter '{}'", param_name);
677                    return Err(ProxyError::InvalidToken);
678                }
679            }
680        }
681    }
682
683    warn!("Missing phantom token in query parameter '{}'", param_name);
684    Err(ProxyError::InvalidToken)
685}
686
687/// Transform URL path based on injection mode.
688///
689/// - `UrlPath`: Replace phantom token with real credential in path
690/// - `QueryParam`: Add/replace query parameter with real credential
691/// - `Header`/`BasicAuth`: No path transformation needed
692fn transform_path_for_mode(
693    mode: &InjectMode,
694    path: &str,
695    path_pattern: Option<&str>,
696    path_replacement: Option<&str>,
697    query_param_name: Option<&str>,
698    credential: &Zeroizing<String>,
699) -> Result<String> {
700    match mode {
701        InjectMode::Header | InjectMode::BasicAuth => {
702            // No path transformation needed
703            Ok(path.to_string())
704        }
705        InjectMode::UrlPath => {
706            let pattern = path_pattern.ok_or_else(|| {
707                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
708            })?;
709            let replacement = path_replacement.unwrap_or(pattern);
710            transform_url_path(path, pattern, replacement, credential)
711        }
712        InjectMode::QueryParam => {
713            let param_name = query_param_name.ok_or_else(|| {
714                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
715            })?;
716            transform_query_param(path, param_name, credential)
717        }
718    }
719}
720
721/// Transform URL path by replacing phantom token pattern with real credential.
722///
723/// Example: `/bot<phantom>/getMe` with pattern `/bot{}/` becomes `/bot<real>/getMe`
724fn transform_url_path(
725    path: &str,
726    pattern: &str,
727    replacement: &str,
728    credential: &Zeroizing<String>,
729) -> Result<String> {
730    // Split pattern on {} to get prefix and suffix
731    let parts: Vec<&str> = pattern.split("{}").collect();
732    if parts.len() != 2 {
733        return Err(ProxyError::HttpParse(format!(
734            "invalid path_pattern '{}': must contain exactly one {{}}",
735            pattern
736        )));
737    }
738    let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
739
740    // Split replacement on {}
741    let repl_parts: Vec<&str> = replacement.split("{}").collect();
742    if repl_parts.len() != 2 {
743        return Err(ProxyError::HttpParse(format!(
744            "invalid path_replacement '{}': must contain exactly one {{}}",
745            replacement
746        )));
747    }
748    let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
749
750    // Find and replace the token in the path
751    if let Some(start) = path.find(pattern_prefix) {
752        let after_prefix = start + pattern_prefix.len();
753
754        // Handle empty suffix case (token extends to end of path or next '/' or '?')
755        let end_offset = if pattern_suffix.is_empty() {
756            // Find the next path segment delimiter or end of path
757            path[after_prefix..]
758                .find(['/', '?'])
759                .unwrap_or(path[after_prefix..].len())
760        } else {
761            // Find the suffix in the remaining path
762            match path[after_prefix..].find(pattern_suffix) {
763                Some(offset) => offset,
764                None => {
765                    return Err(ProxyError::HttpParse(format!(
766                        "path '{}' does not match pattern '{}'",
767                        path, pattern
768                    )));
769                }
770            }
771        };
772
773        let before = &path[..start];
774        let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
775        return Ok(format!(
776            "{}{}{}{}{}",
777            before,
778            repl_prefix,
779            credential.as_str(),
780            repl_suffix,
781            after
782        ));
783    }
784
785    Err(ProxyError::HttpParse(format!(
786        "path '{}' does not match pattern '{}'",
787        path, pattern
788    )))
789}
790
791/// Transform query string by adding or replacing a parameter with the credential.
792fn transform_query_param(
793    path: &str,
794    param_name: &str,
795    credential: &Zeroizing<String>,
796) -> Result<String> {
797    let encoded_value = urlencoding::encode(credential.as_str());
798
799    if let Some(query_start) = path.find('?') {
800        let base_path = &path[..query_start];
801        let query = &path[query_start + 1..];
802
803        // Check if parameter already exists
804        let mut found = false;
805        let new_query: Vec<String> = query
806            .split('&')
807            .map(|pair| {
808                if let Some((name, _)) = pair.split_once('=') {
809                    if name == param_name {
810                        found = true;
811                        return format!("{}={}", param_name, encoded_value);
812                    }
813                }
814                pair.to_string()
815            })
816            .collect();
817
818        if found {
819            Ok(format!("{}?{}", base_path, new_query.join("&")))
820        } else {
821            // Append the parameter
822            Ok(format!(
823                "{}?{}&{}={}",
824                base_path, query, param_name, encoded_value
825            ))
826        }
827    } else {
828        // No query string, add one
829        Ok(format!("{}?{}={}", path, param_name, encoded_value))
830    }
831}
832
833/// Inject credential into request based on mode.
834///
835/// For header/basic_auth modes, adds the credential header.
836/// For url_path/query_param modes, the credential is already in the path.
837fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
838    match cred.inject_mode {
839        InjectMode::Header | InjectMode::BasicAuth => {
840            // Inject credential header
841            request.push_str(&format!(
842                "{}: {}\r\n",
843                cred.header_name,
844                cred.header_value.as_str()
845            ));
846        }
847        InjectMode::UrlPath | InjectMode::QueryParam => {
848            // Credential is already injected into the URL path/query
849            // No header injection needed
850        }
851    }
852}
853
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wraps a string slice in the `Zeroizing<String>` the functions under test take.
    fn secret(value: &str) -> Zeroizing<String> {
        Zeroizing::new(value.to_string())
    }

    // ============================================================================
    // Request Line / Service Prefix Tests
    // ============================================================================

    /// A well-formed request line yields method, path and version.
    #[test]
    fn test_parse_request_line() {
        let (verb, target, proto) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(verb, "POST");
        assert_eq!(target, "/openai/v1/chat");
        assert_eq!(proto, "HTTP/1.1");
    }

    /// A request line consisting of only a method is rejected.
    #[test]
    fn test_parse_request_line_malformed() {
        let outcome = parse_request_line("GET");
        assert!(outcome.is_err());
    }

    /// The first path segment is the service; the rest is the upstream path.
    #[test]
    fn test_parse_service_prefix() {
        let (svc, remainder) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(svc, "openai");
        assert_eq!(remainder, "/v1/chat/completions");
    }

    /// A path with only the service segment maps to the root upstream path.
    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let (svc, remainder) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(svc, "anthropic");
        assert_eq!(remainder, "/");
    }

    // ============================================================================
    // Header Phantom Token Validation Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let expected = secret("secret123");
        let raw = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        let verdict = validate_phantom_token(raw, "Authorization", &expected);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let expected = secret("secret123");
        let raw = b"Authorization: Bearer wrong\r\n\r\n";
        let verdict = validate_phantom_token(raw, "Authorization", &expected);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let expected = secret("secret123");
        let raw = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        let verdict = validate_phantom_token(raw, "x-api-key", &expected);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let expected = secret("secret123");
        let raw = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        let verdict = validate_phantom_token(raw, "x-goog-api-key", &expected);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let expected = secret("secret123");
        let raw = b"Content-Type: application/json\r\n\r\n";
        let verdict = validate_phantom_token(raw, "Authorization", &expected);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let expected = secret("secret123");
        let raw = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        let verdict = validate_phantom_token(raw, "Authorization", &expected);
        assert!(verdict.is_ok());
    }

    // ============================================================================
    // Header Filtering Tests
    // ============================================================================

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let raw = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let kept = filter_headers(raw, "Authorization");
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].0, "Content-Type");
        assert_eq!(kept[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let raw = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let kept = filter_headers(raw, "x-api-key");
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_custom_header() {
        let raw = b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n";
        let kept = filter_headers(raw, "PRIVATE-TOKEN");
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    // ============================================================================
    // Content-Length Tests
    // ============================================================================

    #[test]
    fn test_extract_content_length() {
        let raw = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        let length = extract_content_length(raw);
        assert_eq!(length, Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let raw = b"Content-Type: application/json\r\n\r\n";
        let length = extract_content_length(raw);
        assert_eq!(length, None);
    }

    // ============================================================================
    // Upstream URL Parsing Tests
    // ============================================================================

    #[test]
    fn test_parse_upstream_url_https() {
        let upstream = "https://api.openai.com/v1/chat/completions";
        let (upstream_host, upstream_port, upstream_path) = parse_upstream_url(upstream).unwrap();
        assert_eq!(upstream_host, "api.openai.com");
        assert_eq!(upstream_port, 443);
        assert_eq!(upstream_path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let upstream = "http://localhost:8080/api";
        let (upstream_host, upstream_port, upstream_path) = parse_upstream_url(upstream).unwrap();
        assert_eq!(upstream_host, "localhost");
        assert_eq!(upstream_port, 8080);
        assert_eq!(upstream_path, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let upstream = "https://api.anthropic.com";
        let (upstream_host, upstream_port, upstream_path) = parse_upstream_url(upstream).unwrap();
        assert_eq!(upstream_host, "api.anthropic.com");
        assert_eq!(upstream_port, 443);
        assert_eq!(upstream_path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        let outcome = parse_upstream_url("ftp://example.com");
        assert!(outcome.is_err());
    }

    // ============================================================================
    // Response Status Parsing Tests
    // ============================================================================

    #[test]
    fn test_parse_response_status_200() {
        let status = parse_response_status(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n");
        assert_eq!(status, 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let status = parse_response_status(b"HTTP/1.1 404 Not Found\r\n\r\n");
        assert_eq!(status, 404);
    }

    /// Input that is not an HTTP response is reported as 502.
    #[test]
    fn test_parse_response_status_garbage() {
        let status = parse_response_status(b"not an http response");
        assert_eq!(status, 502);
    }

    /// Empty input is reported as 502.
    #[test]
    fn test_parse_response_status_empty() {
        let status = parse_response_status(b"");
        assert_eq!(status, 502);
    }

    /// A status line cut off before the code is reported as 502.
    #[test]
    fn test_parse_response_status_partial() {
        let status = parse_response_status(b"HTTP/1.1 ");
        assert_eq!(status, 502);
    }

    // ============================================================================
    // URL Path Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let session = secret("session123");
        let verdict = validate_phantom_token_in_path("/bot/session123/getMe", "/bot/{}/", &session);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let session = secret("session123");
        let verdict =
            validate_phantom_token_in_path("/bot/wrong_token/getMe", "/bot/{}/", &session);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let session = secret("session123");
        let verdict = validate_phantom_token_in_path("/api/getMe", "/bot/{}/", &session);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let real = secret("real_token");
        let rewritten =
            transform_url_path("/bot/phantom_token/getMe", "/bot/{}/", "/bot/{}/", &real).unwrap();
        assert_eq!(rewritten, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let real = secret("real_token");
        let rewritten = transform_url_path(
            "/api/v1/phantom_token/chat",
            "/api/v1/{}/",
            "/v2/bot/{}/",
            &real,
        )
        .unwrap();
        assert_eq!(rewritten, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let real = secret("real_token");
        let rewritten =
            transform_url_path("/bot/phantom_token", "/bot/{}", "/bot/{}", &real).unwrap();
        assert_eq!(rewritten, "/bot/real_token");
    }

    // ============================================================================
    // Query Param Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let session = secret("session123");
        let verdict = validate_phantom_token_in_query(
            "/api/data?api_key=session123&other=value",
            "api_key",
            &session,
        );
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let session = secret("session123");
        let verdict =
            validate_phantom_token_in_query("/api/data?api_key=wrong_token", "api_key", &session);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let session = secret("session123");
        let verdict =
            validate_phantom_token_in_query("/api/data?other=value", "api_key", &session);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let session = secret("session123");
        let verdict = validate_phantom_token_in_query("/api/data", "api_key", &session);
        assert!(verdict.is_err());
    }

    /// A percent-encoded token in the query matches the plain session token.
    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let session = secret("token with spaces");
        let verdict = validate_phantom_token_in_query(
            "/api/data?api_key=token%20with%20spaces",
            "api_key",
            &session,
        );
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let real = secret("real_key");
        let rewritten = transform_query_param("/api/data", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let real = secret("real_key");
        let rewritten = transform_query_param("/api/data?other=value", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let real = secret("real_key");
        let rewritten =
            transform_query_param("/api/data?api_key=phantom&other=value", "api_key", &real)
                .unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key&other=value");
    }

    /// Characters that are not URL-safe in the credential are percent-encoded.
    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let real = secret("key with spaces");
        let rewritten = transform_query_param("/api/data", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=key%20with%20spaces");
    }
}