use crate::TRonError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::RwLock;
/// Outcome of [`PolicyEngine::check`] for one `(agent, tool)` pair.
///
/// Only [`PolicyResult::Allow`] grants access; the `Unknown*` variants report that the policy
/// had nothing to say, and it is up to the caller how to treat them.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PolicyResult {
    /// The agent has a matching `allow` pattern and no matching `deny` pattern.
    Allow,
    /// A `deny` pattern matched; the string is a human-readable reason naming tool and agent.
    Deny(String),
    /// The policy contains no entry for the agent.
    UnknownAgent,
    /// The agent has an entry, but neither its `deny` nor its `allow` patterns match the tool.
    UnknownTool,
}
/// Access rules for a single agent, as read from an `[agent."<id>"]` TOML table.
///
/// Every field is `#[serde(default)]`, so a table may list only `allow`, only `deny`, or
/// neither. `Default` yields the same empty policy (no patterns, no rate limit) that
/// [`PolicyEngine::grant`] and [`PolicyEngine::revoke`] create for a previously unseen agent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentPolicy {
    /// Glob patterns naming tools the agent may call (`"*"` or `"prefix*"` or an exact name).
    #[serde(default)]
    pub allow: Vec<String>,
    /// Glob patterns naming tools the agent may not call; these are checked before `allow`.
    #[serde(default)]
    pub deny: Vec<String>,
    /// Optional call-rate limit; `None` when the TOML omits the `rate_limit` table.
    #[serde(default)]
    pub rate_limit: Option<RateLimitPolicy>,
}
/// Rate limit attached to an agent via an `[agent."<id>".rate_limit]` TOML table.
///
/// NOTE(review): this module only parses and stores the limit; enforcement, if any, happens
/// elsewhere — confirm against the caller.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitPolicy {
    /// Maximum number of calls allowed per minute (required when the table is present).
    pub calls_per_minute: u64,
}
/// Complete policy document: one [`AgentPolicy`] per agent id.
///
/// In TOML each agent is a table under `agent`, e.g. `[agent."web-agent"]`; a document with
/// no `agent` tables (including an empty string) deserializes to an empty map.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct PolicyConfig {
    /// Agent id → that agent's rules.
    #[serde(default)]
    pub agent: HashMap<String, AgentPolicy>,
}
/// Shared store of a [`PolicyConfig`] that answers per-agent tool access checks.
///
/// The configuration sits behind an [`RwLock`], so every method takes `&self`. Checks and
/// snapshots take the read lock, while reloads and grants/revokes take the write lock.
#[derive(Debug)]
pub struct PolicyEngine {
    config: RwLock<PolicyConfig>,
}
impl Default for PolicyEngine {
fn default() -> Self {
Self::new()
}
}
impl PolicyEngine {
pub fn new() -> Self {
Self {
config: RwLock::new(PolicyConfig::default()),
}
}
#[must_use]
pub fn config(&self) -> PolicyConfig {
self.config
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.clone()
}
pub fn load_toml(&self, toml_str: &str) -> Result<(), TRonError> {
let config: PolicyConfig =
toml::from_str(toml_str).map_err(|e| TRonError::PolicyConfig(e.to_string()))?;
let mut guard = self
.config
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*guard = config;
tracing::info!("policy reloaded");
Ok(())
}
#[must_use]
pub fn check(&self, agent_id: &str, tool_name: &str) -> PolicyResult {
let config = self
.config
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let policy = match config.agent.get(agent_id) {
Some(p) => p,
None => return PolicyResult::UnknownAgent,
};
for pattern in &policy.deny {
if matches_glob(pattern, tool_name) {
return PolicyResult::Deny(format!(
"tool '{tool_name}' denied by policy for agent '{agent_id}'"
));
}
}
for pattern in &policy.allow {
if matches_glob(pattern, tool_name) {
return PolicyResult::Allow;
}
}
PolicyResult::UnknownTool
}
pub fn grant(&self, agent_id: &str, pattern: &str) {
let mut config = self
.config
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let policy = config
.agent
.entry(agent_id.to_string())
.or_insert_with(|| AgentPolicy {
allow: vec![],
deny: vec![],
rate_limit: None,
});
policy.allow.push(pattern.to_string());
}
pub fn revoke(&self, agent_id: &str, pattern: &str) {
let mut config = self
.config
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let policy = config
.agent
.entry(agent_id.to_string())
.or_insert_with(|| AgentPolicy {
allow: vec![],
deny: vec![],
rate_limit: None,
});
policy.deny.push(pattern.to_string());
}
}
/// Matches `name` against a minimal glob `pattern`.
///
/// Only one trailing `*` is special. `"prefix*"` matches every name that starts with
/// `prefix`, so a bare `"*"` matches everything, including the empty string. Any other
/// pattern, including one with a leading or interior `*`, must equal `name` exactly.
#[inline]
fn matches_glob(pattern: &str, name: &str) -> bool {
    // `"*"` strips to an empty prefix, which every name starts with, so it needs no special case.
    pattern
        .strip_suffix('*')
        .map_or(pattern == name, |prefix| name.starts_with(prefix))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an engine from an inline TOML policy, failing the test if it does not load.
    fn engine_with(toml: &str) -> PolicyEngine {
        let engine = PolicyEngine::new();
        engine.load_toml(toml).expect("test policy TOML should load");
        engine
    }

    #[test]
    fn glob_wildcard() {
        assert!(matches_glob("*", "anything"));
        assert!(matches_glob("*", ""));
        assert!(matches_glob("tarang_*", "tarang_probe"));
        assert!(matches_glob("tarang_*", "tarang_analyze"));
        // A trailing `*` also matches an empty remainder.
        assert!(matches_glob("tarang_*", "tarang_"));
        assert!(!matches_glob("tarang_*", "rasa_edit"));
        assert!(matches_glob("aegis_quarantine", "aegis_quarantine"));
        assert!(!matches_glob("aegis_quarantine", "aegis_scan"));
    }

    #[test]
    fn policy_deny_wins() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        engine.revoke("agent-1", "tarang_delete");
        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("agent-1", "tarang_delete"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn deny_wins_even_when_granted_after_revoke() {
        let engine = PolicyEngine::new();
        engine.revoke("agent-1", "tarang_delete");
        engine.grant("agent-1", "tarang_delete");
        assert!(matches!(
            engine.check("agent-1", "tarang_delete"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn deny_reason_names_tool_and_agent() {
        let engine = PolicyEngine::new();
        engine.revoke("agent-1", "tarang_delete");
        match engine.check("agent-1", "tarang_delete") {
            PolicyResult::Deny(reason) => {
                assert!(reason.contains("tarang_delete"));
                assert!(reason.contains("agent-1"));
            }
            _ => panic!("expected PolicyResult::Deny"),
        }
    }

    #[test]
    fn revoke_registers_previously_unknown_agent() {
        let engine = PolicyEngine::new();
        engine.revoke("fresh", "aegis_*");
        assert!(matches!(
            engine.check("fresh", "aegis_scan"),
            PolicyResult::Deny(_)
        ));
        assert!(matches!(
            engine.check("fresh", "tarang_probe"),
            PolicyResult::UnknownTool
        ));
    }

    #[test]
    fn unknown_agent() {
        let engine = PolicyEngine::new();
        assert!(matches!(
            engine.check("nobody", "any_tool"),
            PolicyResult::UnknownAgent
        ));
    }

    #[test]
    fn load_toml_policy() {
        let engine = engine_with(
            r#"
            [agent."web-agent"]
            allow = ["tarang_*", "rasa_*"]
            deny = ["aegis_*"]
            "#,
        );
        assert!(matches!(
            engine.check("web-agent", "tarang_probe"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("web-agent", "aegis_scan"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn unknown_tool_for_known_agent() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        assert!(matches!(
            engine.check("agent-1", "rasa_edit"),
            PolicyResult::UnknownTool
        ));
    }

    #[test]
    fn malformed_toml_error() {
        let engine = PolicyEngine::new();
        let result = engine.load_toml("this is not valid toml {{{}}}");
        assert!(result.is_err());
    }

    #[test]
    fn failed_reload_keeps_previous_policy() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        assert!(engine.load_toml("this is not valid toml {{{}}}").is_err());
        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::Allow
        ));
    }

    #[test]
    fn deny_only_policy() {
        let engine = engine_with(
            r#"
            [agent."lockdown"]
            deny = ["*"]
            "#,
        );
        assert!(matches!(
            engine.check("lockdown", "anything"),
            PolicyResult::Deny(_)
        ));
    }

    #[test]
    fn allow_only_policy() {
        let engine = engine_with(
            r#"
            [agent."open"]
            allow = ["*"]
            "#,
        );
        assert!(matches!(
            engine.check("open", "anything"),
            PolicyResult::Allow
        ));
    }

    #[test]
    fn reload_policy_replaces_previous() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::Allow
        ));
        engine.load_toml("").unwrap();
        assert!(matches!(
            engine.check("agent-1", "tarang_probe"),
            PolicyResult::UnknownAgent
        ));
    }

    #[test]
    fn multiple_agents_in_policy() {
        let engine = engine_with(
            r#"
            [agent."reader"]
            allow = ["tarang_*"]
            [agent."admin"]
            allow = ["*"]
            deny = ["ark_remove"]
            "#,
        );
        assert!(matches!(
            engine.check("reader", "tarang_probe"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("reader", "aegis_scan"),
            PolicyResult::UnknownTool
        ));
        assert!(matches!(
            engine.check("admin", "aegis_scan"),
            PolicyResult::Allow
        ));
        assert!(matches!(
            engine.check("admin", "ark_remove"),
            PolicyResult::Deny(_)
        ));
    }

    // An empty pattern is an exact-match pattern: it matches only the empty name.
    #[test]
    fn empty_pattern_matches_only_empty_name() {
        assert!(!matches_glob("", "anything"));
        assert!(matches_glob("", ""));
    }

    // Only a trailing `*` is a wildcard; a leading one is compared literally.
    #[test]
    fn glob_star_suffix_only() {
        assert!(!matches_glob("*_delete", "tarang_delete"));
    }

    #[test]
    fn rate_limit_parsed_from_toml() {
        let engine = engine_with(
            r#"
            [agent."limited"]
            allow = ["*"]
            [agent."limited".rate_limit]
            calls_per_minute = 10
            [agent."unlimited"]
            allow = ["*"]
            "#,
        );
        let config = engine.config();
        let limited = config.agent.get("limited").unwrap();
        assert_eq!(limited.rate_limit.as_ref().unwrap().calls_per_minute, 10);
        let unlimited = config.agent.get("unlimited").unwrap();
        assert!(unlimited.rate_limit.is_none());
    }

    #[test]
    fn config_snapshot() {
        let engine = PolicyEngine::new();
        engine.grant("agent-1", "tarang_*");
        let config = engine.config();
        assert!(config.agent.contains_key("agent-1"));
        assert_eq!(config.agent["agent-1"].allow, vec!["tarang_*"]);
    }
}