use std::collections::HashMap;
use std::sync::Arc;
use crate::tools::ToolDefinition;
/// How a tool is treated by the permission layer.
///
/// This is a fieldless enum, so it is cheap to copy and hash. It can be passed
/// by value and used as a key in maps and sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Permission {
    /// The tool is permitted.
    Allow,
    /// The tool requires approval before use.
    Ask,
    /// The tool is blocked.
    Block,
}
impl Default for Permission {
fn default() -> Self {
Self::Allow
}
}
/// An explicit permission rule for a single tool.
#[derive(Debug, Clone)]
pub struct ToolPermission {
    /// Name of the tool this rule applies to.
    pub tool_name: String,
    /// Permission granted to the tool.
    pub permission: Permission,
    /// Optional human-readable explanation attached to the rule.
    pub reason: Option<String>,
}
impl ToolPermission {
    /// Creates a rule granting `permission` to `tool_name`, with no reason.
    pub fn new(tool_name: impl Into<String>, permission: Permission) -> Self {
        let tool_name: String = tool_name.into();
        Self {
            tool_name,
            permission,
            reason: None,
        }
    }

    /// Returns this rule with `reason` attached, replacing any previous reason.
    #[must_use]
    pub fn with_reason(self, reason: impl Into<String>) -> Self {
        Self {
            reason: Some(reason.into()),
            ..self
        }
    }
}
/// Per-tool permission rules plus a fallback for tools without a rule.
#[derive(Debug, Clone)]
pub struct PermissionConfig {
    /// Explicit rules, keyed by tool name.
    permissions: HashMap<String, ToolPermission>,
    /// Permission reported for tools that have no entry in `permissions`.
    default: Permission,
}
impl PermissionConfig {
    /// Creates a configuration with no per-tool rules whose fallback is
    /// [`Permission::default()`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            permissions: HashMap::new(),
            default: Permission::default(),
        }
    }

    /// Sets the permission reported for tools that have no explicit rule.
    #[must_use]
    pub const fn with_default(mut self, perm: Permission) -> Self {
        self.default = perm;
        self
    }

    /// Stores a rule for `tool_name`, replacing any earlier rule (and reason)
    /// for the same tool.
    ///
    /// Every public builder funnels through here, so the "name is both the map
    /// key and stored in the value" bookkeeping lives in one place.
    fn with_rule(
        mut self,
        tool_name: String,
        permission: Permission,
        reason: Option<String>,
    ) -> Self {
        let mut rule = ToolPermission::new(tool_name.clone(), permission);
        rule.reason = reason;
        self.permissions.insert(tool_name, rule);
        self
    }

    /// Adds a rule allowing `tool_name`, with no reason attached.
    #[must_use]
    pub fn allow(self, tool_name: impl Into<String>) -> Self {
        self.with_rule(tool_name.into(), Permission::Allow, None)
    }

    /// Adds a rule requiring approval for `tool_name`, with no reason attached.
    #[must_use]
    pub fn ask(self, tool_name: impl Into<String>) -> Self {
        self.with_rule(tool_name.into(), Permission::Ask, None)
    }

    /// Adds a rule requiring approval for `tool_name` and records `reason`.
    #[must_use]
    pub fn ask_with_reason(
        self,
        tool_name: impl Into<String>,
        reason: impl Into<String>,
    ) -> Self {
        self.with_rule(tool_name.into(), Permission::Ask, Some(reason.into()))
    }

    /// Adds a rule blocking `tool_name`, with no reason attached.
    #[must_use]
    pub fn block(self, tool_name: impl Into<String>) -> Self {
        self.with_rule(tool_name.into(), Permission::Block, None)
    }

    /// Adds a rule blocking `tool_name` and records `reason`.
    #[must_use]
    pub fn block_with_reason(
        self,
        tool_name: impl Into<String>,
        reason: impl Into<String>,
    ) -> Self {
        self.with_rule(tool_name.into(), Permission::Block, Some(reason.into()))
    }

    /// Returns the permission for `tool_name`.
    ///
    /// This is the tool's explicit rule if one was configured. Otherwise it is
    /// the fallback set via [`Self::with_default`].
    #[must_use]
    pub fn get(&self, tool_name: &str) -> &Permission {
        self.permissions
            .get(tool_name)
            .map_or(&self.default, |tp| &tp.permission)
    }

    /// Returns the reason attached to `tool_name`'s explicit rule, if any.
    ///
    /// Tools without an explicit rule never have a reason, even when the
    /// fallback permission is `Ask` or `Block`.
    #[must_use]
    pub fn get_reason(&self, tool_name: &str) -> Option<&str> {
        self.permissions
            .get(tool_name)
            .and_then(|tp| tp.reason.as_deref())
    }

    /// `true` if [`Self::get`] yields [`Permission::Allow`] for `tool_name`.
    #[must_use]
    pub fn is_allowed(&self, tool_name: &str) -> bool {
        matches!(self.get(tool_name), Permission::Allow)
    }

    /// `true` if [`Self::get`] yields [`Permission::Block`] for `tool_name`.
    #[must_use]
    pub fn is_blocked(&self, tool_name: &str) -> bool {
        matches!(self.get(tool_name), Permission::Block)
    }

    /// `true` if [`Self::get`] yields [`Permission::Ask`] for `tool_name`.
    #[must_use]
    pub fn requires_approval(&self, tool_name: &str) -> bool {
        matches!(self.get(tool_name), Permission::Ask)
    }
}
impl Default for PermissionConfig {
fn default() -> Self {
Self::new()
}
}
/// The outcome of looking up the permission for a single tool.
#[derive(Debug, Clone)]
pub struct PermissionCheck {
    /// Name of the tool that was checked.
    pub tool_name: String,
    /// Permission that applies to the tool.
    pub permission: Permission,
    /// Reason attached to the tool's rule, if one was configured.
    pub reason: Option<String>,
}
impl PermissionCheck {
    /// Builds a check result for `tool_name` with the given permission and reason.
    pub fn new(
        tool_name: impl Into<String>,
        permission: Permission,
        reason: Option<String>,
    ) -> Self {
        let tool_name: String = tool_name.into();
        Self {
            tool_name,
            permission,
            reason,
        }
    }

