use regex::Regex;
use serde::Deserialize;
use std::collections::HashMap;
/// Top-level gateway configuration, deserialized from YAML by [`Config::from_file`].
///
/// Fields marked `#[serde(default)]` fall back to their `Default` value when
/// absent; `Option` fields are `None` when absent.
#[derive(Debug, Deserialize)]
pub struct Config {
/// Transport settings; defaults to the HTTP transport built by
/// `TransportConfig::default()`.
#[serde(default)]
pub transport: TransportConfig,
/// A single audit sink.
/// NOTE(review): coexists with `audits`; how the two are combined is decided
/// outside this file — confirm.
#[serde(default)]
pub audit: Option<AuditConfig>,
/// A list of audit sinks.
#[serde(default)]
pub audits: Vec<AuditConfig>,
/// Per-agent policies keyed by agent name (the key is used as the agent name
/// in validation error messages).
#[serde(default)]
pub agents: HashMap<String, AgentPolicy>,
/// Optional fallback policy, validated alongside `agents` under the label
/// `"default_policy"`. Presumably applied to agents that have no entry in
/// `agents` — confirm against the caller.
pub default_policy: Option<AgentPolicy>,
/// Content and traffic rules (block patterns, per-IP rate limit,
/// prompt-injection blocking, filter mode).
#[serde(default)]
pub rules: Rules,
/// Named upstreams (name -> address, presumably a URL). An agent's
/// `AgentPolicy::upstream` must name a key of this map (checked in validation).
#[serde(default)]
pub upstreams: HashMap<String, String>,
/// JWT authentication: one provider config or a list of them.
pub auth: Option<AuthConfig>,
/// NOTE(review): administrative token; how it is used is not visible in this file.
pub admin_token: Option<String>,
/// Optional OTLP telemetry export settings.
pub telemetry: Option<TelemetryConfig>,
}
/// Transport settings, selected by the YAML `type` tag (`http` or `stdio`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum TransportConfig {
/// HTTP transport.
Http {
/// Listen address; defaults to `0.0.0.0:4000`.
#[serde(default = "default_addr")]
addr: String,
/// Upstream URL; defaults to `http://localhost:3000/mcp`.
#[serde(default = "default_upstream_url")]
upstream: String,
/// Session lifetime in seconds; defaults to 3600.
#[serde(default = "default_session_ttl")]
session_ttl_secs: u64,
/// Optional TLS cert/key paths; both files must exist (checked in validation).
tls: Option<TlsConfig>,
/// Circuit-breaker settings; `threshold` must be non-zero (checked in validation).
#[serde(default)]
circuit_breaker: CircuitBreakerConfig,
},
/// Stdio transport.
Stdio {
/// NOTE(review): presumably the server command line (program followed by
/// its arguments) — confirm against the launcher.
server: Vec<String>,
/// Optional integrity checks for the server binary.
#[serde(default)]
verify: Option<BinaryVerifyConfig>,
},
}
/// Optional integrity checks for the stdio server binary (`TransportConfig::Stdio::verify`).
///
/// NOTE(review): field meanings below are inferred from their names; the
/// verification logic lives outside this file.
#[derive(Debug, Deserialize, Clone)]
pub struct BinaryVerifyConfig {
/// Expected SHA-256 digest of the binary (presumably hex-encoded — confirm).
pub sha256: Option<String>,
/// Cosign signature bundle (presumably a file path — confirm).
pub cosign_bundle: Option<String>,
/// Expected cosign signer identity.
pub cosign_identity: Option<String>,
/// Expected cosign certificate issuer.
pub cosign_issuer: Option<String>,
}
/// TLS certificate and private-key file paths for the HTTP transport.
/// Config validation fails if either path does not exist.
#[derive(Debug, Deserialize, Clone)]
pub struct TlsConfig {
/// Path to the certificate file.
pub cert: String,
/// Path to the private-key file.
pub key: String,
}
impl Default for TransportConfig {
fn default() -> Self {
TransportConfig::Http {
addr: default_addr(),
upstream: default_upstream_url(),
session_ttl_secs: default_session_ttl(),
tls: None,
circuit_breaker: CircuitBreakerConfig::default(),
}
}
}
/// Circuit-breaker tuning for the HTTP transport.
/// NOTE(review): the breaker itself is implemented outside this file.
#[derive(Debug, Deserialize, Clone)]
pub struct CircuitBreakerConfig {
/// Failure count that trips the breaker (presumably consecutive failures —
/// confirm). Defaults to 5; validation rejects 0.
#[serde(default = "default_cb_threshold")]
pub threshold: usize,
/// Seconds before the breaker presumably allows a recovery attempt; defaults to 30.
#[serde(default = "default_cb_recovery_secs")]
pub recovery_secs: u64,
}
impl Default for CircuitBreakerConfig {
    /// Same values serde uses when the fields are omitted:
    /// [`default_cb_threshold`] and [`default_cb_recovery_secs`].
    fn default() -> Self {
        let threshold = default_cb_threshold();
        let recovery_secs = default_cb_recovery_secs();
        Self {
            threshold,
            recovery_secs,
        }
    }
}
/// Default `circuit_breaker.threshold`.
fn default_cb_threshold() -> usize {
    5usize
}

