1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::path::PathBuf;
10use zeroize::Zeroizing;
11
/// Where a route's credential is placed in the outgoing request.
///
/// Serialized in `snake_case`: `header`, `url_path`, `query_param`, `basic_auth`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    /// Inject into an HTTP header (the default; header name comes from the route's
    /// `inject_header`).
    #[default]
    Header,
    /// Inject into the URL path (presumably driven by the route's `path_pattern` /
    /// `path_replacement` — confirm at the injection site).
    UrlPath,
    /// Inject as a query parameter (presumably named by the route's
    /// `query_param_name` — confirm at the injection site).
    QueryParam,
    /// Inject as HTTP Basic authentication (exact encoding handled elsewhere).
    BasicAuth,
}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ProxyConfig {
30 #[serde(default = "default_bind_addr")]
32 pub bind_addr: IpAddr,
33
34 #[serde(default)]
36 pub bind_port: u16,
37
38 #[serde(default)]
42 pub allowed_hosts: Vec<String>,
43
44 #[serde(default)]
47 pub strict_filter: bool,
48
49 #[serde(default)]
51 pub routes: Vec<RouteConfig>,
52
53 #[serde(default)]
56 pub external_proxy: Option<ExternalProxyConfig>,
57
58 #[serde(default)]
62 pub direct_connect_ports: Vec<u16>,
63
64 #[serde(default)]
66 pub max_connections: usize,
67
68 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub intercept_ca_dir: Option<PathBuf>,
87
88 #[serde(default, skip)]
96 pub intercept_parent_ca_pems: Option<Vec<u8>>,
97
98 #[serde(default, skip)]
104 pub preloaded_ca: Option<PreloadedCa>,
105
106 #[serde(default, skip)]
110 pub ca_validity: Option<std::time::Duration>,
111}
112
/// CA key and certificate supplied in memory rather than loaded from disk.
///
/// `Debug` is implemented by hand so the private key is never printed.
#[derive(Clone)]
pub struct PreloadedCa {
    /// Private key bytes (DER, per the field name). Wrapped in `Zeroizing` so the
    /// buffer is wiped when dropped.
    pub key_der: Zeroizing<Vec<u8>>,
    /// CA certificate (PEM text, per the field name).
    pub cert_pem: String,
}
137
138impl std::fmt::Debug for PreloadedCa {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 f.debug_struct("PreloadedCa")
141 .field("key_der", &"[REDACTED]")
142 .field("cert_pem_len", &self.cert_pem.len())
143 .finish()
144 }
145}
146
147impl Default for ProxyConfig {
148 fn default() -> Self {
149 Self {
150 bind_addr: default_bind_addr(),
151 bind_port: 0,
152 allowed_hosts: Vec::new(),
153 strict_filter: false,
154 routes: Vec::new(),
155 external_proxy: None,
156 direct_connect_ports: Vec::new(),
157 max_connections: 256,
158 intercept_ca_dir: None,
159 intercept_parent_ca_pems: None,
160 preloaded_ca: None,
161 ca_validity: None,
162 }
163 }
164}
165
/// Default listen address: IPv4 loopback (`127.0.0.1`), so the listener is not
/// reachable from other machines unless explicitly configured.
fn default_bind_addr() -> IpAddr {
    IpAddr::from(std::net::Ipv4Addr::LOCALHOST)
}
169
/// A single proxy route: maps a `prefix` to an `upstream` URL, together with
/// how (and whether) a credential is injected and which endpoints are allowed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// Route prefix (tests use both `"openai"` and `"/my-api"` forms).
    pub prefix: String,

    /// Upstream base URL, e.g. `https://api.openai.com`.
    pub upstream: String,

    /// Credential reference, e.g. `openai` or `env://ANTHROPIC_API_KEY`.
    /// `None` when omitted (e.g. routes that use `oauth2` instead).
    pub credential_key: Option<String>,

    /// How the credential is injected; defaults to `InjectMode::Header`.
    #[serde(default)]
    pub inject_mode: InjectMode,

    /// Header the credential is written to; defaults to `Authorization`.
    #[serde(default = "default_inject_header")]
    pub inject_header: String,

    /// Template for the injected value, with `{}` as the credential placeholder.
    /// When `None`, `resolved_credential_format` uses `Bearer {}` for an
    /// `Authorization` header (case-insensitive) and `{}` for any other header.
    #[serde(default)]
    pub credential_format: Option<String>,

    /// Path pattern (presumably used with `InjectMode::UrlPath` — confirm).
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Path replacement (presumably paired with `path_pattern` — confirm).
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Query parameter name (presumably used with `InjectMode::QueryParam`).
    #[serde(default)]
    pub query_param_name: Option<String>,

    /// Optional overrides for the injection settings above; its fields mirror
    /// `inject_mode` .. `query_param_name`.
    // NOTE(review): when these overrides apply is decided outside this file.
    #[serde(default)]
    pub proxy: Option<ProxyInjectConfig>,

    /// Optional environment variable name.
    // NOTE(review): purpose not visible in this file — confirm at the use site.
    #[serde(default)]
    pub env_var: Option<String>,

    /// Allowed method/path rules. Empty means every endpoint is allowed (see
    /// `CompiledEndpointRules::is_allowed`).
    #[serde(default)]
    pub endpoint_rules: Vec<EndpointRule>,

    /// Optional custom CA for the upstream, e.g. `/run/secrets/k8s-ca.crt`
    /// (presumably a path to a PEM file — confirm).
    #[serde(default)]
    pub tls_ca: Option<String>,

    /// Optional client certificate for upstream TLS (presumably a path; mTLS).
    #[serde(default)]
    pub tls_client_cert: Option<String>,

    /// Optional client private key for upstream TLS (presumably a path; mTLS).
    #[serde(default)]
    pub tls_client_key: Option<String>,

    /// Optional OAuth2 client-credentials settings; tests show it used without
    /// a `credential_key`.
    #[serde(default)]
    pub oauth2: Option<OAuth2Config>,
}
278
/// Optional per-route overrides of the credential-injection settings.
///
/// Every field is optional; `None` presumably means "keep the route-level value"
/// (confirm at the use site). `deny_unknown_fields` makes a misspelled key a hard
/// deserialization error instead of being silently ignored.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProxyInjectConfig {
    /// Override for `RouteConfig::inject_mode`.
    #[serde(default)]
    pub inject_mode: Option<InjectMode>,

    /// Override for `RouteConfig::inject_header`.
    #[serde(default)]
    pub inject_header: Option<String>,

    /// Override for `RouteConfig::credential_format`.
    #[serde(default)]
    pub credential_format: Option<String>,

    /// Override for `RouteConfig::path_pattern`.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Override for `RouteConfig::path_replacement`.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Override for `RouteConfig::query_param_name`.
    #[serde(default)]
    pub query_param_name: Option<String>,
}
311
/// One allowed endpoint: an HTTP method plus a path glob.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EndpointRule {
    /// HTTP method, compared case-insensitively; `"*"` matches any method.
    pub method: String,
    /// Glob pattern matched against the request path after the query string is
    /// stripped, percent-encoding is decoded and empty segments are collapsed.
    pub path: String,
}
326
/// Pre-compiled form of a route's `EndpointRule`s.
///
/// Built once via `CompiledEndpointRules::compile` so per-request checks do not
/// re-parse glob patterns.
pub struct CompiledEndpointRules {
    rules: Vec<CompiledRule>,
}
335
/// One compiled rule: the method (or `"*"`) and the compiled path matcher.
struct CompiledRule {
    method: String,
    matcher: globset::GlobMatcher,
}
340
341impl CompiledEndpointRules {
342 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
345 let mut compiled = Vec::with_capacity(rules.len());
346 for rule in rules {
347 let glob = Glob::new(&rule.path)
348 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
349 compiled.push(CompiledRule {
350 method: rule.method.clone(),
351 matcher: glob.compile_matcher(),
352 });
353 }
354 Ok(Self { rules: compiled })
355 }
356
357 #[must_use]
359 pub fn is_empty(&self) -> bool {
360 self.rules.is_empty()
361 }
362
363 #[must_use]
365 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
366 if self.rules.is_empty() {
367 return true;
368 }
369 let normalized = normalize_path(path);
370 self.rules.iter().any(|r| {
371 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
372 && r.matcher.is_match(&normalized)
373 })
374 }
375}
376
377impl std::fmt::Debug for CompiledEndpointRules {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 f.debug_struct("CompiledEndpointRules")
380 .field("count", &self.rules.len())
381 .finish()
382 }
383}
384
/// Test-only convenience: checks `method`/`path` against uncompiled `rules`.
///
/// Delegates to [`CompiledEndpointRules`] so tests exercise exactly the same
/// normalization and glob semantics as the production hot path, instead of
/// maintaining a second copy of the matching logic. Each rule is compiled on its
/// own, so a rule with an invalid pattern simply never matches rather than
/// invalidating the whole set. An empty rule set allows everything.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    rules.iter().any(|rule| {
        CompiledEndpointRules::compile(std::slice::from_ref(rule))
            .is_ok_and(|compiled| compiled.is_allowed(method, path))
    })
}
404
405fn normalize_path(path: &str) -> String {
411 let path = path.split('?').next().unwrap_or(path);
413
414 let binary = urlencoding::decode_binary(path.as_bytes());
418 let decoded = String::from_utf8_lossy(&binary);
419
420 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
423 if segments.is_empty() {
424 "/".to_string()
425 } else {
426 format!("/{}", segments.join("/"))
427 }
428}
429
/// Serde default for `RouteConfig::inject_header`: the standard `Authorization`
/// header.
fn default_inject_header() -> String {
    String::from("Authorization")
}
433
/// Returns the effective credential template for a route.
///
/// An explicit `credential_format` always wins, even on a custom header. When it
/// is omitted, an `Authorization` header (matched case-insensitively) gets
/// `"Bearer {}"` and every other header gets the bare `"{}"`.
#[must_use]
pub fn resolved_credential_format(inject_header: &str, credential_format: Option<&str>) -> String {
    if let Some(explicit) = credential_format {
        return explicit.to_owned();
    }
    let template = if inject_header.eq_ignore_ascii_case("Authorization") {
        "Bearer {}"
    } else {
        "{}"
    };
    template.to_owned()
}
450
/// An upstream proxy that outbound traffic can be chained through.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyConfig {
    /// Proxy address as `host:port`, e.g. `squid.corp:3128`.
    pub address: String,

    /// Optional proxy authentication; `None` when omitted or `null`.
    pub auth: Option<ExternalProxyAuth>,

    /// Host patterns (e.g. `internal.corp`, `*.private.net`), presumably
    /// connected to without going through the external proxy — confirm at the
    /// use site. Empty by default.
    #[serde(default)]
    pub bypass_hosts: Vec<String>,
}
466
/// Authentication settings for an `ExternalProxyConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyAuth {
    /// Keyring account name under which the proxy credentials are presumably
    /// stored (the secret itself is not held in this struct).
    pub keyring_account: String,

    /// Authentication scheme; defaults to `"basic"`.
    #[serde(default = "default_auth_scheme")]
    pub scheme: String,
}
477
/// Serde default for `ExternalProxyAuth::scheme`: `"basic"`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
481
/// OAuth2 client settings for a route (token endpoint plus client identity).
// NOTE(review): `Debug` is derived, so `{:?}` (including through `RouteConfig`'s
// derived `Debug`) prints `client_secret` verbatim. Tests only show `env://`
// references, which are harmless, but a literal secret here would end up in logs;
// consider a redacting `Debug` as done for `PreloadedCa`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OAuth2Config {
    /// Token endpoint URL, e.g. `https://auth.example.com/oauth/token`.
    pub token_url: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// Client secret or a reference to it (tests use `env://CLIENT_SECRET`).
    pub client_secret: String,
    /// Requested scope string (e.g. `"read write"`); empty when omitted.
    #[serde(default)]
    pub scope: String,
}
503
// Unit tests for config (de)serialization, endpoint-rule matching and
// credential-format resolution. `unwrap` is allowed because a panic is the
// desired failure mode in tests.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // --- ProxyConfig / ExternalProxyConfig ---

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // --- Endpoint rule matching ---

    // An empty rule list means "no restriction".
    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    // Checks a single rule in isolation.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    // `**` in the middle also matches zero segments (`/api/notes`).
    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    // Exercises the pre-compiled path used per request.
    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    // An unterminated character class must be a compile-time error, not a
    // silently non-matching rule.
    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // --- RouteConfig serde ---

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    // --- Percent-encoding must not bypass rules ---

    // `%70` = 'p', `%6A` = 'j': encoded segments are decoded before matching,
    // so the second path decodes to `/api/v4/projects/123/issues`.
    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    // `%64%61%74%61` decodes to `data`.
    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    // `%69` = 'i' (allowed `issues`), `%70` = 'p' (`pulls`, not allowed).
    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    // `%FF` is not valid UTF-8; lossy decoding inserts U+FFFD, so the path no
    // longer matches the literal rule.
    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }

    // --- OAuth2 config ---

    #[test]
    fn test_oauth2_config_deserialization() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://CLIENT_SECRET",
            "scope": "read write"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.token_url, "https://auth.example.com/oauth/token");
        assert_eq!(config.client_id, "my-client");
        assert_eq!(config.client_secret, "env://CLIENT_SECRET");
        assert_eq!(config.scope, "read write");
    }

    #[test]
    fn test_oauth2_config_default_scope() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://SECRET"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.scope, "");
    }

    #[test]
    fn test_route_config_with_oauth2() {
        let json = r#"{
            "prefix": "/my-api",
            "upstream": "https://api.example.com",
            "oauth2": {
                "token_url": "https://auth.example.com/oauth/token",
                "client_id": "agent-1",
                "client_secret": "env://CLIENT_SECRET",
                "scope": "api.read"
            }
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_some());
        assert!(route.credential_key.is_none());
        let oauth2 = route.oauth2.unwrap();
        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
    }

    #[test]
    fn test_route_config_without_oauth2() {
        let json = r#"{
            "prefix": "/openai",
            "upstream": "https://api.openai.com",
            "credential_key": "openai"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_none());
        assert!(route.credential_key.is_some());
    }

    // --- Credential format resolution ---

    // A custom header with no explicit format gets the raw credential, not
    // `Bearer ...`.
    #[test]
    fn test_route_config_credential_format_omitted_is_none() {
        let json = r#"{
            "prefix": "anthropic",
            "upstream": "https://api.anthropic.com",
            "credential_key": "env://ANTHROPIC_API_KEY",
            "inject_header": "x-api-key"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.credential_format.is_none());
        assert_eq!(
            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
            "{}"
        );
    }

    // An explicit format wins even on a non-Authorization header.
    #[test]
    fn test_route_config_explicit_bearer_on_custom_header_preserved() {
        let json = r#"{
            "prefix": "litellm",
            "upstream": "https://litellm",
            "credential_key": "env://LITELLM_TOKEN",
            "inject_header": "x-litellm-api-key",
            "credential_format": "Bearer {}"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.credential_format.as_deref(), Some("Bearer {}"));
        assert_eq!(
            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
            "Bearer {}"
        );
    }

    #[test]
    fn test_resolved_credential_format_authorization_case_insensitive() {
        for header in ["authorization", "AUTHORIZATION", "Authorization"] {
            assert_eq!(
                resolved_credential_format(header, None),
                "Bearer {}",
                "omitted format: Authorization header name is matched case-insensitively for Bearer default"
            );
        }
    }
}