use globset::Glob;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
/// Where a route's credential is placed on the outgoing request.
///
/// Serialized in `snake_case`: `header`, `url_path`, `query_param`,
/// `basic_auth`. When omitted, [`InjectMode::Header`] is used.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names and the related `RouteConfig` fields — confirm against the
/// code that performs the injection.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InjectMode {
    /// Credential in a request header (the default variant).
    #[default]
    Header,
    /// Credential substituted into the URL path.
    UrlPath,
    /// Credential sent as a query parameter.
    QueryParam,
    /// Credential sent as HTTP Basic authentication.
    BasicAuth,
}
/// Connection limit used both by [`ProxyConfig::default`] and by serde when
/// `max_connections` is missing from a serialized config, so the two paths
/// agree.
const DEFAULT_MAX_CONNECTIONS: usize = 256;

/// Top-level proxy configuration.
///
/// Every field may be omitted when deserializing; omitted fields take the same
/// values as [`ProxyConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
    /// Bind address; IPv4 loopback (`127.0.0.1`) when omitted.
    #[serde(default = "default_bind_addr")]
    pub bind_addr: IpAddr,
    /// Bind port; `0` when omitted (presumably "let the OS choose" — confirm).
    #[serde(default)]
    pub bind_port: u16,
    /// Host names, e.g. `api.openai.com`; empty when omitted.
    #[serde(default)]
    pub allowed_hosts: Vec<String>,
    /// Per-prefix routes; empty when omitted.
    #[serde(default)]
    pub routes: Vec<RouteConfig>,
    /// Optional external proxy settings; `None` when omitted.
    #[serde(default)]
    pub external_proxy: Option<ExternalProxyConfig>,
    /// Port list; empty when omitted. Semantics are enforced outside this file.
    #[serde(default)]
    pub direct_connect_ports: Vec<u16>,
    /// Connection limit; `DEFAULT_MAX_CONNECTIONS` (256) when omitted.
    ///
    /// Previously a bare `#[serde(default)]` made an omitted value deserialize
    /// to `0`, which disagreed with `ProxyConfig::default()`.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,
}

impl Default for ProxyConfig {
    fn default() -> Self {
        Self {
            bind_addr: default_bind_addr(),
            bind_port: 0,
            allowed_hosts: Vec::new(),
            routes: Vec::new(),
            external_proxy: None,
            direct_connect_ports: Vec::new(),
            max_connections: DEFAULT_MAX_CONNECTIONS,
        }
    }
}

/// Serde default for `ProxyConfig::bind_addr`: IPv4 loopback.
fn default_bind_addr() -> IpAddr {
    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
}

