1use globset::Glob;
7use serde::{Deserialize, Serialize};
8use std::net::IpAddr;
9
10#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum InjectMode {
14 #[default]
16 Header,
17 UrlPath,
19 QueryParam,
21 BasicAuth,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProxyConfig {
28 #[serde(default = "default_bind_addr")]
30 pub bind_addr: IpAddr,
31
32 #[serde(default)]
34 pub bind_port: u16,
35
36 #[serde(default)]
39 pub allowed_hosts: Vec<String>,
40
41 #[serde(default)]
43 pub routes: Vec<RouteConfig>,
44
45 #[serde(default)]
48 pub external_proxy: Option<ExternalProxyConfig>,
49
50 #[serde(default)]
52 pub max_connections: usize,
53}
54
55impl Default for ProxyConfig {
56 fn default() -> Self {
57 Self {
58 bind_addr: default_bind_addr(),
59 bind_port: 0,
60 allowed_hosts: Vec::new(),
61 routes: Vec::new(),
62 external_proxy: None,
63 max_connections: 256,
64 }
65 }
66}
67
/// Serde default for `ProxyConfig::bind_addr`: the IPv4 loopback address.
fn default_bind_addr() -> IpAddr {
    let loopback = std::net::Ipv4Addr::LOCALHOST;
    IpAddr::from(loopback)
}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RouteConfig {
75 pub prefix: String,
77
78 pub upstream: String,
80
81 pub credential_key: Option<String>,
84
85 #[serde(default)]
87 pub inject_mode: InjectMode,
88
89 #[serde(default = "default_inject_header")]
93 pub inject_header: String,
94
95 #[serde(default = "default_credential_format")]
99 pub credential_format: String,
100
101 #[serde(default)]
106 pub path_pattern: Option<String>,
107
108 #[serde(default)]
112 pub path_replacement: Option<String>,
113
114 #[serde(default)]
118 pub query_param_name: Option<String>,
119
120 #[serde(default)]
127 pub env_var: Option<String>,
128
129 #[serde(default)]
135 pub endpoint_rules: Vec<EndpointRule>,
136}
137
138#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
145pub struct EndpointRule {
146 pub method: String,
148 pub path: String,
151}
152
153pub struct CompiledEndpointRules {
159 rules: Vec<CompiledRule>,
160}
161
162struct CompiledRule {
163 method: String,
164 matcher: globset::GlobMatcher,
165}
166
167impl CompiledEndpointRules {
168 pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
171 let mut compiled = Vec::with_capacity(rules.len());
172 for rule in rules {
173 let glob = Glob::new(&rule.path)
174 .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))?;
175 compiled.push(CompiledRule {
176 method: rule.method.clone(),
177 matcher: glob.compile_matcher(),
178 });
179 }
180 Ok(Self { rules: compiled })
181 }
182
183 #[must_use]
186 pub fn is_allowed(&self, method: &str, path: &str) -> bool {
187 if self.rules.is_empty() {
188 return true;
189 }
190 let normalized = normalize_path(path);
191 self.rules.iter().any(|r| {
192 (r.method == "*" || r.method.eq_ignore_ascii_case(method))
193 && r.matcher.is_match(&normalized)
194 })
195 }
196}
197
198impl std::fmt::Debug for CompiledEndpointRules {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 f.debug_struct("CompiledEndpointRules")
201 .field("count", &self.rules.len())
202 .finish()
203 }
204}
205
/// Test-only, uncompiled evaluation of endpoint rules.
///
/// Rebuilds each rule's glob on every call, so it is only suitable for
/// tests. An empty rule list allows everything; a rule whose pattern is not
/// a valid glob never matches.
///
/// Globs are built with `literal_separator(true)` so that `*` and `?` stay
/// within one path segment and only `**` may span `/`. With globset's
/// default options `*` would also match `/`.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    let normalized = normalize_path(path);
    rules.iter().any(|r| {
        (r.method == "*" || r.method.eq_ignore_ascii_case(method))
            && globset::GlobBuilder::new(&r.path)
                .literal_separator(true)
                .build()
                .is_ok_and(|g| g.compile_matcher().is_match(&normalized))
    })
}
225
226fn normalize_path(path: &str) -> String {
232 let path = path.split('?').next().unwrap_or(path);
234
235 let binary = urlencoding::decode_binary(path.as_bytes());
239 let decoded = String::from_utf8_lossy(&binary);
240
241 let segments: Vec<&str> = decoded.split('/').filter(|s| !s.is_empty()).collect();
244 if segments.is_empty() {
245 "/".to_string()
246 } else {
247 format!("/{}", segments.join("/"))
248 }
249}
250
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
254
/// Serde default for `RouteConfig::credential_format`.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
258
259#[derive(Debug, Clone, Serialize, Deserialize)]
261pub struct ExternalProxyConfig {
262 pub address: String,
264
265 pub auth: Option<ExternalProxyAuth>,
267
268 #[serde(default)]
272 pub bypass_hosts: Vec<String>,
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct ExternalProxyAuth {
278 pub keyring_account: String,
280
281 #[serde(default = "default_auth_scheme")]
283 pub scheme: String,
284}
285
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
289
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an [`EndpointRule`] from string slices to keep test bodies short.
    fn endpoint_rule(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_string(),
            path: path.to_string(),
        }
    }

    /// Evaluates a single rule through both the test-only helper and the
    /// compiled production matcher, asserting that the two agree, so every
    /// single-rule test also covers `CompiledEndpointRules`.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        let rules = std::slice::from_ref(rule);
        let uncompiled = endpoint_allowed(rules, method, path);
        let compiled = CompiledEndpointRules::compile(rules)
            .unwrap()
            .is_allowed(method, path);
        assert_eq!(
            uncompiled, compiled,
            "endpoint_allowed and CompiledEndpointRules disagree for {method} {path}"
        );
        uncompiled
    }

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(config.bind_port, 0);
        assert!(config.allowed_hosts.is_empty());
        assert!(config.routes.is_empty());
        assert!(config.external_proxy.is_none());
        assert_eq!(config.max_connections, 256);
    }

    #[test]
    fn test_config_serialization() {
        let config = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_string()],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let config = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_string(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_string(), "*.private.net".to_string()],
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ProxyConfig = serde_json::from_str(&json).unwrap();
        let ext = deserialized.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts.len(), 2);
        assert_eq!(ext.bypass_hosts[0], "internal.corp");
        assert_eq!(ext.bypass_hosts[1], "*.private.net");
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let json = r#"{"address": "proxy:3128", "auth": null}"#;
        let ext: ExternalProxyConfig = serde_json::from_str(json).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        assert!(endpoint_allowed(&[], "GET", "/anything"));
        assert!(endpoint_allowed(&[], "DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let rule = endpoint_rule("GET", "/v1/chat/completions");
        assert!(check(&rule, "GET", "/v1/chat/completions"));
        assert!(!check(&rule, "GET", "/v1/chat"));
        assert!(!check(&rule, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let rule = endpoint_rule("get", "/api");
        assert!(check(&rule, "GET", "/api"));
        assert!(check(&rule, "Get", "/api"));
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let rule = endpoint_rule("*", "/api/resource");
        assert!(check(&rule, "GET", "/api/resource"));
        assert!(check(&rule, "DELETE", "/api/resource"));
        assert!(check(&rule, "POST", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let rule = endpoint_rule("GET", "/api/resource");
        assert!(!check(&rule, "POST", "/api/resource"));
        assert!(!check(&rule, "DELETE", "/api/resource"));
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let rule = endpoint_rule("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(
            &rule,
            "GET",
            "/api/v4/projects/my-proj/merge_requests"
        ));
        assert!(!check(&rule, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let rule = endpoint_rule("GET", "/api/v4/projects/**");
        assert!(check(&rule, "GET", "/api/v4/projects/123"));
        assert!(check(&rule, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&rule, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&rule, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let rule = endpoint_rule("*", "/api/**/notes");
        assert!(check(&rule, "GET", "/api/notes"));
        assert!(check(&rule, "POST", "/api/projects/123/notes"));
        assert!(check(&rule, "GET", "/api/a/b/c/notes"));
        assert!(!check(&rule, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api/data/"));
        assert!(check(&rule, "GET", "/api/data"));
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let rule = endpoint_rule("GET", "/api/data");
        assert!(check(&rule, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let rule = endpoint_rule("GET", "/");
        assert!(check(&rule, "GET", "/"));
        assert!(!check(&rule, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let rules = [
            endpoint_rule("GET", "/repos/*/issues"),
            endpoint_rule("POST", "/repos/*/issues/*/comments"),
        ];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let rules = [endpoint_rule("GET", "/api/[invalid")];
        assert!(CompiledEndpointRules::compile(&rules).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = [
            endpoint_rule("GET", "/repos/*/issues"),
            endpoint_rule("POST", "/repos/*/issues/*/comments"),
        ];
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(
            &rules,
            "POST",
            "/repos/myrepo/issues/42/comments"
        ));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "/test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        // Percent-encoded letters must be decoded before matching.
        let rule = endpoint_rule("GET", "/api/v4/projects/*/issues");
        assert!(check(&rule, "GET", "/api/v4/%70rojects/123/issues"));
        assert!(check(&rule, "GET", "/api/v4/pro%6Aects/123/issues"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        // "%64%61%74%61" decodes to "data".
        let rule = endpoint_rule("POST", "/api/data");
        assert!(check(&rule, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let rules = [endpoint_rule("GET", "/repos/*/issues")];
        let compiled = CompiledEndpointRules::compile(&rules).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        // "%70ulls" decodes to "pulls", which the rule does not cover.
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        // %FF is not valid UTF-8; it is replaced by U+FFFD during the lossy
        // conversion and therefore cannot match the literal pattern.
        let rule = endpoint_rule("GET", "/api/projects");
        assert!(!check(&rule, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let rule = endpoint_rule("GET", "/api/*/data");
        let json = serde_json::to_string(&rule).unwrap();
        let deserialized: EndpointRule = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.method, "GET");
        assert_eq!(deserialized.path, "/api/*/data");
    }
}