Skip to main content

nono_proxy/
config.rs

1//! Proxy configuration types.
2//!
3//! Defines the configuration for the proxy server, including allowed hosts,
4//! credential routes, and external proxy settings.
5
6use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::path::PathBuf;
10use zeroize::Zeroizing;
11
12/// Credential injection mode determining how credentials are inserted into requests.
13#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum InjectMode {
16    /// Inject credential into an HTTP header (default)
17    #[default]
18    Header,
19    /// Replace a pattern in the URL path with the credential
20    UrlPath,
21    /// Add or replace a query parameter with the credential
22    QueryParam,
23    /// Use HTTP Basic Authentication (credential format: "username:password")
24    BasicAuth,
25}
26
27/// Configuration for the proxy server.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ProxyConfig {
30    /// Bind address (default: 127.0.0.1)
31    #[serde(default = "default_bind_addr")]
32    pub bind_addr: IpAddr,
33
34    /// Bind port (0 = OS-assigned ephemeral port)
35    #[serde(default)]
36    pub bind_port: u16,
37
38    /// Allowed hosts for CONNECT mode (exact match + wildcards).
39    /// Empty = allow all hosts (except deny list), unless `strict_filter`
40    /// is `true`.
41    #[serde(default)]
42    pub allowed_hosts: Vec<String>,
43
44    /// When `true`, an empty `allowed_hosts` denies every host instead of
45    /// falling back to allow-all.
46    #[serde(default)]
47    pub strict_filter: bool,
48
49    /// Reverse proxy credential routes.
50    #[serde(default)]
51    pub routes: Vec<RouteConfig>,
52
53    /// External (enterprise) proxy URL for passthrough mode.
54    /// When set, CONNECT requests are chained to this proxy.
55    #[serde(default)]
56    pub external_proxy: Option<ExternalProxyConfig>,
57
58    /// Outbound TCP ports that the sandbox allows direct connections on
59    /// (via Landlock ConnectTcp). Hosts whose resolved port is NOT in this
60    /// set must go through the proxy and should NOT appear in NO_PROXY.
61    #[serde(default)]
62    pub direct_connect_ports: Vec<u16>,
63
64    /// Maximum concurrent connections (0 = unlimited).
65    #[serde(default)]
66    pub max_connections: usize,
67
68    /// Directory the proxy will write the TLS-intercept trust bundle into.
69    ///
70    /// When set together with at least one route requiring L7 visibility
71    /// (`endpoint_rules`, `credential_key`, or `oauth2`), the proxy generates
72    /// an ephemeral session CA and writes a PEM bundle (system roots +
73    /// optional parent `SSL_CERT_FILE` + ephemeral CA) into this directory at
74    /// startup. The path is exposed via `ProxyHandle::intercept_ca_path()`
75    /// so the CLI can grant the sandboxed child a Landlock/Seatbelt read
76    /// capability for it.
77    ///
78    /// The directory must exist and be owner-only readable (mode `0o700`)
79    /// before `start()` is called. The CLI conventionally points this at
80    /// `~/.nono/sessions/<session_id>/`.
81    ///
82    /// `None` disables TLS interception entirely; CONNECT requests behave
83    /// as before (transparent tunnel for non-route hosts; 403 for routes
84    /// without L7 requirements).
85    #[serde(default, skip_serializing_if = "Option::is_none")]
86    pub intercept_ca_dir: Option<PathBuf>,
87
88    /// Optional contents of the parent process's `SSL_CERT_FILE`, merged
89    /// into the trust bundle so any corporate CA configured on the host
90    /// remains trusted by the sandboxed child.
91    ///
92    /// The CLI reads this from `std::env::var("SSL_CERT_FILE")` and
93    /// `std::fs::read(...)` before calling `start()`. Skipped during
94    /// (de)serialisation: it's not part of any user-authored config file.
95    #[serde(default, skip)]
96    pub intercept_parent_ca_pems: Option<Vec<u8>>,
97
98    /// Pre-generated CA material for cross-session reuse (`--trust-proxy-ca`).
99    ///
100    /// When `Some`, the proxy uses this CA instead of generating a fresh
101    /// ephemeral one. The private key was loaded from macOS Keychain by the
102    /// CLI supervisor; the cert is already trusted in the user's trust store.
103    #[serde(default, skip)]
104    pub preloaded_ca: Option<PreloadedCa>,
105
106    /// Optional CA validity override for TLS interception.
107    /// Default (`None`) uses `CA_VALIDITY_DEFAULT` (24h).
108    /// Set by CLI `--proxy-ca-validity` flag.
109    #[serde(default, skip)]
110    pub ca_validity: Option<std::time::Duration>,
111}
112
113/// Pre-generated CA key material for cross-session CA reuse.
114///
115/// Used by `--trust-proxy-ca` on macOS: the CLI persists the CA in Keychain
116/// and passes it to the proxy so all sessions within the CA's validity window
117/// share the same signing key (and the same trusted cert in the system store).
118///
119/// ## Security note
120///
121/// The Keychain item's access control depends on the binary's code-signing
122/// identity. Release-signed builds get per-app isolation; unsigned dev builds
123/// allow any local process to read the key.
124///
125/// Because the CA is trusted user-wide during its validity window, any
126/// same-user process that can read the Keychain item could mint certificates
127/// trusted by macOS trust consumers. Release-signed builds are expected to
128/// receive stronger Keychain access isolation than unsigned development builds.
129/// The configurable CA validity (`--proxy-ca-validity`) limits exposure.
130#[derive(Clone)]
131pub struct PreloadedCa {
132    /// PKCS#8 DER-encoded private key for the CA. Zeroized on drop.
133    pub key_der: Zeroizing<Vec<u8>>,
134    /// PEM-encoded CA certificate (public).
135    pub cert_pem: String,
136}
137
138impl std::fmt::Debug for PreloadedCa {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        f.debug_struct("PreloadedCa")
141            .field("key_der", &"[REDACTED]")
142            .field("cert_pem_len", &self.cert_pem.len())
143            .finish()
144    }
145}
146
147impl Default for ProxyConfig {
148    fn default() -> Self {
149        Self {
150            bind_addr: default_bind_addr(),
151            bind_port: 0,
152            allowed_hosts: Vec::new(),
153            strict_filter: false,
154            routes: Vec::new(),
155            external_proxy: None,
156            direct_connect_ports: Vec::new(),
157            max_connections: 256,
158            intercept_ca_dir: None,
159            intercept_parent_ca_pems: None,
160            preloaded_ca: None,
161            ca_validity: None,
162        }
163    }
164}
165
/// Fallback bind address for `ProxyConfig::bind_addr`: IPv4 loopback (127.0.0.1).
fn default_bind_addr() -> IpAddr {
    std::net::Ipv4Addr::LOCALHOST.into()
}
169
170/// Configuration for a reverse proxy credential route.
171#[derive(Debug, Clone, Default, Serialize, Deserialize)]
172pub struct RouteConfig {
173    /// Path prefix for routing (e.g., "openai").
174    /// Must NOT include leading or trailing slashes — it is a bare service name, not a URL path.
175    pub prefix: String,
176
177    /// Upstream URL to forward to (e.g., "https://api.openai.com")
178    pub upstream: String,
179
180    /// Keystore account name to load the credential from.
181    /// If `None`, no credential is injected.
182    pub credential_key: Option<String>,
183
184    /// Injection mode (default: "header")
185    #[serde(default)]
186    pub inject_mode: InjectMode,
187
188    // --- Header mode fields ---
189    /// HTTP header name for the credential (default: "Authorization")
190    /// Only used when inject_mode is "header".
191    #[serde(default = "default_inject_header")]
192    pub inject_header: String,
193
194    /// How the injected header value is built (`{}` is replaced by the secret). Only when `inject_mode` is header.
195    ///
196    /// If you set this field, that whole string is used as-is — `Authorization` or any other header.
197    ///
198    /// If you omit it: an `Authorization` header (any capitalization) defaults to `Bearer {}`; any other header defaults to `{}` (secret only, no prefix).
199    #[serde(default)]
200    pub credential_format: Option<String>,
201
202    // --- URL path mode fields ---
203    /// Pattern to match in incoming URL path. Use {} as placeholder for phantom token.
204    /// Example: "/bot{}/" matches "/bot<token>/getMe"
205    /// Only used when inject_mode is "url_path".
206    #[serde(default)]
207    pub path_pattern: Option<String>,
208
209    /// Pattern for outgoing URL path. Use {} as placeholder for real credential.
210    /// Defaults to same as path_pattern if not specified.
211    /// Only used when inject_mode is "url_path".
212    #[serde(default)]
213    pub path_replacement: Option<String>,
214
215    // --- Query param mode fields ---
216    /// Name of the query parameter to add/replace with the credential.
217    /// Only used when inject_mode is "query_param".
218    #[serde(default)]
219    pub query_param_name: Option<String>,
220
221    /// Optional overrides for proxy-side phantom token handling.
222    ///
223    /// When set, these values are used to validate the incoming phantom token
224    /// from the sandboxed client request. Outbound credential injection to the
225    /// upstream continues to use the top-level route fields.
226    #[serde(default)]
227    pub proxy: Option<ProxyInjectConfig>,
228
229    /// Explicit environment variable name for the phantom token (e.g., "OPENAI_API_KEY").
230    ///
231    /// When set, this is used as the SDK API key env var name instead of deriving
232    /// it from `credential_key.to_uppercase()`. Required when `credential_key` is
233    /// a URI manager reference (e.g., `op://`, `apple-password://`) which would
234    /// otherwise produce a nonsensical env var name.
235    #[serde(default)]
236    pub env_var: Option<String>,
237
238    /// Optional L7 endpoint rules for method+path filtering.
239    ///
240    /// When non-empty, only requests matching at least one rule are allowed
241    /// (default-deny). When empty, all method+path combinations are permitted
242    /// (backward compatible).
243    #[serde(default)]
244    pub endpoint_rules: Vec<EndpointRule>,
245
246    /// Optional path to a PEM-encoded CA certificate file for upstream TLS.
247    ///
248    /// When set, the proxy trusts this CA in addition to the system roots
249    /// when connecting to the upstream for this route. This is required for
250    /// upstreams that use self-signed or private CA certificates (e.g.,
251    /// Kubernetes API servers).
252    #[serde(default)]
253    pub tls_ca: Option<String>,
254
255    /// Optional path to a PEM-encoded client certificate for upstream mTLS.
256    ///
257    /// When set together with `tls_client_key`, the proxy presents this
258    /// certificate to the upstream during TLS handshake. Required for
259    /// upstreams that enforce mutual TLS (e.g., Kubernetes API servers
260    /// configured with client-certificate authentication).
261    #[serde(default)]
262    pub tls_client_cert: Option<String>,
263
264    /// Optional path to a PEM-encoded private key for upstream mTLS.
265    ///
266    /// Must be set together with `tls_client_cert`. The key must correspond
267    /// to the certificate in `tls_client_cert`.
268    #[serde(default)]
269    pub tls_client_key: Option<String>,
270
271    /// Optional OAuth2 client_credentials configuration.
272    /// When present, the proxy handles token exchange automatically instead
273    /// of using a static credential from the keystore.
274    /// Mutually exclusive with `credential_key` — use one or the other.
275    #[serde(default)]
276    pub oauth2: Option<OAuth2Config>,
277
278    /// Optional AWS SigV4 signing configuration.
279    ///
280    /// When present, the proxy will sign outbound requests with AWS SigV4
281    /// credentials. Mutually exclusive with `credential_key` and `oauth2`.
282    #[serde(default)]
283    pub aws_auth: Option<AwsAuthConfig>,
284}
285
286/// Optional proxy-side overrides for credential injection shape.
287///
288/// These settings apply only to how the proxy validates the phantom token from
289/// the client request. Any field omitted here falls back to the corresponding
290/// top-level route field.
291#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
292#[serde(deny_unknown_fields)]
293pub struct ProxyInjectConfig {
294    /// Optional injection mode override for proxy-side token parsing.
295    #[serde(default)]
296    pub inject_mode: Option<InjectMode>,
297
298    /// Optional header name override for header/basic_auth modes.
299    #[serde(default)]
300    pub inject_header: Option<String>,
301
302    /// Optional format override for header mode.
303    #[serde(default)]
304    pub credential_format: Option<String>,
305
306    /// Optional path pattern override for url_path mode.
307    #[serde(default)]
308    pub path_pattern: Option<String>,
309
310    /// Optional path replacement override for url_path mode.
311    #[serde(default)]
312    pub path_replacement: Option<String>,
313
314    /// Optional query parameter override for query_param mode.
315    #[serde(default)]
316    pub query_param_name: Option<String>,
317}
318
319/// An HTTP method+path access rule for reverse proxy endpoint filtering.
320///
321/// Used to restrict which API endpoints an agent can access through a
322/// credential route. Patterns use `/` separated segments with wildcards:
323/// - `*` matches exactly one path segment
324/// - `**` matches zero or more path segments
325#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
326pub struct EndpointRule {
327    /// HTTP method to match ("GET", "POST", etc.) or "*" for any method.
328    pub method: String,
329    /// URL path pattern with glob segments.
330    /// Example: "/api/v4/projects/*/merge_requests/**"
331    pub path: String,
332}
333
334/// Pre-compiled endpoint rules for the request hot path.
335///
336/// Built once at proxy startup from `EndpointRule` definitions. Holds
337/// compiled `globset::GlobMatcher`s so the hot path does a regex match,
338/// not a glob compile.
339pub struct CompiledEndpointRules {
340    rules: Vec<CompiledRule>,
341}
342
343struct CompiledRule {
344    method: String,
345    matcher: globset::GlobMatcher,
346}
347
348impl CompiledEndpointRules {
349    /// Compile endpoint rules into matchers. Invalid glob patterns are
350    /// rejected at startup with an error, not silently ignored at runtime.
351    pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
352        let mut compiled = Vec::with_capacity(rules.len());
353        for rule in rules {
354            let glob = Glob::new(&rule.path)
355                .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
356            compiled.push(CompiledRule {
357                method: rule.method.clone(),
358                matcher: glob.compile_matcher(),
359            });
360        }
361        Ok(Self { rules: compiled })
362    }
363
364    /// `true` if no endpoint rules are defined (allow-all).
365    #[must_use]
366    pub fn is_empty(&self) -> bool {
367        self.rules.is_empty()
368    }
369
370    /// `true` if method+path matches a rule, or if no rules are defined.
371    #[must_use]
372    pub fn is_allowed(&self, method: &str, path: &str) -> bool {
373        if self.rules.is_empty() {
374            return true;
375        }
376        let normalized = normalize_path(path);
377        self.rules.iter().any(|r| {
378            (r.method == "*" || r.method.eq_ignore_ascii_case(method))
379                && r.matcher.is_match(&normalized)
380        })
381    }
382}
383
384impl std::fmt::Debug for CompiledEndpointRules {
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        f.debug_struct("CompiledEndpointRules")
387            .field("count", &self.rules.len())
388            .finish()
389    }
390}
391
392/// Check if any endpoint rule permits the given method+path.
393/// Returns `true` if rules is empty (allow-all, backward compatible).
394///
395/// Test convenience only — compiles globs on each call. Production code
396/// should use `CompiledEndpointRules::is_allowed()` instead.
397#[cfg(test)]
398fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
399    if rules.is_empty() {
400        return true;
401    }
402    let normalized = normalize_path(path);
403    rules.iter().any(|r| {
404        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
405            && Glob::new(&r.path)
406                .ok()
407                .map(|g| g.compile_matcher())
408                .is_some_and(|m| m.is_match(&normalized))
409    })
410}
411
412/// Normalize a URL path for matching: percent-decode, strip query string,
413/// collapse double slashes, strip trailing slash (but preserve root "/").
414///
415/// Percent-decoding prevents bypass via encoded characters (e.g.,
416/// `/api/%70rojects` evading a rule for `/api/projects/*`).
417fn normalize_path(path: &str) -> String {
418    // Strip query string
419    let path = path.split('?').next().unwrap_or(path);
420
421    // Percent-decode to prevent bypass via encoded segments.
422    // Use decode_binary + from_utf8_lossy so invalid UTF-8 sequences
423    // (e.g., %FF) become U+FFFD instead of falling back to the raw path.
424    let binary = urlencoding::decode_binary(path.as_bytes());
425    let decoded = String::from_utf8_lossy(&binary);
426
427    // Collapse double slashes by splitting on '/' and filtering empties,
428    // then rejoin. This also strips trailing slash.
429    let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
430    if segments.is_empty() {
431        "/".to_string()
432    } else {
433        format!("/{}", segments.join("/"))
434    }
435}
436
/// Serde fallback for `RouteConfig::inject_header`: the standard `Authorization` header.
fn default_inject_header() -> String {
    String::from("Authorization")
}
440
/// Return the header value template; the `{}` in it is later replaced by the secret.
///
/// An explicitly configured `credential_format` always wins and is returned
/// unchanged. Otherwise the template is `Bearer {}` when `inject_header` is
/// `Authorization` (compared ASCII case-insensitively) and a bare `{}` for
/// every other header.
#[must_use]
pub fn resolved_credential_format(inject_header: &str, credential_format: Option<&str>) -> String {
    if let Some(explicit) = credential_format {
        return explicit.to_owned();
    }
    let template = if inject_header.eq_ignore_ascii_case("Authorization") {
        "Bearer {}"
    } else {
        "{}"
    };
    template.to_owned()
}
457
458/// Configuration for an external (enterprise) proxy.
459#[derive(Debug, Clone, Serialize, Deserialize)]
460pub struct ExternalProxyConfig {
461    /// Proxy address (e.g., "squid.corp.internal:3128")
462    pub address: String,
463
464    /// Optional authentication for the external proxy.
465    pub auth: Option<ExternalProxyAuth>,
466
467    /// Hosts to bypass the external proxy and route directly.
468    /// Supports exact hostnames and `*.` wildcard suffixes (case-insensitive).
469    /// Empty = all traffic goes through the external proxy.
470    #[serde(default)]
471    pub bypass_hosts: Vec<String>,
472}
473
474/// Authentication for an external proxy.
475#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct ExternalProxyAuth {
477    /// Keystore account name for proxy credentials.
478    pub keyring_account: String,
479
480    /// Authentication scheme (only "basic" supported).
481    #[serde(default = "default_auth_scheme")]
482    pub scheme: String,
483}
484
/// Serde fallback for `ExternalProxyAuth::scheme`: HTTP Basic.
fn default_auth_scheme() -> String {
    String::from("basic")
}
488
489/// OAuth2 client_credentials configuration for automatic token exchange.
490///
491/// When configured on a route, the proxy handles the token lifecycle:
492/// 1. Exchanges client_id + client_secret for an access_token at startup
493/// 2. Caches the token with TTL from the `expires_in` response
494/// 3. Refreshes automatically before expiry (30s buffer)
495/// 4. Injects the access_token as `Authorization: Bearer <token>`
496///
497/// The agent never sees client_id or client_secret — only a phantom token.
498#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
499pub struct OAuth2Config {
500    /// Token endpoint URL (e.g., "https://auth.example.com/oauth/token")
501    pub token_url: String,
502    /// Client ID — plain value or credential reference (env://, file://, op://)
503    pub client_id: String,
504    /// Client secret — credential reference (env://, file://, op://)
505    pub client_secret: String,
506    /// OAuth2 scopes (space-separated). Empty = no scope parameter sent.
507    #[serde(default)]
508    pub scope: String,
509}
510
511/// AWS SigV4 signing configuration for a credential route.
512///
513/// When present on a route, the proxy will sign outbound requests using AWS
514/// SigV4. All fields are optional: an empty `aws_auth: {}` block is valid and
515/// uses the default credential chain with region and service auto-detected from
516/// the upstream URL.
517///
518/// Mutually exclusive with `credential_key` and `oauth2`.
519#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
520#[serde(deny_unknown_fields)]
521pub struct AwsAuthConfig {
522    /// AWS profile name to use for credentials.
523    /// If omitted, the default credential chain is used.
524    /// Must be non-empty with no whitespace if provided (whitespace breaks the
525    /// AWS INI config parser; profile names are case-sensitive).
526    #[serde(default)]
527    pub profile: Option<String>,
528
529    /// Explicit SigV4 signing region (e.g., `"us-east-1"`).
530    /// If omitted, auto-detected from the upstream URL.
531    /// Must be non-empty and lowercase if provided (SigV4 credential scope
532    /// requires lowercase region codes).
533    #[serde(default)]
534    pub region: Option<String>,
535
536    /// Explicit SigV4 service name (e.g., `"bedrock"`, `"s3"`, `"execute-api"`).
537    /// If omitted, auto-detected from the upstream URL.
538    /// Must be non-empty and lowercase if provided (SigV4 credential scope
539    /// requires lowercase service codes).
540    #[serde(default)]
541    pub service: Option<String>,
542}
543
544#[cfg(test)]
545#[allow(clippy::unwrap_used)]
546mod tests {
547    use super::*;
548
549    #[test]
550    fn test_default_config() {
551        let config = ProxyConfig::default();
552        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
553        assert_eq!(config.bind_port, 0);
554        assert!(config.allowed_hosts.is_empty());
555        assert!(config.routes.is_empty());
556        assert!(config.external_proxy.is_none());
557    }
558
559    #[test]
560    fn test_config_serialization() {
561        let config = ProxyConfig {
562            allowed_hosts: vec!["api.openai.com".to_string()],
563            ..Default::default()
564        };
565        let json = serde_json::to_string(&config).unwrap();
566        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
567        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
568    }
569
570    #[test]
571    fn test_external_proxy_config_with_bypass_hosts() {
572        let config = ProxyConfig {
573            external_proxy: Some(ExternalProxyConfig {
574                address: "squid.corp:3128".to_string(),
575                auth: None,
576                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
577            }),
578            ..Default::default()
579        };
580        let json = serde_json::to_string(&config).unwrap();
581        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
582        let ext = deserialized.external_proxy.unwrap();
583        assert_eq!(ext.address, "squid.corp:3128");
584        assert_eq!(ext.bypass_hosts.len(), 2);
585        assert_eq!(ext.bypass_hosts[0], "internal.corp");
586        assert_eq!(ext.bypass_hosts[1], "*.private.net");
587    }
588
589    #[test]
590    fn test_external_proxy_config_bypass_hosts_default_empty() {
591        let json = r#"{"address": "proxy:3128", "auth": null}"#;
592        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
593        assert!(ext.bypass_hosts.is_empty());
594    }
595
596    // ========================================================================
597    // EndpointRule + path matching tests
598    // ========================================================================
599
600    #[test]
601    fn test_endpoint_allowed_empty_rules_allows_all() {
602        assert!(endpoint_allowed(&[], "GET", "/anything"));
603        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
604    }
605
606    /// Helper: check a single rule against method+path via endpoint_allowed.
607    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
608        endpoint_allowed(std::slice::from_ref(rule), method, path)
609    }
610
611    #[test]
612    fn test_endpoint_rule_exact_path() {
613        let rule = EndpointRule {
614            method: "GET".to_string(),
615            path: "/v1/chat/completions".to_string(),
616        };
617        assert!(check(&rule, "GET", "/v1/chat/completions"));
618        assert!(!check(&rule, "GET", "/v1/chat"));
619        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
620    }
621
622    #[test]
623    fn test_endpoint_rule_method_case_insensitive() {
624        let rule = EndpointRule {
625            method: "get".to_string(),
626            path: "/api".to_string(),
627        };
628        assert!(check(&rule, "GET", "/api"));
629        assert!(check(&rule, "Get", "/api"));
630    }
631
632    #[test]
633    fn test_endpoint_rule_method_wildcard() {
634        let rule = EndpointRule {
635            method: "*".to_string(),
636            path: "/api/resource".to_string(),
637        };
638        assert!(check(&rule, "GET", "/api/resource"));
639        assert!(check(&rule, "DELETE", "/api/resource"));
640        assert!(check(&rule, "POST", "/api/resource"));
641    }
642
643    #[test]
644    fn test_endpoint_rule_method_mismatch() {
645        let rule = EndpointRule {
646            method: "GET".to_string(),
647            path: "/api/resource".to_string(),
648        };
649        assert!(!check(&rule, "POST", "/api/resource"));
650        assert!(!check(&rule, "DELETE", "/api/resource"));
651    }
652
653    #[test]
654    fn test_endpoint_rule_single_wildcard() {
655        let rule = EndpointRule {
656            method: "GET".to_string(),
657            path: "/api/v4/projects/*/merge_requests".to_string(),
658        };
659        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
660        assert!(check(
661            &rule,
662            "GET",
663            "/api/v4/projects/my-proj/merge_requests"
664        ));
665        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
666    }
667
668    #[test]
669    fn test_endpoint_rule_double_wildcard() {
670        let rule = EndpointRule {
671            method: "GET".to_string(),
672            path: "/api/v4/projects/**".to_string(),
673        };
674        assert!(check(&rule, "GET", "/api/v4/projects/123"));
675        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
676        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
677        assert!(!check(&rule, "GET", "/api/v4/other"));
678    }
679
680    #[test]
681    fn test_endpoint_rule_double_wildcard_middle() {
682        let rule = EndpointRule {
683            method: "*".to_string(),
684            path: "/api/**/notes".to_string(),
685        };
686        assert!(check(&rule, "GET", "/api/notes"));
687        assert!(check(&rule, "POST", "/api/projects/123/notes"));
688        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
689        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
690    }
691
692    #[test]
693    fn test_endpoint_rule_strips_query_string() {
694        let rule = EndpointRule {
695            method: "GET".to_string(),
696            path: "/api/data".to_string(),
697        };
698        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
699    }
700
701    #[test]
702    fn test_endpoint_rule_trailing_slash_normalized() {
703        let rule = EndpointRule {
704            method: "GET".to_string(),
705            path: "/api/data".to_string(),
706        };
707        assert!(check(&rule, "GET", "/api/data/"));
708        assert!(check(&rule, "GET", "/api/data"));
709    }
710
711    #[test]
712    fn test_endpoint_rule_double_slash_normalized() {
713        let rule = EndpointRule {
714            method: "GET".to_string(),
715            path: "/api/data".to_string(),
716        };
717        assert!(check(&rule, "GET", "/api//data"));
718    }
719
720    #[test]
721    fn test_endpoint_rule_root_path() {
722        let rule = EndpointRule {
723            method: "GET".to_string(),
724            path: "/".to_string(),
725        };
726        assert!(check(&rule, "GET", "/"));
727        assert!(!check(&rule, "GET", "/anything"));
728    }
729
730    #[test]
731    fn test_compiled_endpoint_rules_hot_path() {
732        let rules = vec![
733            EndpointRule {
734                method: "GET".to_string(),
735                path: "/repos/*/issues".to_string(),
736            },
737            EndpointRule {
738                method: "POST".to_string(),
739                path: "/repos/*/issues/*/comments".to_string(),
740            },
741        ];
742        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
743        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
744        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
745        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
746        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
747    }
748
749    #[test]
750    fn test_compiled_endpoint_rules_empty_allows_all() {
751        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
752        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
753    }
754
755    #[test]
756    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
757        let rules = vec![EndpointRule {
758            method: "GET".to_string(),
759            path: "/api/[invalid".to_string(),
760        }];
761        assert!(CompiledEndpointRules::compile(&rules).is_err());
762    }
763
764    #[test]
765    fn test_endpoint_allowed_multiple_rules() {
766        let rules = vec![
767            EndpointRule {
768                method: "GET".to_string(),
769                path: "/repos/*/issues".to_string(),
770            },
771            EndpointRule {
772                method: "POST".to_string(),
773                path: "/repos/*/issues/*/comments".to_string(),
774            },
775        ];
776        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
777        assert!(endpoint_allowed(
778            &rules,
779            "POST",
780            "/repos/myrepo/issues/42/comments"
781        ));
782        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
783        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
784    }
785
786    #[test]
787    fn test_endpoint_rule_serde_default() {
788        let json = r#"{
789            "prefix": "test",
790            "upstream": "https://example.com"
791        }"#;
792        let route: RouteConfig = serde_json::from_str(json).unwrap();
793        assert!(route.endpoint_rules.is_empty());
794        assert!(route.tls_ca.is_none());
795    }
796
797    #[test]
798    fn test_tls_ca_serde_roundtrip() {
799        let json = r#"{
800            "prefix": "k8s",
801            "upstream": "https://kubernetes.local:6443",
802            "tls_ca": "/run/secrets/k8s-ca.crt"
803        }"#;
804        let route: RouteConfig = serde_json::from_str(json).unwrap();
805        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));
806
807        let serialized = serde_json::to_string(&route).unwrap();
808        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
809        assert_eq!(
810            deserialized.tls_ca.as_deref(),
811            Some("/run/secrets/k8s-ca.crt")
812        );
813    }
814
815    #[test]
816    fn test_endpoint_rule_percent_encoded_path_decoded() {
817        // Security: percent-encoded segments must not bypass rules.
818        // e.g., /api/v4/%70rojects should match a rule for /api/v4/projects/*
819        let rule = EndpointRule {
820            method: "GET".to_string(),
821            path: "/api/v4/projects/*/issues".to_string(),
822        };
823        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
824        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
825    }
826
827    #[test]
828    fn test_endpoint_rule_percent_encoded_full_segment() {
829        let rule = EndpointRule {
830            method: "POST".to_string(),
831            path: "/api/data".to_string(),
832        };
833        // %64%61%74%61 = "data"
834        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
835    }
836
837    #[test]
838    fn test_compiled_endpoint_rules_percent_encoded() {
839        let rules = vec![EndpointRule {
840            method: "GET".to_string(),
841            path: "/repos/*/issues".to_string(),
842        }];
843        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
844        // %69ssues = "issues"
845        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
846        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
847    }
848
849    #[test]
850    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
851        // Security: invalid UTF-8 percent sequences must not fall back to
852        // the raw path (which could bypass rules). Lossy decoding replaces
853        // invalid bytes with U+FFFD, so the path won't match real segments.
854        let rule = EndpointRule {
855            method: "GET".to_string(),
856            path: "/api/projects".to_string(),
857        };
858        // %FF is not valid UTF-8 — must not match "/api/projects"
859        assert!(!check(&rule, "GET", "/api/%FFprojects"));
860    }
861
862    #[test]
863    fn test_endpoint_rule_serde_roundtrip() {
864        let rule = EndpointRule {
865            method: "GET".to_string(),
866            path: "/api/*/data".to_string(),
867        };
868        let json = serde_json::to_string(&rule).unwrap();
869        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
870        assert_eq!(deserialized.method, "GET");
871        assert_eq!(deserialized.path, "/api/*/data");
872    }
873
874    // ========================================================================
875    // OAuth2Config tests
876    // ========================================================================
877
878    #[test]
879    fn test_oauth2_config_deserialization() {
880        let json = r#"{
881            "token_url": "https://auth.example.com/oauth/token",
882            "client_id": "my-client",
883            "client_secret": "env://CLIENT_SECRET",
884            "scope": "read write"
885        }"#;
886        let config: OAuth2Config = serde_json::from_str(json).unwrap();
887        assert_eq!(config.token_url, "https://auth.example.com/oauth/token");
888        assert_eq!(config.client_id, "my-client");
889        assert_eq!(config.client_secret, "env://CLIENT_SECRET");
890        assert_eq!(config.scope, "read write");
891    }
892
893    #[test]
894    fn test_oauth2_config_default_scope() {
895        let json = r#"{
896            "token_url": "https://auth.example.com/oauth/token",
897            "client_id": "my-client",
898            "client_secret": "env://SECRET"
899        }"#;
900        let config: OAuth2Config = serde_json::from_str(json).unwrap();
901        assert_eq!(config.scope, "");
902    }
903
904    #[test]
905    fn test_route_config_with_oauth2() {
906        let json = r#"{
907            "prefix": "/my-api",
908            "upstream": "https://api.example.com",
909            "oauth2": {
910                "token_url": "https://auth.example.com/oauth/token",
911                "client_id": "agent-1",
912                "client_secret": "env://CLIENT_SECRET",
913                "scope": "api.read"
914            }
915        }"#;
916        let route: RouteConfig = serde_json::from_str(json).unwrap();
917        assert!(route.oauth2.is_some());
918        assert!(route.credential_key.is_none());
919        let oauth2 = route.oauth2.unwrap();
920        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
921    }
922
923    #[test]
924    fn test_route_config_without_oauth2() {
925        let json = r#"{
926            "prefix": "/openai",
927            "upstream": "https://api.openai.com",
928            "credential_key": "openai"
929        }"#;
930        let route: RouteConfig = serde_json::from_str(json).unwrap();
931        assert!(route.oauth2.is_none());
932        assert!(route.credential_key.is_some());
933    }
934
935    #[test]
936    fn test_route_config_credential_format_omitted_is_none() {
937        let json = r#"{
938            "prefix": "anthropic",
939            "upstream": "https://api.anthropic.com",
940            "credential_key": "env://ANTHROPIC_API_KEY",
941            "inject_header": "x-api-key"
942        }"#;
943        let route: RouteConfig = serde_json::from_str(json).unwrap();
944        assert!(route.credential_format.is_none());
945        assert_eq!(
946            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
947            "{}"
948        );
949    }
950
951    #[test]
952    fn test_route_config_explicit_bearer_on_custom_header_preserved() {
953        let json = r#"{
954            "prefix": "litellm",
955            "upstream": "https://litellm",
956            "credential_key": "env://LITELLM_TOKEN",
957            "inject_header": "x-litellm-api-key",
958            "credential_format": "Bearer {}"
959        }"#;
960        let route: RouteConfig = serde_json::from_str(json).unwrap();
961        assert_eq!(route.credential_format.as_deref(), Some("Bearer {}"));
962        assert_eq!(
963            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
964            "Bearer {}"
965        );
966    }
967
968    #[test]
969    fn test_resolved_credential_format_authorization_case_insensitive() {
970        for header in ["authorization", "AUTHORIZATION", "Authorization"] {
971            assert_eq!(
972                resolved_credential_format(header, None),
973                "Bearer {}",
974                "omitted format: Authorization header name is matched case-insensitively for Bearer default"
975            );
976        }
977    }
978
979    // ========================================================================
980    // AwsAuthConfig tests
981    // ========================================================================
982
983    #[test]
984    fn test_aws_auth_config_minimal_deserializes() {
985        let json = r#"{}"#;
986        let aws: AwsAuthConfig = serde_json::from_str(json).unwrap();
987        assert!(aws.profile.is_none());
988        assert!(aws.region.is_none());
989        assert!(aws.service.is_none());
990    }
991
992    #[test]
993    fn test_aws_auth_config_all_fields_roundtrip() {
994        let original = AwsAuthConfig {
995            profile: Some("my-aws-profile".to_string()),
996            region: Some("us-east-1".to_string()),
997            service: Some("bedrock".to_string()),
998        };
999        let json = serde_json::to_string(&original).unwrap();
1000        let deserialized: AwsAuthConfig = serde_json::from_str(&json).unwrap();
1001        assert_eq!(deserialized.profile.as_deref(), Some("my-aws-profile"));
1002        assert_eq!(deserialized.region.as_deref(), Some("us-east-1"));
1003        assert_eq!(deserialized.service.as_deref(), Some("bedrock"));
1004    }
1005
1006    #[test]
1007    fn test_aws_auth_field_absent_is_none() {
1008        let json = r#"{"prefix": "bedrock", "upstream": "https://bedrock-runtime.us-east-1.amazonaws.com"}"#;
1009        let route: RouteConfig = serde_json::from_str(json).unwrap();
1010        assert!(route.aws_auth.is_none());
1011    }
1012
1013    #[test]
1014    fn test_aws_auth_config_unknown_field_rejected() {
1015        let json = r#"{"profile": "foo", "unknown_field": "bar"}"#;
1016        let result: std::result::Result<AwsAuthConfig, _> = serde_json::from_str(json);
1017        assert!(
1018            result.is_err(),
1019            "unknown fields must be rejected by deny_unknown_fields"
1020        );
1021    }
1022
1023    #[test]
1024    fn test_route_config_with_aws_auth_deserializes() {
1025        let json = r#"{
1026            "prefix": "bedrock",
1027            "upstream": "https://bedrock-runtime.us-east-1.amazonaws.com",
1028            "aws_auth": {
1029                "profile": "my-aws-profile"
1030            }
1031        }"#;
1032        let route: RouteConfig = serde_json::from_str(json).unwrap();
1033        let aws = route.aws_auth.unwrap();
1034        assert_eq!(aws.profile.as_deref(), Some("my-aws-profile"));
1035        assert!(aws.region.is_none());
1036        assert!(aws.service.is_none());
1037    }
1038
1039    #[test]
1040    fn test_route_config_with_full_aws_auth_deserializes() {
1041        let json = r#"{
1042            "prefix": "bedrock",
1043            "upstream": "https://bedrock-runtime.us-east-1.amazonaws.com",
1044            "aws_auth": {
1045                "profile": "my-aws-profile",
1046                "region": "us-west-2",
1047                "service": "bedrock"
1048            }
1049        }"#;
1050        let route: RouteConfig = serde_json::from_str(json).unwrap();
1051        let aws = route.aws_auth.unwrap();
1052        assert_eq!(aws.profile.as_deref(), Some("my-aws-profile"));
1053        assert_eq!(aws.region.as_deref(), Some("us-west-2"));
1054        assert_eq!(aws.service.as_deref(), Some("bedrock"));
1055    }
1056
1057    #[test]
1058    fn test_aws_auth_empty_object_sets_all_none() {
1059        let json = r#"{
1060            "prefix": "bedrock",
1061            "upstream": "https://bedrock-runtime.us-east-1.amazonaws.com",
1062            "aws_auth": {}
1063        }"#;
1064        let route: RouteConfig = serde_json::from_str(json).unwrap();
1065        let aws = route.aws_auth.unwrap();
1066        assert!(aws.profile.is_none());
1067        assert!(aws.region.is_none());
1068        assert!(aws.service.is_none());
1069    }
1070}