use super::{PolicyAction, PolicyContext, PolicyDecision, PolicyEvaluator};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Trust tier assigned to a tool.
///
/// Variants are declared from least to most trusted. The derived
/// `PartialOrd`/`Ord` follow declaration order, so `a < b` reads as
/// "`a` is less trusted than `b`". The explicit discriminants mirror
/// that same order. `Medium` is the `Default` tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub enum ToolTrustLevel {
    /// Lowest tier.
    Untrusted = 0,
    /// Above `Untrusted`, below `Medium`.
    Low = 1,
    /// Default tier (selected by `#[default]`).
    #[default]
    Medium = 2,
    /// Above `Medium`, below `System`.
    High = 3,
    /// Highest tier.
    System = 4,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolPermissions {
pub network_access: bool,
pub filesystem_access: bool,
pub filesystem_write: bool,
pub env_access: bool,
pub subprocess_access: bool,
pub pii_access: bool,
pub allowed_domains: HashSet<String>,
pub allowed_paths: HashSet<String>,
}
impl ToolPermissions {
pub fn sandboxed() -> Self {
Self {
network_access: false,
filesystem_access: false,
filesystem_write: false,
env_access: false,
subprocess_access: false,
pii_access: false,
allowed_domains: HashSet::new(),
allowed_paths: HashSet::new(),
}
}
pub fn network_only() -> Self {
Self {
network_access: true,
filesystem_access: false,
filesystem_write: false,
env_access: false,
subprocess_access: false,
pii_access: false,
allowed_domains: HashSet::new(),
allowed_paths: HashSet::new(),
}
}
pub fn full() -> Self {
Self {
network_access: true,
filesystem_access: true,
filesystem_write: true,
env_access: true,
subprocess_access: true,
pii_access: true,
allowed_domains: HashSet::new(),
allowed_paths: HashSet::new(),
}
}
pub fn allow_domain(mut self, domain: impl Into<String>) -> Self {
self.allowed_domains.insert(domain.into());
self
}
pub fn allow_path(mut self, path: impl Into<String>) -> Self {
self.allowed_paths.insert(path.into());
self
}
}
/// Policy governing which tools may be invoked.
///
/// Holds per-tool permission and trust overrides, a fallback permission
/// set, the minimum trust level a tool must meet, and an explicit block list.
#[derive(Debug, Clone)]
pub struct ToolPolicy {
    /// Permissions returned for tools without an entry in `tool_permissions`.
    pub default_permissions: ToolPermissions,
    /// Per-tool permission overrides, keyed by tool name.
    pub tool_permissions: std::collections::HashMap<String, ToolPermissions>,
    /// Per-tool trust levels, keyed by tool name.
    pub tool_trust: std::collections::HashMap<String, ToolTrustLevel>,
    /// Tools whose trust level compares below this are denied.
    pub min_trust_level: ToolTrustLevel,
    /// Tool names that are always denied.
    pub blocked_tools: HashSet<String>,
}
impl Default for ToolPolicy {
fn default() -> Self {
Self {
default_permissions: ToolPermissions::sandboxed(),
tool_permissions: std::collections::HashMap::new(),
tool_trust: std::collections::HashMap::new(),
min_trust_level: ToolTrustLevel::Low,
blocked_tools: HashSet::new(),
}
}
}
impl ToolPolicy {
pub fn new() -> Self {
Self::default()
}
pub fn with_default_permissions(mut self, perms: ToolPermissions) -> Self {
self.default_permissions = perms;
self
}
pub fn set_tool_permissions(mut self, tool: impl Into<String>, perms: ToolPermissions) -> Self {
self.tool_permissions.insert(tool.into(), perms);
self
}
pub fn set_tool_trust(mut self, tool: impl Into<String>, level: ToolTrustLevel) -> Self {
self.tool_trust.insert(tool.into(), level);
self
}
pub fn block_tool(mut self, tool: impl Into<String>) -> Self {
self.blocked_tools.insert(tool.into());
self
}
pub fn get_permissions(&self, tool_name: &str) -> &ToolPermissions {
self.tool_permissions
.get(tool_name)
.unwrap_or(&self.default_permissions)
}
pub fn get_trust_level(&self, tool_name: &str) -> ToolTrustLevel {
self.tool_trust
.get(tool_name)
.copied()
.unwrap_or(ToolTrustLevel::Medium)
}
}
impl PolicyEvaluator for ToolPolicy {
    /// Gates `PolicyAction::InvokeTool`; every other action is allowed.
    ///
    /// A tool invocation is denied if the tool is in `blocked_tools` (checked
    /// first). It is also denied if its trust level from `get_trust_level`
    /// is below `min_trust_level`. Otherwise it is allowed.
    fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
        let tool_name = match &context.action {
            PolicyAction::InvokeTool { tool_name } => tool_name,
            // This evaluator only has an opinion about tool invocations.
            _ => return PolicyDecision::Allow,
        };

