1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::path::PathBuf;
10use zeroize::Zeroizing;
11
/// How a route's credential is attached to the outgoing request.
///
/// Serialized in snake_case: `header`, `url_path`, `query_param`, `basic_auth`.
/// When omitted from a route, `Header` is used.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    /// Put the formatted credential in the header named by the route's
    /// `inject_header` (default mode).
    #[default]
    Header,
    /// Put the credential into the URL path; presumably driven by the route's
    /// `path_pattern` / `path_replacement` — confirm with the injector.
    UrlPath,
    /// Put the credential into a query parameter; presumably the one named by
    /// `query_param_name` — confirm with the injector.
    QueryParam,
    /// Send the credential as HTTP Basic authentication (encoding details live
    /// in the injector — confirm).
    BasicAuth,
}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ProxyConfig {
30 #[serde(default = "default_bind_addr")]
32 pub bind_addr: IpAddr,
33
34 #[serde(default)]
36 pub bind_port: u16,
37
38 #[serde(default)]
41 pub allowed_hosts: Vec<String>,
42
43 #[serde(default)]
45 pub routes: Vec<RouteConfig>,
46
47 #[serde(default)]
50 pub external_proxy: Option<ExternalProxyConfig>,
51
52 #[serde(default)]
56 pub direct_connect_ports: Vec<u16>,
57
58 #[serde(default)]
60 pub max_connections: usize,
61
62 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub intercept_ca_dir: Option<PathBuf>,
81
82 #[serde(default, skip)]
90 pub intercept_parent_ca_pems: Option<Vec<u8>>,
91
92 #[serde(default, skip)]
98 pub preloaded_ca: Option<PreloadedCa>,
99
100 #[serde(default, skip)]
104 pub ca_validity: Option<std::time::Duration>,
105}
106
107#[derive(Clone)]
125pub struct PreloadedCa {
126 pub key_der: Zeroizing<Vec<u8>>,
128 pub cert_pem: String,
130}
131
132impl std::fmt::Debug for PreloadedCa {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 f.debug_struct("PreloadedCa")
135 .field("key_der", &"[REDACTED]")
136 .field("cert_pem_len", &self.cert_pem.len())
137 .finish()
138 }
139}
140
141impl Default for ProxyConfig {
142 fn default() -> Self {
143 Self {
144 bind_addr: default_bind_addr(),
145 bind_port: 0,
146 allowed_hosts: Vec::new(),
147 routes: Vec::new(),
148 external_proxy: None,
149 direct_connect_ports: Vec::new(),
150 max_connections: 256,
151 intercept_ca_dir: None,
152 intercept_parent_ca_pems: None,
153 preloaded_ca: None,
154 ca_validity: None,
155 }
156 }
157}
158
/// Serde default for `ProxyConfig::bind_addr`: the IPv4 loopback `127.0.0.1`.
fn default_bind_addr() -> IpAddr {
    std::net::Ipv4Addr::LOCALHOST.into()
}
162
/// One proxied route: how requests for `prefix` reach `upstream`, and how a
/// credential is injected into them.
///
/// Only `prefix` and `upstream` are required when deserializing. Every other
/// field is an `Option` or has a default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// Route prefix identifying this route. Tests use both `test` and `/my-api`
    /// forms; how it is matched is decided by the router — confirm there.
    pub prefix: String,

    /// Upstream base URL for this route (e.g. `https://api.openai.com`).
    pub upstream: String,

    /// Reference to the credential to inject. Examples: `env://ANTHROPIC_API_KEY`
    /// or a bare name such as `openai`. May be `None`, e.g. when `oauth2`
    /// supplies the credential instead.
    pub credential_key: Option<String>,

    /// How the credential is attached. Defaults to [`InjectMode::Header`].
    #[serde(default)]
    pub inject_mode: InjectMode,

    /// Header that receives the credential in header mode.
    /// Defaults to `Authorization`.
    #[serde(default = "default_inject_header")]
    pub inject_header: String,

    /// Explicit format template for the credential value. When `None`,
    /// `resolved_credential_format` picks `"Bearer {}"` for `Authorization`
    /// and `"{}"` for any other header.
    #[serde(default)]
    pub credential_format: Option<String>,

    /// Path pattern, presumably used by [`InjectMode::UrlPath`] — confirm.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Replacement for `path_pattern`, presumably used by
    /// [`InjectMode::UrlPath`] — confirm.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Query parameter name, presumably used by [`InjectMode::QueryParam`] — confirm.
    #[serde(default)]
    pub query_param_name: Option<String>,

    /// Alternative injection settings; every field is optional. When and how
    /// these override the route-level values is decided by the consumer — confirm.
    #[serde(default)]
    pub proxy: Option<ProxyInjectConfig>,

    /// Optional environment variable name associated with this route
    /// (semantics defined by the consumer — confirm).
    #[serde(default)]
    pub env_var: Option<String>,

    /// Method/path allowlist for this route. An empty list allows every
    /// request (see `CompiledEndpointRules::is_allowed`).
    #[serde(default)]
    pub endpoint_rules: Vec<EndpointRule>,

    /// Path to a CA certificate (e.g. `/run/secrets/k8s-ca.crt`), presumably
    /// used to verify the upstream's TLS certificate — confirm.
    #[serde(default)]
    pub tls_ca: Option<String>,

    /// Client certificate, presumably for mutual TLS with the upstream — confirm.
    #[serde(default)]
    pub tls_client_cert: Option<String>,

    /// Client private key paired with `tls_client_cert` (presumably mTLS — confirm).
    #[serde(default)]
    pub tls_client_key: Option<String>,

    /// OAuth2 client-credentials settings, an alternative to `credential_key`.
    #[serde(default)]
    pub oauth2: Option<OAuth2Config>,
}
271
/// Optional injection settings that mirror the route-level fields of
/// [`RouteConfig`].
///
/// Every field is optional. A `None` field presumably falls back to the
/// route-level value — confirm with the consumer. Unknown keys are rejected
/// at deserialization time.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProxyInjectConfig {
    /// Injection mode for this variant (see [`InjectMode`]).
    #[serde(default)]
    pub inject_mode: Option<InjectMode>,

    /// Header name used in header mode.
    #[serde(default)]
    pub inject_header: Option<String>,

    /// Credential format template (same meaning as on [`RouteConfig`]).
    #[serde(default)]
    pub credential_format: Option<String>,

    /// Path pattern for URL-path injection.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Replacement for `path_pattern`.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Query parameter name for query-param injection.
    #[serde(default)]
    pub query_param_name: Option<String>,
}
304
/// One allowlist entry restricting which requests a route permits.
///
/// A request is permitted when `method` matches and the `path` glob matches
/// the normalized request path.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EndpointRule {
    /// HTTP method, compared ASCII case-insensitively; `*` matches any method.
    pub method: String,
    /// `globset` pattern for the request path (supports `*`, `**`, ...).
    ///
    /// Before matching, the request path is normalized:
    /// - the query string is removed;
    /// - the path is percent-decoded;
    /// - empty segments are collapsed.
    ///
    /// NOTE(review): patterns are built with `Glob::new`, whose defaults let `*`
    /// also match `/`, so `*` is not limited to one segment. Confirm this is intended.
    pub path: String,
}
319
/// Endpoint rules whose glob patterns have been compiled once, so the
/// per-request check does not re-parse them.
///
/// Built with `CompiledEndpointRules::compile`. An empty rule set allows every
/// request.
pub struct CompiledEndpointRules {
    rules: Vec<CompiledRule>,
}

