use std::collections::HashSet;
use std::sync::Arc;
use crate::error::{Error, JsonRpcError};
use crate::prompt::Prompt;
use crate::resource::Resource;
use crate::session::SessionState;
use crate::tool::Tool;
/// A capability type (tool, resource, prompt, ...) that a [`CapabilityFilter`]
/// can inspect.
///
/// Besides the value itself, filters only need the capability's name; the
/// name-based constructors [`CapabilityFilter::allow_list`] and
/// [`CapabilityFilter::deny_list`] match on it. Implementors must be
/// `Send + Sync`.
pub trait Filterable
where
    Self: Send + Sync,
{
    /// Returns the capability's name.
    fn name(&self) -> &str;
}
impl Filterable for Tool {
fn name(&self) -> &str {
&self.name
}
}
impl Filterable for Resource {
fn name(&self) -> &str {
&self.name
}
}
impl Filterable for Prompt {
fn name(&self) -> &str {
&self.name
}
}
/// How a denied capability is reported to the caller.
///
/// Turned into an [`Error`] by [`DenialBehavior::to_error`].
#[non_exhaustive]
#[derive(Clone, Default)]
pub enum DenialBehavior {
    /// Report the capability as unknown, using
    /// `JsonRpcError::method_not_found`. This is the default.
    #[default]
    NotFound,
    /// Report the capability as forbidden, using `JsonRpcError::forbidden`
    /// with a message of the form `Unauthorized: <name>`.
    Unauthorized,
    /// Build the error with a caller-supplied function that receives the
    /// capability name.
    Custom(Arc<dyn Fn(&str) -> Error + Send + Sync>),
}
// Written by hand because the closure held by `Custom` is not `Debug`; each
// variant prints as its name, with the closure elided as `Custom(...)`.
impl std::fmt::Debug for DenialBehavior {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::NotFound => "NotFound",
            Self::Unauthorized => "Unauthorized",
            Self::Custom(_) => "Custom(...)",
        };
        f.write_str(label)
    }
}
impl DenialBehavior {
pub fn custom<F>(f: F) -> Self
where
F: Fn(&str) -> Error + Send + Sync + 'static,
{
Self::Custom(Arc::new(f))
}
pub fn to_error(&self, name: &str) -> Error {
match self {
Self::NotFound => Error::JsonRpc(JsonRpcError::method_not_found(name)),
Self::Unauthorized => {
Error::JsonRpc(JsonRpcError::forbidden(format!("Unauthorized: {}", name)))
}
Self::Custom(f) => f(name),
}
}
}
/// Type-erased visibility predicate stored by [`CapabilityFilter`].
type VisibilityPredicate<T> = dyn Fn(&SessionState, &T) -> bool + Send + Sync;

/// Decides, given a session, which capabilities of type `T` are visible, and
/// how a denied capability is reported.
///
/// The predicate is held behind an `Arc`.
pub struct CapabilityFilter<T: Filterable> {
    filter: Arc<VisibilityPredicate<T>>,
    denial: DenialBehavior,
}
// Implemented by hand instead of derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, although only the shared predicate and the denial
// behavior are cloned here.
impl<T: Filterable> Clone for CapabilityFilter<T> {
    fn clone(&self) -> Self {
        let Self { filter, denial } = self;
        Self {
            filter: Arc::clone(filter),
            denial: denial.clone(),
        }
    }
}
// The predicate is an opaque closure, so only the denial behavior is printed;
// `finish_non_exhaustive` marks the omitted field with `..`.
impl<T: Filterable> std::fmt::Debug for CapabilityFilter<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CapabilityFilter");
        builder.field("denial", &self.denial);
        builder.finish_non_exhaustive()
    }
}
impl<T: Filterable> CapabilityFilter<T> {
    /// Creates a filter from a visibility predicate.
    ///
    /// `filter` is called with the session and a capability, and returns
    /// `true` when the capability should be visible. The denial behavior
    /// starts as [`DenialBehavior::default`] (`NotFound`); change it with
    /// [`CapabilityFilter::denial_behavior`].
    #[must_use]
    pub fn new<F>(filter: F) -> Self
    where
        F: Fn(&SessionState, &T) -> bool + Send + Sync + 'static,
    {
        Self {
            filter: Arc::new(filter),
            denial: DenialBehavior::default(),
        }
    }

