Skip to main content

nono_proxy/
config.rs

1//! Proxy configuration types.
2//!
3//! Defines the configuration for the proxy server, including allowed hosts,
4//! credential routes, and external proxy settings.
5
6use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
10/// Credential injection mode determining how credentials are inserted into requests.
11#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum InjectMode {
14    /// Inject credential into an HTTP header (default)
15    #[default]
16    Header,
17    /// Replace a pattern in the URL path with the credential
18    UrlPath,
19    /// Add or replace a query parameter with the credential
20    QueryParam,
21    /// Use HTTP Basic Authentication (credential format: "username:password")
22    BasicAuth,
23}
24
25/// Configuration for the proxy server.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28    /// Bind address (default: 127.0.0.1)
29    #[serde(default = "default_bind_addr")]
30    pub bind_addr: IpAddr,
31
32    /// Bind port (0 = OS-assigned ephemeral port)
33    #[serde(default)]
34    pub bind_port: u16,
35
36    /// Allowed hosts for CONNECT mode (exact match + wildcards).
37    /// Empty = allow all hosts (except deny list).
38    #[serde(default)]
39    pub allowed_hosts: Vec<String>,
40
41    /// Reverse proxy credential routes.
42    #[serde(default)]
43    pub routes: Vec<RouteConfig>,
44
45    /// External (enterprise) proxy URL for passthrough mode.
46    /// When set, CONNECT requests are chained to this proxy.
47    #[serde(default)]
48    pub external_proxy: Option<ExternalProxyConfig>,
49
50    /// Maximum concurrent connections (0 = unlimited).
51    #[serde(default)]
52    pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56    fn default() -> Self {
57        Self {
58            bind_addr: default_bind_addr(),
59            bind_port: 0,
60            allowed_hosts: Vec::new(),
61            routes: Vec::new(),
62            external_proxy: None,
63            max_connections: 256,
64        }
65    }
66}
67
/// Serde/`Default` fallback for `ProxyConfig::bind_addr`: IPv4 loopback.
fn default_bind_addr() -> IpAddr {
    std::net::Ipv4Addr::LOCALHOST.into()
}
71
72/// Configuration for a reverse proxy credential route.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RouteConfig {
75    /// Path prefix for routing (e.g., "openai").
76    /// Must NOT include leading or trailing slashes — it is a bare service name, not a URL path.
77    pub prefix: String,
78
79    /// Upstream URL to forward to (e.g., "https://api.openai.com")
80    pub upstream: String,
81
82    /// Keystore account name to load the credential from.
83    /// If `None`, no credential is injected.
84    pub credential_key: Option<String>,
85
86    /// Injection mode (default: "header")
87    #[serde(default)]
88    pub inject_mode: InjectMode,
89
90    // --- Header mode fields ---
91    /// HTTP header name for the credential (default: "Authorization")
92    /// Only used when inject_mode is "header".
93    #[serde(default = "default_inject_header")]
94    pub inject_header: String,
95
96    /// Format string for the credential value. `{}` is replaced with the secret.
97    /// Default: "Bearer {}"
98    /// Only used when inject_mode is "header".
99    #[serde(default = "default_credential_format")]
100    pub credential_format: String,
101
102    // --- URL path mode fields ---
103    /// Pattern to match in incoming URL path. Use {} as placeholder for phantom token.
104    /// Example: "/bot{}/" matches "/bot<token>/getMe"
105    /// Only used when inject_mode is "url_path".
106    #[serde(default)]
107    pub path_pattern: Option<String>,
108
109    /// Pattern for outgoing URL path. Use {} as placeholder for real credential.
110    /// Defaults to same as path_pattern if not specified.
111    /// Only used when inject_mode is "url_path".
112    #[serde(default)]
113    pub path_replacement: Option<String>,
114
115    // --- Query param mode fields ---
116    /// Name of the query parameter to add/replace with the credential.
117    /// Only used when inject_mode is "query_param".
118    #[serde(default)]
119    pub query_param_name: Option<String>,
120
121    /// Explicit environment variable name for the phantom token (e.g., "OPENAI_API_KEY").
122    ///
123    /// When set, this is used as the SDK API key env var name instead of deriving
124    /// it from `credential_key.to_uppercase()`. Required when `credential_key` is
125    /// a URI manager reference (e.g., `op://`, `apple-password://`) which would
126    /// otherwise produce a nonsensical env var name.
127    #[serde(default)]
128    pub env_var: Option<String>,
129
130    /// Optional L7 endpoint rules for method+path filtering.
131    ///
132    /// When non-empty, only requests matching at least one rule are allowed
133    /// (default-deny). When empty, all method+path combinations are permitted
134    /// (backward compatible).
135    #[serde(default)]
136    pub endpoint_rules: Vec<EndpointRule>,
137
138    /// Optional path to a PEM-encoded CA certificate file for upstream TLS.
139    ///
140    /// When set, the proxy trusts this CA in addition to the system roots
141    /// when connecting to the upstream for this route. This is required for
142    /// upstreams that use self-signed or private CA certificates (e.g.,
143    /// Kubernetes API servers).
144    #[serde(default)]
145    pub tls_ca: Option<String>,
146}
147
148/// An HTTP method+path access rule for reverse proxy endpoint filtering.
149///
150/// Used to restrict which API endpoints an agent can access through a
151/// credential route. Patterns use `/` separated segments with wildcards:
152/// - `*` matches exactly one path segment
153/// - `**` matches zero or more path segments
154#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
155pub struct EndpointRule {
156    /// HTTP method to match ("GET", "POST", etc.) or "*" for any method.
157    pub method: String,
158    /// URL path pattern with glob segments.
159    /// Example: "/api/v4/projects/*/merge_requests/**"
160    pub path: String,
161}
162
163/// Pre-compiled endpoint rules for the request hot path.
164///
165/// Built once at proxy startup from `EndpointRule` definitions. Holds
166/// compiled `globset::GlobMatcher`s so the hot path does a regex match,
167/// not a glob compile.
168pub struct CompiledEndpointRules {
169    rules: Vec<CompiledRule>,
170}
171
172struct CompiledRule {
173    method: String,
174    matcher: globset::GlobMatcher,
175}
176
177impl CompiledEndpointRules {
178    /// Compile endpoint rules into matchers. Invalid glob patterns are
179    /// rejected at startup with an error, not silently ignored at runtime.
180    pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
181        let mut compiled = Vec::with_capacity(rules.len());
182        for rule in rules {
183            let glob = Glob::new(&rule.path)
184                .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
185            compiled.push(CompiledRule {
186                method: rule.method.clone(),
187                matcher: glob.compile_matcher(),
188            });
189        }
190        Ok(Self { rules: compiled })
191    }
192
193    /// Check if the given method+path is allowed.
194    /// Returns `true` if no rules were compiled (allow-all, backward compatible).
195    #[must_use]
196    pub fn is_allowed(&self, method: &str, path: &str) -> bool {
197        if self.rules.is_empty() {
198            return true;
199        }
200        let normalized = normalize_path(path);
201        self.rules.iter().any(|r| {
202            (r.method == "*" || r.method.eq_ignore_ascii_case(method))
203                && r.matcher.is_match(&normalized)
204        })
205    }
206}
207
208impl std::fmt::Debug for CompiledEndpointRules {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        f.debug_struct("CompiledEndpointRules")
211            .field("count", &self.rules.len())
212            .finish()
213    }
214}
215
/// Check if any endpoint rule permits the given method+path.
/// Returns `true` if rules is empty (allow-all, backward compatible).
///
/// Globs are built with `literal_separator(true)` so `*` matches within a
/// single path segment and only `**` spans `/`, as the `EndpointRule`
/// documentation promises. A rule whose pattern fails to compile never
/// matches.
///
/// Test convenience only — compiles globs on each call. Production code
/// should use `CompiledEndpointRules::is_allowed()` instead.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    let normalized = normalize_path(path);
    rules.iter().any(|r| {
        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
            && globset::GlobBuilder::new(&r.path)
                .literal_separator(true)
                .build()
                .ok()
                .map(|g| g.compile_matcher())
                .is_some_and(|m| m.is_match(&normalized))
    })
}
235
236/// Normalize a URL path for matching: percent-decode, strip query string,
237/// collapse double slashes, strip trailing slash (but preserve root "/").
238///
239/// Percent-decoding prevents bypass via encoded characters (e.g.,
240/// `/api/%70rojects` evading a rule for `/api/projects/*`).
241fn normalize_path(path: &str) -> String {
242    // Strip query string
243    let path = path.split('?').next().unwrap_or(path);
244
245    // Percent-decode to prevent bypass via encoded segments.
246    // Use decode_binary + from_utf8_lossy so invalid UTF-8 sequences
247    // (e.g., %FF) become U+FFFD instead of falling back to the raw path.
248    let binary = urlencoding::decode_binary(path.as_bytes());
249    let decoded = String::from_utf8_lossy(&binary);
250
251    // Collapse double slashes by splitting on '/' and filtering empties,
252    // then rejoin. This also strips trailing slash.
253    let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
254    if segments.is_empty() {
255        "/".to_string()
256    } else {
257        format!("/{}", segments.join("/"))
258    }
259}
260
/// Serde fallback for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
264
/// Serde fallback for `RouteConfig::credential_format`.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
268
269/// Configuration for an external (enterprise) proxy.
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct ExternalProxyConfig {
272    /// Proxy address (e.g., "squid.corp.internal:3128")
273    pub address: String,
274
275    /// Optional authentication for the external proxy.
276    pub auth: Option<ExternalProxyAuth>,
277
278    /// Hosts to bypass the external proxy and route directly.
279    /// Supports exact hostnames and `*.` wildcard suffixes (case-insensitive).
280    /// Empty = all traffic goes through the external proxy.
281    #[serde(default)]
282    pub bypass_hosts: Vec<String>,
283}
284
285/// Authentication for an external proxy.
286#[derive(Debug, Clone, Serialize, Deserialize)]
287pub struct ExternalProxyAuth {
288    /// Keystore account name for proxy credentials.
289    pub keyring_account: String,
290
291    /// Authentication scheme (only "basic" supported).
292    #[serde(default = "default_auth_scheme")]
293    pub scheme: String,
294}
295
/// Serde fallback for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
299
#[cfg(test)]
// Tests may unwrap freely: a panic is reported as a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `ProxyConfig::default()` binds to loopback on port 0 with no hosts,
    /// routes, or external proxy configured.
    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    /// `allowed_hosts` survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    /// External proxy address and bypass hosts (exact and wildcard entries)
    /// survive a JSON round trip in order.
    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    /// An omitted `bypass_hosts` field deserializes to an empty list.
    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // ========================================================================
    // EndpointRule + path matching tests
    // ========================================================================

    /// With no rules at all, every method+path is permitted.
    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    /// Helper: check a single rule against method+path via endpoint_allowed.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    /// A literal path matches only itself — neither a shorter prefix nor a
    /// longer path.
    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    /// The rule's method is compared without regard to ASCII case.
    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    /// A "*" method accepts any request method.
    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    /// A matching path is still rejected when the method differs.
    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    /// `*` stands in for one segment; a path missing that segment does not
    /// match.
    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    /// A trailing `**` matches any depth below the prefix, but not paths
    /// outside it.
    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    /// A `**` in the middle matches zero or more intermediate segments.
    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    /// The query string is ignored when matching.
    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    /// A trailing slash on the request path does not prevent a match.
    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    /// Repeated slashes in the request path are collapsed before matching.
    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api//data".to_string().replace("//", "/"),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    /// A rule for "/" matches only the root path.
    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    /// The pre-compiled matcher enforces both method and path per rule.
    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    /// Compiling an empty rule list yields an allow-all matcher.
    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    /// A malformed glob (unclosed character class) is rejected at compile
    /// time.
    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    /// With several rules, a request is allowed when any one of them
    /// matches.
    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    /// A route with only the required fields gets no endpoint rules and no
    /// custom CA.
    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    /// `tls_ca` is read from JSON and preserved across a round trip.
    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    /// Percent-encoded characters inside a segment are decoded before
    /// matching.
    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        // Security: percent-encoded segments must not bypass rules.
        // e.g., /api/v4/%70rojects should match a rule for /api/v4/projects/*
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    /// A segment that is entirely percent-encoded still matches its
    /// literal rule.
    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        // %64%61%74%61 = "data"
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    /// The pre-compiled matcher decodes percent-encoding as well, both for
    /// allowed and for denied paths.
    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        // %69ssues = "issues"
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        // %70ulls = "pulls", which no rule permits
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    /// Bytes that are not valid UTF-8 after decoding never match a real
    /// segment.
    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        // Security: invalid UTF-8 percent sequences must not fall back to
        // the raw path (which could bypass rules). Lossy decoding replaces
        // invalid bytes with U+FFFD, so the path won't match real segments.
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        // %FF is not valid UTF-8 — must not match "/api/projects"
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    /// `EndpointRule` keeps method and path across a JSON round trip.
    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }
}