1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
10#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum InjectMode {
14 #[default]
16 Header,
17 UrlPath,
19 QueryParam,
21 BasicAuth,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28 #[serde(default = "default_bind_addr")]
30 pub bind_addr: IpAddr,
31
32 #[serde(default)]
34 pub bind_port: u16,
35
36 #[serde(default)]
39 pub allowed_hosts: Vec<String>,
40
41 #[serde(default)]
43 pub routes: Vec<RouteConfig>,
44
45 #[serde(default)]
48 pub external_proxy: Option<ExternalProxyConfig>,
49
50 #[serde(default)]
54 pub direct_connect_ports: Vec<u16>,
55
56 #[serde(default)]
58 pub max_connections: usize,
59}
60
61impl Default for ProxyConfig {
62 fn default() -> Self {
63 Self {
64 bind_addr: default_bind_addr(),
65 bind_port: 0,
66 allowed_hosts: Vec::new(),
67 routes: Vec::new(),
68 external_proxy: None,
69 direct_connect_ports: Vec::new(),
70 max_connections: 256,
71 }
72 }
73}
74
75fn default_bind_addr() -> IpAddr {
76 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct RouteConfig {
82 pub prefix: String,
85
86 pub upstream: String,
88
89 pub credential_key: Option<String>,
92
93 #[serde(default)]
95 pub inject_mode: InjectMode,
96
97 #[serde(default = "default_inject_header")]
101 pub inject_header: String,
102
103 #[serde(default = "default_credential_format")]
107 pub credential_format: String,
108
109 #[serde(default)]
114 pub path_pattern: Option<String>,
115
116 #[serde(default)]
120 pub path_replacement: Option<String>,
121
122 #[serde(default)]
126 pub query_param_name: Option<String>,
127
128 #[serde(default)]
134 pub proxy: Option<ProxyInjectConfig>,
135
136 #[serde(default)]
143 pub env_var: Option<String>,
144
145 #[serde(default)]
151 pub endpoint_rules: Vec<EndpointRule>,
152
153 #[serde(default)]
160 pub tls_ca: Option<String>,
161
162 #[serde(default)]
169 pub tls_client_cert: Option<String>,
170
171 #[serde(default)]
176 pub tls_client_key: Option<String>,
177
178 #[serde(default)]
183 pub oauth2: Option<OAuth2Config>,
184}
185
186#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
192#[serde(deny_unknown_fields)]
193pub struct ProxyInjectConfig {
194 #[serde(default)]
196 pub inject_mode: Option<InjectMode>,
197
198 #[serde(default)]
200 pub inject_header: Option<String>,
201
202 #[serde(default)]
204 pub credential_format: Option<String>,
205
206 #[serde(default)]
208 pub path_pattern: Option<String>,
209
210 #[serde(default)]
212 pub path_replacement: Option<String>,
213
214 #[serde(default)]
216 pub query_param_name: Option<String>,
217}
218
219#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
226pub struct EndpointRule {
227 pub method: String,
229 pub path: String,
232}
233
234pub struct CompiledEndpointRules {
240 rules: Vec<CompiledRule>,
241}
242
243struct CompiledRule {
244 method: String,
245 matcher: globset::GlobMatcher,
246}
247
248impl CompiledEndpointRules {
249 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
252 let mut compiled = Vec::with_capacity(rules.len());
253 for rule in rules {
254 let glob = Glob::new(&rule.path)
255 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
256 compiled.push(CompiledRule {
257 method: rule.method.clone(),
258 matcher: glob.compile_matcher(),
259 });
260 }
261 Ok(Self { rules: compiled })
262 }
263
264 #[must_use]
267 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
268 if self.rules.is_empty() {
269 return true;
270 }
271 let normalized = normalize_path(path);
272 self.rules.iter().any(|r| {
273 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
274 && r.matcher.is_match(&normalized)
275 })
276 }
277}
278
279impl std::fmt::Debug for CompiledEndpointRules {
280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.debug_struct("CompiledEndpointRules")
282 .field("count", &self.rules.len())
283 .finish()
284 }
285}
286
/// Test-only, uncompiled counterpart of `CompiledEndpointRules::is_allowed`.
///
/// Globs are parsed per call. A pattern that fails to parse simply never
/// matches, rather than producing an error.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    let normalized = normalize_path(path);
    for rule in rules {
        let method_ok = rule.method == "*" || rule.method.eq_ignore_ascii_case(method);
        if !method_ok {
            continue;
        }
        // Only parse the glob once the method has matched.
        if let Ok(glob) = Glob::new(&rule.path) {
            if glob.compile_matcher().is_match(&normalized) {
                return true;
            }
        }
    }
    false
}
306
307fn normalize_path(path: &str) -> String {
313 let path = path.split('?').next().unwrap_or(path);
315
316 let binary = urlencoding::decode_binary(path.as_bytes());
320 let decoded = String::from_utf8_lossy(&binary);
321
322 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
325 if segments.is_empty() {
326 "/".to_string()
327 } else {
328 format!("/{}", segments.join("/"))
329 }
330}
331
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
335
/// Serde default for `RouteConfig::credential_format`: a bearer-token template.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
339
340#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct ExternalProxyConfig {
343 pub address: String,
345
346 pub auth: Option<ExternalProxyAuth>,
348
349 #[serde(default)]
353 pub bypass_hosts: Vec<String>,
354}
355
356#[derive(Debug, Clone, Serialize, Deserialize)]
358pub struct ExternalProxyAuth {
359 pub keyring_account: String,
361
362 #[serde(default = "default_auth_scheme")]
364 pub scheme: String,
365}
366
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
370
371#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
381pub struct OAuth2Config {
382 pub token_url: String,
384 pub client_id: String,
386 pub client_secret: String,
388 #[serde(default)]
390 pub scope: String,
391}
392
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an `EndpointRule` from string slices to keep the cases terse.
    fn endpoint(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_string(),
            path: path.to_string(),
        }
    }

    /// Evaluates a single rule through the uncompiled test helper.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    // ---- ProxyConfig / ExternalProxyConfig ---------------------------------

    #[test]
    fn test_default_config() {
        let cfg = ProxyConfig::default();
        assert_eq!(cfg.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(cfg.bind_port, 0);
        assert!(cfg.allowed_hosts.is_empty());
        assert!(cfg.routes.is_empty());
        assert!(cfg.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let original = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let original = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();
        let ext = decoded.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts, ["internal.corp", "*.private.net"]);
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let raw = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(raw).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // ---- endpoint_allowed (uncompiled helper) ------------------------------

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let r = endpoint("GET", "/v1/chat/completions");
        assert!(check(&r, "GET", "/v1/chat/completions"));
        assert!(!check(&r, "GET", "/v1/chat"));
        assert!(!check(&r, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let r = endpoint("get", "/api");
        for verb in ["GET", "Get"] {
            assert!(check(&r, verb, "/api"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let r = endpoint("*", "/api/resource");
        for verb in ["GET", "DELETE", "POST"] {
            assert!(check(&r, verb, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let r = endpoint("GET", "/api/resource");
        for verb in ["POST", "DELETE"] {
            assert!(!check(&r, verb, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let r = endpoint("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&r, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&r, "GET", "/api/v4/projects/my-proj/merge_requests"));
        assert!(!check(&r, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let r = endpoint("GET", "/api/v4/projects/**");
        assert!(check(&r, "GET", "/api/v4/projects/123"));
        assert!(check(&r, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&r, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&r, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let r = endpoint("*", "/api/**/notes");
        assert!(check(&r, "GET", "/api/notes"));
        assert!(check(&r, "POST", "/api/projects/123/notes"));
        assert!(check(&r, "GET", "/api/a/b/c/notes"));
        assert!(!check(&r, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let r = endpoint("GET", "/api/data");
        assert!(check(&r, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let r = endpoint("GET", "/api/data");
        assert!(check(&r, "GET", "/api/data/"));
        assert!(check(&r, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let r = endpoint("GET", "/api/data");
        assert!(check(&r, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let r = endpoint("GET", "/");
        assert!(check(&r, "GET", "/"));
        assert!(!check(&r, "GET", "/anything"));
    }

    // ---- CompiledEndpointRules --------------------------------------------

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            endpoint("GET", "/repos/*/issues"),
            endpoint("POST", "/repos/*/issues/*/comments"),
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![endpoint("GET", "/api/[invalid")];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            endpoint("GET", "/repos/*/issues"),
            endpoint("POST", "/repos/*/issues/*/comments"),
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(&rules, "POST", "/repos/myrepo/issues/42/comments"));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // ---- RouteConfig serde -------------------------------------------------

    #[test]
    fn test_endpoint_rule_serde_default() {
        let raw = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(raw).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let raw = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(raw).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let encoded = serde_json::to_string(&route).unwrap();
        let decoded: RouteConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));
    }

    // ---- percent-encoding normalization -----------------------------------

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let r = endpoint("GET", "/api/v4/projects/*/issues");
        assert!(check(&r, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&r, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let r = endpoint("POST", "/api/data");
        assert!(check(&r, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![endpoint("GET", "/repos/*/issues")];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let r = endpoint("GET", "/api/projects");
        assert!(!check(&r, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let original = endpoint("GET", "/api/*/data");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: EndpointRule = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.method, "GET");
        assert_eq!(decoded.path, "/api/*/data");
    }

    // ---- OAuth2 ------------------------------------------------------------

    #[test]
    fn test_oauth2_config_deserialization() {
        let raw = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://CLIENT_SECRET",
            "scope": "read write"
        }"#;
        let cfg: OAuth2Config = serde_json::from_str(raw).unwrap();
        assert_eq!(cfg.token_url, "https://auth.example.com/oauth/token");
        assert_eq!(cfg.client_id, "my-client");
        assert_eq!(cfg.client_secret, "env://CLIENT_SECRET");
        assert_eq!(cfg.scope, "read write");
    }

    #[test]
    fn test_oauth2_config_default_scope() {
        let raw = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://SECRET"
        }"#;
        let cfg: OAuth2Config = serde_json::from_str(raw).unwrap();
        assert_eq!(cfg.scope, "");
    }

    #[test]
    fn test_route_config_with_oauth2() {
        let raw = r#"{
            "prefix": "/my-api",
            "upstream": "https://api.example.com",
            "oauth2": {
                "token_url": "https://auth.example.com/oauth/token",
                "client_id": "agent-1",
                "client_secret": "env://CLIENT_SECRET",
                "scope": "api.read"
            }
        }"#;
        let route: RouteConfig = serde_json::from_str(raw).unwrap();
        assert!(route.oauth2.is_some());
        assert!(route.credential_key.is_none());
        let oauth2 = route.oauth2.unwrap();
        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
    }

    #[test]
    fn test_route_config_without_oauth2() {
        let raw = r#"{
            "prefix": "/openai",
            "upstream": "https://api.openai.com",
            "credential_key": "openai"
        }"#;
        let route: RouteConfig = serde_json::from_str(raw).unwrap();
        assert!(route.oauth2.is_none());
        assert!(route.credential_key.is_some());
    }
}