1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::path::PathBuf;
10
/// Credential injection mode for a route (see `RouteConfig::inject_mode`).
///
/// Serialized in `snake_case`: `"header"`, `"url_path"`, `"query_param"`,
/// `"basic_auth"`. `Header` is the `Default` variant.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    /// Default mode. NOTE(review): presumably writes the formatted credential
    /// into the header named by `RouteConfig::inject_header` — confirm.
    #[default]
    Header,
    /// NOTE(review): presumably rewrites the URL path using
    /// `path_pattern` / `path_replacement` — confirm.
    UrlPath,
    /// NOTE(review): presumably adds a query parameter named by
    /// `query_param_name` — confirm.
    QueryParam,
    /// NOTE(review): presumably sends the credential as HTTP Basic auth — confirm.
    BasicAuth,
}
25
/// Top-level proxy configuration.
///
/// Missing keys fall back to per-field serde defaults. These do not all agree
/// with `ProxyConfig::default()` (see `max_connections`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
    /// Bind address; defaults to `127.0.0.1` via `default_bind_addr`.
    #[serde(default = "default_bind_addr")]
    pub bind_addr: IpAddr,

    /// Bind port; `0` when absent. NOTE(review): presumably `0` asks the OS
    /// for an ephemeral port — confirm.
    #[serde(default)]
    pub bind_port: u16,

    /// Allowed host names (tests use `"api.openai.com"`); empty when absent.
    /// NOTE(review): matching rules (exact vs. wildcard) are not visible here.
    #[serde(default)]
    pub allowed_hosts: Vec<String>,

    /// Configured routes; empty when absent.
    #[serde(default)]
    pub routes: Vec<RouteConfig>,

    /// Optional external proxy settings; `None` when absent.
    #[serde(default)]
    pub external_proxy: Option<ExternalProxyConfig>,

    /// Port list; empty when absent. NOTE(review): by name, ports that are
    /// connected to directly — semantics not visible here, confirm.
    #[serde(default)]
    pub direct_connect_ports: Vec<u16>,

    /// Connection limit.
    ///
    /// NOTE(review): `#[serde(default)]` makes a missing key deserialize to
    /// `0`, whereas `ProxyConfig::default()` sets `256`. Confirm whether `0`
    /// means "unlimited" and whether this asymmetry is intended.
    #[serde(default)]
    pub max_connections: usize,

    /// Optional directory; omitted from serialized output when `None`.
    /// NOTE(review): by name, where TLS-interception CA material lives — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub intercept_ca_dir: Option<PathBuf>,

    /// Never serialized or deserialized (`skip`). It is always `None` after
    /// deserialization and must be populated in code. NOTE(review): presumably
    /// PEM-encoded parent CA certificate bytes — confirm.
    #[serde(default, skip)]
    pub intercept_parent_ca_pems: Option<Vec<u8>>,
}
91
92impl Default for ProxyConfig {
93 fn default() -> Self {
94 Self {
95 bind_addr: default_bind_addr(),
96 bind_port: 0,
97 allowed_hosts: Vec::new(),
98 routes: Vec::new(),
99 external_proxy: None,
100 direct_connect_ports: Vec::new(),
101 max_connections: 256,
102 intercept_ca_dir: None,
103 intercept_parent_ca_pems: None,
104 }
105 }
106}
107
/// Serde/`Default` value for `ProxyConfig::bind_addr`: IPv4 loopback (`127.0.0.1`).
fn default_bind_addr() -> IpAddr {
    std::net::Ipv4Addr::LOCALHOST.into()
}
111
/// One proxied route.
///
/// NOTE(review): by field names, requests matching `prefix` are forwarded to
/// `upstream` with a credential attached per the injection fields below. That
/// logic is not part of this section — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// Route prefix. Tests use both bare (`"k8s"`) and slash-led (`"/my-api"`) forms.
    pub prefix: String,

    /// Upstream base URL, e.g. `"https://api.openai.com"`.
    pub upstream: String,

    /// Optional credential key; `None` when absent (serde treats a missing
    /// `Option` field as `None`). Tests set either this or `oauth2`.
    /// NOTE(review): presumably the name the credential is stored under.
    pub credential_key: Option<String>,

    /// Injection mode; defaults to `InjectMode::Header`.
    #[serde(default)]
    pub inject_mode: InjectMode,

    /// Header name used for injection; defaults to `"Authorization"`.
    #[serde(default = "default_inject_header")]
    pub inject_header: String,

    /// Credential format template; defaults to `"Bearer {}"`.
    /// NOTE(review): presumably `{}` is replaced by the credential value.
    #[serde(default = "default_credential_format")]
    pub credential_format: String,

    /// Optional path pattern. NOTE(review): presumably used with `InjectMode::UrlPath`.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Optional replacement paired with `path_pattern`.
    /// NOTE(review): presumably used with `InjectMode::UrlPath`.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Optional query parameter name.
    /// NOTE(review): presumably used with `InjectMode::QueryParam`.
    #[serde(default)]
    pub query_param_name: Option<String>,

    /// Optional injection overrides; how they are applied is not visible here.
    #[serde(default)]
    pub proxy: Option<ProxyInjectConfig>,

    /// Optional environment variable name (by field name); usage not visible here.
    #[serde(default)]
    pub env_var: Option<String>,

    /// Method/path allow rules; empty when absent. NOTE(review): presumably
    /// compiled with `CompiledEndpointRules::compile`, whose `is_allowed`
    /// treats an empty rule set as "allow everything".
    #[serde(default)]
    pub endpoint_rules: Vec<EndpointRule>,

    /// Optional CA certificate path for the upstream (tests use
    /// `"/run/secrets/k8s-ca.crt"`). NOTE(review): presumably PEM — confirm.
    #[serde(default)]
    pub tls_ca: Option<String>,

    /// Optional client certificate path.
    /// NOTE(review): presumably for mutual TLS to the upstream — confirm.
    #[serde(default)]
    pub tls_client_cert: Option<String>,

    /// Optional client private-key path, by name paired with `tls_client_cert`.
    #[serde(default)]
    pub tls_client_key: Option<String>,

    /// Optional OAuth2 settings; `None` when absent.
    #[serde(default)]
    pub oauth2: Option<OAuth2Config>,
}
218
/// Optional injection overrides nested under `RouteConfig::proxy`.
///
/// Each field is optional and shares its name with a `RouteConfig` injection
/// field. NOTE(review): how and when these override the route's values is not
/// visible here — confirm. Unknown keys are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProxyInjectConfig {
    /// Optional injection mode override.
    #[serde(default)]
    pub inject_mode: Option<InjectMode>,

    /// Optional header name override.
    #[serde(default)]
    pub inject_header: Option<String>,

    /// Optional credential format override.
    #[serde(default)]
    pub credential_format: Option<String>,

    /// Optional path pattern override.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Optional path replacement override.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Optional query parameter name override.
    #[serde(default)]
    pub query_param_name: Option<String>,
}
251
/// One method + path allow rule for a route's endpoint filtering.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EndpointRule {
    /// HTTP method, compared ASCII case-insensitively; `"*"` matches any method.
    pub method: String,
    /// Glob pattern (globset syntax) matched against the normalized request
    /// path. Normalization strips the query string, percent-decodes the path
    /// and drops empty segments.
    pub path: String,
}
266
/// Endpoint rules with their path globs compiled once up front, so that
/// per-request checks (`is_allowed`) do not re-parse patterns.
pub struct CompiledEndpointRules {
    rules: Vec<CompiledRule>,
}

