1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9use std::path::PathBuf;
10
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
/// Selects how a route attaches its credential to the upstream request.
///
/// Variants (de)serialize in snake_case: `"header"`, `"url_path"`,
/// `"query_param"`, `"basic_auth"`. `Header` is the default, and it is also
/// what `RouteConfig::inject_mode` falls back to when the key is omitted.
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    #[default]
    /// Default mode. NOTE(review): presumably writes the credential into the
    /// header named by `RouteConfig::inject_header` — confirm in the injector.
    Header,
    /// NOTE(review): presumably rewrites the URL path using `path_pattern` /
    /// `path_replacement` — confirm in the injector.
    UrlPath,
    /// NOTE(review): presumably adds the credential as the query parameter
    /// named by `query_param_name` — confirm in the injector.
    QueryParam,
    /// NOTE(review): presumably sends the credential as HTTP Basic auth —
    /// confirm in the injector.
    BasicAuth,
}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ProxyConfig {
29 #[serde(default = "default_bind_addr")]
31 pub bind_addr: IpAddr,
32
33 #[serde(default)]
35 pub bind_port: u16,
36
37 #[serde(default)]
40 pub allowed_hosts: Vec<String>,
41
42 #[serde(default)]
44 pub routes: Vec<RouteConfig>,
45
46 #[serde(default)]
49 pub external_proxy: Option<ExternalProxyConfig>,
50
51 #[serde(default)]
55 pub direct_connect_ports: Vec<u16>,
56
57 #[serde(default)]
59 pub max_connections: usize,
60
61 #[serde(default, skip_serializing_if = "Option::is_none")]
79 pub intercept_ca_dir: Option<PathBuf>,
80
81 #[serde(default, skip)]
89 pub intercept_parent_ca_pems: Option<Vec<u8>>,
90}
91
92impl Default for ProxyConfig {
93 fn default() -> Self {
94 Self {
95 bind_addr: default_bind_addr(),
96 bind_port: 0,
97 allowed_hosts: Vec::new(),
98 routes: Vec::new(),
99 external_proxy: None,
100 direct_connect_ports: Vec::new(),
101 max_connections: 256,
102 intercept_ca_dir: None,
103 intercept_parent_ca_pems: None,
104 }
105 }
106}
107
108fn default_bind_addr() -> IpAddr {
109 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
110}
111
#[derive(Debug, Clone, Serialize, Deserialize)]
/// One route: requests matching `prefix` are forwarded to `upstream`, with a
/// credential injected according to the `inject_*` / `credential_*` fields.
///
/// NOTE(review): prefix matching and forwarding live outside this file; the
/// summary above is inferred from field names — confirm.
pub struct RouteConfig {
    /// Route prefix. Tests use both bare (`"anthropic"`) and slash-prefixed
    /// (`"/my-api"`) forms.
    pub prefix: String,

    /// Upstream base URL, e.g. `"https://api.openai.com"`.
    pub upstream: String,

    /// Reference to the credential to inject. Tests use `env://NAME` and plain
    /// names such as `"openai"`. Missing key deserializes to `None`.
    pub credential_key: Option<String>,

    #[serde(default)]
    /// How the credential is attached; defaults to `InjectMode::Header`.
    pub inject_mode: InjectMode,

    #[serde(default = "default_inject_header")]
    /// Header name used for injection; defaults to `"Authorization"`.
    pub inject_header: String,

    #[serde(default)]
    /// Format template for the credential value (e.g. `"Bearer {}"`).
    /// When `None`, `resolved_credential_format` picks `"Bearer {}"` for an
    /// `Authorization` header (any ASCII case) and `"{}"` otherwise.
    pub credential_format: Option<String>,

    #[serde(default)]
    /// NOTE(review): presumably the path pattern used by `InjectMode::UrlPath`.
    pub path_pattern: Option<String>,

    #[serde(default)]
    /// NOTE(review): presumably the replacement paired with `path_pattern`.
    pub path_replacement: Option<String>,

    #[serde(default)]
    /// NOTE(review): presumably the parameter name used by `InjectMode::QueryParam`.
    pub query_param_name: Option<String>,