    /// `true` when the checked permission is [`Permission::Allow`].
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        matches!(self.permission, Permission::Allow)
    }

    /// `true` when the checked permission is [`Permission::Block`].
    #[must_use]
    pub fn is_blocked(&self) -> bool {
        matches!(self.permission, Permission::Block)
    }

    /// `true` when the checked permission is [`Permission::Ask`].
    #[must_use]
    pub fn requires_approval(&self) -> bool {
        matches!(self.permission, Permission::Ask)
    }
}
/// Evaluates tools against a shared [`PermissionConfig`].
///
/// The config sits behind an `Arc`, so cloning a guard shares the same
/// configuration instead of copying it.
#[derive(Debug, Clone)]
pub struct PermissionGuard {
    /// Rules used by every check performed through this guard.
    config: Arc<PermissionConfig>,
}
impl PermissionGuard {
    /// Wraps `config` so it can be shared by this guard and its clones.
    #[must_use]
    pub fn new(config: PermissionConfig) -> Self {
        let config = Arc::new(config);
        Self { config }
    }

    /// Looks up the permission and configured reason (if any) for `tool_name`.
    #[must_use]
    pub fn check(&self, tool_name: &str) -> PermissionCheck {
        let config: &PermissionConfig = &self.config;
        PermissionCheck::new(
            tool_name,
            config.get(tool_name).clone(),
            config.get_reason(tool_name).map(ToOwned::to_owned),
        )
    }