    /// Replaces the denial behavior and returns the updated filter.
    ///
    /// This consumes `self`, so ignoring the return value drops the filter.
    #[must_use = "`denial_behavior` consumes the filter and returns the updated one"]
    pub fn denial_behavior(mut self, behavior: DenialBehavior) -> Self {
        self.denial = behavior;
        self
    }

    /// Returns `true` if the predicate accepts `capability` for `session`.
    pub fn is_visible(&self, session: &SessionState, capability: &T) -> bool {
        (self.filter)(session, capability)
    }

    /// Returns the error configured by the denial behavior for a capability
    /// called `name`.
    pub fn denial_error(&self, name: &str) -> Error {
        self.denial.to_error(name)
    }

    /// Creates a filter that shows only capabilities whose name is in `names`.
    ///
    /// The session is ignored. An empty list hides every capability.
    #[must_use]
    pub fn allow_list(names: &[&str]) -> Self
    where
        T: 'static,
    {
        let allowed = Self::name_set(names);
        Self::new(move |_session, cap: &T| allowed.contains(cap.name()))
    }

    /// Creates a filter that hides capabilities whose name is in `names` and
    /// shows everything else.
    ///
    /// The session is ignored. An empty list shows every capability.
    #[must_use]
    pub fn deny_list(names: &[&str]) -> Self
    where
        T: 'static,
    {
        let denied = Self::name_set(names);
        Self::new(move |_session, cap: &T| !denied.contains(cap.name()))
    }

    /// Copies `names` into an owned set so the predicate closure can own it
    /// (the closure must be `'static`).
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().copied().map(String::from).collect()
    }
}
impl CapabilityFilter<Tool> {
    /// Creates a tool filter that always shows read-only tools and gates
    /// every other tool on `is_write_allowed`.
    ///
    /// A tool counts as read-only when it has annotations whose
    /// `read_only_hint` is `true`. A tool without annotations is gated like
    /// any write tool. `is_write_allowed` is only called for tools that are
    /// not read-only.
    pub fn write_guard<F>(is_write_allowed: F) -> Self
    where
        F: Fn(&SessionState) -> bool + Send + Sync + 'static,
    {
        Self::new(move |session, tool: &Tool| {
            // Read-only check first, so the session predicate is skipped
            // for tools that cannot write.
            if matches!(&tool.annotations, Some(hints) if hints.read_only_hint) {
                return true;
            }
            is_write_allowed(session)
        })
    }
}
/// A [`CapabilityFilter`] over [`Tool`]s.
pub type ToolFilter = CapabilityFilter<Tool>;

/// A [`CapabilityFilter`] over [`Resource`]s.
pub type ResourceFilter = CapabilityFilter<Resource>;