/// Default `circuit_breaker.recovery_secs`, in seconds.
fn default_cb_recovery_secs() -> u64 {
    30u64
}

/// Default HTTP listen address: `0.0.0.0` (all IPv4 interfaces), port 4000.
fn default_addr() -> String {
    String::from("0.0.0.0:4000")
}

/// Default upstream URL for the HTTP transport.
fn default_upstream_url() -> String {
    String::from("http://localhost:3000/mcp")
}

/// Default session TTL: one hour, expressed in seconds.
fn default_session_ttl() -> u64 {
    60 * 60
}
/// An audit-log sink, selected by the YAML `type` tag (`stdout`, `sqlite` or `webhook`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AuditConfig {
/// Standard-output sink; takes no options.
Stdout,
/// SQLite-backed sink.
Sqlite {
/// Database file path; defaults to `gateway-audit.db`.
#[serde(default = "default_db_path")]
path: String,
/// NOTE(review): presumably a cap on retained entries; `None` = no cap — confirm.
max_entries: Option<usize>,
/// NOTE(review): presumably a retention window in days; `None` = no limit — confirm.
max_age_days: Option<u64>,
},
/// Webhook sink.
Webhook {
/// Destination URL.
url: String,
/// Optional credential (presumably sent with each request — confirm).
token: Option<String>,
/// Presumably selects CloudEvents-formatted payloads; defaults to `false`.
#[serde(default)]
cloudevents: bool,
/// Event source identifier (presumably the CloudEvents `source`); defaults to `/arbit`.
#[serde(default = "default_ce_source")]
source: String,
},
}
/// Default `source` for the webhook audit sink.
fn default_ce_source() -> String {
    String::from("/arbit")
}

/// Default database path for the SQLite audit sink.
fn default_db_path() -> String {
    String::from("gateway-audit.db")
}
/// Access and traffic policy for one agent (an entry of `Config::agents`) or
/// for the fallback `Config::default_policy`.
#[derive(Debug, Deserialize, Clone)]
pub struct AgentPolicy {
/// Tools the agent may call. Entries must match `^[a-zA-Z0-9_/.\-*]+$`
/// (checked in validation), so `*` wildcards are accepted (presumably matched
/// with `tool_matches`). NOTE(review): `None` presumably means "no allow-list".
pub allowed_tools: Option<Vec<String>>,
/// Tools the agent may not call; same naming rules as `allowed_tools`.
#[serde(default)]
pub denied_tools: Vec<String>,
/// Request rate limit; defaults to 60.
/// NOTE(review): the time window is not visible here (presumably per minute).
#[serde(default = "default_rate_limit")]
pub rate_limit: usize,
/// Per-tool rate limits (presumably keyed by tool name — confirm).
#[serde(default)]
pub tool_rate_limits: HashMap<String, usize>,
/// Name of an entry in `Config::upstreams` to route this agent to.
pub upstream: Option<String>,
/// NOTE(review): per-agent API key; its use is not visible in this file.
pub api_key: Option<String>,
/// Optional timeout in seconds (presumably per request — confirm).
#[serde(default)]
pub timeout_secs: Option<u64>,
/// Tools presumably requiring human approval before execution.
#[serde(default)]
pub approval_required: Vec<String>,
/// Human-approval timeout in seconds; defaults to 60.
#[serde(default = "default_hitl_timeout")]
pub hitl_timeout_secs: u64,
/// NOTE(review): tools handled in "shadow" mode; semantics live outside this file.
#[serde(default)]
pub shadow_tools: Vec<String>,
}
/// Default `AgentPolicy::rate_limit`.
fn default_rate_limit() -> usize {
    60usize
}