    /// Checks every tool in `tools`, returning results in the same order.
    #[must_use]
    pub fn check_all(&self, tools: &[ToolDefinition]) -> Vec<PermissionCheck> {
        let mut checks = Vec::with_capacity(tools.len());
        for tool in tools {
            checks.push(self.check(&tool.name));
        }
        checks
    }
}
/// Errors describing why a tool may not be used as requested.
#[derive(Debug, thiserror::Error)]
pub enum PermissionError {
    /// The named tool is blocked; `reason` explains why.
    #[error("tool '{tool}' is blocked: {reason}")]
    Blocked {
        tool: String,
        reason: String,
    },
    /// The named tool needs approval first; `reason` explains why.
    #[error("tool '{tool}' requires approval: {reason}")]
    RequiresApproval {
        tool: String,
        reason: String,
    },
}
impl PermissionError {
    /// Builds [`PermissionError::Blocked`] for `tool` with the given `reason`.
    #[must_use]
    pub fn blocked(tool: impl Into<String>, reason: impl Into<String>) -> Self {
        let (tool, reason) = (tool.into(), reason.into());
        Self::Blocked { tool, reason }
    }

    /// Builds [`PermissionError::RequiresApproval`] for `tool` with the given `reason`.
    #[must_use]
    pub fn requires_approval(tool: impl Into<String>, reason: impl Into<String>) -> Self {
        let (tool, reason) = (tool.into(), reason.into());
        Self::RequiresApproval { tool, reason }
    }
}
// Unit tests for the permission types: default values, the `PermissionConfig`
// builder and query helpers, `PermissionGuard` checks over `ToolDefinition`s,
// error display text, and serde round-trips of `Permission`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // `Permission::default()` is `Allow`.
    #[test]
    fn test_permission_default_is_allow() {
        let perm = Permission::default();
        assert_eq!(perm, Permission::Allow);
    }
    // An empty config falls back to `Allow` for every tool.
    #[test]
    fn test_permission_config_new() {
        let config = PermissionConfig::new();
        assert_eq!(config.get("any_tool"), &Permission::Allow);
        assert!(config.is_allowed("any_tool"));
        assert!(!config.is_blocked("any_tool"));
        assert!(!config.requires_approval("any_tool"));
    }
    // Each builder method records the matching permission for its tool.
    #[test]
    fn test_permission_config_builder() {
        let config = PermissionConfig::new()
            .allow("search")
            .ask("file_delete")
            .block("system_shutdown");
        assert!(config.is_allowed("search"));
        assert!(config.requires_approval("file_delete"));
        assert!(config.is_blocked("system_shutdown"));
    }
    // `get` returns the explicit rule when one exists.
    #[test]
    fn test_permission_config_get_specific() {
        let config = PermissionConfig::new().block("dangerous_tool");
        assert_eq!(config.get("dangerous_tool"), &Permission::Block);
        assert!(config.is_blocked("dangerous_tool"));
    }
    // `get` falls back to the configured default for unknown tools.
    #[test]
    fn test_permission_config_get_default() {
        let config = PermissionConfig::new();
        assert_eq!(config.get("unconfigured_tool"), &Permission::Allow);
        let config_with_default = PermissionConfig::new().with_default(Permission::Ask);
        assert_eq!(
            config_with_default.get("unconfigured_tool"),
            &Permission::Ask
        );
    }
    // `is_allowed` is true only for `Allow` (explicit or via default).
    #[test]
    fn test_permission_config_is_allowed() {
        let config = PermissionConfig::new()
            .allow("allowed_tool")
            .ask("ask_tool")
            .block("blocked_tool");
        assert!(config.is_allowed("allowed_tool"));
        assert!(config.is_allowed("unconfigured_tool"));
        assert!(!config.is_allowed("ask_tool"));
        assert!(!config.is_allowed("blocked_tool"));
    }
    // `is_blocked` is true only for `Block`.
    #[test]
    fn test_permission_config_is_blocked() {
        let config = PermissionConfig::new()
            .allow("allowed_tool")
            .block("blocked_tool");
        assert!(config.is_blocked("blocked_tool"));
        assert!(!config.is_blocked("allowed_tool"));
        assert!(!config.is_blocked("unconfigured_tool"));
    }
    // `requires_approval` is true only for `Ask`.
    #[test]
    fn test_permission_config_requires_approval() {
        let config = PermissionConfig::new()
            .allow("allowed_tool")
            .ask("ask_tool")
            .block("blocked_tool");
        assert!(config.requires_approval("ask_tool"));
        assert!(!config.requires_approval("allowed_tool"));
        assert!(!config.requires_approval("blocked_tool"));
        assert!(!config.requires_approval("unconfigured_tool"));
    }
    // The guard reports the tool name, permission and any configured reason.
    #[test]
    fn test_permission_guard_check() {
        let config = PermissionConfig::new()
            .allow("safe_tool")
            .ask_with_reason("risky_tool", "Potential data loss")
            .block_with_reason("dangerous_tool", "System instability risk");
        let guard = PermissionGuard::new(config);
        let safe_check = guard.check("safe_tool");
        assert!(safe_check.is_allowed());
        assert_eq!(safe_check.tool_name, "safe_tool");
        assert!(safe_check.reason.is_none());
        let risky_check = guard.check("risky_tool");
        assert!(risky_check.requires_approval());
        assert_eq!(risky_check.reason.as_deref(), Some("Potential data loss"));
        let dangerous_check = guard.check("dangerous_tool");
        assert!(dangerous_check.is_blocked());
        assert_eq!(
            dangerous_check.reason.as_deref(),
            Some("System instability risk")
        );
    }
    // `check_all` returns one result per tool, in input order.
    #[test]
    fn test_permission_guard_check_all() {
        let config = PermissionConfig::new()
            .allow("search")
            .ask("delete")
            .block("shutdown");
        let guard = PermissionGuard::new(config);
        let tools = vec![
            ToolDefinition::new("search", "Search tool", json!({"type": "object"})),
            ToolDefinition::new("delete", "Delete tool", json!({"type": "object"})),
            ToolDefinition::new("shutdown", "Shutdown tool", json!({"type": "object"})),
            ToolDefinition::new("unknown", "Unknown tool", json!({"type": "object"})),
        ];
        let results = guard.check_all(&tools);
        assert_eq!(results.len(), 4);
        assert!(results[0].is_allowed());
        assert!(results[1].requires_approval());
        assert!(results[2].is_blocked());
        // "unknown" has no rule, so it falls back to the default (`Allow`).
        assert!(results[3].is_allowed());
    }
    // Display text includes the tool name, the kind of refusal and the reason.
    #[test]
    fn test_permission_error_display() {
        let blocked_err = PermissionError::blocked("system_shutdown", "Not allowed");
        assert!(blocked_err.to_string().contains("system_shutdown"));
        assert!(blocked_err.to_string().contains("blocked"));
        assert!(blocked_err.to_string().contains("Not allowed"));
        let approval_err =
            PermissionError::requires_approval("file_delete", "Irreversible operation");
        assert!(approval_err.to_string().contains("file_delete"));
        assert!(approval_err.to_string().contains("requires approval"));
        assert!(approval_err.to_string().contains("Irreversible operation"));
    }
    // Reasons are stored per rule; unconfigured tools have none.
    #[test]
    fn test_permission_config_with_reason() {
        let config = PermissionConfig::new()
            .ask_with_reason("risky_tool", "Potential data loss")
            .block_with_reason("dangerous_tool", "System instability risk");
        assert_eq!(config.get_reason("risky_tool"), Some("Potential data loss"));
        assert_eq!(
            config.get_reason("dangerous_tool"),
            Some("System instability risk")
        );
        assert_eq!(config.get_reason("unconfigured_tool"), None);
    }
    // A `Permission` survives a JSON serialize/deserialize round-trip.
    #[test]
    fn test_permission_serde_roundtrip() {
        let perm = Permission::Ask;
        let serialized = serde_json::to_string(&perm).expect("serialize failed");
        let deserialized: Permission =
            serde_json::from_str(&serialized).expect("deserialize failed");
        assert_eq!(perm, deserialized);
    }
    // The serialized form of each variant contains its variant name.
    #[test]
    fn test_permission_all_variants_serde() {
        let variants = vec![
            (Permission::Allow, "Allow"),
            (Permission::Ask, "Ask"),
            (Permission::Block, "Block"),
        ];
        for (perm, expected_name) in variants {
            let serialized = serde_json::to_string(&perm).expect("serialize failed");
            assert!(serialized.contains(expected_name));
        }
    }
}