/// Serde default for `ProxyConfig::max_connections`.
fn default_max_connections() -> usize {
    DEFAULT_MAX_CONNECTIONS
}
/// Configuration for a single proxied route.
///
/// Only `prefix` and `upstream` are required; every other field has a serde
/// default. Field order is significant for serialized output and is kept
/// stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteConfig {
    /// Route prefix, e.g. `/openai` (required).
    pub prefix: String,
    /// Upstream base URL, e.g. `https://api.openai.com` (required).
    pub upstream: String,
    /// Optional credential key name; `None` when omitted.
    pub credential_key: Option<String>,
    /// Injection mode; [`InjectMode::Header`] when omitted.
    #[serde(default)]
    pub inject_mode: InjectMode,
    /// Header name used for injection; `"Authorization"` when omitted.
    #[serde(default = "default_inject_header")]
    pub inject_header: String,
    /// Credential template; `"Bearer {}"` when omitted
    /// (presumably `{}` is replaced by the credential — confirm).
    #[serde(default = "default_credential_format")]
    pub credential_format: String,
    /// Optional path pattern (presumably for [`InjectMode::UrlPath`] — confirm).
    #[serde(default)]
    pub path_pattern: Option<String>,
    /// Optional path replacement (presumably paired with `path_pattern` — confirm).
    #[serde(default)]
    pub path_replacement: Option<String>,
    /// Optional query parameter name (presumably for [`InjectMode::QueryParam`] — confirm).
    #[serde(default)]
    pub query_param_name: Option<String>,
    /// Optional injection settings whose fields mirror this route's own.
    #[serde(default)]
    pub proxy: Option<ProxyInjectConfig>,
    /// Optional environment variable name.
    #[serde(default)]
    pub env_var: Option<String>,
    /// Method/path allowlist; an empty list allows every endpoint
    /// (see [`CompiledEndpointRules::is_allowed`]).
    #[serde(default)]
    pub endpoint_rules: Vec<EndpointRule>,
    /// Optional CA certificate for the upstream connection, given as a path
    /// (e.g. `/run/secrets/k8s-ca.crt`).
    #[serde(default)]
    pub tls_ca: Option<String>,
    /// Optional client certificate (presumably a file path for mutual TLS — confirm).
    #[serde(default)]
    pub tls_client_cert: Option<String>,
    /// Optional client private key (presumably a file path for mutual TLS — confirm).
    #[serde(default)]
    pub tls_client_key: Option<String>,
    /// Optional OAuth2 settings; `None` when omitted.
    #[serde(default)]
    pub oauth2: Option<OAuth2Config>,
}
/// Optional injection settings attached to a route through `RouteConfig::proxy`.
///
/// Each field mirrors the same-named `RouteConfig` field but is optional here.
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`).
///
/// NOTE(review): presumably `Some` values override the route-level settings
/// for some subset of requests — confirm against the consumer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProxyInjectConfig {
    /// Optional injection mode.
    #[serde(default)]
    pub inject_mode: Option<InjectMode>,
    /// Optional header name.
    #[serde(default)]
    pub inject_header: Option<String>,
    /// Optional credential template.
    #[serde(default)]
    pub credential_format: Option<String>,
    /// Optional path pattern.
    #[serde(default)]
    pub path_pattern: Option<String>,
    /// Optional path replacement.
    #[serde(default)]
    pub path_replacement: Option<String>,
    /// Optional query parameter name.
    #[serde(default)]
    pub query_param_name: Option<String>,
}
/// One allowlist entry: an HTTP method plus a glob over the request path.
///
/// `method` is compared ASCII-case-insensitively, and `"*"` matches any method.
/// `path` is a `globset` glob. It is matched against the normalized request
/// path: query string removed, percent-decoded, and empty segments dropped.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EndpointRule {
    /// HTTP method, or `"*"` for any method.
    pub method: String,
    /// Glob pattern over the normalized request path, e.g. `/repos/*/issues`.
    pub path: String,
}
/// Endpoint allowlist with every path glob compiled up front.
///
/// Build it with [`CompiledEndpointRules::compile`] and query it with
/// [`CompiledEndpointRules::is_allowed`].
pub struct CompiledEndpointRules {
    /// Compiled rules, in the order they were supplied.
    rules: Vec<CompiledRule>,
}
/// A single [`EndpointRule`] whose path glob has been compiled.
struct CompiledRule {
    /// Method as configured (`"*"` for any); compared ASCII-case-insensitively.
    method: String,
    /// Matcher compiled from the rule's path glob.
    matcher: globset::GlobMatcher,
}
impl CompiledEndpointRules {
    /// Compiles every rule's `path` glob so that request-time matching does
    /// no pattern parsing.
    ///
    /// Patterns are built with `Glob::new`, which uses globset's default
    /// options. NOTE(review): under those defaults `literal_separator` is off,
    /// so a single `*` can also match `/`. Confirm this is intended for an
    /// allowlist.
    ///
    /// # Errors
    ///
    /// Returns a message naming the pattern of the first rule whose `path`
    /// is not a valid glob.
    pub fn compile(rules: &[EndpointRule]) -> Result<Self, String> {
        let compiled_rules = rules
            .iter()
            .map(|rule| {
                Glob::new(&rule.path)
                    .map(|glob| CompiledRule {
                        method: rule.method.clone(),
                        matcher: glob.compile_matcher(),
                    })
                    .map_err(|e| format!("invalid endpoint path pattern '{}': {}", rule.path, e))
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self {
            rules: compiled_rules,
        })
    }

    /// Returns `true` if some rule accepts `method` and `path`.
    ///
    /// An empty rule set places no restrictions, so everything is allowed.
    /// Otherwise `path` is normalized first, and a rule accepts the request
    /// when its method is `"*"` or equals `method` ignoring ASCII case, and
    /// its glob matches the normalized path.
    #[must_use]
    pub fn is_allowed(&self, method: &str, path: &str) -> bool {
        if self.rules.is_empty() {
            return true;
        }
        let canonical = normalize_path(path);
        self.rules
            .iter()
            .filter(|rule| rule.method == "*" || rule.method.eq_ignore_ascii_case(method))
            .any(|rule| rule.matcher.is_match(&canonical))
    }
}
impl std::fmt::Debug for CompiledEndpointRules {
    /// Prints only the number of compiled rules, not the matchers themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.rules.len();
        let mut builder = f.debug_struct("CompiledEndpointRules");
        builder.field("count", &count);
        builder.finish()
    }
}
/// Test-only convenience that evaluates `rules` against `method`/`path`.
///
/// It routes through [`CompiledEndpointRules`], so tests written against this
/// helper exercise the same matching code as the production path instead of
/// a duplicated copy that could drift.
///
/// Unlike [`CompiledEndpointRules::compile`], a rule with an invalid glob is
/// skipped (treated as non-matching) rather than rejecting the whole set.
/// An empty rule set allows everything.
#[cfg(test)]
fn endpoint_allowed(rules: &[EndpointRule], method: &str, path: &str) -> bool {
    if rules.is_empty() {
        return true;
    }
    // Compile rule-by-rule so an invalid pattern only disables that one rule.
    rules
        .iter()
        .filter_map(|rule| CompiledEndpointRules::compile(std::slice::from_ref(rule)).ok())
        .any(|compiled| compiled.is_allowed(method, path))
}
/// Canonicalizes a request path before it is matched against endpoint rules.
///
/// Steps:
/// 1. Drop everything from the first `?` (the query string).
/// 2. Percent-decode the remainder. Invalid UTF-8 becomes U+FFFD, so such a
///    path cannot accidentally match an ASCII rule.
/// 3. Split on `/` and process each segment:
///    - drop empty segments (collapses `//` and a trailing `/`);
///    - drop `.` segments;
///    - resolve `..` by removing the previous segment, clamped at the root as
///      in RFC 3986 §5.2.4.
///
/// Resolving dot segments is a security requirement. Without it,
/// `/allowed/../forbidden` (or its `%2e%2e` form) would be checked as a path
/// under `/allowed/**`. An upstream that resolves dot segments would then
/// serve `/forbidden`.
///
/// Returns `/` when no segments remain.
fn normalize_path(path: &str) -> String {
    let without_query = path.split_once('?').map_or(path, |(before, _)| before);
    let bytes = urlencoding::decode_binary(without_query.as_bytes());
    let decoded = String::from_utf8_lossy(&bytes);

    let mut segments: Vec<&str> = Vec::new();
    for segment in decoded.split('/') {
        match segment {
            "" | "." => {}
            ".." => {
                // Popping an empty stack is a no-op: `..` cannot climb above root.
                segments.pop();
            }
            other => segments.push(other),
        }
    }

    if segments.is_empty() {
        "/".to_string()
    } else {
        format!("/{}", segments.join("/"))
    }
}
/// Serde default for `RouteConfig::inject_header`.
fn default_inject_header() -> String {
    String::from("Authorization")
}
/// Serde default for `RouteConfig::credential_format`.
fn default_credential_format() -> String {
    String::from("Bearer {}")
}
/// Settings for an external proxy, e.g. `squid.corp:3128`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyConfig {
    /// Proxy address as `host:port` (required).
    pub address: String,
    /// Optional proxy authentication; `None` when omitted or `null`.
    pub auth: Option<ExternalProxyAuth>,
    /// Host names or patterns such as `*.private.net`; empty when omitted.
    /// Presumably these are reached without the external proxy — confirm.
    #[serde(default)]
    pub bypass_hosts: Vec<String>,
}
/// Authentication settings for the external proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalProxyAuth {
    /// Account name. Presumably looked up in the OS keyring, per the field
    /// name — confirm.
    pub keyring_account: String,
    /// Authentication scheme; `"basic"` when omitted.
    #[serde(default = "default_auth_scheme")]
    pub scheme: String,
}
/// Serde default for `ExternalProxyAuth::scheme`.
fn default_auth_scheme() -> String {
    String::from("basic")
}
/// OAuth2 token-endpoint settings for a route.
///
/// `Debug` is implemented by hand so that `client_secret` never appears in
/// `{:?}` output. This also covers any enclosing `RouteConfig`, whose derived
/// `Debug` delegates here. All other fields are printed normally.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub struct OAuth2Config {
    /// Token endpoint URL (required).
    pub token_url: String,
    /// OAuth2 client id (required).
    pub client_id: String,
    /// Client secret (required). Tests use reference-style values such as
    /// `env://CLIENT_SECRET`. Whether and where such references are resolved
    /// is not shown in this file.
    pub client_secret: String,
    /// Requested scope; empty string when omitted.
    #[serde(default)]
    pub scope: String,
}

impl std::fmt::Debug for OAuth2Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("token_url", &self.token_url)
            .field("client_id", &self.client_id)
            // Never print the secret itself, even if it is a literal value.
            .field("client_secret", &"<redacted>")
            .field("scope", &self.scope)
            .finish()
    }
}
#[cfg(test)]
// Unwrapping is fine in tests: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an [`EndpointRule`] from borrowed strings.
    fn endpoint_rule(method: &str, path: &str) -> EndpointRule {
        EndpointRule {
            method: method.to_owned(),
            path: path.to_owned(),
        }
    }

    /// Evaluates a single rule through the `endpoint_allowed` helper.
    fn check(rule: &EndpointRule, method: &str, path: &str) -> bool {
        endpoint_allowed(std::slice::from_ref(rule), method, path)
    }

    /// Rule set shared by the compiled and uncompiled multi-rule tests.
    fn issue_rules() -> Vec<EndpointRule> {
        vec![
            endpoint_rule("GET", "/repos/*/issues"),
            endpoint_rule("POST", "/repos/*/issues/*/comments"),
        ]
    }

    // ---- ProxyConfig / ExternalProxyConfig ---------------------------------

    #[test]
    fn test_default_config() {
        let cfg = ProxyConfig::default();
        assert_eq!(cfg.bind_addr, IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
        assert_eq!(cfg.bind_port, 0);
        assert!(cfg.allowed_hosts.is_empty());
        assert!(cfg.routes.is_empty());
        assert!(cfg.external_proxy.is_none());
    }

    #[test]
    fn test_config_serialization() {
        let original = ProxyConfig {
            allowed_hosts: vec!["api.openai.com".to_owned()],
            ..Default::default()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.allowed_hosts, vec!["api.openai.com"]);
    }

    #[test]
    fn test_external_proxy_config_with_bypass_hosts() {
        let original = ProxyConfig {
            external_proxy: Some(ExternalProxyConfig {
                address: "squid.corp:3128".to_owned(),
                auth: None,
                bypass_hosts: vec!["internal.corp".to_owned(), "*.private.net".to_owned()],
            }),
            ..Default::default()
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProxyConfig = serde_json::from_str(&encoded).unwrap();
        let ext = decoded.external_proxy.unwrap();
        assert_eq!(ext.address, "squid.corp:3128");
        assert_eq!(ext.bypass_hosts, vec!["internal.corp", "*.private.net"]);
    }

    #[test]
    fn test_external_proxy_config_bypass_hosts_default_empty() {
        let ext: ExternalProxyConfig =
            serde_json::from_str(r#"{"address": "proxy:3128", "auth": null}"#).unwrap();
        assert!(ext.bypass_hosts.is_empty());
    }

    // ---- endpoint rule matching ---------------------------------------------

    #[test]
    fn test_endpoint_allowed_empty_rules_allows_all() {
        for (method, path) in [("GET", "/anything"), ("DELETE", "/admin/nuke")] {
            assert!(endpoint_allowed(&[], method, path));
        }
    }

    #[test]
    fn test_endpoint_rule_exact_path() {
        let r = endpoint_rule("GET", "/v1/chat/completions");
        assert!(check(&r, "GET", "/v1/chat/completions"));
        assert!(!check(&r, "GET", "/v1/chat"));
        assert!(!check(&r, "GET", "/v1/chat/completions/extra"));
    }

    #[test]
    fn test_endpoint_rule_method_case_insensitive() {
        let r = endpoint_rule("get", "/api");
        for method in ["GET", "Get"] {
            assert!(check(&r, method, "/api"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_wildcard() {
        let r = endpoint_rule("*", "/api/resource");
        for method in ["GET", "DELETE", "POST"] {
            assert!(check(&r, method, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_method_mismatch() {
        let r = endpoint_rule("GET", "/api/resource");
        for method in ["POST", "DELETE"] {
            assert!(!check(&r, method, "/api/resource"));
        }
    }

    #[test]
    fn test_endpoint_rule_single_wildcard() {
        let r = endpoint_rule("GET", "/api/v4/projects/*/merge_requests");
        assert!(check(&r, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&r, "GET", "/api/v4/projects/my-proj/merge_requests"));
        assert!(!check(&r, "GET", "/api/v4/projects/merge_requests"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard() {
        let r = endpoint_rule("GET", "/api/v4/projects/**");
        assert!(check(&r, "GET", "/api/v4/projects/123"));
        assert!(check(&r, "GET", "/api/v4/projects/123/merge_requests"));
        assert!(check(&r, "GET", "/api/v4/projects/a/b/c/d"));
        assert!(!check(&r, "GET", "/api/v4/other"));
    }

    #[test]
    fn test_endpoint_rule_double_wildcard_middle() {
        let r = endpoint_rule("*", "/api/**/notes");
        assert!(check(&r, "GET", "/api/notes"));
        assert!(check(&r, "POST", "/api/projects/123/notes"));
        assert!(check(&r, "GET", "/api/a/b/c/notes"));
        assert!(!check(&r, "GET", "/api/a/b/c/comments"));
    }

    #[test]
    fn test_endpoint_rule_strips_query_string() {
        let r = endpoint_rule("GET", "/api/data");
        assert!(check(&r, "GET", "/api/data?page=1&limit=10"));
    }

    #[test]
    fn test_endpoint_rule_trailing_slash_normalized() {
        let r = endpoint_rule("GET", "/api/data");
        for path in ["/api/data/", "/api/data"] {
            assert!(check(&r, "GET", path));
        }
    }

    #[test]
    fn test_endpoint_rule_double_slash_normalized() {
        let r = endpoint_rule("GET", "/api/data");
        assert!(check(&r, "GET", "/api//data"));
    }

    #[test]
    fn test_endpoint_rule_root_path() {
        let r = endpoint_rule("GET", "/");
        assert!(check(&r, "GET", "/"));
        assert!(!check(&r, "GET", "/anything"));
    }

    #[test]
    fn test_compiled_endpoint_rules_hot_path() {
        let compiled = CompiledEndpointRules::compile(&issue_rules()).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/issues"));
        assert!(compiled.is_allowed("POST", "/repos/myrepo/issues/42/comments"));
        assert!(!compiled.is_allowed("DELETE", "/repos/myrepo"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/pulls"));
    }

    #[test]
    fn test_compiled_endpoint_rules_empty_allows_all() {
        let compiled = CompiledEndpointRules::compile(&[]).unwrap();
        assert!(compiled.is_allowed("DELETE", "/admin/nuke"));
    }

    #[test]
    fn test_compiled_endpoint_rules_invalid_pattern_rejected() {
        let bad = [endpoint_rule("GET", "/api/[invalid")];
        assert!(CompiledEndpointRules::compile(&bad).is_err());
    }

    #[test]
    fn test_endpoint_allowed_multiple_rules() {
        let rules = issue_rules();
        assert!(endpoint_allowed(&rules, "GET", "/repos/myrepo/issues"));
        assert!(endpoint_allowed(&rules, "POST", "/repos/myrepo/issues/42/comments"));
        assert!(!endpoint_allowed(&rules, "DELETE", "/repos/myrepo"));
        assert!(!endpoint_allowed(&rules, "GET", "/repos/myrepo/pulls"));
    }

    // ---- RouteConfig serde --------------------------------------------------

    #[test]
    fn test_endpoint_rule_serde_default() {
        let json = r#"{
            "prefix": "test",
            "upstream": "https://example.com"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.endpoint_rules.is_empty());
        assert!(route.tls_ca.is_none());
    }

    #[test]
    fn test_tls_ca_serde_roundtrip() {
        let json = r#"{
            "prefix": "k8s",
            "upstream": "https://kubernetes.local:6443",
            "tls_ca": "/run/secrets/k8s-ca.crt"
        }"#;
        let expected = Some("/run/secrets/k8s-ca.crt");
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert_eq!(route.tls_ca.as_deref(), expected);
        let encoded = serde_json::to_string(&route).unwrap();
        let decoded: RouteConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.tls_ca.as_deref(), expected);
    }

    // ---- percent-encoding ---------------------------------------------------

    #[test]
    fn test_endpoint_rule_percent_encoded_path_decoded() {
        let r = endpoint_rule("GET", "/api/v4/projects/*/issues");
        for path in ["/api/v4/%70rojects/123/issues", "/api/v4/pro%6Aects/123/issues"] {
            assert!(check(&r, "GET", path));
        }
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_full_segment() {
        let r = endpoint_rule("POST", "/api/data");
        assert!(check(&r, "POST", "/api/%64%61%74%61"));
    }

    #[test]
    fn test_compiled_endpoint_rules_percent_encoded() {
        let compiled =
            CompiledEndpointRules::compile(&[endpoint_rule("GET", "/repos/*/issues")]).unwrap();
        assert!(compiled.is_allowed("GET", "/repos/myrepo/%69ssues"));
        assert!(!compiled.is_allowed("GET", "/repos/myrepo/%70ulls"));
    }

    #[test]
    fn test_endpoint_rule_percent_encoded_invalid_utf8() {
        let r = endpoint_rule("GET", "/api/projects");
        assert!(!check(&r, "GET", "/api/%FFprojects"));
    }

    #[test]
    fn test_endpoint_rule_serde_roundtrip() {
        let original = endpoint_rule("GET", "/api/*/data");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: EndpointRule = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.method, "GET");
        assert_eq!(decoded.path, "/api/*/data");
    }

    // ---- OAuth2 -------------------------------------------------------------

    #[test]
    fn test_oauth2_config_deserialization() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://CLIENT_SECRET",
            "scope": "read write"
        }"#;
        let cfg: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.token_url, "https://auth.example.com/oauth/token");
        assert_eq!(cfg.client_id, "my-client");
        assert_eq!(cfg.client_secret, "env://CLIENT_SECRET");
        assert_eq!(cfg.scope, "read write");
    }

    #[test]
    fn test_oauth2_config_default_scope() {
        let json = r#"{
            "token_url": "https://auth.example.com/oauth/token",
            "client_id": "my-client",
            "client_secret": "env://SECRET"
        }"#;
        let cfg: OAuth2Config = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.scope, "");
    }

    #[test]
    fn test_route_config_with_oauth2() {
        let json = r#"{
            "prefix": "/my-api",
            "upstream": "https://api.example.com",
            "oauth2": {
                "token_url": "https://auth.example.com/oauth/token",
                "client_id": "agent-1",
                "client_secret": "env://CLIENT_SECRET",
                "scope": "api.read"
            }
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_some());
        assert!(route.credential_key.is_none());
        let oauth2 = route.oauth2.unwrap();
        assert_eq!(oauth2.token_url, "https://auth.example.com/oauth/token");
    }

    #[test]
    fn test_route_config_without_oauth2() {
        let json = r#"{
            "prefix": "/openai",
            "upstream": "https://api.openai.com",
            "credential_key": "openai"
        }"#;
        let route: RouteConfig = serde_json::from_str(json).unwrap();
        assert!(route.oauth2.is_none());
        assert!(route.credential_key.is_some());
    }
}