Skip to main content

nono_proxy/
config.rs

1//! Proxy configuration types.
2//!
3//! Defines the configuration for the proxy server, including allowed hosts,
4//! credential routes, and external proxy settings.
5
6use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
10/// Credential injection mode determining how credentials are inserted into requests.
11#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum InjectMode {
14    /// Inject credential into an HTTP header (default)
15    #[default]
16    Header,
17    /// Replace a pattern in the URL path with the credential
18    UrlPath,
19    /// Add or replace a query parameter with the credential
20    QueryParam,
21    /// Use HTTP Basic Authentication (credential format: "username:password")
22    BasicAuth,
23}
24
25/// Configuration for the proxy server.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28    /// Bind address (default: 127.0.0.1)
29    #[serde(default = "default_bind_addr")]
30    pub bind_addr: IpAddr,
31
32    /// Bind port (0 = OS-assigned ephemeral port)
33    #[serde(default)]
34    pub bind_port: u16,
35
36    /// Allowed hosts for CONNECT mode (exact match + wildcards).
37    /// Empty = allow all hosts (except deny list).
38    #[serde(default)]
39    pub allowed_hosts: Vec<String>,
40
41    /// Reverse proxy credential routes.
42    #[serde(default)]
43    pub routes: Vec<RouteConfig>,
44
45    /// External (enterprise) proxy URL for passthrough mode.
46    /// When set, CONNECT requests are chained to this proxy.
47    #[serde(default)]
48    pub external_proxy: Option<ExternalProxyConfig>,
49
50    /// Maximum concurrent connections (0 = unlimited).
51    #[serde(default)]
52    pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56    fn default() -> Self {
57        Self {
58            bind_addr: default_bind_addr(),
59            bind_port: 0,
60            allowed_hosts: Vec::new(),
61            routes: Vec::new(),
62            external_proxy: None,
63            max_connections: 256,
64        }
65    }
66}
67
68fn default_bind_addr() -> IpAddr {
69    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
70}
71
72/// Configuration for a reverse proxy credential route.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RouteConfig {
75    /// Path prefix for routing (e.g., "/openai")
76    pub prefix: String,
77
78    /// Upstream URL to forward to (e.g., "https://api.openai.com")
79    pub upstream: String,
80
81    /// Keystore account name to load the credential from.
82    /// If `None`, no credential is injected.
83    pub credential_key: Option<String>,
84
85    /// Injection mode (default: "header")
86    #[serde(default)]
87    pub inject_mode: InjectMode,
88
89    // --- Header mode fields ---
90    /// HTTP header name for the credential (default: "Authorization")
91    /// Only used when inject_mode is "header".
92    #[serde(default = "default_inject_header")]
93    pub inject_header: String,
94
95    /// Format string for the credential value. `{}` is replaced with the secret.
96    /// Default: "Bearer {}"
97    /// Only used when inject_mode is "header".
98    #[serde(default = "default_credential_format")]
99    pub credential_format: String,
100
101    // --- URL path mode fields ---
102    /// Pattern to match in incoming URL path. Use {} as placeholder for phantom token.
103    /// Example: "/bot{}/" matches "/bot<token>/getMe"
104    /// Only used when inject_mode is "url_path".
105    #[serde(default)]
106    pub path_pattern: Option<String>,
107
108    /// Pattern for outgoing URL path. Use {} as placeholder for real credential.
109    /// Defaults to same as path_pattern if not specified.
110    /// Only used when inject_mode is "url_path".
111    #[serde(default)]
112    pub path_replacement: Option<String>,
113
114    // --- Query param mode fields ---
115    /// Name of the query parameter to add/replace with the credential.
116    /// Only used when inject_mode is "query_param".
117    #[serde(default)]
118    pub query_param_name: Option<String>,
119
120    /// Explicit environment variable name for the phantom token (e.g., "OPENAI_API_KEY").
121    ///
122    /// When set, this is used as the SDK API key env var name instead of deriving
123    /// it from `credential_key.to_uppercase()`. Required when `credential_key` is
124    /// a URI manager reference (e.g., `op://`, `apple-password://`) which would
125    /// otherwise produce a nonsensical env var name.
126    #[serde(default)]
127    pub env_var: Option<String>,
128
129    /// Optional L7 endpoint rules for method+path filtering.
130    ///
131    /// When non-empty, only requests matching at least one rule are allowed
132    /// (default-deny). When empty, all method+path combinations are permitted
133    /// (backward compatible).
134    #[serde(default)]
135    pub endpoint_rules: Vec<EndpointRule>,
136}
137
138/// An HTTP method+path access rule for reverse proxy endpoint filtering.
139///
140/// Used to restrict which API endpoints an agent can access through a
141/// credential route. Patterns use `/` separated segments with wildcards:
142/// - `*` matches exactly one path segment
143/// - `**` matches zero or more path segments
144#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
145pub struct EndpointRule {
146    /// HTTP method to match ("GET", "POST", etc.) or "*" for any method.
147    pub method: String,
148    /// URL path pattern with glob segments.
149    /// Example: "/api/v4/projects/*/merge_requests/**"
150    pub path: String,
151}
152
153/// Pre-compiled endpoint rules for the request hot path.
154///
155/// Built once at proxy startup from `EndpointRule` definitions. Holds
156/// compiled `globset::GlobMatcher`s so the hot path does a regex match,
157/// not a glob compile.
158pub struct CompiledEndpointRules {
159    rules: Vec<CompiledRule>,
160}
161
162struct CompiledRule {
163    method: String,
164    matcher: globset::GlobMatcher,
165}
166
167impl CompiledEndpointRules {
168    /// Compile endpoint rules into matchers. Invalid glob patterns are
169    /// rejected at startup with an error, not silently ignored at runtime.
170    pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
171        let mut compiled = Vec::with_capacity(rules.len());
172        for rule in rules {
173            let glob = Glob::new(&rule.path)
174                .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
175            compiled.push(CompiledRule {
176                method: rule.method.clone(),
177                matcher: glob.compile_matcher(),
178            });
179        }
180        Ok(Self { rules: compiled })
181    }
182
183    /// Check if the given method+path is allowed.
184    /// Returns `true` if no rules were compiled (allow-all, backward compatible).
185    #[must_use]
186    pub fn is_allowed(&self, method: &str, path: &str) -> bool {
187        if self.rules.is_empty() {
188            return true;
189        }
190        let normalized = normalize_path(path);
191        self.rules.iter().any(|r| {
192            (r.method == "*" || r.method.eq_ignore_ascii_case(method))
193                && r.matcher.is_match(&normalized)
194        })
195    }
196}
197
198impl std::fmt::Debug for CompiledEndpointRules {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("CompiledEndpointRules")
201            .field("count", &self.rules.len())
202            .finish()
203    }
204}
205
/// Check if any endpoint rule permits the given method+path.
/// Returns `true` if rules is empty (allow-all, backward compatible).
///
/// Test convenience only: it compiles globs on every call, and a rule with an
/// invalid pattern simply never matches. Production code should use
/// `CompiledEndpointRules::is_allowed()` instead. Patterns are built with
/// `literal_separator(true)`, the same way `CompiledEndpointRules::compile`
/// builds them, so that `*` never matches `/`.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    let normalized = normalize_path(path);
    rules.iter().any(|r| {
        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
            && globset::GlobBuilder::new(&r.path)
                .literal_separator(true)
                .build()
                .is_ok_and(|g: Glob| g.compile_matcher().is_match(&normalized))
    })
}
225
226/// Normalize a URL path for matching: percent-decode, strip query string,
227/// collapse double slashes, strip trailing slash (but preserve root "/").
228///
229/// Percent-decoding prevents bypass via encoded characters (e.g.,
230/// `/api/%70rojects` evading a rule for `/api/projects/*`).
231fn normalize_path(path: &str) -> String {
232    // Strip query string
233    let path = path.split('?').next().unwrap_or(path);
234
235    // Percent-decode to prevent bypass via encoded segments.
236    // Use decode_binary + from_utf8_lossy so invalid UTF-8 sequences
237    // (e.g., %FF) become U+FFFD instead of falling back to the raw path.
238    let binary = urlencoding::decode_binary(path.as_bytes());
239    let decoded = String::from_utf8_lossy(&binary);
240
241    // Collapse double slashes by splitting on '/' and filtering empties,
242    // then rejoin. This also strips trailing slash.
243    let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
244    if segments.is_empty() {
245        "/".to_string()
246    } else {
247        format!("/{}", segments.join("/"))
248    }
249}
250
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
254
/// Serde default for `RouteConfig::credential_format`; `{}` marks the secret.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
258
259/// Configuration for an external (enterprise) proxy.
260#[derive(Debug, Clone, Serialize, Deserialize)]
261pub struct ExternalProxyConfig {
262    /// Proxy address (e.g., "squid.corp.internal:3128")
263    pub address: String,
264
265    /// Optional authentication for the external proxy.
266    pub auth: Option<ExternalProxyAuth>,
267
268    /// Hosts to bypass the external proxy and route directly.
269    /// Supports exact hostnames and `*.` wildcard suffixes (case-insensitive).
270    /// Empty = all traffic goes through the external proxy.
271    #[serde(default)]
272    pub bypass_hosts: Vec<String>,
273}
274
275/// Authentication for an external proxy.
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct ExternalProxyAuth {
278    /// Keystore account name for proxy credentials.
279    pub keyring_account: String,
280
281    /// Authentication scheme (only "basic" supported).
282    #[serde(default = "default_auth_scheme")]
283    pub scheme: String,
284}
285
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
289
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
        assert_eq!(config.max_connections, 256);
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // ========================================================================
    // EndpointRule + path matching tests
    // ========================================================================

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    /// Helper: evaluate one rule against method+path through both the
    /// test-only `endpoint_allowed` and the production
    /// `CompiledEndpointRules::is_allowed`. Asserts that the two agree and
    /// returns their shared answer, so every rule test also covers the
    /// production path.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        let rules = std::slice::from_ref(rule);
        let uncompiled = endpoint_allowed(rules, method, path);
        let compiled = CompiledEndpointRules::compile(rules)
            .unwrap()
            .is_allowed(method, path);
        assert_eq!(
            uncompiled, compiled,
            "endpoint_allowed and CompiledEndpointRules disagree for {method} {path}"
        );
        compiled
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "/test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        // Security: percent-encoded segments must not bypass rules.
        // e.g., /api/v4/%70rojects should match a rule for /api/v4/projects/*
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        // %64%61%74%61 = "data"
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        // %69ssues = "issues"
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        // Security: invalid UTF-8 percent sequences must not fall back to
        // the raw path (which could bypass rules). Lossy decoding replaces
        // invalid bytes with U+FFFD, so the path won't match real segments.
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        // %FF is not valid UTF-8 — must not match "/api/projects"
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }
}