1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::path::PathBuf;
10
11#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum InjectMode {
15 #[default]
17 Header,
18 UrlPath,
20 QueryParam,
22 BasicAuth,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ProxyConfig {
29 #[serde(default = "default_bind_addr")]
31 pub bind_addr: IpAddr,
32
33 #[serde(default)]
35 pub bind_port: u16,
36
37 #[serde(default)]
40 pub allowed_hosts: Vec<String>,
41
42 #[serde(default)]
44 pub routes: Vec<RouteConfig>,
45
46 #[serde(default)]
49 pub external_proxy: Option<ExternalProxyConfig>,
50
51 #[serde(default)]
55 pub direct_connect_ports: Vec<u16>,
56
57 #[serde(default)]
59 pub max_connections: usize,
60
61 #[serde(default, skip_serializing_if = "Option::is_none")]
79 pub intercept_ca_dir: Option<PathBuf>,
80
81 #[serde(default, skip)]
89 pub intercept_parent_ca_pems: Option<Vec<u8>>,
90}
91
92impl Default for ProxyConfig {
93 fn default() -> Self {
94 Self {
95 bind_addr: default_bind_addr(),
96 bind_port: 0,
97 allowed_hosts: Vec::new(),
98 routes: Vec::new(),
99 external_proxy: None,
100 direct_connect_ports: Vec::new(),
101 max_connections: 256,
102 intercept_ca_dir: None,
103 intercept_parent_ca_pems: None,
104 }
105 }
106}
107
/// Default bind address for the proxy listener: IPv4 loopback (`127.0.0.1`).
fn default_bind_addr() -> IpAddr {
    IpAddr::from(std::net::Ipv4Addr::LOCALHOST)
}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct RouteConfig {
115 pub prefix: String,
118
119 pub upstream: String,
121
122 pub credential_key: Option<String>,
125
126 #[serde(default)]
128 pub inject_mode: InjectMode,
129
130 #[serde(default = "default_inject_header")]
134 pub inject_header: String,
135
136 #[serde(default = "default_credential_format")]
140 pub credential_format: String,
141
142 #[serde(default)]
147 pub path_pattern: Option<String>,
148
149 #[serde(default)]
153 pub path_replacement: Option<String>,
154
155 #[serde(default)]
159 pub query_param_name: Option<String>,
160
161 #[serde(default)]
167 pub proxy: Option<ProxyInjectConfig>,
168
169 #[serde(default)]
176 pub env_var: Option<String>,
177
178 #[serde(default)]
184 pub endpoint_rules: Vec<EndpointRule>,
185
186 #[serde(default)]
193 pub tls_ca: Option<String>,
194
195 #[serde(default)]
202 pub tls_client_cert: Option<String>,
203
204 #[serde(default)]
209 pub tls_client_key: Option<String>,
210
211 #[serde(default)]
216 pub oauth2: Option<OAuth2Config>,
217}
218
219#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(deny_unknown_fields)]
226pub struct ProxyInjectConfig {
227 #[serde(default)]
229 pub inject_mode: Option<InjectMode>,
230
231 #[serde(default)]
233 pub inject_header: Option<String>,
234
235 #[serde(default)]
237 pub credential_format: Option<String>,
238
239 #[serde(default)]
241 pub path_pattern: Option<String>,
242
243 #[serde(default)]
245 pub path_replacement: Option<String>,
246
247 #[serde(default)]
249 pub query_param_name: Option<String>,
250}
251
252#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
259pub struct EndpointRule {
260 pub method: String,
262 pub path: String,
265}
266
267pub struct CompiledEndpointRules {
273 rules: Vec<CompiledRule>,
274}
275
276struct CompiledRule {
277 method: String,
278 matcher: globset::GlobMatcher,
279}
280
281impl CompiledEndpointRules {
282 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
285 let mut compiled = Vec::with_capacity(rules.len());
286 for rule in rules {
287 let glob = Glob::new(&rule.path)
288 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
289 compiled.push(CompiledRule {
290 method: rule.method.clone(),
291 matcher: glob.compile_matcher(),
292 });
293 }
294 Ok(Self { rules: compiled })
295 }
296
297 #[must_use]
300 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
301 if self.rules.is_empty() {
302 return true;
303 }
304 let normalized = normalize_path(path);
305 self.rules.iter().any(|r| {
306 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
307 && r.matcher.is_match(&normalized)
308 })
309 }
310}
311
312impl std::fmt::Debug for CompiledEndpointRules {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 f.debug_struct("CompiledEndpointRules")
315 .field("count", &self.rules.len())
316 .finish()
317 }
318}
319
/// Test-only, slice-based entry point to the production endpoint matcher.
///
/// Each rule is compiled on its own with [`CompiledEndpointRules::compile`].
/// This way unit tests exercise exactly the matching semantics used on the
/// request path (glob options, method comparison, path normalization) rather
/// than a hand-maintained copy. A rule whose pattern fails to compile is
/// skipped and never matches. An empty rule list allows everything.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    rules.iter().any(|rule| {
        CompiledEndpointRules::compile(std::slice::from_ref(rule))
            .is_ok_and(|compiled| compiled.is_allowed(method, path))
    })
}
339
340fn normalize_path(path: &str) -> String {
346 let path = path.split('?').next().unwrap_or(path);
348
349 let binary = urlencoding::decode_binary(path.as_bytes());
353 let decoded = String::from_utf8_lossy(&binary);
354
355 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
358 if segments.is_empty() {
359 "/".to_string()
360 } else {
361 format!("/{}", segments.join("/"))
362 }
363}
364
/// Serde default for `RouteConfig::inject_header`: the `Authorization` header.
fn default_inject_header() -> String {
    String::from("Authorization")
}
368
/// Serde default for `RouteConfig::credential_format`: a bearer-token template.
///
/// The `{}` is stored literally; this is not a `format!` call. Presumably the
/// injector substitutes the credential at that position.
fn default_credential_format() -> String {
    "Bearer {}".to_owned()
}
372
373#[derive(Debug, Clone, Serialize, Deserialize)]
375pub struct ExternalProxyConfig {
376 pub address: String,
378
379 pub auth: Option<ExternalProxyAuth>,
381
382 #[serde(default)]
386 pub bypass_hosts: Vec<String>,
387}
388
389#[derive(Debug, Clone, Serialize, Deserialize)]
391pub struct ExternalProxyAuth {
392 pub keyring_account: String,
394
395 #[serde(default = "default_auth_scheme")]
397 pub scheme: String,
398}
399
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
403
404#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
414pub struct OAuth2Config {
415 pub token_url: String,
417 pub client_id: String,
419 pub client_secret: String,
421 #[serde(default)]
423 pub scope: String,
424}
425
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an [`EndpointRule`] from string slices.
    fn endpoint_rule(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_string(),
            path: path.to_string(),
        }
    }

    /// Evaluates a single rule through the slice-based `endpoint_allowed` helper.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    // --- ProxyConfig / ExternalProxyConfig ---------------------------------

    #[test]
    fn test_default_config() {
        let cfg = ProxyConfig::default();
        assert_eq!(cfg.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(cfg.bind_port, 0);
        assert!(cfg.allowed_hosts.is_empty());
        assert!(cfg.routes.is_empty());
        assert!(cfg.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let original = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let original = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();
        let Some(ext) = decoded.external_proxy else {
            panic!("external_proxy lost in round trip");
        };
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts, ["internal.corp", "*.private.net"]);
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // --- endpoint rules: semantics -----------------------------------------

    /// No rules means no restriction: every method and path is allowed.
    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = endpoint_rule("GET", "/v1/chat/completions");
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = endpoint_rule("get", "/api");
        for method in ["GET", "Get"] {
            assert!(check(&rule, method, "/api"), "method {method} should match");
        }
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = endpoint_rule("*", "/api/resource");
        for method in ["GET", "DELETE", "POST"] {
            assert!(check(&rule, method, "/api/resource"), "method {method}");
        }
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = endpoint_rule("GET", "/api/resource");
        for method in ["POST", "DELETE"] {
            assert!(!check(&rule, method, "/api/resource"), "method {method}");
        }
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = endpoint_rule("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = endpoint_rule("GET", "/api/v4/projects/**");
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = endpoint_rule("*", "/api/**/notes");
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    // --- endpoint rules: path normalization --------------------------------

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = endpoint_rule("GET", "/");
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    // --- CompiledEndpointRules ---------------------------------------------

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            endpoint_rule("GET", "/repos/*/issues"),
            endpoint_rule("POST", "/repos/*/issues/*/comments"),
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![endpoint_rule("GET", "/api/[invalid")];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            endpoint_rule("GET", "/repos/*/issues"),
            endpoint_rule("POST", "/repos/*/issues/*/comments"),
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // --- RouteConfig serde -------------------------------------------------

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let expected = Some("/run/secrets/k8s-ca.crt");

        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), expected);

        let reencoded = serde_json::to_string(&route).unwrap();
        let round_tripped: RouteConfig = serde_json::from_str(&reencoded).unwrap();
        assert_eq!(round_tripped.tls_ca.as_deref(), expected);
    }

    // --- percent-encoding must not dodge a rule ----------------------------

    /// The path is percent-decoded before matching, so an encoded letter
    /// (`%70` = 'p', `%6A` = 'j') still has to satisfy the rule.
    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let rule = endpoint_rule("GET", "/api/v4/projects/*/issues");
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    /// A fully encoded segment (`%64%61%74%61` = "data") is decoded too.
    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = endpoint_rule("POST", "/api/data");
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    /// Same decoding on the compiled path: `%69` = 'i' matches `issues`,
    /// while `%70ulls` decodes to `pulls` and stays denied.
    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![endpoint_rule("GET", "/repos/*/issues")];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    /// `%FF` is not valid UTF-8. Lossy decoding turns it into U+FFFD, which
    /// cannot match the literal segment `projects`.
    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let rule = endpoint_rule("GET", "/api/projects");
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = endpoint_rule("GET", "/api/*/data");
        let encoded = serde_json::to_string(&rule).unwrap();
        let decoded: EndpointRule = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.method, "GET");
        assert_eq!(decoded.path, "/api/*/data");
    }

    // --- OAuth2Config ------------------------------------------------------

    #[test]
    fn test_oauth2_config_deserialization() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://CLIENT_SECRET",
            "scope": "read write"
        }"#;
        let cfg: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.token_url, "https://auth.example.com/oauth/token");
        assert_eq!(cfg.client_id, "my-client");
        assert_eq!(cfg.client_secret, "env://CLIENT_SECRET");
        assert_eq!(cfg.scope, "read write");
    }

    #[test]
    fn test_oauth2_config_default_scope() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://SECRET"
        }"#;
        let cfg: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.scope, "");
    }

    #[test]
    fn test_route_config_with_oauth2() {
        let json = r#"{
            "prefix": "/my-api",
            "upstream": "https://api.example.com",
            "oauth2": {
                "token_url": "https://auth.example.com/oauth/token",
                "client_id": "agent-1",
                "client_secret": "env://CLIENT_SECRET",
                "scope": "api.read"
            }
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.credential_key.is_none());
        let Some(oauth2) = route.oauth2 else {
            panic!("oauth2 section should be parsed");
        };
        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
    }

    #[test]
    fn test_route_config_without_oauth2() {
        let json = r#"{
            "prefix": "/openai",
            "upstream": "https://api.openai.com",
            "credential_key": "openai"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_none());
        assert!(route.credential_key.is_some());
    }
}
816}