Skip to main content

nono_proxy/
reverse.rs

1//! Reverse proxy handler (Mode 2 — Credential Injection).
2//!
3//! Routes requests by path prefix to upstream APIs, injecting credentials
4//! from the keystore. The agent uses `http://localhost:PORT/openai/v1/chat`
5//! and the proxy rewrites to `https://api.openai.com/v1/chat` with the
6//! real credential injected.
7//!
8//! Supports multiple injection modes:
9//! - `header`: Inject into HTTP header (e.g., `Authorization: Bearer ...`)
10//! - `url_path`: Replace pattern in URL path (e.g., Telegram `/bot{}/`)
11//! - `query_param`: Add/replace query parameter (e.g., `?api_key=...`)
12//! - `basic_auth`: HTTP Basic Authentication
13//!
14//! Streaming responses (SSE, MCP Streamable HTTP, A2A JSON-RPC) are
15//! forwarded without buffering.
16
17use crate::audit;
18use crate::config::InjectMode;
19use crate::credential::{CredentialStore, LoadedCredential};
20use crate::error::{ProxyError, Result};
21use crate::filter::ProxyFilter;
22use crate::token;
23use std::net::SocketAddr;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound on an accepted request body: 16 MiB (`16 << 20` bytes).
///
/// A declared `Content-Length` above this is refused before any buffer is
/// allocated, so a malicious header cannot force a huge allocation.
const MAX_REQUEST_BODY: usize = 16 << 20;
33
/// How long a single upstream TCP connect attempt may take (30 seconds)
/// before it is abandoned.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
36
37/// Handle a non-CONNECT HTTP request (reverse proxy mode).
38///
39/// Reads the full HTTP request from the client, matches path prefix to
40/// a configured route, injects credentials, and forwards to the upstream.
41/// Shared context passed from the server to the reverse proxy handler.
42pub struct ReverseProxyCtx<'a> {
43    /// Credential store for service lookups
44    pub credential_store: &'a CredentialStore,
45    /// Session token for authentication
46    pub session_token: &'a Zeroizing<String>,
47    /// Host filter for upstream validation
48    pub filter: &'a ProxyFilter,
49    /// Shared TLS connector
50    pub tls_connector: &'a TlsConnector,
51}
52
53/// Handle a non-CONNECT HTTP request (reverse proxy mode).
54///
55/// `buffered_body` contains any bytes the BufReader read ahead beyond the
56/// headers. These are prepended to the body read from the stream to prevent
57/// data loss.
58///
59/// ## Phantom Token Pattern
60///
61/// The client (SDK) sends the session token as its "API key". The proxy:
62/// 1. Extracts the service from the path (e.g., `/openai/v1/chat` → `openai`)
63/// 2. Looks up which header that service uses (e.g., `Authorization` or `x-api-key`)
64/// 3. Validates the phantom token from that header
65/// 4. Replaces it with the real credential from keyring
66pub async fn handle_reverse_proxy(
67    first_line: &str,
68    stream: &mut TcpStream,
69    remaining_header: &[u8],
70    ctx: &ReverseProxyCtx<'_>,
71    buffered_body: &[u8],
72) -> Result<()> {
73    // Parse method, path, and HTTP version
74    let (method, path, version) = parse_request_line(first_line)?;
75    debug!("Reverse proxy: {} {}", method, path);
76
77    // Extract service prefix from path (e.g., "/openai/v1/chat" -> ("openai", "/v1/chat"))
78    let (service, upstream_path) = parse_service_prefix(&path)?;
79
80    // Look up credential for service
81    let cred = ctx
82        .credential_store
83        .get(&service)
84        .ok_or_else(|| ProxyError::UnknownService {
85            prefix: service.clone(),
86        })?;
87
88    // Validate phantom token based on injection mode.
89    // For header/basic_auth modes: validate from Authorization/x-api-key header
90    // For url_path mode: validate from URL path pattern
91    // For query_param mode: validate from query parameter
92    if let Err(e) = validate_phantom_token_for_mode(
93        &cred.inject_mode,
94        remaining_header,
95        &upstream_path,
96        &cred.header_name,
97        cred.path_pattern.as_deref(),
98        cred.query_param_name.as_deref(),
99        ctx.session_token,
100    ) {
101        audit::log_denied(audit::ProxyMode::Reverse, &service, 0, &e.to_string());
102        send_error(stream, 401, "Unauthorized").await?;
103        return Ok(());
104    }
105
106    // Transform the path based on injection mode (url_path and query_param modes)
107    let transformed_path = transform_path_for_mode(
108        &cred.inject_mode,
109        &upstream_path,
110        cred.path_pattern.as_deref(),
111        cred.path_replacement.as_deref(),
112        cred.query_param_name.as_deref(),
113        &cred.raw_credential,
114    )?;
115
116    // Parse upstream URL with potentially transformed path
117    let upstream_url = format!(
118        "{}{}",
119        cred.upstream.trim_end_matches('/'),
120        transformed_path
121    );
122    debug!("Forwarding to upstream: {} {}", method, upstream_url);
123
124    let (upstream_host, upstream_port, upstream_path_full) = parse_upstream_url(&upstream_url)?;
125
126    // DNS resolve + CIDR check via the filter (prevents rebinding TOCTOU)
127    let check = ctx.filter.check_host(&upstream_host, upstream_port).await?;
128    if !check.result.is_allowed() {
129        let reason = check.result.reason();
130        warn!("Upstream host denied by filter: {}", reason);
131        send_error(stream, 403, "Forbidden").await?;
132        audit::log_denied(audit::ProxyMode::Reverse, &service, 0, &reason);
133        return Ok(());
134    }
135
136    // Collect remaining request headers (excluding X-Nono-Token and Host)
137    let filtered_headers = filter_headers(remaining_header);
138    let content_length = extract_content_length(remaining_header);
139
140    // Read request body if present, with size limit.
141    // `buffered_body` may contain bytes the BufReader read ahead beyond
142    // headers; we prepend those to avoid data loss.
143    let body = if let Some(len) = content_length {
144        if len > MAX_REQUEST_BODY {
145            send_error(stream, 413, "Payload Too Large").await?;
146            return Ok(());
147        }
148        let mut buf = Vec::with_capacity(len);
149        let pre = buffered_body.len().min(len);
150        buf.extend_from_slice(&buffered_body[..pre]);
151        let remaining = len - pre;
152        if remaining > 0 {
153            let mut rest = vec![0u8; remaining];
154            stream.read_exact(&mut rest).await?;
155            buf.extend_from_slice(&rest);
156        }
157        buf
158    } else {
159        Vec::new()
160    };
161
162    // Connect to upstream over TLS using pre-resolved addresses
163    let upstream_result = connect_upstream_tls(
164        &upstream_host,
165        upstream_port,
166        &check.resolved_addrs,
167        ctx.tls_connector,
168    )
169    .await;
170    let mut tls_stream = match upstream_result {
171        Ok(s) => s,
172        Err(e) => {
173            warn!("Upstream connection failed: {}", e);
174            send_error(stream, 502, "Bad Gateway").await?;
175            audit::log_denied(audit::ProxyMode::Reverse, &service, 0, &e.to_string());
176            return Ok(());
177        }
178    };
179
180    // Build the upstream request into a Zeroizing buffer since it may contain
181    // credential values. This ensures credentials are zeroed from heap memory
182    // when the buffer is dropped.
183    let mut request = Zeroizing::new(format!(
184        "{} {} {}\r\nHost: {}\r\n",
185        method, upstream_path_full, version, upstream_host
186    ));
187
188    // Inject credential based on mode
189    inject_credential_for_mode(cred, &mut request);
190
191    // Forward filtered headers (excluding auth headers that we're replacing)
192    let auth_header_lower = cred.header_name.to_lowercase();
193    for (name, value) in &filtered_headers {
194        // Skip the auth header if we're using header/basic_auth mode
195        // (we already injected our own)
196        if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
197            && name.to_lowercase() == auth_header_lower
198        {
199            continue;
200        }
201        request.push_str(&format!("{}: {}\r\n", name, value));
202    }
203
204    // Content-Length for body
205    if !body.is_empty() {
206        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
207    }
208    request.push_str("\r\n");
209
210    tls_stream.write_all(request.as_bytes()).await?;
211    if !body.is_empty() {
212        tls_stream.write_all(&body).await?;
213    }
214    tls_stream.flush().await?;
215
216    // Stream the response back to the client without buffering.
217    // This handles SSE (text/event-stream), chunked transfer, and regular responses.
218    let mut response_buf = [0u8; 8192];
219    let mut status_code: u16 = 502;
220    let mut first_chunk = true;
221
222    loop {
223        let n = match tls_stream.read(&mut response_buf).await {
224            Ok(0) => break,
225            Ok(n) => n,
226            Err(e) => {
227                debug!("Upstream read error: {}", e);
228                break;
229            }
230        };
231
232        // Parse status from first chunk. The HTTP status line format is:
233        // "HTTP/1.1 200 OK\r\n..." — we need the 3-digit code after the
234        // first space. We scan up to 32 bytes (enough for any valid status line).
235        if first_chunk {
236            status_code = parse_response_status(&response_buf[..n]);
237            first_chunk = false;
238        }
239
240        stream.write_all(&response_buf[..n]).await?;
241        stream.flush().await?;
242    }
243
244    audit::log_reverse_proxy(&service, &method, &upstream_path, status_code);
245    Ok(())
246}
247
248/// Parse an HTTP request line into (method, path, version).
249fn parse_request_line(line: &str) -> Result<(String, String, String)> {
250    let parts: Vec<&str> = line.split_whitespace().collect();
251    if parts.len() < 3 {
252        return Err(ProxyError::HttpParse(format!(
253            "malformed request line: {}",
254            line
255        )));
256    }
257    Ok((
258        parts[0].to_string(),
259        parts[1].to_string(),
260        parts[2].to_string(),
261    ))
262}
263
264/// Extract service prefix from path.
265///
266/// "/openai/v1/chat/completions" -> ("openai", "/v1/chat/completions")
267/// "/anthropic/v1/messages" -> ("anthropic", "/v1/messages")
268fn parse_service_prefix(path: &str) -> Result<(String, String)> {
269    let trimmed = path.strip_prefix('/').unwrap_or(path);
270    if let Some((prefix, rest)) = trimmed.split_once('/') {
271        Ok((prefix.to_string(), format!("/{}", rest)))
272    } else {
273        // No sub-path, just the prefix
274        Ok((trimmed.to_string(), "/".to_string()))
275    }
276}
277
278/// Validate the phantom token from the service's auth header.
279///
280/// The SDK sends the session token as its "API key" in the standard auth header
281/// for that service (e.g., `Authorization: Bearer <token>` for OpenAI,
282/// `x-api-key: <token>` for Anthropic). We validate the token matches the
283/// session token before swapping in the real credential.
284fn validate_phantom_token(
285    header_bytes: &[u8],
286    header_name: &str,
287    session_token: &Zeroizing<String>,
288) -> Result<()> {
289    let header_str = std::str::from_utf8(header_bytes).map_err(|_| ProxyError::InvalidToken)?;
290    let header_name_lower = header_name.to_lowercase();
291
292    for line in header_str.lines() {
293        let lower = line.to_lowercase();
294        if lower.starts_with(&format!("{}:", header_name_lower)) {
295            let value = line.split_once(':').map(|(_, v)| v.trim()).unwrap_or("");
296
297            // Handle "Bearer <token>" format (strip "Bearer " prefix if present)
298            // Use case-insensitive check, then slice original value by length
299            let value_lower = value.to_lowercase();
300            let token_value = if value_lower.starts_with("bearer ") {
301                // "bearer ".len() == 7
302                value[7..].trim()
303            } else {
304                value
305            };
306
307            if token::constant_time_eq(token_value.as_bytes(), session_token.as_bytes()) {
308                return Ok(());
309            }
310            warn!("Invalid phantom token in {} header", header_name);
311            return Err(ProxyError::InvalidToken);
312        }
313    }
314
315    warn!(
316        "Missing {} header for phantom token validation",
317        header_name
318    );
319    Err(ProxyError::InvalidToken)
320}
321
/// Filter headers, removing Host, Content-Length, and auth headers.
///
/// Content-Length is re-added after the body is read, and Host is rewritten
/// to the upstream. `Authorization`, `x-api-key` and `x-goog-api-key` are
/// stripped since we inject our own credential (the phantom token is
/// validated but not forwarded).
///
/// Matching uses the parsed, trimmed field name (case-insensitive), so a
/// malformed line such as `Content-Length : 5` is stripped too instead of
/// being forwarded. Lines without a `:` and fields with an empty name are
/// dropped. Returns an empty list if the headers are not valid UTF-8.
fn filter_headers(header_bytes: &[u8]) -> Vec<(String, String)> {
    const STRIPPED: &[&str] = &[
        "host",
        "content-length",
        "authorization",
        "x-api-key",
        "x-goog-api-key",
    ];

    let header_str = std::str::from_utf8(header_bytes).unwrap_or("");
    header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim(), value.trim()))
        .filter(|(name, _)| {
            !name.is_empty() && !STRIPPED.iter().any(|s| name.eq_ignore_ascii_case(s))
        })
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect()
}
349
/// Extract the Content-Length value from raw request headers.
///
/// The first field whose trimmed name is `Content-Length` (case-insensitive)
/// wins. The match is on the parsed name, so `Content-Length : 5` is
/// recognised as well. Returns `None` when the headers are not UTF-8, no such
/// field exists, or its value does not parse as `usize`.
fn extract_content_length(header_bytes: &[u8]) -> Option<usize> {
    let header_str = std::str::from_utf8(header_bytes).ok()?;
    let (_, value) = header_str
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))?;
    value.trim().parse().ok()
}
361
362/// Parse an upstream URL into (host, port, path).
363fn parse_upstream_url(url_str: &str) -> Result<(String, u16, String)> {
364    let parsed = url::Url::parse(url_str)
365        .map_err(|e| ProxyError::HttpParse(format!("invalid upstream URL '{}': {}", url_str, e)))?;
366
367    let scheme = parsed.scheme();
368    if scheme != "https" && scheme != "http" {
369        return Err(ProxyError::HttpParse(format!(
370            "unsupported URL scheme: {}",
371            url_str
372        )));
373    }
374
375    let host = parsed
376        .host_str()
377        .ok_or_else(|| ProxyError::HttpParse(format!("missing host in URL: {}", url_str)))?
378        .to_string();
379
380    let default_port = if scheme == "https" { 443 } else { 80 };
381    let port = parsed.port().unwrap_or(default_port);
382
383    let path = parsed.path().to_string();
384    let path = if path.is_empty() {
385        "/".to_string()
386    } else {
387        path
388    };
389
390    // Include query string if present
391    let path_with_query = if let Some(query) = parsed.query() {
392        format!("{}?{}", path, query)
393    } else {
394        path
395    };
396
397    Ok((host, port, path_with_query))
398}
399
400/// Connect to an upstream host over TLS using pre-resolved addresses.
401///
402/// Uses the pre-resolved `SocketAddr`s from the filter check to prevent
403/// DNS rebinding TOCTOU. Falls back to hostname resolution only if no
404/// pre-resolved addresses are available.
405///
406/// The `TlsConnector` is shared across all connections (created once at
407/// server startup with the system root certificate store).
408async fn connect_upstream_tls(
409    host: &str,
410    port: u16,
411    resolved_addrs: &[SocketAddr],
412    connector: &TlsConnector,
413) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
414    let tcp = if resolved_addrs.is_empty() {
415        // Fallback: no pre-resolved addresses (shouldn't happen in practice)
416        let addr = format!("{}:{}", host, port);
417        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr)).await {
418            Ok(Ok(s)) => s,
419            Ok(Err(e)) => {
420                return Err(ProxyError::UpstreamConnect {
421                    host: host.to_string(),
422                    reason: e.to_string(),
423                });
424            }
425            Err(_) => {
426                return Err(ProxyError::UpstreamConnect {
427                    host: host.to_string(),
428                    reason: "connection timed out".to_string(),
429                });
430            }
431        }
432    } else {
433        connect_to_resolved(resolved_addrs, host).await?
434    };
435
436    let server_name = rustls::pki_types::ServerName::try_from(host.to_string()).map_err(|_| {
437        ProxyError::UpstreamConnect {
438            host: host.to_string(),
439            reason: "invalid server name for TLS".to_string(),
440        }
441    })?;
442
443    let tls_stream =
444        connector
445            .connect(server_name, tcp)
446            .await
447            .map_err(|e| ProxyError::UpstreamConnect {
448                host: host.to_string(),
449                reason: format!("TLS handshake failed: {}", e),
450            })?;
451
452    Ok(tls_stream)
453}
454
455/// Connect to one of the pre-resolved socket addresses with timeout.
456async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
457    let mut last_err = None;
458    for addr in addrs {
459        match tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(addr)).await {
460            Ok(Ok(stream)) => return Ok(stream),
461            Ok(Err(e)) => {
462                debug!("Connect to {} failed: {}", addr, e);
463                last_err = Some(e.to_string());
464            }
465            Err(_) => {
466                debug!("Connect to {} timed out", addr);
467                last_err = Some("connection timed out".to_string());
468            }
469        }
470    }
471    Err(ProxyError::UpstreamConnect {
472        host: host.to_string(),
473        reason: last_err.unwrap_or_else(|| "no addresses to connect to".to_string()),
474    })
475}
476
/// Parse the HTTP status code from the first response chunk.
///
/// Only the first line (capped at 64 bytes) is inspected. It must look like
/// `HTTP/x.y NNN ...`, where `NNN` is exactly three ASCII digits. Returns 502
/// if the response doesn't contain a valid status line (upstream sent garbage
/// or incomplete data).
fn parse_response_status(data: &[u8]) -> u16 {
    const FALLBACK: u16 = 502;

    // Find the end of the first line (or use full data if no newline)
    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(64)];

    // A multi-byte character cut by the 64-byte cap, or a non-UTF-8 reason
    // phrase, must not hide a valid code: keep the valid UTF-8 prefix.
    let line = match std::str::from_utf8(first_line) {
        Ok(s) => s,
        Err(e) => std::str::from_utf8(&first_line[..e.valid_up_to()]).unwrap_or(""),
    };

    // ["HTTP/1.1", "200", "OK"]
    let mut parts = line.split_whitespace();
    match (parts.next(), parts.next()) {
        (Some(version), Some(code))
            if version.starts_with("HTTP/")
                && code.len() == 3
                && code.bytes().all(|b| b.is_ascii_digit()) =>
        {
            code.parse().unwrap_or(FALLBACK)
        }
        _ => FALLBACK,
    }
}
505
506/// Send an HTTP error response.
507async fn send_error(stream: &mut TcpStream, status: u16, reason: &str) -> Result<()> {
508    let body = format!("{{\"error\":\"{}\"}}", reason);
509    let response = format!(
510        "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
511        status,
512        reason,
513        body.len(),
514        body
515    );
516    stream.write_all(response.as_bytes()).await?;
517    stream.flush().await?;
518    Ok(())
519}
520
521// ============================================================================
522// Injection mode helpers
523// ============================================================================
524
525/// Validate phantom token based on injection mode.
526///
527/// Different modes extract the phantom token from different locations:
528/// - `Header`/`BasicAuth`: From the auth header (Authorization, x-api-key, etc.)
529/// - `UrlPath`: From the URL path pattern (e.g., `/bot<token>/getMe`)
530/// - `QueryParam`: From the query parameter (e.g., `?api_key=<token>`)
531fn validate_phantom_token_for_mode(
532    mode: &InjectMode,
533    header_bytes: &[u8],
534    path: &str,
535    header_name: &str,
536    path_pattern: Option<&str>,
537    query_param_name: Option<&str>,
538    session_token: &Zeroizing<String>,
539) -> Result<()> {
540    match mode {
541        InjectMode::Header | InjectMode::BasicAuth => {
542            // Validate from header (existing behavior)
543            validate_phantom_token(header_bytes, header_name, session_token)
544        }
545        InjectMode::UrlPath => {
546            // Validate from URL path
547            let pattern = path_pattern.ok_or_else(|| {
548                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
549            })?;
550            validate_phantom_token_in_path(path, pattern, session_token)
551        }
552        InjectMode::QueryParam => {
553            // Validate from query parameter
554            let param_name = query_param_name.ok_or_else(|| {
555                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
556            })?;
557            validate_phantom_token_in_query(path, param_name, session_token)
558        }
559    }
560}
561
562/// Validate phantom token embedded in URL path.
563///
564/// Extracts the token from the path using the pattern (e.g., `/bot{}/` matches
565/// `/bot<token>/getMe` and extracts `<token>`).
566fn validate_phantom_token_in_path(
567    path: &str,
568    pattern: &str,
569    session_token: &Zeroizing<String>,
570) -> Result<()> {
571    // Split pattern on {} to get prefix and suffix
572    let parts: Vec<&str> = pattern.split("{}").collect();
573    if parts.len() != 2 {
574        return Err(ProxyError::HttpParse(format!(
575            "invalid path_pattern '{}': must contain exactly one {{}}",
576            pattern
577        )));
578    }
579    let (prefix, suffix) = (parts[0], parts[1]);
580
581    // Find the token in the path
582    if let Some(start) = path.find(prefix) {
583        let after_prefix = start + prefix.len();
584
585        // Handle empty suffix case (token extends to end of path or next '/' or '?')
586        let end_offset = if suffix.is_empty() {
587            path[after_prefix..]
588                .find(['/', '?'])
589                .unwrap_or(path[after_prefix..].len())
590        } else {
591            match path[after_prefix..].find(suffix) {
592                Some(offset) => offset,
593                None => {
594                    warn!("Missing phantom token in URL path (pattern: {})", pattern);
595                    return Err(ProxyError::InvalidToken);
596                }
597            }
598        };
599
600        let token = &path[after_prefix..after_prefix + end_offset];
601        if token::constant_time_eq(token.as_bytes(), session_token.as_bytes()) {
602            return Ok(());
603        }
604        warn!("Invalid phantom token in URL path");
605        return Err(ProxyError::InvalidToken);
606    }
607
608    warn!("Missing phantom token in URL path (pattern: {})", pattern);
609    Err(ProxyError::InvalidToken)
610}
611
612/// Validate phantom token in query parameter.
613fn validate_phantom_token_in_query(
614    path: &str,
615    param_name: &str,
616    session_token: &Zeroizing<String>,
617) -> Result<()> {
618    // Parse query string from path
619    if let Some(query_start) = path.find('?') {
620        let query = &path[query_start + 1..];
621        for pair in query.split('&') {
622            if let Some((name, value)) = pair.split_once('=') {
623                if name == param_name {
624                    // URL-decode the value
625                    let decoded = urlencoding::decode(value).unwrap_or_else(|_| value.into());
626                    if token::constant_time_eq(decoded.as_bytes(), session_token.as_bytes()) {
627                        return Ok(());
628                    }
629                    warn!("Invalid phantom token in query parameter '{}'", param_name);
630                    return Err(ProxyError::InvalidToken);
631                }
632            }
633        }
634    }
635
636    warn!("Missing phantom token in query parameter '{}'", param_name);
637    Err(ProxyError::InvalidToken)
638}
639
640/// Transform URL path based on injection mode.
641///
642/// - `UrlPath`: Replace phantom token with real credential in path
643/// - `QueryParam`: Add/replace query parameter with real credential
644/// - `Header`/`BasicAuth`: No path transformation needed
645fn transform_path_for_mode(
646    mode: &InjectMode,
647    path: &str,
648    path_pattern: Option<&str>,
649    path_replacement: Option<&str>,
650    query_param_name: Option<&str>,
651    credential: &Zeroizing<String>,
652) -> Result<String> {
653    match mode {
654        InjectMode::Header | InjectMode::BasicAuth => {
655            // No path transformation needed
656            Ok(path.to_string())
657        }
658        InjectMode::UrlPath => {
659            let pattern = path_pattern.ok_or_else(|| {
660                ProxyError::HttpParse("url_path mode requires path_pattern".to_string())
661            })?;
662            let replacement = path_replacement.unwrap_or(pattern);
663            transform_url_path(path, pattern, replacement, credential)
664        }
665        InjectMode::QueryParam => {
666            let param_name = query_param_name.ok_or_else(|| {
667                ProxyError::HttpParse("query_param mode requires query_param_name".to_string())
668            })?;
669            transform_query_param(path, param_name, credential)
670        }
671    }
672}
673
674/// Transform URL path by replacing phantom token pattern with real credential.
675///
676/// Example: `/bot<phantom>/getMe` with pattern `/bot{}/` becomes `/bot<real>/getMe`
677fn transform_url_path(
678    path: &str,
679    pattern: &str,
680    replacement: &str,
681    credential: &Zeroizing<String>,
682) -> Result<String> {
683    // Split pattern on {} to get prefix and suffix
684    let parts: Vec<&str> = pattern.split("{}").collect();
685    if parts.len() != 2 {
686        return Err(ProxyError::HttpParse(format!(
687            "invalid path_pattern '{}': must contain exactly one {{}}",
688            pattern
689        )));
690    }
691    let (pattern_prefix, pattern_suffix) = (parts[0], parts[1]);
692
693    // Split replacement on {}
694    let repl_parts: Vec<&str> = replacement.split("{}").collect();
695    if repl_parts.len() != 2 {
696        return Err(ProxyError::HttpParse(format!(
697            "invalid path_replacement '{}': must contain exactly one {{}}",
698            replacement
699        )));
700    }
701    let (repl_prefix, repl_suffix) = (repl_parts[0], repl_parts[1]);
702
703    // Find and replace the token in the path
704    if let Some(start) = path.find(pattern_prefix) {
705        let after_prefix = start + pattern_prefix.len();
706
707        // Handle empty suffix case (token extends to end of path or next '/' or '?')
708        let end_offset = if pattern_suffix.is_empty() {
709            // Find the next path segment delimiter or end of path
710            path[after_prefix..]
711                .find(['/', '?'])
712                .unwrap_or(path[after_prefix..].len())
713        } else {
714            // Find the suffix in the remaining path
715            match path[after_prefix..].find(pattern_suffix) {
716                Some(offset) => offset,
717                None => {
718                    return Err(ProxyError::HttpParse(format!(
719                        "path '{}' does not match pattern '{}'",
720                        path, pattern
721                    )));
722                }
723            }
724        };
725
726        let before = &path[..start];
727        let after = &path[after_prefix + end_offset + pattern_suffix.len()..];
728        return Ok(format!(
729            "{}{}{}{}{}",
730            before,
731            repl_prefix,
732            credential.as_str(),
733            repl_suffix,
734            after
735        ));
736    }
737
738    Err(ProxyError::HttpParse(format!(
739        "path '{}' does not match pattern '{}'",
740        path, pattern
741    )))
742}
743
744/// Transform query string by adding or replacing a parameter with the credential.
745fn transform_query_param(
746    path: &str,
747    param_name: &str,
748    credential: &Zeroizing<String>,
749) -> Result<String> {
750    let encoded_value = urlencoding::encode(credential.as_str());
751
752    if let Some(query_start) = path.find('?') {
753        let base_path = &path[..query_start];
754        let query = &path[query_start + 1..];
755
756        // Check if parameter already exists
757        let mut found = false;
758        let new_query: Vec<String> = query
759            .split('&')
760            .map(|pair| {
761                if let Some((name, _)) = pair.split_once('=') {
762                    if name == param_name {
763                        found = true;
764                        return format!("{}={}", param_name, encoded_value);
765                    }
766                }
767                pair.to_string()
768            })
769            .collect();
770
771        if found {
772            Ok(format!("{}?{}", base_path, new_query.join("&")))
773        } else {
774            // Append the parameter
775            Ok(format!(
776                "{}?{}&{}={}",
777                base_path, query, param_name, encoded_value
778            ))
779        }
780    } else {
781        // No query string, add one
782        Ok(format!("{}?{}={}", path, param_name, encoded_value))
783    }
784}
785
786/// Inject credential into request based on mode.
787///
788/// For header/basic_auth modes, adds the credential header.
789/// For url_path/query_param modes, the credential is already in the path.
790fn inject_credential_for_mode(cred: &LoadedCredential, request: &mut Zeroizing<String>) {
791    match cred.inject_mode {
792        InjectMode::Header | InjectMode::BasicAuth => {
793            // Inject credential header
794            request.push_str(&format!(
795                "{}: {}\r\n",
796                cred.header_name,
797                cred.header_value.as_str()
798            ));
799        }
800        InjectMode::UrlPath | InjectMode::QueryParam => {
801            // Credential is already injected into the URL path/query
802            // No header injection needed
803        }
804    }
805}
806
// Unit tests for the pure helpers of the reverse proxy: request-line and
// service-prefix parsing, phantom-token validation (header, URL path, query),
// header filtering, Content-Length extraction, upstream URL parsing, response
// status parsing, and the URL path / query-param credential rewrites.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ============================================================================
    // Request Line / Service Prefix Parsing Tests
    // ============================================================================

    #[test]
    fn test_parse_request_line() {
        let (method, path, version) = parse_request_line("POST /openai/v1/chat HTTP/1.1").unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/openai/v1/chat");
        assert_eq!(version, "HTTP/1.1");
    }

    #[test]
    fn test_parse_request_line_malformed() {
        // A request line with only a method (no path/version) must be rejected.
        assert!(parse_request_line("GET").is_err());
    }

    #[test]
    fn test_parse_service_prefix() {
        let (service, path) = parse_service_prefix("/openai/v1/chat/completions").unwrap();
        assert_eq!(service, "openai");
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_service_prefix_no_subpath() {
        // A bare service prefix maps to the upstream root path "/".
        let (service, path) = parse_service_prefix("/anthropic").unwrap();
        assert_eq!(service, "anthropic");
        assert_eq!(path, "/");
    }

    // ============================================================================
    // Header Phantom Token Validation Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_bearer_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_bearer_invalid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Authorization: Bearer wrong\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_x_api_key_valid() {
        // Non-Authorization headers carry the raw token (no "Bearer " prefix).
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_x_goog_api_key_valid() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"x-goog-api-key: secret123\r\nContent-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "x-goog-api-key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_missing() {
        let token = Zeroizing::new("secret123".to_string());
        let header = b"Content-Type: application/json\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_case_insensitive_header() {
        // Header-name matching must ignore case, as HTTP header names do.
        let token = Zeroizing::new("secret123".to_string());
        let header = b"AUTHORIZATION: Bearer secret123\r\n\r\n";
        assert!(validate_phantom_token(header, "Authorization", &token).is_ok());
    }

    // ============================================================================
    // Header Filtering Tests
    // ============================================================================

    #[test]
    fn test_filter_headers_removes_host_auth() {
        // Host and Authorization are stripped; remaining headers keep their order.
        let header = b"Host: localhost:8080\r\nAuthorization: Bearer old\r\nContent-Type: application/json\r\nAccept: */*\r\n\r\n";
        let filtered = filter_headers(header);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].0, "Content-Type");
        assert_eq!(filtered[1].0, "Accept");
    }

    #[test]
    fn test_filter_headers_removes_x_api_key() {
        let header = b"x-api-key: sk-old\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    #[test]
    fn test_filter_headers_removes_x_goog_api_key() {
        let header = b"x-goog-api-key: gemini-key\r\nContent-Type: application/json\r\n\r\n";
        let filtered = filter_headers(header);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].0, "Content-Type");
    }

    // ============================================================================
    // Content-Length Extraction Tests
    // ============================================================================

    #[test]
    fn test_extract_content_length() {
        let header = b"Content-Type: application/json\r\nContent-Length: 42\r\n\r\n";
        assert_eq!(extract_content_length(header), Some(42));
    }

    #[test]
    fn test_extract_content_length_missing() {
        let header = b"Content-Type: application/json\r\n\r\n";
        assert_eq!(extract_content_length(header), None);
    }

    // ============================================================================
    // Upstream URL Parsing Tests
    // ============================================================================

    #[test]
    fn test_parse_upstream_url_https() {
        // https without an explicit port defaults to 443.
        let (host, port, path) =
            parse_upstream_url("https://api.openai.com/v1/chat/completions").unwrap();
        assert_eq!(host, "api.openai.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/v1/chat/completions");
    }

    #[test]
    fn test_parse_upstream_url_http_with_port() {
        let (host, port, path) = parse_upstream_url("http://localhost:8080/api").unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 8080);
        assert_eq!(path, "/api");
    }

    #[test]
    fn test_parse_upstream_url_no_path() {
        // A URL with no path component maps to "/".
        let (host, port, path) = parse_upstream_url("https://api.anthropic.com").unwrap();
        assert_eq!(host, "api.anthropic.com");
        assert_eq!(port, 443);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_parse_upstream_url_invalid_scheme() {
        // Only http/https upstreams are accepted.
        assert!(parse_upstream_url("ftp://example.com").is_err());
    }

    // ============================================================================
    // Response Status Parsing Tests
    // ============================================================================

    #[test]
    fn test_parse_response_status_200() {
        let data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n";
        assert_eq!(parse_response_status(data), 200);
    }

    #[test]
    fn test_parse_response_status_404() {
        let data = b"HTTP/1.1 404 Not Found\r\n\r\n";
        assert_eq!(parse_response_status(data), 404);
    }

    // The following cases check that unparseable input falls back to 502
    // (Bad Gateway) instead of producing an error.
    #[test]
    fn test_parse_response_status_garbage() {
        let data = b"not an http response";
        assert_eq!(parse_response_status(data), 502);
    }

    #[test]
    fn test_parse_response_status_empty() {
        assert_eq!(parse_response_status(b""), 502);
    }

    #[test]
    fn test_parse_response_status_partial() {
        let data = b"HTTP/1.1 ";
        assert_eq!(parse_response_status(data), 502);
    }

    // ============================================================================
    // URL Path Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_path_valid() {
        // "{}" in the pattern marks where the phantom token sits in the path.
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/session123/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_path_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/bot/wrong_token/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_path_missing() {
        // A path that does not match the pattern prefix at all is rejected.
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/getMe";
        let pattern = "/bot/{}/";
        assert!(validate_phantom_token_in_path(path, pattern, &token).is_err());
    }

    #[test]
    fn test_transform_url_path_basic() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token/getMe";
        let pattern = "/bot/{}/";
        let replacement = "/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token/getMe");
    }

    #[test]
    fn test_transform_url_path_different_replacement() {
        // The replacement template may differ from the match pattern; the
        // remainder of the path after the token segment is preserved.
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/api/v1/phantom_token/chat";
        let pattern = "/api/v1/{}/";
        let replacement = "/v2/bot/{}/";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/v2/bot/real_token/chat");
    }

    #[test]
    fn test_transform_url_path_no_trailing_slash() {
        let credential = Zeroizing::new("real_token".to_string());
        let path = "/bot/phantom_token";
        let pattern = "/bot/{}";
        let replacement = "/bot/{}";
        let result = transform_url_path(path, pattern, replacement, &credential).unwrap();
        assert_eq!(result, "/bot/real_token");
    }

    // ============================================================================
    // Query Param Injection Mode Tests
    // ============================================================================

    #[test]
    fn test_validate_phantom_token_in_query_valid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=session123&other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_validate_phantom_token_in_query_invalid() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?api_key=wrong_token";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_missing_param() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data?other=value";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_no_query_string() {
        let token = Zeroizing::new("session123".to_string());
        let path = "/api/data";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_err());
    }

    #[test]
    fn test_validate_phantom_token_in_query_url_encoded() {
        // The query value is percent-decoded before comparing with the token.
        let token = Zeroizing::new("token with spaces".to_string());
        let path = "/api/data?api_key=token%20with%20spaces";
        assert!(validate_phantom_token_in_query(path, "api_key", &token).is_ok());
    }

    #[test]
    fn test_transform_query_param_add_to_no_query() {
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_add_to_existing_query() {
        // A missing parameter is appended after the existing query.
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?other=value&api_key=real_key");
    }

    #[test]
    fn test_transform_query_param_replace_existing() {
        // An existing parameter is replaced in place, preserving order.
        let credential = Zeroizing::new("real_key".to_string());
        let path = "/api/data?api_key=phantom&other=value";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=real_key&other=value");
    }

    #[test]
    fn test_transform_query_param_url_encodes_special_chars() {
        // Spaces in the credential are percent-encoded as %20.
        let credential = Zeroizing::new("key with spaces".to_string());
        let path = "/api/data";
        let result = transform_query_param(path, "api_key", &credential).unwrap();
        assert_eq!(result, "/api/data?api_key=key%20with%20spaces");
    }
}