/// One compiled `EndpointRule`: the method to compare and the path matcher.
struct CompiledRule {
    // Method exactly as written in the rule; `*` means "any method".
    method: String,
    // Compiled glob for the rule's path pattern.
    matcher: globset::GlobMatcher,
}
333
334impl CompiledEndpointRules {
335 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
338 let mut compiled = Vec::with_capacity(rules.len());
339 for rule in rules {
340 let glob = Glob::new(&rule.path)
341 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
342 compiled.push(CompiledRule {
343 method: rule.method.clone(),
344 matcher: glob.compile_matcher(),
345 });
346 }
347 Ok(Self { rules: compiled })
348 }
349
350 #[must_use]
352 pub fn is_empty(&self) -> bool {
353 self.rules.is_empty()
354 }
355
356 #[must_use]
358 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
359 if self.rules.is_empty() {
360 return true;
361 }
362 let normalized = normalize_path(path);
363 self.rules.iter().any(|r| {
364 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
365 && r.matcher.is_match(&normalized)
366 })
367 }
368}
369
370impl std::fmt::Debug for CompiledEndpointRules {
371 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372 f.debug_struct("CompiledEndpointRules")
373 .field("count", &self.rules.len())
374 .finish()
375 }
376}
377
/// Test-only reference implementation of endpoint matching.
///
/// Makes the same method/glob decision as `CompiledEndpointRules::is_allowed`,
/// but recompiles each glob on every call. A rule whose pattern fails to parse
/// counts as non-matching; the rest of the rule set still applies.
///
/// When rules exist, a normalized path that still contains a `.`/`..`
/// segment is never allowed. The check looks at the part before any `;`, so
/// `..;x` also counts. This stops an upstream's dot-segment resolution from
/// escaping a pattern such as `/api/**`.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    let normalized = normalize_path(path);
    let has_dot_segment = normalized
        .split('/')
        .map(|segment| segment.split(';').next().unwrap_or(segment))
        .any(|segment| segment == "." || segment == "..");
    if has_dot_segment {
        return false;
    }
    rules.iter().any(|r| {
        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
            && Glob::new(&r.path)
                .ok()
                .map(|g| g.compile_matcher())
                .is_some_and(|m| m.is_match(&normalized))
    })
}
397
398fn normalize_path(path: &str) -> String {
404 let path = path.split('?').next().unwrap_or(path);
406
407 let binary = urlencoding::decode_binary(path.as_bytes());
411 let decoded = String::from_utf8_lossy(&binary);
412
413 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
416 if segments.is_empty() {
417 "/".to_string()
418 } else {
419 format!("/{}", segments.join("/"))
420 }
421}
422
/// Serde default for `RouteConfig::inject_header`: the `Authorization` header.
fn default_inject_header() -> String {
    String::from("Authorization")
}
426
/// Returns the format template used to render an injected credential.
///
/// An explicit `credential_format` is always returned unchanged. Otherwise:
/// - the `Authorization` header (matched ASCII case-insensitively) gets
///   `"Bearer {}"`;
/// - any other header gets the bare `"{}"` placeholder.
///
/// The `{}` is presumably replaced with the credential by the caller.
#[must_use]
pub fn resolved_credential_format(inject_header: &str, credential_format: Option<&str>) -> String {
    if let Some(explicit) = credential_format {
        return explicit.to_owned();
    }
    let template = if inject_header.eq_ignore_ascii_case("Authorization") {
        "Bearer {}"
    } else {
        "{}"
    };
    template.to_owned()
}
443
/// Upstream HTTP proxy that outbound traffic is presumably chained through
/// (e.g. a corporate Squid) — confirm with the connector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyConfig {
    /// Proxy address as `host:port`, e.g. `squid.corp:3128`.
    pub address: String,

    /// Optional credentials for authenticating to the proxy.
    pub auth: Option<ExternalProxyAuth>,

    /// Hosts that should not go through the external proxy. Entries may be
    /// exact names (`internal.corp`) or wildcard patterns (`*.private.net`).
    /// The consumer defines the matching rules — confirm there. Defaults to empty.
    #[serde(default)]
    pub bypass_hosts: Vec<String>,
}
459
/// Credentials for the external proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyAuth {
    /// Keyring account name, presumably the entry holding the proxy
    /// credentials — confirm with the loader.
    pub keyring_account: String,

    /// Authentication scheme. Defaults to `"basic"`.
    #[serde(default = "default_auth_scheme")]
    pub scheme: String,
}
470
/// Serde default for `ExternalProxyAuth::scheme`: HTTP `basic` authentication.
fn default_auth_scheme() -> String {
    String::from("basic")
}
474
/// OAuth2 settings a route can use to obtain its credential instead of
/// `credential_key`. The field set suggests the client-credentials flow — confirm.
///
/// NOTE(review): `Debug` is derived, so `client_secret` is printed verbatim.
/// That is harmless for references like `env://NAME`, but it would leak a
/// literal secret. Consider a redacting `Debug` like `PreloadedCa`'s if
/// literals are allowed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OAuth2Config {
    /// Token endpoint URL, e.g. `https://auth.example.com/oauth/token`.
    pub token_url: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// Client secret or a reference to it (tests use `env://CLIENT_SECRET`).
    pub client_secret: String,
    /// Space-separated scopes (e.g. `"read write"`). Defaults to the empty string.
    #[serde(default)]
    pub scope: String,
}
496
#[cfg(test)]
// Panicking on an unexpected `Err`/`None` is the intended failure mode in tests.
#[allow(clippy::unwrap_used)]
mod tests {
    // Unit tests for config (de)serialization, endpoint-rule matching and
    // credential-format resolution.
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // An empty rule list means the route is unrestricted.
    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        // Evaluate a single rule without building a slice by hand.
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    // `/**/` in the middle also matches zero directories (`/api/notes`).
    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    // An unterminated character class is a glob syntax error and must be
    // reported at compile time rather than silently never matching.
    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let rule = EndpointRule {
            // `%70` is `p` and `%6A` is `j`: percent-encoded literal characters
            // must still match the rule once the path is decoded.
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
        // Every byte of the last segment is encoded; it decodes to `data`.
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        // `%70ulls` decodes to `pulls`, which no rule allows.
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let rule = EndpointRule {
            // `%FF` is not valid UTF-8 on its own. Lossy decoding turns it into
            // U+FFFD, so the segment can no longer equal `projects` and the
            // request must be rejected.
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
        // Fails closed: undecodable bytes never widen what a rule allows.
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }

    // `client_secret` may be a credential reference such as `env://...`;
    // the config layer keeps it verbatim.
    #[test]
    fn test_oauth2_config_deserialization() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://CLIENT_SECRET",
            "scope": "read write"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.token_url, "https://auth.example.com/oauth/token");
        assert_eq!(config.client_id, "my-client");
        assert_eq!(config.client_secret, "env://CLIENT_SECRET");
        assert_eq!(config.scope, "read write");
    }

    #[test]
    fn test_oauth2_config_default_scope() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://SECRET"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.scope, "");
    }

    #[test]
    fn test_route_config_with_oauth2() {
        let json = r#"{
            "prefix": "/my-api",
            "upstream": "https://api.example.com",
            "oauth2": {
                "token_url": "https://auth.example.com/oauth/token",
                "client_id": "agent-1",
                "client_secret": "env://CLIENT_SECRET",
                "scope": "api.read"
            }
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_some());
        assert!(route.credential_key.is_none());
        let oauth2 = route.oauth2.unwrap();
        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
    }

    #[test]
    fn test_route_config_without_oauth2() {
        let json = r#"{
            "prefix": "/openai",
            "upstream": "https://api.openai.com",
            "credential_key": "openai"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_none());
        assert!(route.credential_key.is_some());
    }

    // A custom header with no explicit format gets the bare `{}` template.
    #[test]
    fn test_route_config_credential_format_omitted_is_none() {
        let json = r#"{
            "prefix": "anthropic",
            "upstream": "https://api.anthropic.com",
            "credential_key": "env://ANTHROPIC_API_KEY",
            "inject_header": "x-api-key"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.credential_format.is_none());
        assert_eq!(
            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
            "{}"
        );
    }

    // An explicit format wins even on a non-Authorization header.
    #[test]
    fn test_route_config_explicit_bearer_on_custom_header_preserved() {
        let json = r#"{
            "prefix": "litellm",
            "upstream": "https://litellm",
            "credential_key": "env://LITELLM_TOKEN",
            "inject_header": "x-litellm-api-key",
            "credential_format": "Bearer {}"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.credential_format.as_deref(), Some("Bearer {}"));
        assert_eq!(
            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
            "Bearer {}"
        );
    }

    #[test]
    fn test_resolved_credential_format_authorization_case_insensitive() {
        for header in ["authorization", "AUTHORIZATION", "Authorization"] {
            assert_eq!(
                resolved_credential_format(header, None),
                "Bearer {}",
                "omitted format: Authorization header name is matched case-insensitively for Bearer default"
            );
        }
    }
}