1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
/// Selects how a route injects its credential into proxied requests
/// (see `RouteConfig::inject_mode`).
///
/// Serialized in `snake_case`, e.g. `"header"`, `"url_path"`, `"query_param"`,
/// `"basic_auth"`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    /// The default mode.
    /// NOTE(review): presumably writes the formatted credential into the header
    /// named by `RouteConfig::inject_header` — confirm against the proxy code.
    #[default]
    Header,
    /// NOTE(review): presumably rewrites the URL path using
    /// `RouteConfig::path_pattern` / `RouteConfig::path_replacement` — confirm.
    UrlPath,
    /// NOTE(review): presumably adds a query parameter named by
    /// `RouteConfig::query_param_name` — confirm.
    QueryParam,
    /// NOTE(review): presumably sends the credential via HTTP Basic auth — confirm.
    BasicAuth,
}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28 #[serde(default = "default_bind_addr")]
30 pub bind_addr: IpAddr,
31
32 #[serde(default)]
34 pub bind_port: u16,
35
36 #[serde(default)]
39 pub allowed_hosts: Vec<String>,
40
41 #[serde(default)]
43 pub routes: Vec<RouteConfig>,
44
45 #[serde(default)]
48 pub external_proxy: Option<ExternalProxyConfig>,
49
50 #[serde(default)]
52 pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56 fn default() -> Self {
57 Self {
58 bind_addr: default_bind_addr(),
59 bind_port: 0,
60 allowed_hosts: Vec::new(),
61 routes: Vec::new(),
62 external_proxy: None,
63 max_connections: 256,
64 }
65 }
66}
67
68fn default_bind_addr() -> IpAddr {
69 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
70}
71
/// A single credential-injection route.
///
/// Only `prefix` and `upstream` are required when deserializing; every other
/// field is optional or has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// Route prefix, e.g. `"k8s"`.
    /// NOTE(review): presumably used to select this route from the incoming
    /// request — confirm against the router.
    pub prefix: String,

    /// Upstream URL for this route, e.g. `"https://kubernetes.local:6443"`.
    pub upstream: String,

    /// Name of the credential to inject; `None` when omitted.
    /// NOTE(review): presumably a keyring/secret-store key, and `None` presumably
    /// means no injection — confirm.
    pub credential_key: Option<String>,

    /// How the credential is injected. Defaults to `InjectMode::Header`.
    #[serde(default)]
    pub inject_mode: InjectMode,

    /// Header name used for injection. Defaults to `"Authorization"`.
    /// NOTE(review): presumably only consulted for `InjectMode::Header`.
    #[serde(default = "default_inject_header")]
    pub inject_header: String,

    /// Format string for the injected value. Defaults to `"Bearer {}"`.
    /// NOTE(review): `{}` presumably receives the credential — confirm.
    #[serde(default = "default_credential_format")]
    pub credential_format: String,

    /// Optional path pattern.
    /// NOTE(review): presumably used with `InjectMode::UrlPath` — confirm syntax.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Optional replacement paired with `path_pattern`.
    /// NOTE(review): presumably used with `InjectMode::UrlPath` — confirm.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Optional query-parameter name.
    /// NOTE(review): presumably used with `InjectMode::QueryParam` — confirm.
    #[serde(default)]
    pub query_param_name: Option<String>,

    /// Optional environment-variable name.
    /// NOTE(review): its purpose is not visible in this file — confirm with callers.
    #[serde(default)]
    pub env_var: Option<String>,

    /// Method/path allowlist for this route. An empty list allows every request
    /// (see `CompiledEndpointRules::is_allowed`). Defaults to empty.
    #[serde(default)]
    pub endpoint_rules: Vec<EndpointRule>,

    /// Optional CA setting for the upstream TLS connection, e.g.
    /// `"/run/secrets/k8s-ca.crt"`.
    /// NOTE(review): presumably a filesystem path to a CA certificate — confirm.
    #[serde(default)]
    pub tls_ca: Option<String>,
}
147
/// One allowlist entry: an HTTP method plus a glob over the request path.
///
/// `method` is compared ASCII-case-insensitively, and `"*"` matches any method.
/// `path` is compiled as a `globset` glob and matched against the request path
/// after normalization (query string removed, percent-decoded, empty segments
/// collapsed). See `CompiledEndpointRules::compile` for the matcher options.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EndpointRule {
    /// HTTP method such as `"GET"`, or `"*"` for any method.
    pub method: String,
    /// Glob over the normalized request path, e.g. `"/api/v4/projects/*/issues"`
    /// or `"/api/**"`.
    pub path: String,
}
162
/// Pre-compiled form of a route's `EndpointRule`s.
///
/// Built once by `CompiledEndpointRules::compile`, so per-request checks reuse
/// the compiled glob matchers instead of re-parsing patterns.
pub struct CompiledEndpointRules {
    rules: Vec<CompiledRule>,
}