/// A single compiled rule: the configured method string plus its path matcher.
struct CompiledRule {
    method: String,
    matcher: globset::GlobMatcher,
}
280
281impl CompiledEndpointRules {
282 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
285 let mut compiled = Vec::with_capacity(rules.len());
286 for rule in rules {
287 let glob = Glob::new(&rule.path)
288 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
289 compiled.push(CompiledRule {
290 method: rule.method.clone(),
291 matcher: glob.compile_matcher(),
292 });
293 }
294 Ok(Self { rules: compiled })
295 }
296
297 #[must_use]
299 pub fn is_empty(&self) -> bool {
300 self.rules.is_empty()
301 }
302
303 #[must_use]
305 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
306 if self.rules.is_empty() {
307 return true;
308 }
309 let normalized = normalize_path(path);
310 self.rules.iter().any(|r| {
311 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
312 && r.matcher.is_match(&normalized)
313 })
314 }
315}
316
317impl std::fmt::Debug for CompiledEndpointRules {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 f.debug_struct("CompiledEndpointRules")
320 .field("count", &self.rules.len())
321 .finish()
322 }
323}
324
/// Uncompiled endpoint matcher used by the unit tests.
///
/// An empty rule list allows everything. Otherwise some rule must match the
/// method (`"*"` or ASCII case-insensitive) and the normalized path. Each
/// pattern is compiled per call with `literal_separator(true)`, so `*` stays
/// within one path segment and only a `**` component spans several. Invalid
/// patterns are treated as non-matching rather than as errors.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    let normalized = normalize_path(path);
    rules.iter().any(|r| {
        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
            && globset::GlobBuilder::new(&r.path)
                .literal_separator(true)
                .build()
                .ok()
                .map(|g| g.compile_matcher())
                .is_some_and(|m| m.is_match(&normalized))
    })
}
344
345fn normalize_path(path: &str) -> String {
351 let path = path.split('?').next().unwrap_or(path);
353
354 let binary = urlencoding::decode_binary(path.as_bytes());
358 let decoded = String::from_utf8_lossy(&binary);
359
360 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
363 if segments.is_empty() {
364 "/".to_string()
365 } else {
366 format!("/{}", segments.join("/"))
367 }
368}
369
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
373
/// Serde default for `RouteConfig::credential_format`.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
377
/// Settings for an external proxy (tests use a Squid-style `host:port` address).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyConfig {
    /// Proxy address, e.g. `"squid.corp:3128"`.
    pub address: String,

    /// Optional proxy credentials; `None` when the key is absent or `null`.
    pub auth: Option<ExternalProxyAuth>,

    /// Host list, empty when absent. Tests include wildcard forms such as
    /// `"*.private.net"`. NOTE(review): by name, hosts that skip the external
    /// proxy — matching semantics are not visible here, confirm.
    #[serde(default)]
    pub bypass_hosts: Vec<String>,
}
393
/// Credentials reference for the external proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyAuth {
    /// Keyring account name. NOTE(review): presumably the secret is read from
    /// the system keyring under this account — confirm.
    pub keyring_account: String,

    /// Auth scheme; defaults to `"basic"` via `default_auth_scheme`.
    #[serde(default = "default_auth_scheme")]
    pub scheme: String,
}
404
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
408
/// OAuth2 settings for a route (`RouteConfig::oauth2`).
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints `client_secret`
/// verbatim. Tests use references such as `"env://CLIENT_SECRET"`; if a
/// literal secret can appear here, consider a redacting `Debug` impl.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OAuth2Config {
    /// Token endpoint URL.
    pub token_url: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// Client secret (tests use `env://…` references).
    pub client_secret: String,
    /// Requested scope string (e.g. `"read write"`); empty when absent.
    #[serde(default)]
    pub scope: String,
}
430
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is reported as a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // `ProxyConfig::default()` field values.
    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    // A JSON round-trip preserves `allowed_hosts`.
    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    // A JSON round-trip preserves the external proxy address and `bypass_hosts` order.
    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    // `bypass_hosts` defaults to empty when the key is absent.
    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // An empty rule list allows any method and path.
    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    // Evaluates a single rule through `endpoint_allowed`.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    // A literal pattern matches only the exact path (no prefix or extension matches).
    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    // Method comparison ignores ASCII case.
    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    // A `"*"` method matches every request method.
    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    // A matching path is still rejected when the method differs.
    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    // A `*` in the middle requires a non-empty component between its neighbours.
    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    // A trailing `**` matches one or more segments below the prefix.
    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    // A `**` component in the middle matches zero or more segments.
    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    // The query string is ignored when matching.
    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    // A trailing slash does not change the match result.
    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    // Repeated slashes collapse to one.
    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    // The root pattern `/` matches only the root path.
    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    // The precompiled matcher agrees with the expected allow/deny decisions.
    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    // An empty compiled rule set allows everything.
    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    // A malformed glob (unclosed `[`) is an error at compile time.
    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    // A request is allowed if any one of several rules matches.
    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // `endpoint_rules` and `tls_ca` default to empty/None when absent.
    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    // `tls_ca` survives deserialize -> serialize -> deserialize.
    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let rule = EndpointRule {
            // Percent-encoded letters inside a literal segment must still match
            // after decoding (`%70` = 'p', `%6A` = 'j').
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    // A fully percent-encoded segment (`%64%61%74%61` = "data") matches its literal.
    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    // The compiled matcher also decodes before matching.
    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        // `%70ulls` decodes to "pulls", which the rule does not allow.
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let rule = EndpointRule {
            // `%FF` is not valid UTF-8; the lossy decode in `normalize_path`
            // turns it into U+FFFD, so the path cannot equal the literal
            // pattern.
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    // `EndpointRule` survives a JSON round-trip unchanged.
    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }

    // All `OAuth2Config` fields deserialize as given.
    #[test]
    fn test_oauth2_config_deserialization() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://CLIENT_SECRET",
            "scope": "read write"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.token_url, "https://auth.example.com/oauth/token");
        assert_eq!(config.client_id, "my-client");
        assert_eq!(config.client_secret, "env://CLIENT_SECRET");
        assert_eq!(config.scope, "read write");
    }

    // `scope` defaults to the empty string when absent.
    #[test]
    fn test_oauth2_config_default_scope() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://SECRET"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.scope, "");
    }

    // A route may carry `oauth2` without a `credential_key`.
    #[test]
    fn test_route_config_with_oauth2() {
        let json = r#"{
            "prefix": "/my-api",
            "upstream": "https://api.example.com",
            "oauth2": {
                "token_url": "https://auth.example.com/oauth/token",
                "client_id": "agent-1",
                "client_secret": "env://CLIENT_SECRET",
                "scope": "api.read"
            }
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_some());
        assert!(route.credential_key.is_none());
        let oauth2 = route.oauth2.unwrap();
        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
    }

    // A route with only `credential_key` has no `oauth2` section.
    #[test]
    fn test_route_config_without_oauth2() {
        let json = r#"{
            "prefix": "/openai",
            "upstream": "https://api.openai.com",
            "credential_key": "openai"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_none());
        assert!(route.credential_key.is_some());
    }
}