/// A [`CapabilityFilter`] over [`Prompt`]s.
pub type PromptFilter = CapabilityFilter<Prompt>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CallToolResult;
    use crate::tool::ToolBuilder;

    /// Builds a tool without calling `read_only()`; `write_guard` gates it
    /// as a write tool.
    fn make_test_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Test tool")
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    /// Builds a tool marked read-only via `ToolBuilder::read_only`.
    fn make_read_only_tool(name: &str) -> Tool {
        ToolBuilder::new(name)
            .description("Read-only tool")
            .read_only()
            .handler(|_: serde_json::Value| async { Ok(CallToolResult::text("ok")) })
            .build()
    }

    /// Asserts `error` is what `DenialBehavior::Unauthorized` produces:
    /// JSON-RPC code -32007 with an "Unauthorized" message.
    fn assert_forbidden(error: Error) {
        match error {
            Error::JsonRpc(e) => {
                assert_eq!(e.code, -32007);
                assert!(e.message.contains("Unauthorized"));
            }
            _ => panic!("Expected JsonRpc error"),
        }
    }

    /// Asserts `error` matches `JsonRpcError::method_not_found(name)` in
    /// both code and message.
    fn assert_not_found(error: Error, name: &str) {
        let expected = JsonRpcError::method_not_found(name);
        match error {
            Error::JsonRpc(e) => {
                assert_eq!(e.code, expected.code);
                assert_eq!(e.message, expected.message);
            }
            _ => panic!("Expected JsonRpc error"),
        }
    }

    #[test]
    fn test_filter_allows() {
        let filter = CapabilityFilter::new(|_, tool: &Tool| tool.name() != "blocked");
        let session = SessionState::new();
        let allowed = make_test_tool("allowed");
        let blocked = make_test_tool("blocked");
        assert!(filter.is_visible(&session, &allowed));
        assert!(!filter.is_visible(&session, &blocked));
    }

    #[test]
    fn test_denial_behavior_not_found() {
        let behavior = DenialBehavior::NotFound;
        assert_not_found(behavior.to_error("test_tool"), "test_tool");
    }

    #[test]
    fn test_denial_behavior_unauthorized() {
        let behavior = DenialBehavior::Unauthorized;
        assert_forbidden(behavior.to_error("test_tool"));
    }

    #[test]
    fn test_denial_behavior_custom() {
        let behavior = DenialBehavior::custom(|name| Error::tool(format!("No access to {}", name)));
        let error = behavior.to_error("secret_tool");
        match error {
            Error::Tool(e) => {
                assert!(e.message.contains("No access to secret_tool"));
            }
            _ => panic!("Expected Tool error"),
        }
    }

    #[test]
    fn test_filter_clone() {
        let filter = CapabilityFilter::new(|_, _: &Tool| true);
        let cloned = filter.clone();
        let session = SessionState::new();
        let tool = make_test_tool("test");
        assert!(cloned.is_visible(&session, &tool));
    }

    #[test]
    fn test_filter_default_denial_is_not_found() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false);
        assert_not_found(filter.denial_error("test"), "test");
    }

    #[test]
    fn test_filter_with_denial_behavior() {
        let filter = CapabilityFilter::new(|_, _: &Tool| false)
            .denial_behavior(DenialBehavior::Unauthorized);
        assert_forbidden(filter.denial_error("test"));
    }

    #[test]
    fn test_write_guard_allows_read_only_when_writes_blocked() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false);
        let session = SessionState::new();
        let tool = make_read_only_tool("reader");
        assert!(filter.is_visible(&session, &tool));
    }

    #[test]
    fn test_write_guard_blocks_write_tool_when_writes_blocked() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false);
        let session = SessionState::new();
        let tool = make_test_tool("writer");
        assert!(!filter.is_visible(&session, &tool));
    }

    #[test]
    fn test_write_guard_allows_write_tool_when_writes_allowed() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| true);
        let session = SessionState::new();
        let tool = make_test_tool("writer");
        assert!(filter.is_visible(&session, &tool));
    }

    #[test]
    fn test_write_guard_with_denial_behavior() {
        let filter = CapabilityFilter::<Tool>::write_guard(|_| false)
            .denial_behavior(DenialBehavior::Unauthorized);
        let session = SessionState::new();
        let tool = make_test_tool("writer");
        assert!(!filter.is_visible(&session, &tool));
        assert_forbidden(filter.denial_error("writer"));
    }

    #[test]
    fn test_allow_list_shows_listed_tools() {
        let filter = CapabilityFilter::<Tool>::allow_list(&["query", "list_tables"]);
        let session = SessionState::new();
        assert!(filter.is_visible(&session, &make_test_tool("query")));
        assert!(filter.is_visible(&session, &make_test_tool("list_tables")));
        assert!(!filter.is_visible(&session, &make_test_tool("delete")));
        assert!(!filter.is_visible(&session, &make_test_tool("drop_table")));
    }

    #[test]
    fn test_allow_list_empty_blocks_all() {
        let filter = CapabilityFilter::<Tool>::allow_list(&[]);
        let session = SessionState::new();
        assert!(!filter.is_visible(&session, &make_test_tool("anything")));
    }

    #[test]
    fn test_deny_list_hides_listed_tools() {
        let filter = CapabilityFilter::<Tool>::deny_list(&["delete", "drop_table"]);
        let session = SessionState::new();
        assert!(filter.is_visible(&session, &make_test_tool("query")));
        assert!(filter.is_visible(&session, &make_test_tool("list_tables")));
        assert!(!filter.is_visible(&session, &make_test_tool("delete")));
        assert!(!filter.is_visible(&session, &make_test_tool("drop_table")));
    }

    #[test]
    fn test_deny_list_empty_allows_all() {
        let filter = CapabilityFilter::<Tool>::deny_list(&[]);
        let session = SessionState::new();
        assert!(filter.is_visible(&session, &make_test_tool("anything")));
    }

    #[test]
    fn test_allow_list_with_denial_behavior() {
        let filter = CapabilityFilter::<Tool>::allow_list(&["query"])
            .denial_behavior(DenialBehavior::Unauthorized);
        let session = SessionState::new();
        assert!(!filter.is_visible(&session, &make_test_tool("delete")));
        assert_forbidden(filter.denial_error("delete"));
    }
}