        if self.blocked_tools.contains(tool_name) {
            return PolicyDecision::Deny {
                reason: format!("Tool '{}' is blocked by policy", tool_name),
            };
        }

        let trust = self.get_trust_level(tool_name);
        if trust >= self.min_trust_level {
            PolicyDecision::Allow
        } else {
            PolicyDecision::Deny {
                reason: format!(
                    "Tool '{}' trust level {:?} is below minimum {:?}",
                    tool_name, trust, self.min_trust_level
                ),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Wraps `action` in a context with no tenant, user, or metadata.
    fn context_for(action: PolicyAction) -> PolicyContext {
        PolicyContext {
            tenant_id: None,
            user_id: None,
            action,
            metadata: HashMap::new(),
        }
    }

    /// Context for invoking the tool named `tool_name`.
    fn invoke_tool(tool_name: &str) -> PolicyContext {
        context_for(PolicyAction::InvokeTool {
            tool_name: tool_name.to_string(),
        })
    }

    #[test]
    fn test_tool_trust_level_ordering() {
        assert!(ToolTrustLevel::Untrusted < ToolTrustLevel::Low);
        assert!(ToolTrustLevel::Low < ToolTrustLevel::Medium);
        assert!(ToolTrustLevel::Medium < ToolTrustLevel::High);
        assert!(ToolTrustLevel::High < ToolTrustLevel::System);
    }

    #[test]
    fn test_tool_trust_level_default() {
        assert_eq!(ToolTrustLevel::default(), ToolTrustLevel::Medium);
    }

    #[test]
    fn test_tool_permissions_default() {
        let perms = ToolPermissions::default();
        assert!(!perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
    }

    #[test]
    fn test_tool_permissions_sandboxed() {
        let perms = ToolPermissions::sandboxed();
        assert!(!perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
        assert!(perms.allowed_domains.is_empty());
        assert!(perms.allowed_paths.is_empty());
    }

    #[test]
    fn test_tool_permissions_network_only() {
        let perms = ToolPermissions::network_only();
        assert!(perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
    }

    #[test]
    fn test_tool_permissions_full() {
        let perms = ToolPermissions::full();
        assert!(perms.network_access);
        assert!(perms.filesystem_access);
        assert!(perms.filesystem_write);
        assert!(perms.env_access);
        assert!(perms.subprocess_access);
        assert!(perms.pii_access);
    }

    #[test]
    fn test_tool_permissions_allow_domain() {
        let perms = ToolPermissions::network_only()
            .allow_domain("api.example.com")
            .allow_domain("cdn.example.com");
        assert!(perms.allowed_domains.contains("api.example.com"));
        assert!(perms.allowed_domains.contains("cdn.example.com"));
        assert_eq!(perms.allowed_domains.len(), 2);
    }

    #[test]
    fn test_tool_permissions_allow_path() {
        let perms = ToolPermissions::sandboxed()
            .allow_path("/tmp")
            .allow_path("/var/data");
        assert!(perms.allowed_paths.contains("/tmp"));
        assert!(perms.allowed_paths.contains("/var/data"));
    }

    #[test]
    fn test_tool_policy_default() {
        let policy = ToolPolicy::default();
        assert_eq!(policy.min_trust_level, ToolTrustLevel::Low);
        assert!(policy.blocked_tools.is_empty());
    }

    #[test]
    fn test_tool_policy_with_default_permissions() {
        let policy = ToolPolicy::new().with_default_permissions(ToolPermissions::network_only());
        assert!(policy.default_permissions.network_access);
    }

    #[test]
    fn test_tool_policy_set_tool_permissions() {
        let policy =
            ToolPolicy::new().set_tool_permissions("web_search", ToolPermissions::network_only());
        let perms = policy.get_permissions("web_search");
        assert!(perms.network_access);
        let default_perms = policy.get_permissions("other_tool");
        assert!(!default_perms.network_access);
    }

    #[test]
    fn test_tool_policy_set_tool_trust() {
        let policy = ToolPolicy::new()
            .set_tool_trust("trusted_tool", ToolTrustLevel::High)
            .set_tool_trust("untrusted_tool", ToolTrustLevel::Untrusted);
        assert_eq!(policy.get_trust_level("trusted_tool"), ToolTrustLevel::High);
        assert_eq!(
            policy.get_trust_level("untrusted_tool"),
            ToolTrustLevel::Untrusted
        );
        assert_eq!(
            policy.get_trust_level("unknown_tool"),
            ToolTrustLevel::Medium
        );
    }

    #[test]
    fn test_tool_policy_block_tool() {
        let policy = ToolPolicy::new()
            .block_tool("dangerous_tool")
            .block_tool("another_dangerous");
        assert!(policy.blocked_tools.contains("dangerous_tool"));
        assert!(policy.blocked_tools.contains("another_dangerous"));
    }

    #[test]
    fn test_tool_policy_evaluate_allowed() {
        let policy = ToolPolicy::new();
        let decision = policy.evaluate(&invoke_tool("safe_tool"));
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_tool_policy_evaluate_blocked() {
        let policy = ToolPolicy::new().block_tool("blocked_tool");
        let decision = policy.evaluate(&invoke_tool("blocked_tool"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_blocked_takes_precedence_over_trust() {
        // Even the highest trust tier must not bypass an explicit block.
        let policy = ToolPolicy::new()
            .set_tool_trust("system_tool", ToolTrustLevel::System)
            .block_tool("system_tool");
        match policy.evaluate(&invoke_tool("system_tool")) {
            PolicyDecision::Deny { reason } => assert!(reason.contains("blocked")),
            _ => panic!("expected a Deny decision for a blocked tool"),
        }
    }

    #[test]
    fn test_tool_policy_evaluate_trust_level_denied() {
        let policy = ToolPolicy {
            min_trust_level: ToolTrustLevel::High,
            ..ToolPolicy::new()
        }
        .set_tool_trust("low_trust", ToolTrustLevel::Low);
        let decision = policy.evaluate(&invoke_tool("low_trust"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_trust_at_minimum_allowed() {
        // The threshold is inclusive: a tool exactly at the minimum passes.
        let policy = ToolPolicy {
            min_trust_level: ToolTrustLevel::High,
            ..ToolPolicy::new()
        }
        .set_tool_trust("high_trust", ToolTrustLevel::High);
        assert!(policy.evaluate(&invoke_tool("high_trust")).is_allowed());
    }

    #[test]
    fn test_tool_policy_default_denies_untrusted_tool() {
        let policy = ToolPolicy::new().set_tool_trust("sketchy", ToolTrustLevel::Untrusted);
        match policy.evaluate(&invoke_tool("sketchy")) {
            PolicyDecision::Deny { reason } => assert!(reason.contains("below minimum")),
            _ => panic!("expected a Deny decision for an untrusted tool"),
        }
    }

    #[test]
    fn test_tool_policy_evaluate_non_tool_action_allowed() {
        let policy = ToolPolicy::new().block_tool("some_tool");
        let decision = policy.evaluate(&context_for(PolicyAction::LlmCall {
            model: "gpt-4".to_string(),
        }));
        assert!(decision.is_allowed());
    }
}