Skip to main content

nono_proxy/
config.rs

//! Proxy configuration types.
//!
//! Defines the configuration for the proxy server, including allowed hosts,
//! credential routes (with their optional endpoint rules), and external proxy settings.
5
6use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
10/// Credential injection mode determining how credentials are inserted into requests.
11#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum InjectMode {
14    /// Inject credential into an HTTP header (default)
15    #[default]
16    Header,
17    /// Replace a pattern in the URL path with the credential
18    UrlPath,
19    /// Add or replace a query parameter with the credential
20    QueryParam,
21    /// Use HTTP Basic Authentication (credential format: "username:password")
22    BasicAuth,
23}
24
25/// Configuration for the proxy server.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28    /// Bind address (default: 127.0.0.1)
29    #[serde(default = "default_bind_addr")]
30    pub bind_addr: IpAddr,
31
32    /// Bind port (0 = OS-assigned ephemeral port)
33    #[serde(default)]
34    pub bind_port: u16,
35
36    /// Allowed hosts for CONNECT mode (exact match + wildcards).
37    /// Empty = allow all hosts (except deny list).
38    #[serde(default)]
39    pub allowed_hosts: Vec<String>,
40
41    /// Reverse proxy credential routes.
42    #[serde(default)]
43    pub routes: Vec<RouteConfig>,
44
45    /// External (enterprise) proxy URL for passthrough mode.
46    /// When set, CONNECT requests are chained to this proxy.
47    #[serde(default)]
48    pub external_proxy: Option<ExternalProxyConfig>,
49
50    /// Maximum concurrent connections (0 = unlimited).
51    #[serde(default)]
52    pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56    fn default() -> Self {
57        Self {
58            bind_addr: default_bind_addr(),
59            bind_port: 0,
60            allowed_hosts: Vec::new(),
61            routes: Vec::new(),
62            external_proxy: None,
63            max_connections: 256,
64        }
65    }
66}
67
68fn default_bind_addr() -> IpAddr {
69    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
70}
71
72/// Configuration for a reverse proxy credential route.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RouteConfig {
75    /// Path prefix for routing (e.g., "openai").
76    /// Must NOT include leading or trailing slashes — it is a bare service name, not a URL path.
77    pub prefix: String,
78
79    /// Upstream URL to forward to (e.g., "https://api.openai.com")
80    pub upstream: String,
81
82    /// Keystore account name to load the credential from.
83    /// If `None`, no credential is injected.
84    pub credential_key: Option<String>,
85
86    /// Injection mode (default: "header")
87    #[serde(default)]
88    pub inject_mode: InjectMode,
89
90    // --- Header mode fields ---
91    /// HTTP header name for the credential (default: "Authorization")
92    /// Only used when inject_mode is "header".
93    #[serde(default = "default_inject_header")]
94    pub inject_header: String,
95
96    /// Format string for the credential value. `{}` is replaced with the secret.
97    /// Default: "Bearer {}"
98    /// Only used when inject_mode is "header".
99    #[serde(default = "default_credential_format")]
100    pub credential_format: String,
101
102    // --- URL path mode fields ---
103    /// Pattern to match in incoming URL path. Use {} as placeholder for phantom token.
104    /// Example: "/bot{}/" matches "/bot<token>/getMe"
105    /// Only used when inject_mode is "url_path".
106    #[serde(default)]
107    pub path_pattern: Option<String>,
108
109    /// Pattern for outgoing URL path. Use {} as placeholder for real credential.
110    /// Defaults to same as path_pattern if not specified.
111    /// Only used when inject_mode is "url_path".
112    #[serde(default)]
113    pub path_replacement: Option<String>,
114
115    // --- Query param mode fields ---
116    /// Name of the query parameter to add/replace with the credential.
117    /// Only used when inject_mode is "query_param".
118    #[serde(default)]
119    pub query_param_name: Option<String>,
120
121    /// Explicit environment variable name for the phantom token (e.g., "OPENAI_API_KEY").
122    ///
123    /// When set, this is used as the SDK API key env var name instead of deriving
124    /// it from `credential_key.to_uppercase()`. Required when `credential_key` is
125    /// a URI manager reference (e.g., `op://`, `apple-password://`) which would
126    /// otherwise produce a nonsensical env var name.
127    #[serde(default)]
128    pub env_var: Option<String>,
129
130    /// Optional L7 endpoint rules for method+path filtering.
131    ///
132    /// When non-empty, only requests matching at least one rule are allowed
133    /// (default-deny). When empty, all method+path combinations are permitted
134    /// (backward compatible).
135    #[serde(default)]
136    pub endpoint_rules: Vec<EndpointRule>,
137
138    /// Optional path to a PEM-encoded CA certificate file for upstream TLS.
139    ///
140    /// When set, the proxy trusts this CA in addition to the system roots
141    /// when connecting to the upstream for this route. This is required for
142    /// upstreams that use self-signed or private CA certificates (e.g.,
143    /// Kubernetes API servers).
144    #[serde(default)]
145    pub tls_ca: Option<String>,
146
147    /// Optional path to a PEM-encoded client certificate for upstream mTLS.
148    ///
149    /// When set together with `tls_client_key`, the proxy presents this
150    /// certificate to the upstream during TLS handshake. Required for
151    /// upstreams that enforce mutual TLS (e.g., Kubernetes API servers
152    /// configured with client-certificate authentication).
153    #[serde(default)]
154    pub tls_client_cert: Option<String>,
155
156    /// Optional path to a PEM-encoded private key for upstream mTLS.
157    ///
158    /// Must be set together with `tls_client_cert`. The key must correspond
159    /// to the certificate in `tls_client_cert`.
160    #[serde(default)]
161    pub tls_client_key: Option<String>,
162}
163
164/// An HTTP method+path access rule for reverse proxy endpoint filtering.
165///
166/// Used to restrict which API endpoints an agent can access through a
167/// credential route. Patterns use `/` separated segments with wildcards:
168/// - `*` matches exactly one path segment
169/// - `**` matches zero or more path segments
170#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
171pub struct EndpointRule {
172    /// HTTP method to match ("GET", "POST", etc.) or "*" for any method.
173    pub method: String,
174    /// URL path pattern with glob segments.
175    /// Example: "/api/v4/projects/*/merge_requests/**"
176    pub path: String,
177}
178
179/// Pre-compiled endpoint rules for the request hot path.
180///
181/// Built once at proxy startup from `EndpointRule` definitions. Holds
182/// compiled `globset::GlobMatcher`s so the hot path does a regex match,
183/// not a glob compile.
184pub struct CompiledEndpointRules {
185    rules: Vec<CompiledRule>,
186}
187
188struct CompiledRule {
189    method: String,
190    matcher: globset::GlobMatcher,
191}
192
193impl CompiledEndpointRules {
194    /// Compile endpoint rules into matchers. Invalid glob patterns are
195    /// rejected at startup with an error, not silently ignored at runtime.
196    pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
197        let mut compiled = Vec::with_capacity(rules.len());
198        for rule in rules {
199            let glob = Glob::new(&rule.path)
200                .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
201            compiled.push(CompiledRule {
202                method: rule.method.clone(),
203                matcher: glob.compile_matcher(),
204            });
205        }
206        Ok(Self { rules: compiled })
207    }
208
209    /// Check if the given method+path is allowed.
210    /// Returns `true` if no rules were compiled (allow-all, backward compatible).
211    #[must_use]
212    pub fn is_allowed(&self, method: &str, path: &str) -> bool {
213        if self.rules.is_empty() {
214            return true;
215        }
216        let normalized = normalize_path(path);
217        self.rules.iter().any(|r| {
218            (r.method == "*" || r.method.eq_ignore_ascii_case(method))
219                && r.matcher.is_match(&normalized)
220        })
221    }
222}
223
224impl std::fmt::Debug for CompiledEndpointRules {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        f.debug_struct("CompiledEndpointRules")
227            .field("count", &self.rules.len())
228            .finish()
229    }
230}
231
232/// Check if any endpoint rule permits the given method+path.
233/// Returns `true` if rules is empty (allow-all, backward compatible).
234///
235/// Test convenience only — compiles globs on each call. Production code
236/// should use `CompiledEndpointRules::is_allowed()` instead.
237#[cfg(test)]
238fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
239    if rules.is_empty() {
240        return true;
241    }
242    let normalized = normalize_path(path);
243    rules.iter().any(|r| {
244        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
245            && Glob::new(&r.path)
246                .ok()
247                .map(|g| g.compile_matcher())
248                .is_some_and(|m| m.is_match(&normalized))
249    })
250}
251
252/// Normalize a URL path for matching: percent-decode, strip query string,
253/// collapse double slashes, strip trailing slash (but preserve root "/").
254///
255/// Percent-decoding prevents bypass via encoded characters (e.g.,
256/// `/api/%70rojects` evading a rule for `/api/projects/*`).
257fn normalize_path(path: &str) -> String {
258    // Strip query string
259    let path = path.split('?').next().unwrap_or(path);
260
261    // Percent-decode to prevent bypass via encoded segments.
262    // Use decode_binary + from_utf8_lossy so invalid UTF-8 sequences
263    // (e.g., %FF) become U+FFFD instead of falling back to the raw path.
264    let binary = urlencoding::decode_binary(path.as_bytes());
265    let decoded = String::from_utf8_lossy(&binary);
266
267    // Collapse double slashes by splitting on '/' and filtering empties,
268    // then rejoin. This also strips trailing slash.
269    let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
270    if segments.is_empty() {
271        "/".to_string()
272    } else {
273        format!("/{}", segments.join("/"))
274    }
275}
276
/// Serde default for `RouteConfig::inject_header`: the `Authorization` header.
fn default_inject_header() -> String {
    String::from("Authorization")
}
280
/// Serde default for `RouteConfig::credential_format`: a bearer token, with
/// `{}` marking where the secret goes.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
284
285/// Configuration for an external (enterprise) proxy.
286#[derive(Debug, Clone, Serialize, Deserialize)]
287pub struct ExternalProxyConfig {
288    /// Proxy address (e.g., "squid.corp.internal:3128")
289    pub address: String,
290
291    /// Optional authentication for the external proxy.
292    pub auth: Option<ExternalProxyAuth>,
293
294    /// Hosts to bypass the external proxy and route directly.
295    /// Supports exact hostnames and `*.` wildcard suffixes (case-insensitive).
296    /// Empty = all traffic goes through the external proxy.
297    #[serde(default)]
298    pub bypass_hosts: Vec<String>,
299}
300
301/// Authentication for an external proxy.
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct ExternalProxyAuth {
304    /// Keystore account name for proxy credentials.
305    pub keyring_account: String,
306
307    /// Authentication scheme (only "basic" supported).
308    #[serde(default = "default_auth_scheme")]
309    pub scheme: String,
310}
311
/// Serde default for `ExternalProxyAuth::scheme`: HTTP Basic.
fn default_auth_scheme() -> String {
    String::from("basic")
}
315
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build an `EndpointRule` from borrowed method and path strings.
    fn rule(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_owned(),
            path: path.to_owned(),
        }
    }

    /// Evaluate a single rule through `endpoint_allowed`.
    fn check(endpoint: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(endpoint), method, path)
    }

    /// Rule set shared by the compiled and uncompiled multi-rule tests.
    fn issue_rules() -> Vec<EndpointRule> {
        vec![
            rule("GET", "/repos/*/issues"),
            rule("POST", "/repos/*/issues/*/comments"),
        ]
    }

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::from(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let original = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_owned()],
            ..ProxyConfig::default()
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.allowed_hosts, ["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let original = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_owned(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_owned(), "*.private.net".to_owned()],
            }),
            ..ProxyConfig::default()
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: ProxyConfig = serde_json::from_str(&json).unwrap();
        let external = restored.external_proxy.unwrap();
        assert_eq!(external.address, "squid.corp:3128");
        assert_eq!(external.bypass_hosts, ["internal.corp", "*.private.net"]);
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let external: ExternalProxyConfig =
            serde_json::from_str(r#"{"address": "proxy:3128", "auth": null}"#).unwrap();
        assert!(external.bypass_hosts.is_empty());
    }

    // ------------------------------------------------------------------
    // EndpointRule + path matching
    // ------------------------------------------------------------------

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        for (method, path) in [("GET", "/anything"), ("DELETE", "/admin/nuke")] {
            assert!(endpoint_allowed(&[], method, path));
        }
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let r = rule("GET", "/v1/chat/completions");
        assert!(check(&r, "GET", "/v1/chat/completions"));
        assert!(!check(&r, "GET", "/v1/chat"));
        assert!(!check(&r, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let r = rule("get", "/api");
        for method in ["GET", "Get"] {
            assert!(check(&r, method, "/api"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let r = rule("*", "/api/resource");
        for method in ["GET", "DELETE", "POST"] {
            assert!(check(&r, method, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let r = rule("GET", "/api/resource");
        for method in ["POST", "DELETE"] {
            assert!(!check(&r, method, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let r = rule("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&r, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&r, "GET", "/api/v4/projects/my-proj/merge_requests"));
        assert!(!check(&r, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let r = rule("GET", "/api/v4/projects/**");
        for path in [
            "/api/v4/projects/123",
            "/api/v4/projects/123/merge_requests",
            "/api/v4/projects/a/b/c/d",
        ] {
            assert!(check(&r, "GET", path));
        }
        assert!(!check(&r, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let r = rule("*", "/api/**/notes");
        assert!(check(&r, "GET", "/api/notes"));
        assert!(check(&r, "POST", "/api/projects/123/notes"));
        assert!(check(&r, "GET", "/api/a/b/c/notes"));
        assert!(!check(&r, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        assert!(check(&rule("GET", "/api/data"), "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let r = rule("GET", "/api/data");
        assert!(check(&r, "GET", "/api/data/"));
        assert!(check(&r, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        assert!(check(&rule("GET", "/api/data"), "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let r = rule("GET", "/");
        assert!(check(&r, "GET", "/"));
        assert!(!check(&r, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let compiled = CompiledEndpointRules::compile(&issue_rules()).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        assert!(CompiledEndpointRules::compile(&[rule("GET", "/api/[invalid")]).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = issue_rules();
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(&rules, "POST", "/repos/myrepo/issues/42/comments"));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        const CA_PATH: &str = "/run/secrets/k8s-ca.crt";
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some(CA_PATH));

        let reparsed: RouteConfig =
            serde_json::from_str(&serde_json::to_string(&route).unwrap()).unwrap();
        assert_eq!(reparsed.tls_ca.as_deref(), Some(CA_PATH));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        // Security: percent-encoded segments must not slip past a rule;
        // `/api/v4/%70rojects/...` has to be judged as `/api/v4/projects/...`.
        let r = rule("GET", "/api/v4/projects/*/issues");
        for path in ["/api/v4/%70rojects/123/issues", "/api/v4/pro%6Aects/123/issues"] {
            assert!(check(&r, "GET", path));
        }
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        // `%64%61%74%61` decodes to "data".
        assert!(check(&rule("POST", "/api/data"), "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let compiled = CompiledEndpointRules::compile(&[rule("GET", "/repos/*/issues")]).unwrap();
        // `%69ssues` decodes to "issues"; `%70ulls` decodes to "pulls".
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        // Security: an invalid UTF-8 escape (`%FF`) must not cause a fallback
        // to the raw path; lossy decoding turns it into U+FFFD, which cannot
        // line up with a real segment.
        assert!(!check(&rule("GET", "/api/projects"), "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let original = rule("GET", "/api/*/data");
        let json = serde_json::to_string(&original).unwrap();
        let restored: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.method, "GET");
        assert_eq!(restored.path, "/api/*/data");
    }
}