1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
/// Selects one of four credential injection modes for a route.
///
/// Variants are (de)serialized in `snake_case` (`header`, `url_path`,
/// `query_param`, `basic_auth`). `Header` is the `Default` variant, so a
/// `RouteConfig` that omits `inject_mode` ends up with `Header`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    /// Default mode.
    /// NOTE(review): presumably places the credential in a request header
    /// (cf. `RouteConfig::inject_header`) — confirm against the injector.
    #[default]
    Header,
    /// NOTE(review): presumably rewrites the URL path (cf.
    /// `RouteConfig::path_pattern` / `path_replacement`) — confirm.
    UrlPath,
    /// NOTE(review): presumably adds a query parameter (cf.
    /// `RouteConfig::query_param_name`) — confirm.
    QueryParam,
    /// NOTE(review): presumably emits HTTP Basic authentication — confirm.
    BasicAuth,
}
24
/// Top-level proxy settings.
///
/// Every field carries a serde default, so a document that omits any of
/// them (or all of them) still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
    /// Bind address; `default_bind_addr()` (IPv4 loopback) when omitted.
    #[serde(default = "default_bind_addr")]
    pub bind_addr: IpAddr,

    /// Bind port; 0 when omitted.
    /// NOTE(review): 0 presumably asks the OS to pick a free port — confirm.
    #[serde(default)]
    pub bind_port: u16,

    /// Host strings; empty when omitted.
    /// NOTE(review): how entries are matched, and what an empty list means,
    /// is decided by code outside this file — confirm.
    #[serde(default)]
    pub allowed_hosts: Vec<String>,

    /// Per-route settings; empty when omitted.
    #[serde(default)]
    pub routes: Vec<RouteConfig>,

    /// Optional external proxy settings; `None` when omitted.
    #[serde(default)]
    pub external_proxy: Option<ExternalProxyConfig>,

    /// Connection limit.
    /// NOTE(review): when the key is absent serde yields 0, whereas
    /// `ProxyConfig::default()` yields 256. Confirm whether 0 is a deliberate
    /// sentinel (e.g. "unlimited") or whether the two defaults should agree.
    #[serde(default)]
    pub max_connections: usize,
}
54
55impl Default for ProxyConfig {
56 fn default() -> Self {
57 Self {
58 bind_addr: default_bind_addr(),
59 bind_port: 0,
60 allowed_hosts: Vec::new(),
61 routes: Vec::new(),
62 external_proxy: None,
63 max_connections: 256,
64 }
65 }
66}
67
/// Serde/`Default` fallback for `ProxyConfig::bind_addr`: the IPv4 loopback
/// address.
fn default_bind_addr() -> IpAddr {
    IpAddr::from(std::net::Ipv4Addr::LOCALHOST)
}
71
/// Settings for a single route.
///
/// Only `prefix` and `upstream` must be present when deserializing; every
/// other field is either an `Option` or has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// Required.
    /// NOTE(review): presumably the prefix that selects this route — confirm.
    pub prefix: String,

    /// Required.
    /// NOTE(review): presumably the upstream base URL requests are forwarded
    /// to (tests use `https://…` values) — confirm.
    pub upstream: String,

    /// Optional; may be omitted from the document.
    /// NOTE(review): presumably names the stored credential — confirm.
    pub credential_key: Option<String>,

    /// Injection mode; `InjectMode::Header` when omitted.
    #[serde(default)]
    pub inject_mode: InjectMode,

    /// Header name; `"Authorization"` when omitted.
    /// NOTE(review): presumably only consulted in `Header` mode — confirm.
    #[serde(default = "default_inject_header")]
    pub inject_header: String,

    /// Credential format string; `"Bearer {}"` when omitted.
    /// NOTE(review): presumably `{}` is replaced by the credential — confirm.
    #[serde(default = "default_credential_format")]
    pub credential_format: String,

    /// Optional; `None` when omitted.
    /// NOTE(review): presumably used by `UrlPath` mode — confirm.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Optional; `None` when omitted.
    /// NOTE(review): presumably the counterpart of `path_pattern` — confirm.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Optional; `None` when omitted.
    /// NOTE(review): presumably used by `QueryParam` mode — confirm.
    #[serde(default)]
    pub query_param_name: Option<String>,

    /// Optional; `None` when omitted.
    /// NOTE(review): semantics are defined outside this file — confirm.
    #[serde(default)]
    pub env_var: Option<String>,

    /// Method + path rules; empty when omitted. Once compiled into
    /// `CompiledEndpointRules`, an empty list allows every request.
    #[serde(default)]
    pub endpoint_rules: Vec<EndpointRule>,

    /// Optional; `None` when omitted.
    /// NOTE(review): presumably a path to a CA certificate used for the
    /// upstream TLS connection (tests use a `.crt` file path) — confirm.
    #[serde(default)]
    pub tls_ca: Option<String>,

    /// Optional; `None` when omitted.
    /// NOTE(review): presumably a client certificate path for mTLS — confirm.
    #[serde(default)]
    pub tls_client_cert: Option<String>,

    /// Optional; `None` when omitted.
    /// NOTE(review): presumably the private key matching `tls_client_cert`
    /// — confirm.
    #[serde(default)]
    pub tls_client_key: Option<String>,
}
163
/// One method + path rule, in its serialized (uncompiled) form.
///
/// `CompiledEndpointRules::compile` turns a slice of these into matchers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EndpointRule {
    /// HTTP method. Compared ASCII-case-insensitively when matching; the
    /// literal `"*"` matches any method.
    pub method: String,
    /// Path pattern in `globset` glob syntax. An invalid pattern makes
    /// `CompiledEndpointRules::compile` return an error.
    pub path: String,
}
178
/// A set of endpoint rules whose path globs have already been compiled.
///
/// Built with `CompiledEndpointRules::compile` and queried with
/// `is_allowed`. `Debug` is implemented by hand and prints only the number
/// of rules.
pub struct CompiledEndpointRules {
    // One entry per source `EndpointRule`, in the same order.
    rules: Vec<CompiledRule>,
}
187
/// A single rule after compilation: the method string copied from the
/// source `EndpointRule` plus the matcher built from its path pattern.
struct CompiledRule {
    // Either "*" (any method) or a method name compared case-insensitively.
    method: String,
    // Matcher compiled from `EndpointRule::path`.
    matcher: globset::GlobMatcher,
}
192
193impl CompiledEndpointRules {
194 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
197 let mut compiled = Vec::with_capacity(rules.len());
198 for rule in rules {
199 let glob = Glob::new(&rule.path)
200 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
201 compiled.push(CompiledRule {
202 method: rule.method.clone(),
203 matcher: glob.compile_matcher(),
204 });
205 }
206 Ok(Self { rules: compiled })
207 }
208
209 #[must_use]
212 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
213 if self.rules.is_empty() {
214 return true;
215 }
216 let normalized = normalize_path(path);
217 self.rules.iter().any(|r| {
218 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
219 && r.matcher.is_match(&normalized)
220 })
221 }
222}
223
224impl std::fmt::Debug for CompiledEndpointRules {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 f.debug_struct("CompiledEndpointRules")
227 .field("count", &self.rules.len())
228 .finish()
229 }
230}
231
/// Test-only convenience: reports whether `rules` permit `method` on `path`.
///
/// Delegates to `CompiledEndpointRules::compile` and `is_allowed`, so unit
/// tests exercise the real matching code rather than a hand-maintained copy
/// that could drift from it. The rules are recompiled on every call, which
/// is acceptable in tests. A rule set containing an invalid glob permits
/// nothing.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    CompiledEndpointRules::compile(rules).is_ok_and(|compiled| compiled.is_allowed(method, path))
}
251
252fn normalize_path(path: &str) -> String {
258 let path = path.split('?').next().unwrap_or(path);
260
261 let binary = urlencoding::decode_binary(path.as_bytes());
265 let decoded = String::from_utf8_lossy(&binary);
266
267 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
270 if segments.is_empty() {
271 "/".to_string()
272 } else {
273 format!("/{}", segments.join("/"))
274 }
275}
276
/// Serde fallback for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
280
/// Serde fallback for `RouteConfig::credential_format`.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
284
/// Settings for an external proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyConfig {
    /// Required. Tests use `host:port` values such as `squid.corp:3128`.
    pub address: String,

    /// Optional authentication settings; may be omitted or `null`.
    pub auth: Option<ExternalProxyAuth>,

    /// Host strings; empty when omitted. Tests include a wildcard entry
    /// (`*.private.net`).
    /// NOTE(review): presumably hosts that skip the external proxy; the
    /// matching rules live outside this file — confirm.
    #[serde(default)]
    pub bypass_hosts: Vec<String>,
}
300
/// Authentication settings for an external proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyAuth {
    /// Required.
    /// NOTE(review): presumably the keyring account under which the proxy
    /// credential is stored — confirm.
    pub keyring_account: String,

    /// Authentication scheme; `"basic"` when omitted.
    #[serde(default = "default_auth_scheme")]
    pub scheme: String,
}
311
/// Serde fallback for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
315
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an `EndpointRule` from string slices.
    fn make_rule(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_string(),
            path: path.to_string(),
        }
    }

    /// Evaluates a single rule against one request.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    // ---- ProxyConfig / ExternalProxyConfig -------------------------------

    #[test]
    fn test_default_config() {
        let ProxyConfig {
            bind_addr,
            bind_port,
            allowed_hosts,
            routes,
            external_proxy,
            ..
        } = ProxyConfig::default();

        assert_eq!(bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(bind_port, 0);
        assert!(allowed_hosts.is_empty());
        assert!(routes.is_empty());
        assert!(external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let original = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let external = ExternalProxyConfig {
            address: "squid.corp:3128".to_string(),
            auth: None,
            bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
        };
        let original = ProxyConfig {
            external_proxy: Some(external),
            ..Default::default()
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();

        let ext = decoded.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts, ["internal.corp", "*.private.net"]);
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        // `bypass_hosts` is absent from the document on purpose.
        let ext: ExternalProxyConfig =
            serde_json::from_str(r#"{"address": "proxy:3128", "auth": null}"#).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // ---- endpoint rule matching ------------------------------------------

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        let no_rules: &[EndpointRule] = &[];
        assert!(endpoint_allowed(no_rules, "GET", "/anything"));
        assert!(endpoint_allowed(no_rules, "DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = make_rule("GET", "/v1/chat/completions");
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        // Neither a shorter nor a longer path may match.
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = make_rule("get", "/api");
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = make_rule("*", "/api/resource");
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = make_rule("GET", "/api/resource");
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = make_rule("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/my-proj/merge_requests"));
        // The wildcard segment itself must be present.
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = make_rule("GET", "/api/v4/projects/**");
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = make_rule("*", "/api/**/notes");
        // `/**/` may stand for zero segments as well as several.
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    // ---- path normalization as seen through the rules --------------------

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = make_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = make_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = make_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = make_rule("GET", "/");
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    // ---- CompiledEndpointRules -------------------------------------------

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = [
            make_rule("GET", "/repos/*/issues"),
            make_rule("POST", "/repos/*/issues/*/comments"),
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();

        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        // An unterminated character class is not a valid glob.
        let rules = [make_rule("GET", "/api/[invalid")];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = [
            make_rule("GET", "/repos/*/issues"),
            make_rule("POST", "/repos/*/issues/*/comments"),
        ];

        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(&rules, "POST", "/repos/myrepo/issues/42/comments"));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // ---- RouteConfig serde -----------------------------------------------

    #[test]
    fn test_endpoint_rule_serde_default() {
        // Only the two required keys are supplied.
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let parsed: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let encoded = serde_json::to_string(&parsed).unwrap();
        let reparsed: RouteConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(reparsed.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));
    }

    // ---- percent-encoding ------------------------------------------------

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let rule = make_rule("GET", "/api/v4/projects/*/issues");
        // %70 is 'p' and %6A is 'j': both requests spell "projects".
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = make_rule("POST", "/api/data");
        // %64%61%74%61 spells "data".
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = [make_rule("GET", "/repos/*/issues")];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        // %69 is 'i' -> "issues"; %70 is 'p' -> "pulls".
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let rule = make_rule("GET", "/api/projects");
        // %FF is not valid UTF-8; the lossy decode inserts U+FFFD, so the
        // segment can no longer equal "projects".
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let original = make_rule("GET", "/api/*/data");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: EndpointRule = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.method, "GET");
        assert_eq!(decoded.path, "/api/*/data");
    }
}