/// Default `AgentPolicy::hitl_timeout_secs`, in seconds.
fn default_hitl_timeout() -> u64 {
    60u64
}
/// Returns `true` if `tool` matches `pattern`.
///
/// A `*` in `pattern` matches any (possibly empty) run of bytes. Every other
/// byte must match exactly. A pattern without `*` is a plain equality check.
///
/// Uses the iterative "backtrack to the most recent `*`" algorithm. It runs in
/// O(len(pattern) * len(tool)) time and O(1) space. Naive recursive
/// backtracking tries every split point for every `*`, which is exponential in
/// the number of stars (e.g. `*a*a*a*a*b` against a long run of `a`s). That
/// matters if `tool` can be chosen by a caller.
pub(crate) fn tool_matches(pattern: &str, tool: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == tool;
    }
    let (p, t) = (pattern.as_bytes(), tool.as_bytes());
    let (mut pi, mut ti) = (0usize, 0usize);
    // Index of the most recent `*` in `p`, paired with the position in `t`
    // where that star's match currently ends.
    let mut star: Option<(usize, usize)> = None;
    while ti < t.len() {
        match p.get(pi) {
            Some(b'*') => {
                // Tentatively let the star match nothing; widen it on mismatch.
                star = Some((pi, ti));
                pi += 1;
            }
            Some(&c) if c == t[ti] => {
                pi += 1;
                ti += 1;
            }
            _ => match star {
                // Mismatch: let the last star swallow one more byte and retry.
                Some((sp, st)) => {
                    pi = sp + 1;
                    ti = st + 1;
                    star = Some((sp, st + 1));
                }
                None => return false,
            },
        }
    }
    // Tool exhausted: whatever remains of the pattern must be only stars.
    p[pi..].iter().all(|&b| b == b'*')
}
/// JWT authentication settings: either one provider config or a list of them.
///
/// `untagged`: serde tries `Single` (a mapping) first, then `Multi` (a sequence).
/// Use [`AuthConfig::into_configs`] to get a normalised list with provider
/// presets applied.
#[derive(Debug, Deserialize, Clone)]
#[serde(untagged)]
pub enum AuthConfig {
/// A single provider written as a mapping.
Single(JwtConfig),
/// Several providers written as a sequence.
Multi(Vec<JwtConfig>),
}
impl AuthConfig {
pub fn into_configs(self) -> anyhow::Result<Vec<JwtConfig>> {
match self {
AuthConfig::Single(c) => Ok(vec![c.with_provider_defaults()?]),
AuthConfig::Multi(cs) => cs.into_iter().map(|c| c.with_provider_defaults()).collect(),
}
}
}
/// Settings for accepting JWTs from one issuer / identity provider.
#[derive(Debug, Deserialize, Clone)]
pub struct JwtConfig {
/// Shared verification secret (presumably for HMAC-signed tokens — confirm).
pub secret: Option<String>,
/// URL of a JWKS document (presumably used to fetch verification keys).
pub jwks_url: Option<String>,
/// Expected issuer. `with_provider_defaults` fills it in for `google` and
/// `github-actions` when unset, and requires it for `auth0` and `okta`.
pub issuer: Option<String>,
/// Expected audience (presumably compared with the token's `aud` — confirm).
pub audience: Option<String>,
/// Claim presumably used as the agent identity; defaults to `sub`.
#[serde(default = "default_agent_claim")]
pub agent_claim: String,
/// Whether OIDC discovery is enabled; defaults to `false`.
/// `with_provider_defaults` sets it to `true` for every recognised provider.
#[serde(default)]
pub oidc_discovery: bool,
/// Optional preset name: `google`, `github-actions`, `auth0` or `okta`.
pub provider: Option<String>,
}
impl Default for JwtConfig {
    /// Every optional field `None`, OIDC discovery off, and `agent_claim` set
    /// to [`default_agent_claim`] (the same value serde uses when it is omitted).
    fn default() -> Self {
        Self {
            agent_claim: default_agent_claim(),
            oidc_discovery: false,
            provider: Default::default(),
            secret: Default::default(),
            jwks_url: Default::default(),
            issuer: Default::default(),
            audience: Default::default(),
        }
    }
}
impl JwtConfig {
    /// Applies the built-in settings for the named `provider`.
    ///
    /// * `google`, `github-actions`: fill in the provider's well-known `issuer`
    ///   if none is configured (an explicit one is kept), and enable
    ///   `oidc_discovery`.
    /// * `auth0`, `okta`: require an explicit `issuer`, and enable `oidc_discovery`.
    /// * no provider: the config is returned unchanged.
    ///
    /// NOTE(review): the google and github-actions issuers are shared by every
    /// tenant of those platforms, and this method does not require `audience`.
    /// Confirm the token verifier enforces `audience` (or another
    /// tenant-specific claim) for them.
    ///
    /// # Errors
    /// Fails for `auth0`/`okta` without `issuer`, and for any other provider name.
    pub fn with_provider_defaults(mut self) -> anyhow::Result<Self> {
        let Some(provider) = self.provider.as_deref() else {
            return Ok(self);
        };
        let well_known_issuer = match provider {
            "google" => Some("https://accounts.google.com"),
            "github-actions" => Some("https://token.actions.githubusercontent.com"),
            "auth0" | "okta" => {
                if self.issuer.is_none() {
                    return Err(anyhow::anyhow!(
                        "provider '{provider}' requires 'issuer' to be set"
                    ));
                }
                None
            }
            _ => {
                return Err(anyhow::anyhow!(
                    "unknown auth provider '{provider}'. Supported: google, github-actions, auth0, okta"
                ));
            }
        };
        if let Some(url) = well_known_issuer {
            self.issuer.get_or_insert_with(|| url.to_string());
        }
        self.oidc_discovery = true;
        Ok(self)
    }
}
/// Default `JwtConfig::agent_claim`: the `sub` claim.
fn default_agent_claim() -> String {
    String::from("sub")
}
/// OTLP telemetry export settings.
#[derive(Debug, Deserialize, Clone)]
pub struct TelemetryConfig {
/// OTLP endpoint to export to (required when `telemetry` is present).
pub otlp_endpoint: String,
/// Service name presumably attached to exported telemetry; defaults to `arbit`.
#[serde(default = "default_service_name")]
pub service_name: String,
}
/// Default `TelemetryConfig::service_name`.
fn default_service_name() -> String {
    String::from("arbit")
}
/// Action taken when content matches a rule; YAML values are `block` and
/// `redact`. Defaults to `Block`.
#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum FilterMode {
/// Presumably rejects the offending request/response — confirm.
#[default]
Block,
/// Presumably replaces the matched content and lets the message through — confirm.
Redact,
}
/// Content and traffic rules. Every field is optional in YAML.
#[derive(Debug, Deserialize, Default)]
pub struct Rules {
/// Regex patterns to filter on; each must compile (checked in validation).
#[serde(default)]
pub block_patterns: Vec<String>,
/// Optional per-IP rate limit. NOTE(review): units/window not visible here;
/// `None` presumably disables it.
pub ip_rate_limit: Option<usize>,
/// Whether to block suspected prompt injection; defaults to `false`.
#[serde(default)]
pub block_prompt_injection: bool,
/// What to do on a match; defaults to `FilterMode::Block`.
#[serde(default)]
pub filter_mode: FilterMode,
}
/// Test helper: builds an [`AgentPolicy`] with the given allow-list, deny-list
/// and rate limit. Every other collection is empty, every other `Option` is
/// `None`, and `hitl_timeout_secs` is 60.
#[cfg(test)]
pub(crate) fn make_agent(
    allowed: Option<Vec<&str>>,
    denied: Vec<&str>,
    rate_limit: usize,
) -> AgentPolicy {
    fn owned(names: Vec<&str>) -> Vec<String> {
        names.into_iter().map(str::to_owned).collect()
    }
    AgentPolicy {
        allowed_tools: allowed.map(owned),
        denied_tools: owned(denied),
        rate_limit,
        tool_rate_limits: HashMap::new(),
        upstream: None,
        api_key: None,
        timeout_secs: None,
        approval_required: Vec::new(),
        hitl_timeout_secs: 60,
        shadow_tools: Vec::new(),
    }
}
impl Config {
pub fn from_file(path: &str) -> anyhow::Result<Self> {
let s = std::fs::read_to_string(path)
.map_err(|e| anyhow::anyhow!("could not read '{}': {}", path, e))?;
let config: Self =
serde_yaml::from_str(&s).map_err(|e| anyhow::anyhow!("invalid config: {}", e))?;
config.validate()?;
Ok(config)
}
fn validate(&self) -> anyhow::Result<()> {
for pattern in &self.rules.block_patterns {
Regex::new(pattern)
.map_err(|e| anyhow::anyhow!("invalid block_pattern '{}': {}", pattern, e))?;
}
let tool_name_re = Regex::new(r"^[a-zA-Z0-9_/.\-*]+$").unwrap();
let all_policies = self
.agents
.iter()
.map(|(k, v)| (k.as_str(), v))
.chain(self.default_policy.as_ref().map(|p| ("default_policy", p)));
for (agent, policy) in all_policies {
for tool in policy
.allowed_tools
.iter()
.flatten()
.chain(&policy.denied_tools)
{
if !tool_name_re.is_match(tool) {
return Err(anyhow::anyhow!(
"agent '{}': invalid tool name '{}'",
agent,
tool
));
}
}
}
for (agent, policy) in &self.agents {
if let Some(upstream_name) = &policy.upstream
&& !self.upstreams.contains_key(upstream_name)
{
return Err(anyhow::anyhow!(
"agent '{}' references unknown upstream '{}'",
agent,
upstream_name
));
}
}
if let TransportConfig::Http { tls: Some(tls), .. } = &self.transport {
if !std::path::Path::new(&tls.cert).exists() {
return Err(anyhow::anyhow!("TLS cert file not found: {}", tls.cert));
}
if !std::path::Path::new(&tls.key).exists() {
return Err(anyhow::anyhow!("TLS key file not found: {}", tls.key));
}
}
if let TransportConfig::Http {
circuit_breaker: cb,
..
} = &self.transport
&& cb.threshold == 0
{
return Err(anyhow::anyhow!("circuit_breaker.threshold must be > 0"));
}
Ok(())
}
}
// Unit tests for `Config::validate`, the `JwtConfig` provider presets and the
// `tool_matches` wildcard matcher.
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
/// Minimal config: default (HTTP, no TLS) transport, no agents, no rules, no
/// upstreams. Tests mutate a copy to introduce exactly one problem each.
fn base() -> Config {
Config {
transport: TransportConfig::default(),
audit: None,
audits: vec![],
agents: HashMap::new(),
default_policy: None,
rules: Rules::default(),
upstreams: HashMap::new(),
auth: None,
admin_token: None,
telemetry: None,
}
}
// --- Config::validate: baseline and block patterns ---
#[test]
fn empty_config_passes_validate() {
assert!(base().validate().is_ok());
}
#[test]
fn invalid_regex_is_rejected() {
let mut cfg = base();
cfg.rules.block_patterns = vec!["[unclosed".to_string()];
assert!(cfg.validate().is_err());
}
#[test]
fn valid_block_patterns_pass() {
let mut cfg = base();
cfg.rules.block_patterns = vec!["private_key".to_string(), r"\bsecret\b".to_string()];
assert!(cfg.validate().is_ok());
}
// --- Config::validate: tool-name syntax ---
#[test]
fn tool_name_with_spaces_is_rejected() {
let mut cfg = base();
cfg.agents.insert(
"a".to_string(),
make_agent(Some(vec!["bad name"]), vec![], 60),
);
assert!(cfg.validate().is_err());
}
#[test]
fn tool_name_with_exclamation_is_rejected() {
let mut cfg = base();
cfg.agents
.insert("a".to_string(), make_agent(None, vec!["bad!tool"], 60));
assert!(cfg.validate().is_err());
}
#[test]
fn valid_tool_names_pass() {
let mut cfg = base();
cfg.agents.insert(
"a".to_string(),
make_agent(
Some(vec!["read_file", "list-dir", "tools/v2.echo"]),
vec!["delete_file"],
60,
),
);
assert!(cfg.validate().is_ok());
}
// --- Config::validate: upstream references ---
#[test]
fn unknown_upstream_reference_fails() {
let mut cfg = base();
let mut policy = make_agent(None, vec![], 60);
policy.upstream = Some("ghost".to_string());
cfg.agents.insert("a".to_string(), policy);
assert!(cfg.validate().is_err());
}
#[test]
fn known_upstream_reference_passes() {
let mut cfg = base();
cfg.upstreams
.insert("mcp".to_string(), "http://localhost:3000/mcp".to_string());
let mut policy = make_agent(None, vec![], 60);
policy.upstream = Some("mcp".to_string());
cfg.agents.insert("a".to_string(), policy);
assert!(cfg.validate().is_ok());
}
// --- Config::validate: circuit breaker ---
#[test]
fn zero_circuit_breaker_threshold_fails() {
let mut cfg = base();
cfg.transport = TransportConfig::Http {
addr: "0.0.0.0:4000".to_string(),
upstream: "http://localhost:3000/mcp".to_string(),
session_ttl_secs: 3600,
tls: None,
circuit_breaker: CircuitBreakerConfig {
threshold: 0,
recovery_secs: 30,
},
};
assert!(cfg.validate().is_err());
}
// --- JwtConfig::with_provider_defaults presets ---
#[test]
fn google_preset_sets_issuer_and_discovery() {
let cfg = JwtConfig {
provider: Some("google".to_string()),
..JwtConfig::default()
}
.with_provider_defaults()
.unwrap();
assert_eq!(cfg.issuer.as_deref(), Some("https://accounts.google.com"));
assert!(cfg.oidc_discovery);
}
#[test]
fn github_actions_preset_sets_issuer() {
let cfg = JwtConfig {
provider: Some("github-actions".to_string()),
..JwtConfig::default()
}
.with_provider_defaults()
.unwrap();
assert_eq!(
cfg.issuer.as_deref(),
Some("https://token.actions.githubusercontent.com")
);
assert!(cfg.oidc_discovery);
}
#[test]
fn auth0_without_issuer_fails() {
let cfg = JwtConfig {
provider: Some("auth0".to_string()),
..JwtConfig::default()
};
assert!(cfg.with_provider_defaults().is_err());
}
#[test]
fn auth0_with_issuer_enables_discovery() {
let cfg = JwtConfig {
provider: Some("auth0".to_string()),
issuer: Some("https://myapp.auth0.com".to_string()),
..JwtConfig::default()
}
.with_provider_defaults()
.unwrap();
assert!(cfg.oidc_discovery);
}
#[test]
fn unknown_provider_fails() {
let cfg = JwtConfig {
provider: Some("magic".to_string()),
..JwtConfig::default()
};
assert!(cfg.with_provider_defaults().is_err());
}
// Without a provider the config passes through untouched.
#[test]
fn no_provider_is_unchanged() {
let cfg = JwtConfig {
secret: Some("s".to_string()),
..JwtConfig::default()
}
.with_provider_defaults()
.unwrap();
assert_eq!(cfg.secret.as_deref(), Some("s"));
assert!(!cfg.oidc_discovery);
}
// --- tool_matches wildcard semantics ---
#[test]
fn exact_match() {
assert!(tool_matches("read_file", "read_file"));
assert!(!tool_matches("read_file", "write_file"));
}
#[test]
fn suffix_wildcard() {
assert!(tool_matches("read_*", "read_file"));
assert!(tool_matches("read_*", "read_dir"));
assert!(tool_matches("read_*", "read_"));
assert!(!tool_matches("read_*", "write_file"));
}
#[test]
fn prefix_wildcard() {
assert!(tool_matches("*_file", "read_file"));
assert!(tool_matches("*_file", "write_file"));
assert!(!tool_matches("*_file", "read_dir"));
}
#[test]
fn star_matches_all() {
assert!(tool_matches("*", "read_file"));
assert!(tool_matches("*", "anything"));
assert!(tool_matches("*", ""));
}
#[test]
fn middle_wildcard() {
assert!(tool_matches("read_*_v2", "read_file_v2"));
assert!(!tool_matches("read_*_v2", "read_file_v3"));
}
// Wildcard entries (in both the allow- and deny-lists) must pass the
// tool-name validation regex.
#[test]
fn wildcard_in_denied_tools_validation() {
let mut cfg = base();
cfg.agents.insert(
"a".to_string(),
make_agent(Some(vec!["read_*", "list_*"]), vec!["delete_*"], 60),
);
assert!(cfg.validate().is_ok());
}
}