use super::grant::{hash_params, Grant};
use super::store::{GrantStore, GrantStoreError, MemoryGrantStore};
use serde_json::Value;
/// What [`ToolCallAuthorizer::check`] answers when no stored grant matches a call.
///
/// The default is [`ToolAuthorizationPolicy::AutoDeny`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ToolAuthorizationPolicy {
    /// Unmatched calls are answered with `Authorization::Denied`.
    #[default]
    AutoDeny,
    /// Unmatched calls are answered with `Authorization::PendingApproval`.
    Interactive,
}
/// Decides whether a tool invocation may proceed.
///
/// Calls are checked against grants held in a [`GrantStore`]. When no grant
/// matches, the configured [`ToolAuthorizationPolicy`] decides the outcome.
pub struct ToolCallAuthorizer {
    /// Fallback behaviour used when no stored grant matches a call.
    policy: ToolAuthorizationPolicy,
    /// Backing storage for tool-wide and exact-parameter grants.
    store: Box<dyn GrantStore>,
}
impl ToolCallAuthorizer {
pub fn new() -> Self {
Self {
store: Box::new(MemoryGrantStore::new()),
policy: ToolAuthorizationPolicy::default(),
}
}
pub fn interactive() -> Self {
Self::new().with_authorization_policy(ToolAuthorizationPolicy::Interactive)
}
pub fn with_store(store: impl GrantStore + 'static) -> Self {
Self {
store: Box::new(store),
policy: ToolAuthorizationPolicy::default(),
}
}
pub fn with_boxed_store(store: Box<dyn GrantStore>) -> Self {
Self {
store,
policy: ToolAuthorizationPolicy::default(),
}
}
pub fn with_authorization_policy(mut self, policy: ToolAuthorizationPolicy) -> Self {
self.policy = policy;
self
}
pub fn policy(&self) -> ToolAuthorizationPolicy {
self.policy
}
pub async fn grant_tool(&self, tool: &str) -> Result<(), GrantStoreError> {
self.store.save(Grant::tool(tool)).await
}
pub async fn grant_params(&self, tool: &str, params: &Value) -> Result<(), GrantStoreError> {
let hash = hash_params(params);
self.store.save(Grant::exact(tool, hash)).await
}
pub async fn grant_params_hash(
&self,
tool: &str,
params_hash: &str,
) -> Result<(), GrantStoreError> {
self.store.save(Grant::exact(tool, params_hash)).await
}
pub async fn check(&self, tool_name: &str, params: &Value) -> Authorization {
let params_hash = hash_params(params);
match self.store.load(tool_name).await {
Ok(grants) => {
for grant in grants {
if grant.matches(¶ms_hash) {
return Authorization::Granted { grant };
}
}
}
Err(e) => {
eprintln!("Warning: Failed to load grants for {}: {}", tool_name, e);
}
}
match self.policy {
ToolAuthorizationPolicy::AutoDeny => Authorization::Denied {
reason: format!("No grant configured for tool '{}'", tool_name),
},
ToolAuthorizationPolicy::Interactive => Authorization::PendingApproval { params_hash },
}
}
pub async fn revoke(
&self,
tool: &str,
params_hash: Option<&str>,
) -> Result<bool, GrantStoreError> {
self.store.delete(tool, params_hash).await
}
pub async fn grants(&self) -> Result<Vec<Grant>, GrantStoreError> {
self.store.load_all().await
}
pub async fn clear(&self) -> Result<(), GrantStoreError> {
self.store.clear().await
}
}
impl Default for ToolCallAuthorizer {
fn default() -> Self {
Self::new()
}
}
/// Outcome of [`ToolCallAuthorizer::check`] for a single tool call.
#[derive(Clone, Debug)]
pub enum Authorization {
    /// A stored grant matched; `grant` is the grant that matched.
    Granted { grant: Grant },
    /// No grant matched under the auto-deny policy; `reason` explains why.
    Denied { reason: String },
    /// No grant matched under the interactive policy.
    ///
    /// `params_hash` is the hash of the call's parameters.
    PendingApproval { params_hash: String },
}
impl Authorization {
pub fn is_authorized(&self) -> bool {
matches!(self, Authorization::Granted { .. })
}
pub fn is_denied(&self) -> bool {
matches!(self, Authorization::Denied { .. })
}
pub fn is_pending(&self) -> bool {
matches!(self, Authorization::PendingApproval { .. })
}
}
/// A reply to a pending authorization request.
///
/// NOTE(review): nothing in this module consumes this type. The per-variant
/// meanings below are inferred from the names and should be confirmed
/// against the caller.
#[derive(Clone, Debug)]
pub enum AuthorizationResponse {
    /// Presumably: allow this one invocation without persisting a grant.
    Once,
    /// Presumably: allow, and remember `grant` for future calls.
    Trust { grant: Grant },
    /// Refuse the call, optionally with a human-readable `reason`.
    Deny { reason: Option<String> },
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_policy_is_auto_deny() {
        let auth = ToolCallAuthorizer::new();
        assert_eq!(auth.policy(), ToolAuthorizationPolicy::AutoDeny);
    }