    #[serde(default)]
    /// Optional injection settings with the same shape as this route's
    /// injection fields (see `ProxyInjectConfig`).
    /// NOTE(review): where/when these take precedence is decided elsewhere.
    pub proxy: Option<ProxyInjectConfig>,

    #[serde(default)]
    /// Optional environment variable name.
    /// NOTE(review): its use is not visible in this file — confirm.
    pub env_var: Option<String>,

    #[serde(default)]
    /// Method/path allow-list for this route. Empty (the default) allows every
    /// request; see `CompiledEndpointRules::is_allowed`.
    pub endpoint_rules: Vec<EndpointRule>,

    #[serde(default)]
    /// Optional CA certificate for the upstream; tests use a file path such as
    /// `"/run/secrets/k8s-ca.crt"`.
    pub tls_ca: Option<String>,

    #[serde(default)]
    /// Optional client certificate for the upstream.
    /// NOTE(review): presumably a file path, like `tls_ca` — confirm.
    pub tls_client_cert: Option<String>,

    #[serde(default)]
    /// Optional client private key paired with `tls_client_cert`.
    /// NOTE(review): presumably a file path — confirm.
    pub tls_client_key: Option<String>,

    #[serde(default)]
    /// Optional OAuth2 settings (see `OAuth2Config`). Tests show a route
    /// configured with `oauth2` and no `credential_key`.
    pub oauth2: Option<OAuth2Config>,
}
220
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// Optional injection settings attached to a route via `RouteConfig::proxy`.
///
/// Each field mirrors the identically named field on `RouteConfig`; `None`
/// means "not set here".
/// NOTE(review): presumably unset fields fall back to the route's own values —
/// confirm where this struct is consumed.
/// Unknown keys are rejected (`deny_unknown_fields`), so typos fail at load time.
#[serde(deny_unknown_fields)]
pub struct ProxyInjectConfig {
    #[serde(default)]
    /// Counterpart of `RouteConfig::inject_mode`.
    pub inject_mode: Option<InjectMode>,

    #[serde(default)]
    /// Counterpart of `RouteConfig::inject_header`.
    pub inject_header: Option<String>,

    #[serde(default)]
    /// Counterpart of `RouteConfig::credential_format`.
    pub credential_format: Option<String>,

    #[serde(default)]
    /// Counterpart of `RouteConfig::path_pattern`.
    pub path_pattern: Option<String>,

    #[serde(default)]
    /// Counterpart of `RouteConfig::path_replacement`.
    pub path_replacement: Option<String>,

