Skip to main content

nono_proxy/
config.rs

1//! Proxy configuration types.
2//!
3//! Defines the configuration for the proxy server, including allowed hosts,
4//! credential routes, and external proxy settings.
5
6use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
10/// Credential injection mode determining how credentials are inserted into requests.
11#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum InjectMode {
14    /// Inject credential into an HTTP header (default)
15    #[default]
16    Header,
17    /// Replace a pattern in the URL path with the credential
18    UrlPath,
19    /// Add or replace a query parameter with the credential
20    QueryParam,
21    /// Use HTTP Basic Authentication (credential format: "username:password")
22    BasicAuth,
23}
24
25/// Configuration for the proxy server.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28    /// Bind address (default: 127.0.0.1)
29    #[serde(default = "default_bind_addr")]
30    pub bind_addr: IpAddr,
31
32    /// Bind port (0 = OS-assigned ephemeral port)
33    #[serde(default)]
34    pub bind_port: u16,
35
36    /// Allowed hosts for CONNECT mode (exact match + wildcards).
37    /// Empty = allow all hosts (except deny list).
38    #[serde(default)]
39    pub allowed_hosts: Vec<String>,
40
41    /// Reverse proxy credential routes.
42    #[serde(default)]
43    pub routes: Vec<RouteConfig>,
44
45    /// External (enterprise) proxy URL for passthrough mode.
46    /// When set, CONNECT requests are chained to this proxy.
47    #[serde(default)]
48    pub external_proxy: Option<ExternalProxyConfig>,
49
50    /// Maximum concurrent connections (0 = unlimited).
51    #[serde(default)]
52    pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56    fn default() -> Self {
57        Self {
58            bind_addr: default_bind_addr(),
59            bind_port: 0,
60            allowed_hosts: Vec::new(),
61            routes: Vec::new(),
62            external_proxy: None,
63            max_connections: 256,
64        }
65    }
66}
67
68fn default_bind_addr() -> IpAddr {
69    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
70}
71
72/// Configuration for a reverse proxy credential route.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RouteConfig {
75    /// Path prefix for routing (e.g., "/openai")
76    pub prefix: String,
77
78    /// Upstream URL to forward to (e.g., "https://api.openai.com")
79    pub upstream: String,
80
81    /// Keystore account name to load the credential from.
82    /// If `None`, no credential is injected.
83    pub credential_key: Option<String>,
84
85    /// Injection mode (default: "header")
86    #[serde(default)]
87    pub inject_mode: InjectMode,
88
89    // --- Header mode fields ---
90    /// HTTP header name for the credential (default: "Authorization")
91    /// Only used when inject_mode is "header".
92    #[serde(default = "default_inject_header")]
93    pub inject_header: String,
94
95    /// Format string for the credential value. `{}` is replaced with the secret.
96    /// Default: "Bearer {}"
97    /// Only used when inject_mode is "header".
98    #[serde(default = "default_credential_format")]
99    pub credential_format: String,
100
101    // --- URL path mode fields ---
102    /// Pattern to match in incoming URL path. Use {} as placeholder for phantom token.
103    /// Example: "/bot{}/" matches "/bot<token>/getMe"
104    /// Only used when inject_mode is "url_path".
105    #[serde(default)]
106    pub path_pattern: Option<String>,
107
108    /// Pattern for outgoing URL path. Use {} as placeholder for real credential.
109    /// Defaults to same as path_pattern if not specified.
110    /// Only used when inject_mode is "url_path".
111    #[serde(default)]
112    pub path_replacement: Option<String>,
113
114    // --- Query param mode fields ---
115    /// Name of the query parameter to add/replace with the credential.
116    /// Only used when inject_mode is "query_param".
117    #[serde(default)]
118    pub query_param_name: Option<String>,
119
120    /// Explicit environment variable name for the phantom token (e.g., "OPENAI_API_KEY").
121    ///
122    /// When set, this is used as the SDK API key env var name instead of deriving
123    /// it from `credential_key.to_uppercase()`. Required when `credential_key` is
124    /// a URI manager reference (e.g., `op://`, `apple-password://`) which would
125    /// otherwise produce a nonsensical env var name.
126    #[serde(default)]
127    pub env_var: Option<String>,
128
129    /// Optional L7 endpoint rules for method+path filtering.
130    ///
131    /// When non-empty, only requests matching at least one rule are allowed
132    /// (default-deny). When empty, all method+path combinations are permitted
133    /// (backward compatible).
134    #[serde(default)]
135    pub endpoint_rules: Vec<EndpointRule>,
136
137    /// Optional path to a PEM-encoded CA certificate file for upstream TLS.
138    ///
139    /// When set, the proxy trusts this CA in addition to the system roots
140    /// when connecting to the upstream for this route. This is required for
141    /// upstreams that use self-signed or private CA certificates (e.g.,
142    /// Kubernetes API servers).
143    #[serde(default)]
144    pub tls_ca: Option<String>,
145}
146
147/// An HTTP method+path access rule for reverse proxy endpoint filtering.
148///
149/// Used to restrict which API endpoints an agent can access through a
150/// credential route. Patterns use `/` separated segments with wildcards:
151/// - `*` matches exactly one path segment
152/// - `**` matches zero or more path segments
153#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
154pub struct EndpointRule {
155    /// HTTP method to match ("GET", "POST", etc.) or "*" for any method.
156    pub method: String,
157    /// URL path pattern with glob segments.
158    /// Example: "/api/v4/projects/*/merge_requests/**"
159    pub path: String,
160}
161
162/// Pre-compiled endpoint rules for the request hot path.
163///
164/// Built once at proxy startup from `EndpointRule` definitions. Holds
165/// compiled `globset::GlobMatcher`s so the hot path does a regex match,
166/// not a glob compile.
167pub struct CompiledEndpointRules {
168    rules: Vec<CompiledRule>,
169}
170
171struct CompiledRule {
172    method: String,
173    matcher: globset::GlobMatcher,
174}
175
176impl CompiledEndpointRules {
177    /// Compile endpoint rules into matchers. Invalid glob patterns are
178    /// rejected at startup with an error, not silently ignored at runtime.
179    pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
180        let mut compiled = Vec::with_capacity(rules.len());
181        for rule in rules {
182            let glob = Glob::new(&rule.path)
183                .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
184            compiled.push(CompiledRule {
185                method: rule.method.clone(),
186                matcher: glob.compile_matcher(),
187            });
188        }
189        Ok(Self { rules: compiled })
190    }
191
192    /// Check if the given method+path is allowed.
193    /// Returns `true` if no rules were compiled (allow-all, backward compatible).
194    #[must_use]
195    pub fn is_allowed(&self, method: &str, path: &str) -> bool {
196        if self.rules.is_empty() {
197            return true;
198        }
199        let normalized = normalize_path(path);
200        self.rules.iter().any(|r| {
201            (r.method == "*" || r.method.eq_ignore_ascii_case(method))
202                && r.matcher.is_match(&normalized)
203        })
204    }
205}
206
207impl std::fmt::Debug for CompiledEndpointRules {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        f.debug_struct("CompiledEndpointRules")
210            .field("count", &self.rules.len())
211            .finish()
212    }
213}
214
/// Check if any endpoint rule permits the given method+path.
/// Returns `true` if `rules` is empty (allow-all, backward compatible).
///
/// Test convenience only — compiles globs on each call. Production code
/// should use `CompiledEndpointRules::is_allowed()` instead.
///
/// This helper delegates to `CompiledEndpointRules`, so tests exercise
/// exactly the matching semantics production uses: glob options, method
/// matching and path normalization. A rule set containing an invalid
/// pattern permits nothing, mirroring production, where such a set is
/// rejected at startup.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    CompiledEndpointRules::compile(rules)
        .is_ok_and(|compiled| compiled.is_allowed(method, path))
}
234
235/// Normalize a URL path for matching: percent-decode, strip query string,
236/// collapse double slashes, strip trailing slash (but preserve root "/").
237///
238/// Percent-decoding prevents bypass via encoded characters (e.g.,
239/// `/api/%70rojects` evading a rule for `/api/projects/*`).
240fn normalize_path(path: &str) -> String {
241    // Strip query string
242    let path = path.split('?').next().unwrap_or(path);
243
244    // Percent-decode to prevent bypass via encoded segments.
245    // Use decode_binary + from_utf8_lossy so invalid UTF-8 sequences
246    // (e.g., %FF) become U+FFFD instead of falling back to the raw path.
247    let binary = urlencoding::decode_binary(path.as_bytes());
248    let decoded = String::from_utf8_lossy(&binary);
249
250    // Collapse double slashes by splitting on '/' and filtering empties,
251    // then rejoin. This also strips trailing slash.
252    let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
253    if segments.is_empty() {
254        "/".to_string()
255    } else {
256        format!("/{}", segments.join("/"))
257    }
258}
259
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}

