1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
10#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum InjectMode {
14 #[default]
16 Header,
17 UrlPath,
19 QueryParam,
21 BasicAuth,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28 #[serde(default = "default_bind_addr")]
30 pub bind_addr: IpAddr,
31
32 #[serde(default)]
34 pub bind_port: u16,
35
36 #[serde(default)]
39 pub allowed_hosts: Vec<String>,
40
41 #[serde(default)]
43 pub routes: Vec<RouteConfig>,
44
45 #[serde(default)]
48 pub external_proxy: Option<ExternalProxyConfig>,
49
50 #[serde(default)]
52 pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56 fn default() -> Self {
57 Self {
58 bind_addr: default_bind_addr(),
59 bind_port: 0,
60 allowed_hosts: Vec::new(),
61 routes: Vec::new(),
62 external_proxy: None,
63 max_connections: 256,
64 }
65 }
66}
67
68fn default_bind_addr() -> IpAddr {
69 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RouteConfig {
75 pub prefix: String,
78
79 pub upstream: String,
81
82 pub credential_key: Option<String>,
85
86 #[serde(default)]
88 pub inject_mode: InjectMode,
89
90 #[serde(default = "default_inject_header")]
94 pub inject_header: String,
95
96 #[serde(default = "default_credential_format")]
100 pub credential_format: String,
101
102 #[serde(default)]
107 pub path_pattern: Option<String>,
108
109 #[serde(default)]
113 pub path_replacement: Option<String>,
114
115 #[serde(default)]
119 pub query_param_name: Option<String>,
120
121 #[serde(default)]
127 pub proxy: Option<ProxyInjectConfig>,
128
129 #[serde(default)]
136 pub env_var: Option<String>,
137
138 #[serde(default)]
144 pub endpoint_rules: Vec<EndpointRule>,
145
146 #[serde(default)]
153 pub tls_ca: Option<String>,
154
155 #[serde(default)]
162 pub tls_client_cert: Option<String>,
163
164 #[serde(default)]
169 pub tls_client_key: Option<String>,
170}
171
172#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
178#[serde(deny_unknown_fields)]
179pub struct ProxyInjectConfig {
180 #[serde(default)]
182 pub inject_mode: Option<InjectMode>,
183
184 #[serde(default)]
186 pub inject_header: Option<String>,
187
188 #[serde(default)]
190 pub credential_format: Option<String>,
191
192 #[serde(default)]
194 pub path_pattern: Option<String>,
195
196 #[serde(default)]
198 pub path_replacement: Option<String>,
199
200 #[serde(default)]
202 pub query_param_name: Option<String>,
203}
204
205#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
212pub struct EndpointRule {
213 pub method: String,
215 pub path: String,
218}
219
220pub struct CompiledEndpointRules {
226 rules: Vec<CompiledRule>,
227}
228
229struct CompiledRule {
230 method: String,
231 matcher: globset::GlobMatcher,
232}
233
234impl CompiledEndpointRules {
235 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
238 let mut compiled = Vec::with_capacity(rules.len());
239 for rule in rules {
240 let glob = Glob::new(&rule.path)
241 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
242 compiled.push(CompiledRule {
243 method: rule.method.clone(),
244 matcher: glob.compile_matcher(),
245 });
246 }
247 Ok(Self { rules: compiled })
248 }
249
250 #[must_use]
253 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
254 if self.rules.is_empty() {
255 return true;
256 }
257 let normalized = normalize_path(path);
258 self.rules.iter().any(|r| {
259 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
260 && r.matcher.is_match(&normalized)
261 })
262 }
263}
264
265impl std::fmt::Debug for CompiledEndpointRules {
266 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 f.debug_struct("CompiledEndpointRules")
268 .field("count", &self.rules.len())
269 .finish()
270 }
271}
272
273#[cfg(test)]
279fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
280 if rules.is_empty() {
281 return true;
282 }
283 let normalized = normalize_path(path);
284 rules.iter().any(|r| {
285 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
286 && Glob::new(&r.path)
287 .ok()
288 .map(|g| g.compile_matcher())
289 .is_some_and(|m| m.is_match(&normalized))
290 })
291}
292
293fn normalize_path(path: &str) -> String {
299 let path = path.split('?').next().unwrap_or(path);
301
302 let binary = urlencoding::decode_binary(path.as_bytes());
306 let decoded = String::from_utf8_lossy(&binary);
307
308 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
311 if segments.is_empty() {
312 "/".to_string()
313 } else {
314 format!("/{}", segments.join("/"))
315 }
316}
317
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
321
/// Serde default for `RouteConfig::credential_format`.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
325
326#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct ExternalProxyConfig {
329 pub address: String,
331
332 pub auth: Option<ExternalProxyAuth>,
334
335 #[serde(default)]
339 pub bypass_hosts: Vec<String>,
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize)]
344pub struct ExternalProxyAuth {
345 pub keyring_account: String,
347
348 #[serde(default = "default_auth_scheme")]
350 pub scheme: String,
351}
352
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    "basic".to_owned()
}
356
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build an [`EndpointRule`] from borrowed method and path strings.
    fn endpoint_rule(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_string(),
            path: path.to_string(),
        }
    }

    /// Two-rule fixture shared by the compiled and uncompiled multi-rule tests.
    fn repo_rules() -> Vec<EndpointRule> {
        vec![
            endpoint_rule("GET", "/repos/*/issues"),
            endpoint_rule("POST", "/repos/*/issues/*/comments"),
        ]
    }

    /// Serialize `config` to JSON and parse it back.
    fn json_roundtrip(config: &ProxyConfig) -> ProxyConfig {
        let encoded = serde_json::to_string(config).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Evaluate a single rule with the uncompiled, test-only matcher.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    #[test]
    fn test_default_config() {
        let ProxyConfig {
            bind_addr,
            bind_port,
            allowed_hosts,
            routes,
            external_proxy,
            ..
        } = ProxyConfig::default();
        assert_eq!(bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(bind_port, 0);
        assert!(allowed_hosts.is_empty());
        assert!(routes.is_empty());
        assert!(external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let original = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        assert_eq!(
            json_roundtrip(&original).allowed_hosts,
            vec!["api.openai.com"]
        );
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let original = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let external = json_roundtrip(&original).external_proxy.unwrap();
        assert_eq!(external.address, "squid.corp:3128");
        assert_eq!(external.bypass_hosts, ["internal.corp", "*.private.net"]);
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let parsed: ExternalProxyConfig =
            serde_json::from_str(r#"{"address": "proxy:3128", "auth": null}"#).unwrap();
        assert!(parsed.bypass_hosts.is_empty());
    }

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        for (method, path) in [("GET", "/anything"), ("DELETE", "/admin/nuke")] {
            assert!(endpoint_allowed(&[], method, path));
        }
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = endpoint_rule("GET", "/v1/chat/completions");
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = endpoint_rule("get", "/api");
        for method in ["GET", "Get"] {
            assert!(check(&rule, method, "/api"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = endpoint_rule("*", "/api/resource");
        for method in ["GET", "DELETE", "POST"] {
            assert!(check(&rule, method, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = endpoint_rule("GET", "/api/resource");
        for method in ["POST", "DELETE"] {
            assert!(!check(&rule, method, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = endpoint_rule("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/my-proj/merge_requests"));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = endpoint_rule("GET", "/api/v4/projects/**");
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = endpoint_rule("*", "/api/**/notes");
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = endpoint_rule("GET", "/");
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let compiled = CompiledEndpointRules::compile(&repo_rules()).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = [endpoint_rule("GET", "/api/[invalid")];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = repo_rules();
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        const CA_PATH: &str = "/run/secrets/k8s-ca.crt";
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some(CA_PATH));

        let reencoded = serde_json::to_string(&route).unwrap();
        let reparsed: RouteConfig = serde_json::from_str(&reencoded).unwrap();
        assert_eq!(reparsed.tls_ca.as_deref(), Some(CA_PATH));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        // Percent-encoded letters inside a literal segment must be decoded
        // before matching (%70 = 'p', %6A = 'j').
        let rule = endpoint_rule("GET", "/api/v4/projects/*/issues");
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = endpoint_rule("POST", "/api/data");
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = [endpoint_rule("GET", "/repos/*/issues")];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        // %FF is not valid UTF-8; lossy decoding turns it into U+FFFD, so the
        // segment can no longer equal the plain-ASCII rule.
        let rule = endpoint_rule("GET", "/api/projects");
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let original = endpoint_rule("GET", "/api/*/data");
        let encoded = serde_json::to_string(&original).unwrap();
        let EndpointRule { method, path }: EndpointRule = serde_json::from_str(&encoded).unwrap();
        assert_eq!(method, "GET");
        assert_eq!(path, "/api/*/data");
    }
}