    #[serde(default)]
    /// Counterpart of `RouteConfig::query_param_name`.
    pub query_param_name: Option<String>,
}
253
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// One allow-list entry for `RouteConfig::endpoint_rules`: an HTTP method
/// plus a path glob.
///
/// A request is allowed when any rule matches both its method and its
/// normalized path. A route with no rules allows everything (see
/// `CompiledEndpointRules::is_allowed`).
pub struct EndpointRule {
    /// HTTP method, compared ASCII-case-insensitively; `"*"` matches any method.
    pub method: String,
    /// `globset` pattern matched against the normalized request path. The
    /// path has its query string removed, is percent-decoded, and has empty
    /// segments collapsed. `**` spans any number of segments.
    /// NOTE(review): patterns are compiled with `Glob::new` defaults, under
    /// which `*` is not restricted to a single path segment — confirm intent.
    pub path: String,
}
268
269pub struct CompiledEndpointRules {
275 rules: Vec<CompiledRule>,
276}
277
278struct CompiledRule {
279 method: String,
280 matcher: globset::GlobMatcher,
281}
282
283impl CompiledEndpointRules {
284 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
287 let mut compiled = Vec::with_capacity(rules.len());
288 for rule in rules {
289 let glob = Glob::new(&rule.path)
290 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
291 compiled.push(CompiledRule {
292 method: rule.method.clone(),
293 matcher: glob.compile_matcher(),
294 });
295 }
296 Ok(Self { rules: compiled })
297 }
298
299 #[must_use]
301 pub fn is_empty(&self) -> bool {
302 self.rules.is_empty()
303 }
304
305 #[must_use]
307 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
308 if self.rules.is_empty() {
309 return true;
310 }
311 let normalized = normalize_path(path);
312 self.rules.iter().any(|r| {
313 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
314 && r.matcher.is_match(&normalized)
315 })
316 }
317}
318
319impl std::fmt::Debug for CompiledEndpointRules {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 f.debug_struct("CompiledEndpointRules")
322 .field("count", &self.rules.len())
323 .finish()
324 }
325}
326
/// Test-only convenience: compiles `rules` and checks a single request.
///
/// Delegates to [`CompiledEndpointRules`] so that tests exercise exactly the
/// matching logic used in production rather than a parallel copy that can
/// drift. `CompiledEndpointRules::compile` rejects a rule set containing an
/// invalid pattern, so such a set denies the request here (fail closed).
/// Previously the invalid rule was silently skipped.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    CompiledEndpointRules::compile(rules).is_ok_and(|compiled| compiled.is_allowed(method, path))
}
346
347fn normalize_path(path: &str) -> String {
353 let path = path.split('?').next().unwrap_or(path);
355
356 let binary = urlencoding::decode_binary(path.as_bytes());
360 let decoded = String::from_utf8_lossy(&binary);
361
362 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
365 if segments.is_empty() {
366 "/".to_string()
367 } else {
368 format!("/{}", segments.join("/"))
369 }
370}
371
/// Serde default for `RouteConfig::inject_header`: the standard
/// `Authorization` header name.
fn default_inject_header() -> String {
    String::from("Authorization")
}
375
#[must_use]
/// Returns the template used to render an injected credential.
///
/// An explicit `credential_format` always wins, even on custom headers. When
/// it is absent, an `Authorization` header (any ASCII case) gets
/// `"Bearer {}"`, and every other header gets the bare `"{}"`.
pub fn resolved_credential_format(inject_header: &str, credential_format: Option<&str>) -> String {
    credential_format
        .unwrap_or_else(|| {
            let is_authorization = inject_header.eq_ignore_ascii_case("Authorization");
            if is_authorization {
                "Bearer {}"
            } else {
                "{}"
            }
        })
        .to_owned()
}
392
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Settings for an external proxy (`ProxyConfig::external_proxy`).
/// NOTE(review): presumably outbound traffic is chained through it — confirm.
pub struct ExternalProxyConfig {
    /// Proxy address; tests use the `host:port` form, e.g. `"squid.corp:3128"`.
    pub address: String,

    /// Optional credentials for the external proxy; `None` when omitted or `null`.
    pub auth: Option<ExternalProxyAuth>,

    #[serde(default)]
    /// Host entries exempt from the external proxy; empty by default. Tests
    /// include exact names (`"internal.corp"`) and wildcards (`"*.private.net"`).
    /// NOTE(review): the matching semantics are implemented elsewhere — confirm.
    pub bypass_hosts: Vec<String>,
}
408
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Credentials reference for the external proxy.
pub struct ExternalProxyAuth {
    /// Account name used to look up the proxy credential.
    /// NOTE(review): presumably an OS keyring entry, per the field name — confirm.
    pub keyring_account: String,