/// Serde default for `RouteConfig::credential_format`; `{}` marks where the
/// secret is substituted.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
267
268/// Configuration for an external (enterprise) proxy.
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct ExternalProxyConfig {
271    /// Proxy address (e.g., "squid.corp.internal:3128")
272    pub address: String,
273
274    /// Optional authentication for the external proxy.
275    pub auth: Option<ExternalProxyAuth>,
276
277    /// Hosts to bypass the external proxy and route directly.
278    /// Supports exact hostnames and `*.` wildcard suffixes (case-insensitive).
279    /// Empty = all traffic goes through the external proxy.
280    #[serde(default)]
281    pub bypass_hosts: Vec<String>,
282}
283
284/// Authentication for an external proxy.
285#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct ExternalProxyAuth {
287    /// Keystore account name for proxy credentials.
288    pub keyring_account: String,
289
290    /// Authentication scheme (only "basic" supported).
291    #[serde(default = "default_auth_scheme")]
292    pub scheme: String,
293}
294
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
298
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build an `EndpointRule` from string slices.
    fn rule(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_string(),
            path: path.to_string(),
        }
    }

    /// Check a single rule against method+path via `endpoint_allowed`.
    fn check(r: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(r), method, path)
    }

    // ========================================================================
    // Config defaults + serde tests
    // ========================================================================

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
        assert_eq!(config.max_connections, 256);
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    #[test]
    fn test_inject_mode_serde_snake_case() {
        let cases = [
            (InjectMode::Header, "\"header\""),
            (InjectMode::UrlPath, "\"url_path\""),
            (InjectMode::QueryParam, "\"query_param\""),
            (InjectMode::BasicAuth, "\"basic_auth\""),
        ];
        for (mode, json) in cases {
            assert_eq!(serde_json::to_string(&mode).unwrap(), json);
            assert_eq!(serde_json::from_str::<InjectMode>(json).unwrap(), mode);
        }
        assert_eq!(InjectMode::default(), InjectMode::Header);
    }

    #[test]
    fn test_route_config_injection_defaults() {
        let json = r#"{
            "prefix": "/test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.credential_key.is_none());
        assert_eq!(route.inject_mode, InjectMode::Header);
        assert_eq!(route.inject_header, "Authorization");
        assert_eq!(route.credential_format, "Bearer {}");
        assert!(route.path_pattern.is_none());
        assert!(route.path_replacement.is_none());
        assert!(route.query_param_name.is_none());
        assert!(route.env_var.is_none());
    }

    // ========================================================================
    // EndpointRule + path matching tests
    // ========================================================================

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let r = rule("GET", "/v1/chat/completions");
        assert!(check(&r, "GET", "/v1/chat/completions"));
        assert!(!check(&r, "GET", "/v1/chat"));
        assert!(!check(&r, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let r = rule("get", "/api");
        assert!(check(&r, "GET", "/api"));
        assert!(check(&r, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let r = rule("*", "/api/resource");
        assert!(check(&r, "GET", "/api/resource"));
        assert!(check(&r, "DELETE", "/api/resource"));
        assert!(check(&r, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let r = rule("GET", "/api/resource");
        assert!(!check(&r, "POST", "/api/resource"));
        assert!(!check(&r, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let r = rule("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&r, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&r, "GET", "/api/v4/projects/my-proj/merge_requests"));
        assert!(!check(&r, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let r = rule("GET", "/api/v4/projects/**");
        assert!(check(&r, "GET", "/api/v4/projects/123"));
        assert!(check(&r, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&r, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&r, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let r = rule("*", "/api/**/notes");
        assert!(check(&r, "GET", "/api/notes"));
        assert!(check(&r, "POST", "/api/projects/123/notes"));
        assert!(check(&r, "GET", "/api/a/b/c/notes"));
        assert!(!check(&r, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let r = rule("GET", "/api/data");
        assert!(check(&r, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let r = rule("GET", "/api/data");
        assert!(check(&r, "GET", "/api/data/"));
        assert!(check(&r, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let r = rule("GET", "/api/data");
        assert!(check(&r, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let r = rule("GET", "/");
        assert!(check(&r, "GET", "/"));
        assert!(!check(&r, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = [
            rule("GET", "/repos/*/issues"),
            rule("POST", "/repos/*/issues/*/comments"),
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = [rule("GET", "/api/[invalid")];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = [
            rule("GET", "/repos/*/issues"),
            rule("POST", "/repos/*/issues/*/comments"),
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "/test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "/k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        // Security: percent-encoded segments must not bypass rules.
        // e.g., /api/v4/%70rojects should match a rule for /api/v4/projects/*
        let r = rule("GET", "/api/v4/projects/*/issues");
        assert!(check(&r, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&r, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let r = rule("POST", "/api/data");
        // %64%61%74%61 = "data"
        assert!(check(&r, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = [rule("GET", "/repos/*/issues")];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        // %69ssues = "issues"
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        // Security: invalid UTF-8 percent sequences must not fall back to
        // the raw path (which could bypass rules). Lossy decoding replaces
        // invalid bytes with U+FFFD, so the path won't match real segments.
        let r = rule("GET", "/api/projects");
        // %FF is not valid UTF-8 — must not match "/api/projects"
        assert!(!check(&r, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let r = rule("GET", "/api/*/data");
        let json = serde_json::to_string(&r).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }
}