Skip to main content

nono_proxy/
reverse.rs

1//! Reverse proxy handler (Mode 2 — Credential Injection).
2//!
3//! Routes requests by path prefix to upstream APIs, injecting credentials
4//! from the keystore. The agent uses `http://localhost:PORT/openai/v1/chat`
5//! and the proxy rewrites to `https://api.openai.com/v1/chat` with the
6//! real credential injected.
7//!
8//! Supports multiple injection modes:
9//! - `header`: Inject into HTTP header (e.g., `Authorization: Bearer ...`)
10//! - `url_path`: Replace pattern in URL path (e.g., Telegram `/bot{}/`)
11//! - `query_param`: Add/replace query parameter (e.g., `?api_key=...`)
12//! - `basic_auth`: HTTP Basic Authentication
13//!
14//! Streaming responses (SSE, MCP Streamable HTTP, A2A JSON-RPC) are
15//! forwarded without buffering.
16
17use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound on an accepted request body, in bytes (16 MiB).
///
/// A `Content-Length` above this is answered with 413 instead of being
/// buffered, so a hostile length cannot force a huge allocation.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// How long to wait for an upstream TCP connection before giving up.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
37/// Handle a non-CONNECT HTTP request (reverse proxy mode).
38///
39/// Reads the full HTTP request from the client, matches path prefix to
40/// a configured route, injects credentials, and forwards to the upstream.
41/// Shared context passed from the server to the reverse proxy handler.
42pub struct ReverseProxyCtx<'a> {
43    /// Credential store for service lookups
44    pub credential_store: &'a CredentialStore,
45    /// Session token for authentication
46    pub session_token: &'a Zeroizing<String>,
47    /// Host filter for upstream validation
48    pub filter: &'a ProxyFilter,
49    /// Shared TLS connector
50    pub tls_connector: &'a TlsConnector,
51    /// Shared network audit sink for session metadata capture
52    pub audit_log: Option<&'a audit::SharedAuditLog>,
53}
54
55/// Handle a non-CONNECT HTTP request (reverse proxy mode).
56///
57/// `buffered_body` contains any bytes the BufReader read ahead beyond the
58/// headers. These are prepended to the body read from the stream to prevent
59/// data loss.
60///
61/// ## Phantom Token Pattern
62///
63/// The client (SDK) sends the session token as its "API key". The proxy:
64/// 1. Extracts the service from the path (e.g., `/openai/v1/chat` → `openai`)
65/// 2. Looks up which header that service uses (e.g., `Authorization` or `x-api-key`)
66/// 3. Validates the phantom token from that header
67/// 4. Replaces it with the real credential from keyring
68pub async fn handle_reverse_proxy(
69    first_line: &str,
70    stream: &mut TcpStream,
71    remaining_header: &[u8],
72    ctx: &ReverseProxyCtx<'_>,
73    buffered_body: &[u8],
74) -> Result<()> {
75    // Parse method, path, and HTTP version
76    let (method, path, version) = parse_request_line(first_line)?;
77    debug!("Reverse proxy: {} {}", method, path);
78
79    // Extract service prefix from path (e.g., "/openai/v1/chat" -> ("openai", "/v1/chat"))
80    let (service, upstream_path) = parse_service_prefix(&path)?;
81
82    // Look up credential for service
83    let cred = ctx
84        .credential_store
85        .get(&service)
86        .ok_or_else(|| ProxyError::UnknownService {
87            prefix: service.clone(),
88        })?;
89
90    // Validate phantom token based on injection mode.
91    // For header/basic_auth modes: validate from Authorization/x-api-key header
92    // For url_path mode: validate from URL path pattern
93    // For query_param mode: validate from query parameter
94    if let Err(e) = validate_phantom_token_for_mode(
95        &cred.inject_mode,
96        remaining_header,
97        &upstream_path,
98        &cred.header_name,
99        cred.path_pattern.as_deref(),
100        cred.query_param_name.as_deref(),
101        ctx.session_token,
102    ) {
103        audit::log_denied(
104            ctx.audit_log,
105            audit::ProxyMode::Reverse,
106            &service,
107            0,
108            &e.to_string(),
109        );
110        send_error(stream, 401, "Unauthorized").await?;
111        return Ok(());
112    }
113
114    // Transform the path based on injection mode (url_path and query_param modes)
115    let transformed_path = transform_path_for_mode(
116        &cred.inject_mode,
117        &upstream_path,
118        cred.path_pattern.as_deref(),
119        cred.path_replacement.as_deref(),
120        cred.query_param_name.as_deref(),
121        &cred.raw_credential,
122    )?;
123
124    // Parse upstream URL with potentially transformed path
125    let upstream_url = format!(
126        "{}{}",
127        cred.upstream.trim_end_matches('/'),
128        transformed_path
129    );
130    debug!("Forwarding to upstream: {} {}", method, upstream_url);
131
132    let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
133
134    // DNS resolve + host check via the filter
135    let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
136    if !check.result.is_allowed() {
137        let reason = check.result.reason();
138        warn!("Upstream host denied by filter: {}", reason);
139        send_error(stream, 403, "Forbidden").await?;
140        audit::log_denied(
141            ctx.audit_log,
142            audit::ProxyMode::Reverse,
143            &service,
144            0,
145            &reason,
146        );
147        return Ok(());
148    }
149
150    // Collect remaining request headers (excluding X-Nono-Token and Host)
151    let filtered_headers = filter_headers(remaining_header);
152    let content_length = extract_content_length(remaining_header);
153
154    // Read request body if present, with size limit.
155    // `buffered_body` may contain bytes the BufReader read ahead beyond
156    // headers; we prepend those to avoid data loss.
157    let body = if let Some(len) = content_length {
158        if len > MAX_REQUEST_BODY {
159            send_error(stream, 413, "Payload Too Large").await?;
160            return Ok(());
161        }
162        let mut buf = Vec::with_capacity(len);
163        let pre = buffered_body.len().min(len);
164        buf.extend_from_slice(&buffered_body[..pre]);
165        let remaining = len - pre;
166        if remaining > 0 {
167            let mut rest = vec![0u8; remaining];
168            stream.read_exact(&mut rest).await?;
169            buf.extend_from_slice(&rest);
170        }
171        buf
172    } else {
173        Vec::new()
174    };
175
176    // Connect to upstream over TLS using pre-resolved addresses
177    let upstream_result = connect_upstream_tls(
178        &upstream_host,
179        upstream_port,
180        &check.resolved_addrs,
181        ctx.tls_connector,
182    )
183    .await;
184    let mut tls_stream = match upstream_result {
185        Ok(s) => s,
186        Err(e) => {
187            warn!("Upstream connection failed: {}", e);
188            send_error(stream, 502, "Bad Gateway").await?;
189            audit::log_denied(
190                ctx.audit_log,
191                audit::ProxyMode::Reverse,
192                &service,
193                0,
194                &e.to_string(),
195            );
196            return Ok(());
197        }
198    };
199
200    // Build the upstream request into a Zeroizing buffer since it may contain
201    // credential values. This ensures credentials are zeroed from heap memory
202    // when the buffer is dropped.
203    let mut request = Zeroizing::new(format!(
204        "{} {} {}\r\nHost: {}\r\n",
205        method, upstream_path_full, version, upstream_host
206    ));
207
208    // Inject credential based on mode
209    inject_credential_for_mode(cred, &mut request);
210
211    // Forward filtered headers (excluding auth headers that we're replacing)
212    let auth_header_lower = cred.header_name.to_lowercase();
213    for (name, value) in &filtered_headers {
214        // Skip the auth header if we're using header/basic_auth mode
215        // (we already injected our own)
216        if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
217            && name.to_lowercase() == auth_header_lower
218        {
219            continue;
220        }
221        request.push_str(&format!("{}: {}\r\n", name, value));
222    }
223
224    // Content-Length for body
225    if !body.is_empty() {
226        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
227    }
228    request.push_str("\r\n");
229
230    tls_stream.write_all(request.as_bytes()).await?;
231    if !body.is_empty() {
232        tls_stream.write_all(&body).await?;
233    }
234    tls_stream.flush().await?;
235
236    // Stream the response back to the client without buffering.
237    // This handles SSE (text/event-stream), chunked transfer, and regular responses.
238    let mut response_buf = [0u8; 8192];
239    let mut status_code: u16 = 502;
240    let mut first_chunk = true;
241
242    loop {
243        let n = match tls_stream.read(&mut response_buf).await {
244            Ok(0) => break,
245            Ok(n) => n,
246            Err(e) => {
247                debug!("Upstream read error: {}", e);
248                break;
249            }
250        };
251
252        // Parse status from first chunk. The HTTP status line format is:
253        // "HTTP/1.1 200 OK\r\n..." — we need the 3-digit code after the
254        // first space. We scan up to 32 bytes (enough for any valid status line).
255        if first_chunk {
256            status_code = parse_response_status(&response_buf[..n]);
257            first_chunk = false;
258        }
259
260        stream.write_all(&response_buf[..n]).await?;
261        stream.flush().await?;
262    }
263
264    audit::log_reverse_proxy(
265        ctx.audit_log,
266        &service,
267        &method,
268        &upstream_path,
269        status_code,
270    );
271    Ok(())
272}
273
274/// Parse an HTTP request line into (method, path, version).
275fn parse_request_line(line: &str) -> Result<(String, String, String)> {
276    let parts: Vec<&str> = line.split_whitespace().collect();
277    if parts.len() < 3 {
278        return Err(ProxyError::HttpParse(format!(
279            "malformed request line: {}",
280            line
281        )));
282    }
283    Ok((
284        parts[0].to_string(),
285        parts[1].to_string(),
286        parts[2].to_string(),
287    ))
288}
289
290/// Extract service prefix from path.
291///
292/// "/openai/v1/chat/completions" -> ("openai", "/v1/chat/completions")
293/// "/anthropic/v1/messages" -> ("anthropic", "/v1/messages")
294fn parse_service_prefix(path: &str) -> Result<(String, String)> {
295    let trimmed = path.strip_prefix('/').unwrap_or(path);
296    if let Some((prefix, rest)) = trimmed.split_once('/') {
297        Ok((prefix.to_string(), format!("/{}", rest)))
298    } else {
299        // No sub-path, just the prefix
300        Ok((trimmed.to_string(), "/".to_string()))
301    }
302}
303
304/// Validate the phantom token from the service's auth header.
305///
306/// The SDK sends the session token as its "API key" in the standard auth header
307/// for that service (e.g., `Authorization: Bearer <token>` for OpenAI,
308/// `x-api-key: <token>` for Anthropic). We validate the token matches the
309/// session token before swapping in the real credential.
310fn validate_phantom_token(
311    header_bytes: &[u8],
312    header_name: &str,
313    session_token: &Zeroizing<String>,
314) -> Result<()> {
315    let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
316    let header_name_lower = header_name.to_lowercase();
317
318    for line in header_str.lines() {
319        let lower = line.to_lowercase();
320        if lower.starts_with(&format!("{}:", header_name_lower)) {
321            let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
322
323            // Handle "Bearer <token>" format (strip "Bearer " prefix if present)
324            // Use case-insensitive check, then slice original value by length
325            let value_lower = value.to_lowercase();
326            let token_value = if value_lower.starts_with("bearer ") {
327                // "bearer ".len() == 7
328                value[7..].trim()
329            } else {
330                value
331            };
332
333            if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
334                return Ok(());
335            }
336            warn!("Invalid phantom token in {} header", header_name);
337            return Err(ProxyError::InvalidToken);
338        }
339    }
340
341    warn!(
342        "Missing {} header for phantom token validation",
343        header_name
344    );
345    Err(ProxyError::InvalidToken)
346}
347
/// Filter headers, removing Host, Content-Length, and auth headers.
///
/// Content-Length is re-added after the body is read, and Host is rewritten
/// for the upstream. `Authorization`, `x-api-key` and `x-goog-api-key` are
/// dropped because the proxy injects its own credential (the phantom token is
/// validated but never forwarded). Blank lines and lines without a `:` are
/// skipped. Names and values are returned trimmed. Non-UTF-8 input yields an
/// empty list.
fn filter_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
    const DROPPED_PREFIXES: [&str; 5] = [
        "host:",
        "content-length:",
        "authorization:",
        "x-api-key:",
        "x-goog-api-key:",
    ];

    std::str::from_utf8(header_bytes)
        .unwrap_or("")
        .lines()
        .filter(|line| !line.trim().is_empty())
        .filter(|line| {
            let lowered = line.to_lowercase();
            !DROPPED_PREFIXES
                .iter()
                .any(|prefix| lowered.starts_with(prefix))
        })
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
375
/// Extract the Content-Length value from raw headers.
///
/// Only the first `Content-Length` line (matched case-insensitively) is
/// considered. Returns `None` when the headers are not UTF-8, the header is
/// missing, or its trimmed value does not parse as `usize`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(header_bytes).ok()?;
    let line = text
        .lines()
        .find(|candidate| candidate.to_lowercase().starts_with("content-length:"))?;
    let (_, raw_value) = line.split_once(':')?;
    raw_value.trim().parse::<usize>().ok()
}
387
388/// Parse an upstream URL into (host, port, path).
389fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
390    let parsed = url::Url::parse(url_str)
391        .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
392
393    let scheme = parsed.scheme();
394    if scheme != "https" && scheme != "http" {
395        return Err(ProxyError::HttpParse(format!(
396            "unsupported URL scheme: {}",
397            url_str
398        )));
399    }
400
401    let host = parsed
402        .host_str()
403        .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
404        .to_string();
405
406    let default_port = if scheme == "https" { 443 } else { 80 };
407    let port = parsed.port().unwrap_or(default_port);
408
409    let path = parsed.path().to_string();
410    let path = if path.is_empty() {
411        "/".to_string()
412    } else {
413        path
414    };
415
416    // Include query string if present
417    let path_with_query = if let Some(query) = parsed.query() {
418        format!("{}?{}", path, query)
419    } else {
420        path
421    };
422
423    Ok((host, port, path_with_query))
424}
425
426/// Connect to an upstream host over TLS using pre-resolved addresses.
427///
428/// Uses the pre-resolved `SocketAddr`s from the filter check to prevent
429/// DNS rebinding TOCTOU. Falls back to hostname resolution only if no
430/// pre-resolved addresses are available.
431///
432/// The `TlsConnector` is shared across all connections (created once at
433/// server startup with the system root certificate store).
434async fn connect_upstream_tls(
435    host: &str,
436    port: u16,
437    resolved_addrs: &[SocketAddr],
438    connector: &TlsConnector,
439) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
440    let tcp = if resolved_addrs.is_empty() {
441        // Fallback: no pre-resolved addresses (shouldn't happen in practice)
442        let addr = format!("{}:{}", host, port);
443        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
444            Ok(Ok(s)) => s,
445            Ok(Err(e)) => {
446                return Err(ProxyError::UpstreamConnect {
447                    host: host.to_string(),
448                    reason: e.to_string(),
449                });
450            }
451            Err(_) => {
452                return Err(ProxyError::UpstreamConnect {
453                    host: host.to_string(),
454                    reason: "connection timed out".to_string(),
455                });
456            }
457        }
458    } else {
459        connect_to_resolved(resolved_addrs, host).await?
460    };
461
462    let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
463        ProxyError::UpstreamConnect {
464            host: host.to_string(),
465            reason: "invalid server name for TLS".to_string(),
466        }
467    })?;
468
469    let tls_stream =
470        connector
471            .connect(server_name, tcp)
472            .await
473            .map_err(|e| ProxyError::UpstreamConnect {
474                host: host.to_string(),
475                reason: format!("TLS handshake failed: {}", e),
476            })?;
477
478    Ok(tls_stream)
479}
480
481/// Connect to one of the pre-resolved socket addresses with timeout.
482async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
483    let mut last_err = None;
484    for addr in addrs {
485        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
486            Ok(Ok(stream)) => return Ok(stream),
487            Ok(Err(e)) => {
488                debug!("Connect to {} failed: {}", addr, e);
489                last_err = Some(e.to_string());
490            }
491            Err(_) => {
492                debug!("Connect to {} timed out", addr);
493                last_err = Some("connection timed out".to_string());
494            }
495        }
496    }
497    Err(ProxyError::UpstreamConnect {
498        host: host.to_string(),
499        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
500    })
501}
502
/// Parse the HTTP status code from the first response chunk.
///
/// Looks for the "HTTP/x.y NNN" pattern in the first line, examining at most
/// its first 64 bytes. Returns 502 if the data does not start with a valid
/// status line (upstream sent garbage or incomplete data).
///
/// Parsing works on raw bytes. The reason phrase may legally contain
/// non-UTF-8 `obs-text` (RFC 9112 §4), and the 64-byte cut can land inside a
/// multi-byte character; neither should hide a valid status code. The code
/// itself must be exactly three ASCII digits, so a sign such as `+12` is
/// rejected.
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    // Find the end of the first line (or use full data if no newline)
    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(64)];

    // Split on whitespace: ["HTTP/1.1", "200", "OK"]
    let mut tokens = first_line
        .split(|b| b.is_ascii_whitespace())
        .filter(|token| !token.is_empty());
    match (tokens.next(), tokens.next()) {
        (Some(version), Some(code))
            if version.starts_with(b"HTTP/")
                && code.len() == 3
                && code.iter().all(u8::is_ascii_digit) =>
        {
            code.iter()
                .fold(0u16, |acc, &digit| acc * 10 + u16::from(digit - b'0'))
        }
        _ => FALLBACK,
    }
}
531
532/// Send an HTTP error response.
533async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
534    let body = format!("{{\"error\":\"{}\"}}", reason);
535    let response = format!(
536        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
537        status,
538        reason,
539        body.len(),
540        body
541    );
542    stream.write_all(response.as_bytes()).await?;
543    stream.flush().await?;
544    Ok(())
545}
546
547// ============================================================================
548// Injection mode helpers
549// ============================================================================
550
551/// Validate phantom token based on injection mode.
552///
553/// Different modes extract the phantom token from different locations:
554/// - `Header`/`BasicAuth`: From the auth header (Authorization, x-api-key, etc.)
555/// - `UrlPath`: From the URL path pattern (e.g., `/bot<token>/getMe`)
556/// - `QueryParam`: From the query parameter (e.g., `?api_key=<token>`)
557fn validate_phantom_token_for_mode(
558    mode: &InjectMode,
559    header_bytes: &[u8],
560    path: &str,
561    header_name: &str,
562    path_pattern: Option<&str>,
563    query_param_name: Option<&str>,
564    session_token: &Zeroizing<String>,
565) -> Result<()> {
566    match mode {
567        InjectMode::Header | InjectMode::BasicAuth => {
568            // Validate from header (existing behavior)
569            validate_phantom_token(header_bytes, header_name, session_token)
570        }
571        InjectMode::UrlPath => {
572            // Validate from URL path
573            let pattern = path_pattern.ok_or_else(|| {
574                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
575            })?;
576            validate_phantom_token_in_path(path, pattern, session_token)
577        }
578        InjectMode::QueryParam => {
579            // Validate from query parameter
580            let param_name = query_param_name.ok_or_else(|| {
581                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
582            })?;
583            validate_phantom_token_in_query(path, param_name, session_token)
584        }
585    }
586}
587
588/// Validate phantom token embedded in URL path.
589///
590/// Extracts the token from the path using the pattern (e.g., `/bot{}/` matches
591/// `/bot<token>/getMe` and extracts `<token>`).
592fn validate_phantom_token_in_path(
593    path: &str,
594    pattern: &str,
595    session_token: &Zeroizing<String>,
596) -> Result<()> {
597    // Split pattern on {} to get prefix and suffix
598    let parts: Vec<&str> = pattern.split("{}").collect();
599    if parts.len() != 2 {
600        return Err(ProxyError::HttpParse(format!(
601            "invalid path_pattern '{}': must contain exactly one {{}}",
602            pattern
603        )));
604    }
605    let (prefix, suffix) = (parts[0], parts[1]);
606
607    // Find the token in the path
608    if let Some(start) = path.find(prefix) {
609        let after_prefix = start + prefix.len();
610
611        // Handle empty suffix case (token extends to end of path or next '/' or '?')
612        let end_offset = if suffix.is_empty() {
613            path[after_prefix..]
614                .find(['/', '?'])
615                .unwrap_or(path[after_prefix..].len())
616        } else {
617            match path[after_prefix..].find(suffix) {
618                Some(offset) => offset,
619                None => {
620                    warn!("Missing phantom token in URL path (pattern: {})", pattern);
621                    return Err(ProxyError::InvalidToken);
622                }
623            }
624        };
625
626        let token = &path[after_prefix..after_prefix + end_offset];
627        if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
628            return Ok(());
629        }
630        warn!("Invalid phantom token in URL path");
631        return Err(ProxyError::InvalidToken);
632    }
633
634    warn!("Missing phantom token in URL path (pattern: {})", pattern);
635    Err(ProxyError::InvalidToken)
636}
637
638/// Validate phantom token in query parameter.
639fn validate_phantom_token_in_query(
640    path: &str,
641    param_name: &str,
642    session_token: &Zeroizing<String>,
643) -> Result<()> {
644    // Parse query string from path
645    if let Some(query_start) = path.find('?') {
646        let query = &path[query_start + 1..];
647        for pair in query.split('&') {
648            if let Some((name, value)) = pair.split_once('=') {
649                if name == param_name {
650                    // URL-decode the value
651                    let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
652                    if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
653                        return Ok(());
654                    }
655                    warn!("Invalid phantom token in query parameter '{}'", param_name);
656                    return Err(ProxyError::InvalidToken);
657                }
658            }
659        }
660    }
661
662    warn!("Missing phantom token in query parameter '{}'", param_name);
663    Err(ProxyError::InvalidToken)
664}
665
666/// Transform URL path based on injection mode.
667///
668/// - `UrlPath`: Replace phantom token with real credential in path
669/// - `QueryParam`: Add/replace query parameter with real credential
670/// - `Header`/`BasicAuth`: No path transformation needed
671fn transform_path_for_mode(
672    mode: &InjectMode,
673    path: &str,
674    path_pattern: Option<&str>,
675    path_replacement: Option<&str>,
676    query_param_name: Option<&str>,
677    credential: &Zeroizing<String>,
678) -> Result<String> {
679    match mode {
680        InjectMode::Header | InjectMode::BasicAuth => {
681            // No path transformation needed
682            Ok(path.to_string())
683        }
684        InjectMode::UrlPath => {
685            let pattern = path_pattern.ok_or_else(|| {
686                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
687            })?;
688            let replacement = path_replacement.unwrap_or(pattern);
689            transform_url_path(path, pattern, replacement, credential)
690        }
691        InjectMode::QueryParam => {
692            let param_name = query_param_name.ok_or_else(|| {
693                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
694            })?;
695            transform_query_param(path, param_name, credential)
696        }
697    }
698}
699
700/// Transform URL path by replacing phantom token pattern with real credential.
701///
702/// Example: `/bot<phantom>/getMe` with pattern `/bot{}/` becomes `/bot<real>/getMe`
703fn transform_url_path(
704    path: &str,
705    pattern: &str,
706    replacement: &str,
707    credential: &Zeroizing<String>,
708) -> Result<String> {
709    // Split pattern on {} to get prefix and suffix
710    let parts: Vec<&str> = pattern.split("{}").collect();
711    if parts.len() != 2 {
712        return Err(ProxyError::HttpParse(format!(
713            "invalid path_pattern '{}': must contain exactly one {{}}",
714            pattern
715        )));
716    }
717    let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
718
719    // Split replacement on {}
720    let repl_parts: Vec<&str> = replacement.split("{}").collect();
721    if repl_parts.len() != 2 {
722        return Err(ProxyError::HttpParse(format!(
723            "invalid path_replacement '{}': must contain exactly one {{}}",
724            replacement
725        )));
726    }
727    let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
728
729    // Find and replace the token in the path
730    if let Some(start) = path.find(pattern_prefix) {
731        let after_prefix = start + pattern_prefix.len();
732
733        // Handle empty suffix case (token extends to end of path or next '/' or '?')
734        let end_offset = if pattern_suffix.is_empty() {
735            // Find the next path segment delimiter or end of path
736            path[after_prefix..]
737                .find(['/', '?'])
738                .unwrap_or(path[after_prefix..].len())
739        } else {
740            // Find the suffix in the remaining path
741            match path[after_prefix..].find(pattern_suffix) {
742                Some(offset) => offset,
743                None => {
744                    return Err(ProxyError::HttpParse(format!(
745                        "path '{}' does not match pattern '{}'",
746                        path, pattern
747                    )));
748                }
749            }
750        };
751
752        let before = &path[..start];
753        let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
754        return Ok(format!(
755            "{}{}{}{}{}",
756            before,
757            repl_prefix,
758            credential.as_str(),
759            repl_suffix,
760            after
761        ));
762    }
763
764    Err(ProxyError::HttpParse(format!(
765        "path '{}' does not match pattern '{}'",
766        path, pattern
767    )))
768}
769
770/// Transform query string by adding or replacing a parameter with the credential.
771fn transform_query_param(
772    path: &str,
773    param_name: &str,
774    credential: &Zeroizing<String>,
775) -> Result<String> {
776    let encoded_value = urlencoding::encode(credential.as_str());
777
778    if let Some(query_start) = path.find('?') {
779        let base_path = &path[..query_start];
780        let query = &path[query_start + 1..];
781
782        // Check if parameter already exists
783        let mut found = false;
784        let new_query: Vec<String> = query
785            .split('&')
786            .map(|pair| {
787                if let Some((name, _)) = pair.split_once('=') {
788                    if name == param_name {
789                        found = true;
790                        return format!("{}={}", param_name, encoded_value);
791                    }
792                }
793                pair.to_string()
794            })
795            .collect();
796
797        if found {
798            Ok(format!("{}?{}", base_path, new_query.join("&")))
799        } else {
800            // Append the parameter
801            Ok(format!(
802                "{}?{}&{}={}",
803                base_path, query, param_name, encoded_value
804            ))
805        }
806    } else {
807        // No query string, add one
808        Ok(format!("{}?{}={}", path, param_name, encoded_value))
809    }
810}
811
812/// Inject credential into request based on mode.
813///
814/// For header/basic_auth modes, adds the credential header.
815/// For url_path/query_param modes, the credential is already in the path.
816fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
817    match cred.inject_mode {
818        InjectMode::Header | InjectMode::BasicAuth => {
819            // Inject credential header
820            request.push_str(&format!(
821                "{}: {}\r\n",
822                cred.header_name,
823                cred.header_value.as_str()
824            ));
825        }
826        InjectMode::UrlPath | InjectMode::QueryParam => {
827            // Credential is already injected into the URL path/query
828            // No header injection needed
829        }
830    }
831}
832
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for request-line parsing, service routing, phantom-token
    //! validation, header filtering, upstream URL parsing, and the
    //! URL-path / query-parameter credential placement helpers.

    use super::*;

    /// Wrap a fixture value in `Zeroizing`, the same container the proxy uses
    /// for session tokens and credentials.
    fn secret(value: &str) -> Zeroizing<String> {
        Zeroizing::new(value.to_string())
    }

    // ============================================================================
    // Request Line / Service Prefix Tests
    // ============================================================================

    #[test]
    fn test_parse_request_line() {
        let parts = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(parts.0, "POST");
        assert_eq!(parts.1, "/openai/v1/chat");
        assert_eq!(parts.2, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        let outcome = parse_request_line("GET");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let routed = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(routed.0, "openai");
        assert_eq!(routed.1, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let routed = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(routed.0, "anthropic");
        assert_eq!(routed.1, "/");
    }

    // ============================================================================
    // Header Phantom Token Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let session = secret("secret123");
        let raw = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        let outcome = validate_phantom_token(raw, "Authorization", &session);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let session = secret("secret123");
        let raw = b"Authorization: Bearer wrong\r\n\r\n";
        let outcome = validate_phantom_token(raw, "Authorization", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        let session = secret("secret123");
        let raw = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        let outcome = validate_phantom_token(raw, "x-api-key", &session);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let session = secret("secret123");
        let raw = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        let outcome = validate_phantom_token(raw, "x-goog-api-key", &session);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let session = secret("secret123");
        let raw = b"Content-Type: application/json\r\n\r\n";
        let outcome = validate_phantom_token(raw, "Authorization", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        let session = secret("secret123");
        let raw = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        let outcome = validate_phantom_token(raw, "Authorization", &session);
        assert!(outcome.is_ok());
    }

    // ============================================================================
    // Header Filtering / Content-Length Tests
    // ============================================================================

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let raw = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let kept = filter_headers(raw);
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].0, "Content-Type");
        assert_eq!(kept[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let raw = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let kept = filter_headers(raw);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_x_goog_api_key() {
        let raw = b"x-goog-api-key: gemini-key\r\nContent-Type: application/json\r\n\r\n";
        let kept = filter_headers(raw);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    #[test]
    fn test_extract_content_length() {
        let raw = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        let length = extract_content_length(raw);
        assert_eq!(length, Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let raw = b"Content-Type: application/json\r\n\r\n";
        let length = extract_content_length(raw);
        assert_eq!(length, None);
    }

    // ============================================================================
    // Upstream URL Parsing Tests
    // ============================================================================

    #[test]
    fn test_parse_upstream_url_https() {
        let target = parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(target.0, "api.openai.com");
        assert_eq!(target.1, 443);
        assert_eq!(target.2, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let target = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(target.0, "localhost");
        assert_eq!(target.1, 8080);
        assert_eq!(target.2, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let target = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(target.0, "api.anthropic.com");
        assert_eq!(target.1, 443);
        assert_eq!(target.2, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        let outcome = parse_upstream_url("ftp://example.com");
        assert!(outcome.is_err());
    }

    // ============================================================================
    // Response Status Parsing Tests
    // ============================================================================

    #[test]
    fn test_parse_response_status_200() {
        let status = parse_response_status(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n");
        assert_eq!(status, 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let status = parse_response_status(b"HTTP/1.1 404 Not Found\r\n\r\n");
        assert_eq!(status, 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        let status = parse_response_status(b"not an http response");
        assert_eq!(status, 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        let status = parse_response_status(b"");
        assert_eq!(status, 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        let status = parse_response_status(b"HTTP/1.1 ");
        assert_eq!(status, 502);
    }

    // ============================================================================
    // URL Path Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        let session = secret("session123");
        let outcome = validate_phantom_token_in_path("/bot/session123/getMe", "/bot/{}/", &session);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let session = secret("session123");
        let outcome =
            validate_phantom_token_in_path("/bot/wrong_token/getMe", "/bot/{}/", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        let session = secret("session123");
        let outcome = validate_phantom_token_in_path("/api/getMe", "/bot/{}/", &session);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let real = secret("real_token");
        let rewritten =
            transform_url_path("/bot/phantom_token/getMe", "/bot/{}/", "/bot/{}/", &real).unwrap();
        assert_eq!(rewritten, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        let real = secret("real_token");
        let rewritten = transform_url_path(
            "/api/v1/phantom_token/chat",
            "/api/v1/{}/",
            "/v2/bot/{}/",
            &real,
        )
        .unwrap();
        assert_eq!(rewritten, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let real = secret("real_token");
        let rewritten =
            transform_url_path("/bot/phantom_token", "/bot/{}", "/bot/{}", &real).unwrap();
        assert_eq!(rewritten, "/bot/real_token");
    }

    // ============================================================================
    // Query Param Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let session = secret("session123");
        let target = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(target, "api_key", &session).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let session = secret("session123");
        let target = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(target, "api_key", &session).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let session = secret("session123");
        let target = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(target, "api_key", &session).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let session = secret("session123");
        let target = "/api/data";
        assert!(validate_phantom_token_in_query(target, "api_key", &session).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        let session = secret("token with spaces");
        let target = "/api/data?api_key=token%20with%20spaces";
        assert!(validate_phantom_token_in_query(target, "api_key", &session).is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let real = secret("real_key");
        let rewritten = transform_query_param("/api/data", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        let real = secret("real_key");
        let rewritten = transform_query_param("/api/data?other=value", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        let real = secret("real_key");
        let rewritten =
            transform_query_param("/api/data?api_key=phantom&other=value", "api_key", &real)
                .unwrap();
        assert_eq!(rewritten, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        let real = secret("key with spaces");
        let rewritten = transform_query_param("/api/data", "api_key", &real).unwrap();
        assert_eq!(rewritten, "/api/data?api_key=key%20with%20spaces");
    }
}