    #[serde(default = "default_auth_scheme")]
    /// Authentication scheme name; defaults to `"basic"`.
    pub scheme: String,
}
419
/// Serde default for `ExternalProxyAuth::scheme`: `"basic"`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
423
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// OAuth2 settings for a route (`RouteConfig::oauth2`).
///
/// NOTE(review): presumably used for a client-credentials token exchange
/// against `token_url` — confirm in the token-fetching code.
/// NOTE(review): `Debug` is derived, so `{:?}` prints `client_secret` verbatim.
/// Tests only store `env://` references here. If literal secrets are accepted,
/// consider a redacting `Debug` impl.
pub struct OAuth2Config {
    /// Token endpoint URL.
    pub token_url: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// Client secret or a reference to one (tests use `env://NAME`).
    pub client_secret: String,
    #[serde(default)]
    /// Requested scope string (e.g. `"read write"`); empty when omitted.
    pub scope: String,
}
445
#[cfg(test)]
// Tests may unwrap freely: a panic here is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    //! Tests for config (de)serialization defaults, endpoint-rule matching
    //! (both `endpoint_allowed` and `CompiledEndpointRules`), and
    //! `resolved_credential_format`.

    use super::*;

    // ---- ProxyConfig / ExternalProxyConfig ----

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        // Order and wildcard entries survive the round trip unchanged.
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // ---- Endpoint rules: uncompiled helper ----

    #[test]
    // No rules configured means no endpoint restriction at all.
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    /// Evaluates `rule` on its own through `endpoint_allowed`.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        // The `*` position must be filled by something.
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        // A middle `**` also matches zero segments.
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    // ---- Path normalization ----

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    // ---- Endpoint rules: compiled (production) form ----

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        // Unclosed character class: globset refuses to compile it.
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // ---- RouteConfig serde ----

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    // ---- Percent-encoding ----

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        // Encoded spellings of literal characters must still match the
        // decoded pattern: %70 = 'p', %6A = 'j'.
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        // %64%61%74%61 decodes to "data".
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        // %70ulls decodes to "pulls", which no rule allows.
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        // %FF is not valid UTF-8 on its own; after lossy decoding it becomes
        // U+FFFD, so the segment can no longer equal "projects".
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }

    // ---- OAuth2 ----

    #[test]
    fn test_oauth2_config_deserialization() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://CLIENT_SECRET",
            "scope": "read write"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.token_url, "https://auth.example.com/oauth/token");
        assert_eq!(config.client_id, "my-client");
        assert_eq!(config.client_secret, "env://CLIENT_SECRET");
        assert_eq!(config.scope, "read write");
    }

    #[test]
    fn test_oauth2_config_default_scope() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://SECRET"
        }"#;
        let config: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(config.scope, "");
    }

    #[test]
    fn test_route_config_with_oauth2() {
        let json = r#"{
            "prefix": "/my-api",
            "upstream": "https://api.example.com",
            "oauth2": {
                "token_url": "https://auth.example.com/oauth/token",
                "client_id": "agent-1",
                "client_secret": "env://CLIENT_SECRET",
                "scope": "api.read"
            }
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_some());
        assert!(route.credential_key.is_none());
        let oauth2 = route.oauth2.unwrap();
        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
    }

    #[test]
    fn test_route_config_without_oauth2() {
        let json = r#"{
            "prefix": "/openai",
            "upstream": "https://api.openai.com",
            "credential_key": "openai"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_none());
        assert!(route.credential_key.is_some());
    }

    // ---- Credential format resolution ----

    #[test]
    fn test_route_config_credential_format_omitted_is_none() {
        // A custom header with no explicit format gets the bare "{}" template.
        let json = r#"{
            "prefix": "anthropic",
            "upstream": "https://api.anthropic.com",
            "credential_key": "env://ANTHROPIC_API_KEY",
            "inject_header": "x-api-key"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.credential_format.is_none());
        assert_eq!(
            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
            "{}"
        );
    }

    #[test]
    fn test_route_config_explicit_bearer_on_custom_header_preserved() {
        // An explicit format wins even on a non-Authorization header.
        let json = r#"{
            "prefix": "litellm",
            "upstream": "https://litellm",
            "credential_key": "env://LITELLM_TOKEN",
            "inject_header": "x-litellm-api-key",
            "credential_format": "Bearer {}"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.credential_format.as_deref(), Some("Bearer {}"));
        assert_eq!(
            resolved_credential_format(&route.inject_header, route.credential_format.as_deref()),
            "Bearer {}"
        );
    }

    #[test]
    fn test_resolved_credential_format_authorization_case_insensitive() {
        for header in ["authorization", "AUTHORIZATION", "Authorization"] {
            assert_eq!(
                resolved_credential_format(header, None),
                "Bearer {}",
                "omitted format: Authorization header name is matched case-insensitively for Bearer default"
            );
        }
    }
}
880}