/// A single compiled rule: the configured method string plus its path matcher.
struct CompiledRule {
    // Configured method; `"*"` is a wildcard, otherwise compared case-insensitively.
    method: String,
    matcher: globset::GlobMatcher,
}
176
177impl CompiledEndpointRules {
178 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
181 let mut compiled = Vec::with_capacity(rules.len());
182 for rule in rules {
183 let glob = Glob::new(&rule.path)
184 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
185 compiled.push(CompiledRule {
186 method: rule.method.clone(),
187 matcher: glob.compile_matcher(),
188 });
189 }
190 Ok(Self { rules: compiled })
191 }
192
193 #[must_use]
196 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
197 if self.rules.is_empty() {
198 return true;
199 }
200 let normalized = normalize_path(path);
201 self.rules.iter().any(|r| {
202 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
203 && r.matcher.is_match(&normalized)
204 })
205 }
206}
207
208impl std::fmt::Debug for CompiledEndpointRules {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 f.debug_struct("CompiledEndpointRules")
211 .field("count", &self.rules.len())
212 .finish()
213 }
214}
215
/// Test-only convenience: compile `rules` and check `method`/`path` in one call.
///
/// Delegates to `CompiledEndpointRules`, so the unit tests exercise exactly the
/// matcher configuration used in production rather than a parallel copy that
/// could drift. An empty rule list allows everything. If any pattern fails to
/// compile, the request is denied (fail closed).
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    CompiledEndpointRules::compile(rules).is_ok_and(|compiled| compiled.is_allowed(method, path))
}
235
236fn normalize_path(path: &str) -> String {
242 let path = path.split('?').next().unwrap_or(path);
244
245 let binary = urlencoding::decode_binary(path.as_bytes());
249 let decoded = String::from_utf8_lossy(&binary);
250
251 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
254 if segments.is_empty() {
255 "/".to_string()
256 } else {
257 format!("/{}", segments.join("/"))
258 }
259}
260
/// Serde default for `RouteConfig::inject_header`: the `Authorization` header.
fn default_inject_header() -> String {
    String::from("Authorization")
}
264
/// Serde default for `RouteConfig::credential_format`: `"Bearer {}"`.
///
/// NOTE(review): `{}` presumably marks where the credential is substituted.
fn default_credential_format() -> String {
    "Bearer {}".to_owned()
}
268
/// An external proxy, e.g. a corporate Squid at `"squid.corp:3128"`.
///
/// NOTE(review): presumably outbound upstream traffic is chained through this
/// proxy — confirm against the connector code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyConfig {
    /// Proxy address, e.g. `"squid.corp:3128"` (host:port form in tests).
    pub address: String,

    /// Optional authentication for the external proxy; `None` when not needed.
    pub auth: Option<ExternalProxyAuth>,

    /// Hosts listed here, either exact names or patterns like `"*.private.net"`.
    /// Defaults to empty.
    /// NOTE(review): presumably these are reached directly, bypassing the
    /// external proxy; matching logic lives elsewhere.
    #[serde(default)]
    pub bypass_hosts: Vec<String>,
}
284
/// Authentication settings for the external proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyAuth {
    /// Keyring account name for the proxy credential.
    /// NOTE(review): presumably the secret itself is fetched from the keyring at
    /// runtime rather than stored in config — confirm.
    pub keyring_account: String,

    /// Authentication scheme. Defaults to `"basic"`.
    #[serde(default = "default_auth_scheme")]
    pub scheme: String,
}
295
/// Serde default for `ExternalProxyAuth::scheme`: `"basic"`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
299
// Unit tests for config (de)serialization and endpoint-rule matching.
#[cfg(test)]
// unwrap is acceptable here: a panic is exactly the failure signal a test wants.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    #[test]
    // An empty rule list means "no restrictions": every method and path passes.
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    /// Evaluate a single rule through `endpoint_allowed`.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        // `*` must consume a segment; it cannot vanish along with its slashes.
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        // `/**/` may match zero segments.
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        // Unclosed character class is not a valid glob.
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_endpoint_rule_serde_default() {
        // Only the required fields are present; the rest fall back to defaults.
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let rule = EndpointRule {
            // Percent-encoded letters (`%70` = 'p', `%6A` = 'j') must be decoded
            // before matching so encoding cannot be used to dodge a rule.
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
        // "%64%61%74%61" decodes to "data".
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        // "%70ulls" decodes to "pulls", which no rule allows.
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let rule = EndpointRule {
            // `%FF` is not valid UTF-8; lossy decoding turns it into U+FFFD,
            // so the decoded path can no longer equal the ASCII pattern.
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
        // Denied rather than panicking or matching.
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }
}