1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
/// Where a route's credential is placed in the forwarded request.
///
/// Serialized in `snake_case` (`"header"`, `"url_path"`, `"query_param"`,
/// `"basic_auth"`). NOTE(review): the injection itself is implemented outside
/// this file — per-variant notes below are inferred from the related
/// `RouteConfig` fields and should be confirmed.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    /// Inject into an HTTP header. This is the default mode.
    /// NOTE(review): presumably uses `RouteConfig::inject_header`.
    #[default]
    Header,
    /// Inject into the URL path.
    /// NOTE(review): presumably uses `RouteConfig::path_pattern` / `path_replacement`.
    UrlPath,
    /// Inject as a query parameter.
    /// NOTE(review): presumably uses `RouteConfig::query_param_name`.
    QueryParam,
    /// Inject as HTTP Basic authentication. NOTE(review): confirm encoding rules.
    BasicAuth,
}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28 #[serde(default = "default_bind_addr")]
30 pub bind_addr: IpAddr,
31
32 #[serde(default)]
34 pub bind_port: u16,
35
36 #[serde(default)]
39 pub allowed_hosts: Vec<String>,
40
41 #[serde(default)]
43 pub routes: Vec<RouteConfig>,
44
45 #[serde(default)]
48 pub external_proxy: Option<ExternalProxyConfig>,
49
50 #[serde(default)]
52 pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56 fn default() -> Self {
57 Self {
58 bind_addr: default_bind_addr(),
59 bind_port: 0,
60 allowed_hosts: Vec::new(),
61 routes: Vec::new(),
62 external_proxy: None,
63 max_connections: 256,
64 }
65 }
66}
67
/// Loopback default for `ProxyConfig::bind_addr` (`127.0.0.1`).
fn default_bind_addr() -> IpAddr {
    let loopback = std::net::Ipv4Addr::LOCALHOST;
    IpAddr::from(loopback)
}
71
/// One reverse-proxy route.
///
/// Only `prefix` and `upstream` are required when deserializing. Every other
/// field is optional or has a default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// Path prefix that selects this route (e.g. `/k8s`).
    /// NOTE(review): matching logic lives outside this file.
    pub prefix: String,

    /// Upstream base URL for this route (e.g. `https://kubernetes.local:6443`).
    pub upstream: String,

    /// Key naming the credential to inject. `None` when omitted.
    /// NOTE(review): presumably `None` disables injection — confirm.
    pub credential_key: Option<String>,

    /// How the credential is injected. Defaults to `InjectMode::Header`.
    #[serde(default)]
    pub inject_mode: InjectMode,

    /// Header name used for injection. Defaults to `"Authorization"`.
    /// NOTE(review): presumably only consulted in `InjectMode::Header`.
    #[serde(default = "default_inject_header")]
    pub inject_header: String,

    /// Template for the injected value. Defaults to `"Bearer {}"`.
    /// NOTE(review): presumably `{}` is replaced by the credential.
    #[serde(default = "default_credential_format")]
    pub credential_format: String,

    /// Optional path pattern. `None` by default.
    /// NOTE(review): presumably used with `InjectMode::UrlPath`.
    #[serde(default)]
    pub path_pattern: Option<String>,

    /// Optional replacement paired with `path_pattern`. `None` by default.
    #[serde(default)]
    pub path_replacement: Option<String>,

    /// Optional query parameter name. `None` by default.
    /// NOTE(review): presumably used with `InjectMode::QueryParam`.
    #[serde(default)]
    pub query_param_name: Option<String>,

    /// Optional environment variable name. `None` by default.
    /// NOTE(review): its role is not visible in this file — confirm.
    #[serde(default)]
    pub env_var: Option<String>,

    /// Method+path allowlist for this route. Empty by default.
    /// `CompiledEndpointRules::is_allowed` treats an empty list as "allow everything".
    #[serde(default)]
    pub endpoint_rules: Vec<EndpointRule>,

    /// Optional CA certificate path (e.g. `/run/secrets/k8s-ca.crt`). `None` by default.
    /// NOTE(review): presumably trusted for the upstream TLS connection — confirm.
    #[serde(default)]
    pub tls_ca: Option<String>,
}
146
/// A single endpoint allowlist rule: an HTTP method plus a glob over the request path.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EndpointRule {
    /// HTTP method. Compared ASCII-case-insensitively; `"*"` matches any method.
    pub method: String,
    /// Glob pattern matched against the normalized request path
    /// (e.g. `/repos/*/issues`, `/api/**/notes`).
    pub path: String,
}
161
/// Endpoint rules with their path globs compiled once, for per-request checks.
///
/// Build with `CompiledEndpointRules::compile`; query with
/// `CompiledEndpointRules::is_allowed`.
pub struct CompiledEndpointRules {
    rules: Vec<CompiledRule>,
}
170
/// One compiled rule: the configured method and the matcher for its path glob.
struct CompiledRule {
    // Method as configured; `"*"` is a wildcard, otherwise compared case-insensitively.
    method: String,
    // Compiled form of `EndpointRule::path`.
    matcher: globset::GlobMatcher,
}
175
176impl CompiledEndpointRules {
177 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
180 let mut compiled = Vec::with_capacity(rules.len());
181 for rule in rules {
182 let glob = Glob::new(&rule.path)
183 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
184 compiled.push(CompiledRule {
185 method: rule.method.clone(),
186 matcher: glob.compile_matcher(),
187 });
188 }
189 Ok(Self { rules: compiled })
190 }
191
192 #[must_use]
195 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
196 if self.rules.is_empty() {
197 return true;
198 }
199 let normalized = normalize_path(path);
200 self.rules.iter().any(|r| {
201 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
202 && r.matcher.is_match(&normalized)
203 })
204 }
205}
206
207impl std::fmt::Debug for CompiledEndpointRules {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 f.debug_struct("CompiledEndpointRules")
210 .field("count", &self.rules.len())
211 .finish()
212 }
213}
214
/// Test-only, uncompiled rule check mirroring `CompiledEndpointRules::is_allowed`.
///
/// Compiles each rule's glob on every call. A rule whose pattern fails to
/// compile simply never matches. An empty rule slice allows everything.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    let normalized = normalize_path(path);
    rules.iter().any(|r| {
        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
            && globset::GlobBuilder::new(&r.path)
                // `*` must stay within one path segment; only `**` may span `/`.
                .literal_separator(true)
                .build()
                .is_ok_and(|g| g.compile_matcher().is_match(&normalized))
    })
}
234
235fn normalize_path(path: &str) -> String {
241 let path = path.split('?').next().unwrap_or(path);
243
244 let binary = urlencoding::decode_binary(path.as_bytes());
248 let decoded = String::from_utf8_lossy(&binary);
249
250 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
253 if segments.is_empty() {
254 "/".to_string()
255 } else {
256 format!("/{}", segments.join("/"))
257 }
258}
259
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
263
/// Serde default for `RouteConfig::credential_format`.
fn default_credential_format() -> String {
    let template: &str = "Bearer {}";
    template.to_owned()
}
267
/// Configuration for an external proxy.
///
/// NOTE(review): presumably outbound traffic is chained through this proxy,
/// except for `bypass_hosts` — confirm in the connector code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyConfig {
    /// Proxy address (e.g. `squid.corp:3128`).
    pub address: String,

    /// Optional proxy authentication. `None` when omitted or `null`.
    pub auth: Option<ExternalProxyAuth>,

    /// Hosts that bypass the external proxy. Defaults to empty.
    /// Entries may be exact hosts (`internal.corp`) or patterns (`*.private.net`).
    /// NOTE(review): matching semantics are implemented elsewhere.
    #[serde(default)]
    pub bypass_hosts: Vec<String>,
}
283
/// Authentication settings for the external proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyAuth {
    /// Keyring account name.
    /// NOTE(review): presumably the secret is fetched from a keyring under this
    /// account rather than stored in config — confirm.
    pub keyring_account: String,

    /// Authentication scheme. Defaults to `"basic"`.
    #[serde(default = "default_auth_scheme")]
    pub scheme: String,
}
294
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    let scheme: &str = "basic";
    String::from(scheme)
}
298
/// Unit tests for config defaults, serde round-trips, and endpoint-rule matching
/// (both the uncompiled `endpoint_allowed` helper and `CompiledEndpointRules`).
#[cfg(test)]
// `unwrap` is acceptable in tests: a panic is how a test reports failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ---- ProxyConfig / ExternalProxyConfig defaults and serde ----

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // ---- endpoint_allowed: method and glob matching ----

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    // Convenience wrapper: evaluate a single rule without building a Vec.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/v1/chat/completions".to_string(),
        };
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = EndpointRule {
            method: "get".to_string(),
            path: "/api".to_string(),
        };
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/resource".to_string(),
        };
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/merge_requests".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/**".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = EndpointRule {
            method: "*".to_string(),
            path: "/api/**/notes".to_string(),
        };
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    // ---- path normalization (query strip, slashes) ----

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/".to_string(),
        };
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    // ---- CompiledEndpointRules ----

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/api/[invalid".to_string(),
        }];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = vec![
            EndpointRule {
                method: "GET".to_string(),
                path: "/repos/*/issues".to_string(),
            },
            EndpointRule {
                method: "POST".to_string(),
                path: "/repos/*/issues/*/comments".to_string(),
            },
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // ---- RouteConfig serde ----

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "/test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "/k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), Some("/run/secrets/k8s-ca.crt"));

        let serialized = serde_json::to_string(&route).unwrap();
        let deserialized: RouteConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.tls_ca.as_deref(),
            Some("/run/secrets/k8s-ca.crt")
        );
    }

    // ---- percent-encoding: encoded paths must be checked in decoded form ----

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/v4/projects/*/issues".to_string(),
        };
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let rule = EndpointRule {
            method: "POST".to_string(),
            path: "/api/data".to_string(),
        };
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = vec![EndpointRule {
            method: "GET".to_string(),
            path: "/repos/*/issues".to_string(),
        }];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    // Invalid UTF-8 decodes lossily to U+FFFD, which must not match the ASCII pattern.
    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/projects".to_string(),
        };
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = EndpointRule {
            method: "GET".to_string(),
            path: "/api/*/data".to_string(),
        };
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }
}
628}