    #[test]
    fn test_interactive_constructor_sets_interactive_policy() {
        let auth = ToolCallAuthorizer::interactive();
        assert_eq!(auth.policy(), ToolAuthorizationPolicy::Interactive);
    }

    #[test]
    fn test_with_authorization_policy() {
        let auth = ToolCallAuthorizer::new()
            .with_authorization_policy(ToolAuthorizationPolicy::Interactive);
        assert_eq!(auth.policy(), ToolAuthorizationPolicy::Interactive);
    }

    #[tokio::test]
    async fn test_auto_deny_policy_returns_denied() {
        let auth = ToolCallAuthorizer::new();
        let params = serde_json::json!({"key": "value"});
        let result = auth.check("test", &params).await;
        assert!(result.is_denied());
        assert!(!result.is_authorized());
        assert!(!result.is_pending());
    }

    #[tokio::test]
    async fn test_interactive_policy_returns_pending() {
        let auth = ToolCallAuthorizer::interactive();
        let params = serde_json::json!({"key": "value"});
        let result = auth.check("test", &params).await;
        assert!(result.is_pending());
        assert!(!result.is_authorized());
        assert!(!result.is_denied());
    }

    #[tokio::test]
    async fn test_pending_hash_can_be_granted() {
        // The hash handed back with PendingApproval must be usable to grant
        // exactly that call via grant_params_hash.
        let auth = ToolCallAuthorizer::interactive();
        let params = serde_json::json!({"key": "value"});
        let params_hash = match auth.check("test", &params).await {
            Authorization::PendingApproval { params_hash } => params_hash,
            other => panic!("expected pending approval, got {:?}", other),
        };
        auth.grant_params_hash("test", &params_hash).await.unwrap();
        assert!(auth.check("test", &params).await.is_authorized());
        let other = serde_json::json!({"key": "other"});
        assert!(auth.check("test", &other).await.is_pending());
    }

    #[tokio::test]
    async fn test_grant_overrides_policy() {
        let auth = ToolCallAuthorizer::new();
        auth.grant_tool("test").await.unwrap();
        let result = auth.check("test", &serde_json::json!({})).await;
        assert!(result.is_authorized());
    }

    #[tokio::test]
    async fn test_authorizer_tool_wide_grant() {
        let auth = ToolCallAuthorizer::new();
        auth.grant_tool("test").await.unwrap();
        let result = auth.check("test", &serde_json::json!({"a": 1})).await;
        assert!(result.is_authorized());
        let result = auth.check("test", &serde_json::json!({"b": 2})).await;
        assert!(result.is_authorized());
    }

    #[tokio::test]
    async fn test_authorizer_params_grant() {
        let auth = ToolCallAuthorizer::new();
        let params = serde_json::json!({"key": "value"});
        auth.grant_params("test", &params).await.unwrap();
        let result = auth.check("test", &params).await;
        assert!(result.is_authorized());
        let other = serde_json::json!({"key": "other"});
        let result = auth.check("test", &other).await;
        assert!(result.is_denied());
    }

    #[tokio::test]
    async fn test_authorizer_wrong_tool() {
        let auth = ToolCallAuthorizer::new();
        auth.grant_tool("tool_a").await.unwrap();
        let result = auth.check("tool_b", &serde_json::json!({})).await;
        assert!(result.is_denied());
    }

    #[tokio::test]
    async fn test_authorizer_revoke() {
        let auth = ToolCallAuthorizer::new();
        auth.grant_tool("test").await.unwrap();
        assert!(auth
            .check("test", &serde_json::json!({}))
            .await
            .is_authorized());
        auth.revoke("test", None).await.unwrap();
        assert!(auth.check("test", &serde_json::json!({})).await.is_denied());
    }

    #[tokio::test]
    async fn test_authorizer_grants() {
        let auth = ToolCallAuthorizer::new();
        auth.grant_tool("a").await.unwrap();
        auth.grant_tool("b").await.unwrap();
        let grants = auth.grants().await.unwrap();
        assert_eq!(grants.len(), 2);
    }

    #[tokio::test]
    async fn test_authorizer_clear() {
        let auth = ToolCallAuthorizer::new();
        auth.grant_tool("test").await.unwrap();
        auth.clear().await.unwrap();
        assert!(auth.grants().await.unwrap().is_empty());
    }

    #[test]
    fn test_authorization_methods() {
        let granted = Authorization::Granted {
            grant: Grant::tool("test"),
        };
        assert!(granted.is_authorized());
        assert!(!granted.is_denied());
        assert!(!granted.is_pending());
        let denied = Authorization::Denied {
            reason: "test".to_string(),
        };
        assert!(!denied.is_authorized());
        assert!(denied.is_denied());
        assert!(!denied.is_pending());
        let pending = Authorization::PendingApproval {
            params_hash: "abc".to_string(),
        };
        assert!(!pending.is_authorized());
        assert!(!pending.is_denied());
        assert!(pending.is_pending());
    }
}