Skip to main content

nono_proxy/
reverse.rs

1//! Reverse proxy handler (Mode 2 — Credential Injection).
2//!
3//! Routes requests by path prefix to upstream APIs, injecting credentials
4//! from the keystore. The agent uses `http://localhost:PORT/openai/v1/chat`
5//! and the proxy rewrites to `https://api.openai.com/v1/chat` with the
6//! real credential injected.
7//!
8//! Supports multiple injection modes:
9//! - `header`: Inject into HTTP header (e.g., `Authorization: Bearer ...`)
10//! - `url_path`: Replace pattern in URL path (e.g., Telegram `/bot{}/`)
11//! - `query_param`: Add/replace query parameter (e.g., `?api_key=...`)
12//! - `basic_auth`: HTTP Basic Authentication
13//!
14//! Streaming responses (SSE, MCP Streamable HTTP, A2A JSON-RPC) are
15//! forwarded without buffering.
16
17use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound on an accepted request body: 16 MiB (16 × 2^20 bytes).
///
/// A declared `Content-Length` above this limit is rejected before any body
/// buffer is allocated, so a hostile length cannot exhaust memory.
const MAX_REQUEST_BODY: usize = 16 << 20;

/// How long a single upstream connection attempt may take before it is
/// abandoned.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
36
37/// Handle a non-CONNECT HTTP request (reverse proxy mode).
38///
39/// Reads the full HTTP request from the client, matches path prefix to
40/// a configured route, injects credentials, and forwards to the upstream.
41/// Shared context passed from the server to the reverse proxy handler.
42pub struct ReverseProxyCtx<'a> {
43    /// Credential store for service lookups
44    pub credential_store: &'a CredentialStore,
45    /// Session token for authentication
46    pub session_token: &'a Zeroizing<String>,
47    /// Host filter for upstream validation
48    pub filter: &'a ProxyFilter,
49    /// Shared TLS connector
50    pub tls_connector: &'a TlsConnector,
51    /// Shared network audit sink for session metadata capture
52    pub audit_log: Option<&'a audit::SharedAuditLog>,
53}
54
55/// Handle a non-CONNECT HTTP request (reverse proxy mode).
56///
57/// `buffered_body` contains any bytes the BufReader read ahead beyond the
58/// headers. These are prepended to the body read from the stream to prevent
59/// data loss.
60///
61/// ## Phantom Token Pattern
62///
63/// The client (SDK) sends the session token as its "API key". The proxy:
64/// 1. Extracts the service from the path (e.g., `/openai/v1/chat` → `openai`)
65/// 2. Looks up which header that service uses (e.g., `Authorization` or `x-api-key`)
66/// 3. Validates the phantom token from that header
67/// 4. Replaces it with the real credential from keyring
68pub async fn handle_reverse_proxy(
69    first_line: &str,
70    stream: &mut TcpStream,
71    remaining_header: &[u8],
72    ctx: &ReverseProxyCtx<'_>,
73    buffered_body: &[u8],
74) -> Result<()> {
75    // Parse method, path, and HTTP version
76    let (method, path, version) = parse_request_line(first_line)?;
77    debug!("Reverse proxy: {} {}", method, path);
78
79    // Extract service prefix from path (e.g., "/openai/v1/chat" -> ("openai", "/v1/chat"))
80    let (service, upstream_path) = parse_service_prefix(&path)?;
81
82    // Look up credential for service
83    let cred = ctx
84        .credential_store
85        .get(&service)
86        .ok_or_else(|| ProxyError::UnknownService {
87            prefix: service.clone(),
88        })?;
89
90    // L7 endpoint filtering: check method+path against rules before any
91    // credential operations. Denied endpoints get 403 immediately.
92    if !cred.endpoint_rules.is_allowed(&method, &upstream_path) {
93        let reason = format!(
94            "endpoint denied: {} {} on service '{}'",
95            method, upstream_path, service
96        );
97        warn!("{}", reason);
98        audit::log_denied(
99            ctx.audit_log,
100            audit::ProxyMode::Reverse,
101            &service,
102            0,
103            &reason,
104        );
105        send_error(stream, 403, "Forbidden").await?;
106        return Ok(());
107    }
108
109    // Validate phantom token based on injection mode.
110    // For header/basic_auth modes: validate from Authorization/x-api-key header
111    // For url_path mode: validate from URL path pattern
112    // For query_param mode: validate from query parameter
113    if let Err(e) = validate_phantom_token_for_mode(
114        &cred.inject_mode,
115        remaining_header,
116        &upstream_path,
117        &cred.header_name,
118        cred.path_pattern.as_deref(),
119        cred.query_param_name.as_deref(),
120        ctx.session_token,
121    ) {
122        audit::log_denied(
123            ctx.audit_log,
124            audit::ProxyMode::Reverse,
125            &service,
126            0,
127            &e.to_string(),
128        );
129        send_error(stream, 401, "Unauthorized").await?;
130        return Ok(());
131    }
132
133    // Transform the path based on injection mode (url_path and query_param modes)
134    let transformed_path = transform_path_for_mode(
135        &cred.inject_mode,
136        &upstream_path,
137        cred.path_pattern.as_deref(),
138        cred.path_replacement.as_deref(),
139        cred.query_param_name.as_deref(),
140        &cred.raw_credential,
141    )?;
142
143    // Parse upstream URL with potentially transformed path
144    let upstream_url = format!(
145        "{}{}",
146        cred.upstream.trim_end_matches('/'),
147        transformed_path
148    );
149    debug!("Forwarding to upstream: {} {}", method, upstream_url);
150
151    let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
152
153    // DNS resolve + host check via the filter
154    let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
155    if !check.result.is_allowed() {
156        let reason = check.result.reason();
157        warn!("Upstream host denied by filter: {}", reason);
158        send_error(stream, 403, "Forbidden").await?;
159        audit::log_denied(
160            ctx.audit_log,
161            audit::ProxyMode::Reverse,
162            &service,
163            0,
164            &reason,
165        );
166        return Ok(());
167    }
168
169    // Collect remaining request headers (excluding X-Nono-Token and Host)
170    let filtered_headers = filter_headers(remaining_header, &cred.header_name);
171    let content_length = extract_content_length(remaining_header);
172
173    // Read request body if present, with size limit.
174    // `buffered_body` may contain bytes the BufReader read ahead beyond
175    // headers; we prepend those to avoid data loss.
176    let body = if let Some(len) = content_length {
177        if len > MAX_REQUEST_BODY {
178            send_error(stream, 413, "Payload Too Large").await?;
179            return Ok(());
180        }
181        let mut buf = Vec::with_capacity(len);
182        let pre = buffered_body.len().min(len);
183        buf.extend_from_slice(&buffered_body[..pre]);
184        let remaining = len - pre;
185        if remaining > 0 {
186            let mut rest = vec![0u8; remaining];
187            stream.read_exact(&mut rest).await?;
188            buf.extend_from_slice(&rest);
189        }
190        buf
191    } else {
192        Vec::new()
193    };
194
195    // Connect to upstream over TLS using pre-resolved addresses
196    let upstream_result = connect_upstream_tls(
197        &upstream_host,
198        upstream_port,
199        &check.resolved_addrs,
200        ctx.tls_connector,
201    )
202    .await;
203    let mut tls_stream = match upstream_result {
204        Ok(s) => s,
205        Err(e) => {
206            warn!("Upstream connection failed: {}", e);
207            send_error(stream, 502, "Bad Gateway").await?;
208            audit::log_denied(
209                ctx.audit_log,
210                audit::ProxyMode::Reverse,
211                &service,
212                0,
213                &e.to_string(),
214            );
215            return Ok(());
216        }
217    };
218
219    // Build the upstream request into a Zeroizing buffer since it may contain
220    // credential values. This ensures credentials are zeroed from heap memory
221    // when the buffer is dropped.
222    let mut request = Zeroizing::new(format!(
223        "{} {} {}\r\nHost: {}\r\n",
224        method, upstream_path_full, version, upstream_host
225    ));
226
227    // Inject credential based on mode
228    inject_credential_for_mode(cred, &mut request);
229
230    // Forward filtered headers (excluding auth headers that we're replacing)
231    let auth_header_lower = cred.header_name.to_lowercase();
232    for (name, value) in &filtered_headers {
233        // Skip the auth header if we're using header/basic_auth mode
234        // (we already injected our own)
235        if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
236            && name.to_lowercase() == auth_header_lower
237        {
238            continue;
239        }
240        request.push_str(&format!("{}: {}\r\n", name, value));
241    }
242
243    // Content-Length for body
244    if !body.is_empty() {
245        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
246    }
247    request.push_str("\r\n");
248
249    tls_stream.write_all(request.as_bytes()).await?;
250    if !body.is_empty() {
251        tls_stream.write_all(&body).await?;
252    }
253    tls_stream.flush().await?;
254
255    // Stream the response back to the client without buffering.
256    // This handles SSE (text/event-stream), chunked transfer, and regular responses.
257    let mut response_buf = [0u8; 8192];
258    let mut status_code: u16 = 502;
259    let mut first_chunk = true;
260
261    loop {
262        let n = match tls_stream.read(&mut response_buf).await {
263            Ok(0) => break,
264            Ok(n) => n,
265            Err(e) => {
266                debug!("Upstream read error: {}", e);
267                break;
268            }
269        };
270
271        // Parse status from first chunk. The HTTP status line format is:
272        // "HTTP/1.1 200 OK\r\n..." — we need the 3-digit code after the
273        // first space. We scan up to 32 bytes (enough for any valid status line).
274        if first_chunk {
275            status_code = parse_response_status(&response_buf[..n]);
276            first_chunk = false;
277        }
278
279        stream.write_all(&response_buf[..n]).await?;
280        stream.flush().await?;
281    }
282
283    audit::log_reverse_proxy(
284        ctx.audit_log,
285        &service,
286        &method,
287        &upstream_path,
288        status_code,
289    );
290    Ok(())
291}
292
293/// Parse an HTTP request line into (method, path, version).
294fn parse_request_line(line: &str) -> Result<(String, String, String)> {
295    let parts: Vec<&str> = line.split_whitespace().collect();
296    if parts.len() < 3 {
297        return Err(ProxyError::HttpParse(format!(
298            "malformed request line: {}",
299            line
300        )));
301    }
302    Ok((
303        parts[0].to_string(),
304        parts[1].to_string(),
305        parts[2].to_string(),
306    ))
307}
308
309/// Extract service prefix from path.
310///
311/// "/openai/v1/chat/completions" -> ("openai", "/v1/chat/completions")
312/// "/anthropic/v1/messages" -> ("anthropic", "/v1/messages")
313fn parse_service_prefix(path: &str) -> Result<(String, String)> {
314    let trimmed = path.strip_prefix('/').unwrap_or(path);
315    if let Some((prefix, rest)) = trimmed.split_once('/') {
316        Ok((prefix.to_string(), format!("/{}", rest)))
317    } else {
318        // No sub-path, just the prefix
319        Ok((trimmed.to_string(), "/".to_string()))
320    }
321}
322
323/// Validate the phantom token from the service's auth header.
324///
325/// The SDK sends the session token as its "API key" in the standard auth header
326/// for that service (e.g., `Authorization: Bearer <token>` for OpenAI,
327/// `x-api-key: <token>` for Anthropic). We validate the token matches the
328/// session token before swapping in the real credential.
329fn validate_phantom_token(
330    header_bytes: &[u8],
331    header_name: &str,
332    session_token: &Zeroizing<String>,
333) -> Result<()> {
334    let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
335    let header_name_lower = header_name.to_lowercase();
336
337    for line in header_str.lines() {
338        let lower = line.to_lowercase();
339        if lower.starts_with(&format!("{}:", header_name_lower)) {
340            let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
341
342            // Handle "Bearer <token>" format (strip "Bearer " prefix if present)
343            // Use case-insensitive check, then slice original value by length
344            let value_lower = value.to_lowercase();
345            let token_value = if value_lower.starts_with("bearer ") {
346                // "bearer ".len() == 7
347                value[7..].trim()
348            } else {
349                value
350            };
351
352            if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
353                return Ok(());
354            }
355            warn!("Invalid phantom token in {} header", header_name);
356            return Err(ProxyError::InvalidToken);
357        }
358    }
359
360    warn!(
361        "Missing {} header for phantom token validation",
362        header_name
363    );
364    Err(ProxyError::InvalidToken)
365}
366
/// Filter headers, removing Host, Content-Length, and the credential header.
///
/// Content-Length is re-added after body is read, and Host is rewritten
/// to the upstream. The route's configured credential header is stripped
/// so the phantom token is not forwarded alongside the real credential.
///
/// The deny check is made on the *parsed, trimmed* header name — the same
/// form that is emitted. Checking the raw line instead would let spellings
/// such as `Host : x` or ` Content-Length: 5` slip past the filter and then
/// be forwarded, normalised, as the very headers this function must drop.
///
/// Lines without a `:` and headers with an empty name are dropped. If the
/// header block is not valid UTF-8, no headers are returned.
///
/// Returns `(name, value)` pairs with surrounding whitespace trimmed, in
/// their original order.
fn filter_headers(header_bytes: &[u8], cred_header: &str) -> Vec<(String, String)> {
    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");
    let cred_header_lower = cred_header.to_lowercase();

    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .filter_map(|(raw_name, raw_value)| {
            let name = raw_name.trim();
            let name_lower = name.to_lowercase();
            let dropped = name.is_empty()
                || name_lower == "host"
                || name_lower == "content-length"
                || name_lower == cred_header_lower;
            if dropped {
                None
            } else {
                Some((name.to_string(), raw_value.trim().to_string()))
            }
        })
        .collect()
}
393
/// Extract Content-Length value from raw headers.
///
/// The first header whose name is `Content-Length` (ASCII case-insensitive)
/// decides the result. Whitespace between the name and the colon is
/// tolerated: header names are trimmed when they are forwarded, so a
/// `Content-Length : N` line that was ignored here could still reach the
/// upstream as a real `Content-Length`, leaving the two sides disagreeing
/// about where the body ends.
///
/// Returns `None` when the header block is not valid UTF-8, when no such
/// header exists, or when its value does not parse as `usize`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let header_str = std::str::from_utf8(header_bytes).ok()?;
    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim_end().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse().ok())
}
405
406/// Parse an upstream URL into (host, port, path).
407fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
408    let parsed = url::Url::parse(url_str)
409        .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
410
411    let scheme = parsed.scheme();
412    if scheme != "https" && scheme != "http" {
413        return Err(ProxyError::HttpParse(format!(
414            "unsupported URL scheme: {}",
415            url_str
416        )));
417    }
418
419    let host = parsed
420        .host_str()
421        .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
422        .to_string();
423
424    let default_port = if scheme == "https" { 443 } else { 80 };
425    let port = parsed.port().unwrap_or(default_port);
426
427    let path = parsed.path().to_string();
428    let path = if path.is_empty() {
429        "/".to_string()
430    } else {
431        path
432    };
433
434    // Include query string if present
435    let path_with_query = if let Some(query) = parsed.query() {
436        format!("{}?{}", path, query)
437    } else {
438        path
439    };
440
441    Ok((host, port, path_with_query))
442}
443
444/// Connect to an upstream host over TLS using pre-resolved addresses.
445///
446/// Uses the pre-resolved `SocketAddr`s from the filter check to prevent
447/// DNS rebinding TOCTOU. Falls back to hostname resolution only if no
448/// pre-resolved addresses are available.
449///
450/// The `TlsConnector` is shared across all connections (created once at
451/// server startup with the system root certificate store).
452async fn connect_upstream_tls(
453    host: &str,
454    port: u16,
455    resolved_addrs: &[SocketAddr],
456    connector: &TlsConnector,
457) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
458    let tcp = if resolved_addrs.is_empty() {
459        // Fallback: no pre-resolved addresses (shouldn't happen in practice)
460        let addr = format!("{}:{}", host, port);
461        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
462            Ok(Ok(s)) => s,
463            Ok(Err(e)) => {
464                return Err(ProxyError::UpstreamConnect {
465                    host: host.to_string(),
466                    reason: e.to_string(),
467                });
468            }
469            Err(_) => {
470                return Err(ProxyError::UpstreamConnect {
471                    host: host.to_string(),
472                    reason: "connection timed out".to_string(),
473                });
474            }
475        }
476    } else {
477        connect_to_resolved(resolved_addrs, host).await?
478    };
479
480    let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
481        ProxyError::UpstreamConnect {
482            host: host.to_string(),
483            reason: "invalid server name for TLS".to_string(),
484        }
485    })?;
486
487    let tls_stream =
488        connector
489            .connect(server_name, tcp)
490            .await
491            .map_err(|e| ProxyError::UpstreamConnect {
492                host: host.to_string(),
493                reason: format!("TLS handshake failed: {}", e),
494            })?;
495
496    Ok(tls_stream)
497}
498
499/// Connect to one of the pre-resolved socket addresses with timeout.
500async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
501    let mut last_err = None;
502    for addr in addrs {
503        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
504            Ok(Ok(stream)) => return Ok(stream),
505            Ok(Err(e)) => {
506                debug!("Connect to {} failed: {}", addr, e);
507                last_err = Some(e.to_string());
508            }
509            Err(_) => {
510                debug!("Connect to {} timed out", addr);
511                last_err = Some("connection timed out".to_string());
512            }
513        }
514    }
515    Err(ProxyError::UpstreamConnect {
516        host: host.to_string(),
517        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
518    })
519}
520
/// Parse HTTP status code from the first response chunk.
///
/// Looks for the "HTTP/x.y NNN" pattern in the first line, inspecting at
/// most its first 64 bytes. The code must be exactly three ASCII digits;
/// integer parsing alone would also accept a signed token such as `+20`.
///
/// Returns 502 if the response doesn't contain a valid status line
/// (upstream sent garbage or incomplete data).
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK_STATUS: u16 = 502;

    // Find the end of the first line (or use full data if no newline)
    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(64)];

    let line = match std::str::from_utf8(first_line) {
        Ok(text) => text,
        Err(_) => return FALLBACK_STATUS,
    };

    // Split on whitespace: ["HTTP/1.1", "200", "OK"]
    let mut parts = line.split_whitespace();
    match (parts.next(), parts.next()) {
        (Some(version), Some(code))
            if version.starts_with("HTTP/")
                && code.len() == 3
                && code.bytes().all(|b| b.is_ascii_digit()) =>
        {
            code.parse().unwrap_or(FALLBACK_STATUS)
        }
        _ => FALLBACK_STATUS,
    }
}
549
550/// Send an HTTP error response.
551async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
552    let body = format!("{{\"error\":\"{}\"}}", reason);
553    let response = format!(
554        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
555        status,
556        reason,
557        body.len(),
558        body
559    );
560    stream.write_all(response.as_bytes()).await?;
561    stream.flush().await?;
562    Ok(())
563}
564
565// ============================================================================
566// Injection mode helpers
567// ============================================================================
568
569/// Validate phantom token based on injection mode.
570///
571/// Different modes extract the phantom token from different locations:
572/// - `Header`/`BasicAuth`: From the auth header (Authorization, x-api-key, etc.)
573/// - `UrlPath`: From the URL path pattern (e.g., `/bot<token>/getMe`)
574/// - `QueryParam`: From the query parameter (e.g., `?api_key=<token>`)
575fn validate_phantom_token_for_mode(
576    mode: &InjectMode,
577    header_bytes: &[u8],
578    path: &str,
579    header_name: &str,
580    path_pattern: Option<&str>,
581    query_param_name: Option<&str>,
582    session_token: &Zeroizing<String>,
583) -> Result<()> {
584    match mode {
585        InjectMode::Header | InjectMode::BasicAuth => {
586            // Validate from header (existing behavior)
587            validate_phantom_token(header_bytes, header_name, session_token)
588        }
589        InjectMode::UrlPath => {
590            // Validate from URL path
591            let pattern = path_pattern.ok_or_else(|| {
592                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
593            })?;
594            validate_phantom_token_in_path(path, pattern, session_token)
595        }
596        InjectMode::QueryParam => {
597            // Validate from query parameter
598            let param_name = query_param_name.ok_or_else(|| {
599                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
600            })?;
601            validate_phantom_token_in_query(path, param_name, session_token)
602        }
603    }
604}
605
606/// Validate phantom token embedded in URL path.
607///
608/// Extracts the token from the path using the pattern (e.g., `/bot{}/` matches
609/// `/bot<token>/getMe` and extracts `<token>`).
610fn validate_phantom_token_in_path(
611    path: &str,
612    pattern: &str,
613    session_token: &Zeroizing<String>,
614) -> Result<()> {
615    // Split pattern on {} to get prefix and suffix
616    let parts: Vec<&str> = pattern.split("{}").collect();
617    if parts.len() != 2 {
618        return Err(ProxyError::HttpParse(format!(
619            "invalid path_pattern '{}': must contain exactly one {{}}",
620            pattern
621        )));
622    }
623    let (prefix, suffix) = (parts[0], parts[1]);
624
625    // Find the token in the path
626    if let Some(start) = path.find(prefix) {
627        let after_prefix = start + prefix.len();
628
629        // Handle empty suffix case (token extends to end of path or next '/' or '?')
630        let end_offset = if suffix.is_empty() {
631            path[after_prefix..]
632                .find(['/', '?'])
633                .unwrap_or(path[after_prefix..].len())
634        } else {
635            match path[after_prefix..].find(suffix) {
636                Some(offset) => offset,
637                None => {
638                    warn!("Missing phantom token in URL path (pattern: {})", pattern);
639                    return Err(ProxyError::InvalidToken);
640                }
641            }
642        };
643
644        let token = &path[after_prefix..after_prefix + end_offset];
645        if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
646            return Ok(());
647        }
648        warn!("Invalid phantom token in URL path");
649        return Err(ProxyError::InvalidToken);
650    }
651
652    warn!("Missing phantom token in URL path (pattern: {})", pattern);
653    Err(ProxyError::InvalidToken)
654}
655
656/// Validate phantom token in query parameter.
657fn validate_phantom_token_in_query(
658    path: &str,
659    param_name: &str,
660    session_token: &Zeroizing<String>,
661) -> Result<()> {
662    // Parse query string from path
663    if let Some(query_start) = path.find('?') {
664        let query = &path[query_start + 1..];
665        for pair in query.split('&') {
666            if let Some((name, value)) = pair.split_once('=') {
667                if name == param_name {
668                    // URL-decode the value
669                    let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
670                    if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
671                        return Ok(());
672                    }
673                    warn!("Invalid phantom token in query parameter '{}'", param_name);
674                    return Err(ProxyError::InvalidToken);
675                }
676            }
677        }
678    }
679
680    warn!("Missing phantom token in query parameter '{}'", param_name);
681    Err(ProxyError::InvalidToken)
682}
683
684/// Transform URL path based on injection mode.
685///
686/// - `UrlPath`: Replace phantom token with real credential in path
687/// - `QueryParam`: Add/replace query parameter with real credential
688/// - `Header`/`BasicAuth`: No path transformation needed
689fn transform_path_for_mode(
690    mode: &InjectMode,
691    path: &str,
692    path_pattern: Option<&str>,
693    path_replacement: Option<&str>,
694    query_param_name: Option<&str>,
695    credential: &Zeroizing<String>,
696) -> Result<String> {
697    match mode {
698        InjectMode::Header | InjectMode::BasicAuth => {
699            // No path transformation needed
700            Ok(path.to_string())
701        }
702        InjectMode::UrlPath => {
703            let pattern = path_pattern.ok_or_else(|| {
704                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
705            })?;
706            let replacement = path_replacement.unwrap_or(pattern);
707            transform_url_path(path, pattern, replacement, credential)
708        }
709        InjectMode::QueryParam => {
710            let param_name = query_param_name.ok_or_else(|| {
711                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
712            })?;
713            transform_query_param(path, param_name, credential)
714        }
715    }
716}
717
718/// Transform URL path by replacing phantom token pattern with real credential.
719///
720/// Example: `/bot<phantom>/getMe` with pattern `/bot{}/` becomes `/bot<real>/getMe`
721fn transform_url_path(
722    path: &str,
723    pattern: &str,
724    replacement: &str,
725    credential: &Zeroizing<String>,
726) -> Result<String> {
727    // Split pattern on {} to get prefix and suffix
728    let parts: Vec<&str> = pattern.split("{}").collect();
729    if parts.len() != 2 {
730        return Err(ProxyError::HttpParse(format!(
731            "invalid path_pattern '{}': must contain exactly one {{}}",
732            pattern
733        )));
734    }
735    let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
736
737    // Split replacement on {}
738    let repl_parts: Vec<&str> = replacement.split("{}").collect();
739    if repl_parts.len() != 2 {
740        return Err(ProxyError::HttpParse(format!(
741            "invalid path_replacement '{}': must contain exactly one {{}}",
742            replacement
743        )));
744    }
745    let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
746
747    // Find and replace the token in the path
748    if let Some(start) = path.find(pattern_prefix) {
749        let after_prefix = start + pattern_prefix.len();
750
751        // Handle empty suffix case (token extends to end of path or next '/' or '?')
752        let end_offset = if pattern_suffix.is_empty() {
753            // Find the next path segment delimiter or end of path
754            path[after_prefix..]
755                .find(['/', '?'])
756                .unwrap_or(path[after_prefix..].len())
757        } else {
758            // Find the suffix in the remaining path
759            match path[after_prefix..].find(pattern_suffix) {
760                Some(offset) => offset,
761                None => {
762                    return Err(ProxyError::HttpParse(format!(
763                        "path '{}' does not match pattern '{}'",
764                        path, pattern
765                    )));
766                }
767            }
768        };
769
770        let before = &path[..start];
771        let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
772        return Ok(format!(
773            "{}{}{}{}{}",
774            before,
775            repl_prefix,
776            credential.as_str(),
777            repl_suffix,
778            after
779        ));
780    }
781
782    Err(ProxyError::HttpParse(format!(
783        "path '{}' does not match pattern '{}'",
784        path, pattern
785    )))
786}
787
788/// Transform query string by adding or replacing a parameter with the credential.
789fn transform_query_param(
790    path: &str,
791    param_name: &str,
792    credential: &Zeroizing<String>,
793) -> Result<String> {
794    let encoded_value = urlencoding::encode(credential.as_str());
795
796    if let Some(query_start) = path.find('?') {
797        let base_path = &path[..query_start];
798        let query = &path[query_start + 1..];
799
800        // Check if parameter already exists
801        let mut found = false;
802        let new_query: Vec<String> = query
803            .split('&')
804            .map(|pair| {
805                if let Some((name, _)) = pair.split_once('=') {
806                    if name == param_name {
807                        found = true;
808                        return format!("{}={}", param_name, encoded_value);
809                    }
810                }
811                pair.to_string()
812            })
813            .collect();
814
815        if found {
816            Ok(format!("{}?{}", base_path, new_query.join("&")))
817        } else {
818            // Append the parameter
819            Ok(format!(
820                "{}?{}&{}={}",
821                base_path, query, param_name, encoded_value
822            ))
823        }
824    } else {
825        // No query string, add one
826        Ok(format!("{}?{}={}", path, param_name, encoded_value))
827    }
828}
829
830/// Inject credential into request based on mode.
831///
832/// For header/basic_auth modes, adds the credential header.
833/// For url_path/query_param modes, the credential is already in the path.
834fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
835    match cred.inject_mode {
836        InjectMode::Header | InjectMode::BasicAuth => {
837            // Inject credential header
838            request.push_str(&format!(
839                "{}: {}\r\n",
840                cred.header_name,
841                cred.header_value.as_str()
842            ));
843        }
844        InjectMode::UrlPath | InjectMode::QueryParam => {
845            // Credential is already injected into the URL path/query
846            // No header injection needed
847        }
848    }
849}
850
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wrap a string in the zeroizing container used for tokens and credentials.
    fn secret(value: &str) -> Zeroizing<String> {
        Zeroizing::new(value.to_string())
    }

    // ------------------------------------------------------------------------
    // Request line and service prefix parsing
    // ------------------------------------------------------------------------

    #[test]
    fn test_parse_request_line() {
        let parsed = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(parsed.0, "POST");
        assert_eq!(parsed.1, "/openai/v1/chat");
        assert_eq!(parsed.2, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        // A lone method without path and version must be rejected.
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let split = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(split.0, "openai");
        assert_eq!(split.1, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        let split = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(split.0, "anthropic");
        assert_eq!(split.1, "/");
    }

    // ------------------------------------------------------------------------
    // Phantom token validation in headers
    // ------------------------------------------------------------------------

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        assert!(validate_phantom_token(
            b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n",
            "Authorization",
            &secret("secret123"),
        )
        .is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        assert!(validate_phantom_token(
            b"Authorization: Bearer wrong\r\n\r\n",
            "Authorization",
            &secret("secret123"),
        )
        .is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        assert!(validate_phantom_token(
            b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n",
            "x-api-key",
            &secret("secret123"),
        )
        .is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        assert!(validate_phantom_token(
            b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n",
            "x-goog-api-key",
            &secret("secret123"),
        )
        .is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        assert!(validate_phantom_token(
            b"Content-Type: application/json\r\n\r\n",
            "Authorization",
            &secret("secret123"),
        )
        .is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        assert!(validate_phantom_token(
            b"AUTHORIZATION: Bearer secret123\r\n\r\n",
            "Authorization",
            &secret("secret123"),
        )
        .is_ok());
    }

    // ------------------------------------------------------------------------
    // Header filtering and Content-Length extraction
    // ------------------------------------------------------------------------

    #[test]
    fn test_filter_headers_removes_host_auth() {
        let kept = filter_headers(
            b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n",
            "Authorization",
        );
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].0, "Content-Type");
        assert_eq!(kept[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let kept = filter_headers(
            b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n",
            "x-api-key",
        );
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_custom_header() {
        let kept = filter_headers(
            b"PRIVATE-TOKEN: phantom123\r\nContent-Type: application/json\r\n\r\n",
            "PRIVATE-TOKEN",
        );
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].0, "Content-Type");
    }

    #[test]
    fn test_extract_content_length() {
        assert_eq!(
            extract_content_length(b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n"),
            Some(42)
        );
    }

    #[test]
    fn test_extract_content_length_missing() {
        assert_eq!(
            extract_content_length(b"Content-Type: application/json\r\n\r\n"),
            None
        );
    }

    // ------------------------------------------------------------------------
    // Upstream URL parsing
    // ------------------------------------------------------------------------

    #[test]
    fn test_parse_upstream_url_https() {
        let target = parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(target.0, "api.openai.com");
        assert_eq!(target.1, 443);
        assert_eq!(target.2, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let target = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(target.0, "localhost");
        assert_eq!(target.1, 8080);
        assert_eq!(target.2, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        let target = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(target.0, "api.anthropic.com");
        assert_eq!(target.1, 443);
        assert_eq!(target.2, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // ------------------------------------------------------------------------
    // Response status parsing
    // ------------------------------------------------------------------------

    #[test]
    fn test_parse_response_status_200() {
        assert_eq!(
            parse_response_status(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
            200
        );
    }

    #[test]
    fn test_parse_response_status_404() {
        assert_eq!(parse_response_status(b"HTTP/1.1 404 Not Found\r\n\r\n"), 404);
    }

    #[test]
    fn test_parse_response_status_garbage() {
        assert_eq!(parse_response_status(b"not an http response"), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        assert_eq!(parse_response_status(b"HTTP/1.1 "), 502);
    }

    // ------------------------------------------------------------------------
    // URL path injection mode
    // ------------------------------------------------------------------------

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        assert!(
            validate_phantom_token_in_path("/bot/session123/getMe", "/bot/{}/", &secret("session123"))
                .is_ok()
        );
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        assert!(
            validate_phantom_token_in_path("/bot/wrong_token/getMe", "/bot/{}/", &secret("session123"))
                .is_err()
        );
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        assert!(
            validate_phantom_token_in_path("/api/getMe", "/bot/{}/", &secret("session123")).is_err()
        );
    }

    #[test]
    fn test_transform_url_path_basic() {
        assert_eq!(
            transform_url_path(
                "/bot/phantom_token/getMe",
                "/bot/{}/",
                "/bot/{}/",
                &secret("real_token"),
            )
            .unwrap(),
            "/bot/real_token/getMe"
        );
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        assert_eq!(
            transform_url_path(
                "/api/v1/phantom_token/chat",
                "/api/v1/{}/",
                "/v2/bot/{}/",
                &secret("real_token"),
            )
            .unwrap(),
            "/v2/bot/real_token/chat"
        );
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        assert_eq!(
            transform_url_path(
                "/bot/phantom_token",
                "/bot/{}",
                "/bot/{}",
                &secret("real_token"),
            )
            .unwrap(),
            "/bot/real_token"
        );
    }

    // ------------------------------------------------------------------------
    // Query parameter injection mode
    // ------------------------------------------------------------------------

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        assert!(validate_phantom_token_in_query(
            "/api/data?api_key=session123&other=value",
            "api_key",
            &secret("session123"),
        )
        .is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        assert!(validate_phantom_token_in_query(
            "/api/data?api_key=wrong_token",
            "api_key",
            &secret("session123"),
        )
        .is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        assert!(validate_phantom_token_in_query(
            "/api/data?other=value",
            "api_key",
            &secret("session123"),
        )
        .is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        assert!(
            validate_phantom_token_in_query("/api/data", "api_key", &secret("session123")).is_err()
        );
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        // The token arrives percent-encoded in the query string.
        assert!(validate_phantom_token_in_query(
            "/api/data?api_key=token%20with%20spaces",
            "api_key",
            &secret("token with spaces"),
        )
        .is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        assert_eq!(
            transform_query_param("/api/data", "api_key", &secret("real_key")).unwrap(),
            "/api/data?api_key=real_key"
        );
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        assert_eq!(
            transform_query_param("/api/data?other=value", "api_key", &secret("real_key")).unwrap(),
            "/api/data?other=value&api_key=real_key"
        );
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        assert_eq!(
            transform_query_param(
                "/api/data?api_key=phantom&other=value",
                "api_key",
                &secret("real_key"),
            )
            .unwrap(),
            "/api/data?api_key=real_key&other=value"
        );
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        assert_eq!(
            transform_query_param("/api/data", "api_key", &secret("key with spaces")).unwrap(),
            "/api/data?api_key=key%20with%20spaces"
        );
    }
}