use std::sync::Arc;
use async_trait::async_trait;
use crate::{
agent::{
context_engineering::ModelRequest,
middleware::{Middleware, MiddlewareContext, MiddlewareError},
AgentState,
},
tools::{Tool, ToolContext},
};
/// Thread-safe callback that maps an outgoing [`ModelRequest`] to the tools the model may use.
type ToolFilterFn = dyn Fn(&ModelRequest) -> Vec<Arc<dyn Tool>> + Send + Sync;

/// Middleware that narrows the set of tools exposed to the model on each call.
///
/// The stored callback receives the outgoing [`ModelRequest`] and returns the tools that
/// should remain available for that call. It is held behind an [`Arc`], so clones of the
/// middleware share one callback.
pub struct DynamicToolsMiddleware {
    tool_filter: Arc<ToolFilterFn>,
}
impl DynamicToolsMiddleware {
    /// Creates a middleware from an arbitrary filter over the outgoing request.
    ///
    /// `filter` receives the full [`ModelRequest`] and returns the tools that should be
    /// offered to the model for that call.
    pub fn new<F>(filter: F) -> Self
    where
        F: Fn(&ModelRequest) -> Vec<Arc<dyn Tool>> + Send + Sync + 'static,
    {
        Self {
            tool_filter: Arc::new(filter),
        }
    }

    /// Keeps only the request's tools whose names `filter` returns for the current
    /// [`AgentState`].
    ///
    /// The state is obtained by synchronously awaiting `request.state()`. That is only
    /// possible on a multi-threaded Tokio runtime (via [`tokio::task::block_in_place`]).
    /// On a current-thread runtime, or when no runtime is available, the request's tools
    /// are returned unchanged.
    // NOTE(review): the fallback keeps *all* tools (fails open), as the original
    // no-runtime branch did. Confirm this is acceptable where the filter guards
    // privileged tools.
    pub fn from_state<F>(filter: F) -> Self
    where
        F: Fn(&AgentState) -> Vec<String> + Send + Sync + 'static,
    {
        Self::new(move |request: &ModelRequest| {
            let Ok(handle) = tokio::runtime::Handle::try_current() else {
                return request.tools.clone();
            };
            // This closure runs inside the async `before_model_call`, and `Handle::block_on`
            // panics when called from an async execution context. `block_in_place` first
            // moves the current worker out of that context, but it panics on the
            // current-thread flavor, so only use it on the multi-threaded runtime.
            if !matches!(
                handle.runtime_flavor(),
                tokio::runtime::RuntimeFlavor::MultiThread
            ) {
                return request.tools.clone();
            }
            let state = tokio::task::block_in_place(|| handle.block_on(request.state()));
            let allowed_tool_names = filter(&state);
            Self::retain_named(request, &allowed_tool_names)
        })
    }

    /// Keeps only the request's tools whose names `filter` returns for the request's
    /// runtime tool context.
    ///
    /// When `request.runtime()` is `None`, the request's tools are returned unchanged.
    pub fn from_permissions<F>(filter: F) -> Self
    where
        F: Fn(&dyn ToolContext) -> Vec<String> + Send + Sync + 'static,
    {
        Self::new(move |request: &ModelRequest| {
            if let Some(runtime) = request.runtime() {
                let allowed_tool_names = filter(runtime.context());
                Self::retain_named(request, &allowed_tool_names)
            } else {
                request.tools.clone()
            }
        })
    }

    /// Keeps only the tools whose name starts with at least one of `prefixes`.
    ///
    /// An empty `prefixes` list removes every tool.
    pub fn allow_prefixes(prefixes: Vec<String>) -> Self {
        Self::new(move |request: &ModelRequest| {
            request
                .tools
                .iter()
                .filter(|tool| {
                    prefixes
                        .iter()
                        .any(|prefix| tool.name().starts_with(prefix))
                })
                .cloned()
                .collect()
        })
    }

    /// Removes every tool whose name exactly matches an entry in `excluded`.
    pub fn exclude_tools(excluded: Vec<String>) -> Self {
        Self::new(move |request: &ModelRequest| {
            request
                .tools
                .iter()
                .filter(|tool| !excluded.contains(&tool.name()))
                .cloned()
                .collect()
        })
    }

    /// Returns (cloned handles to) the request's tools whose names appear in `allowed`,
    /// preserving the request's tool order.
    fn retain_named(request: &ModelRequest, allowed: &[String]) -> Vec<Arc<dyn Tool>> {
        request
            .tools
            .iter()
            .filter(|tool| allowed.contains(&tool.name()))
            .cloned()
            .collect()
    }
}
#[async_trait]
impl Middleware for DynamicToolsMiddleware {
async fn before_model_call(
&self,
request: &ModelRequest,
_context: &mut MiddlewareContext,
) -> Result<Option<ModelRequest>, MiddlewareError> {
let filtered_tools = (self.tool_filter)(request);
if filtered_tools.len() != request.tools.len() {
Ok(Some(request.with_messages_and_tools(
request.messages.clone(),
filtered_tools,
)))
} else {
Ok(None)
}
}
}
impl Clone for DynamicToolsMiddleware {
    /// Returns a new middleware handle sharing the same filter callback.
    ///
    /// Only the reference count of the shared callback is bumped; the callback itself is
    /// not copied.
    fn clone(&self) -> Self {
        let Self { tool_filter } = self;
        Self {
            tool_filter: Arc::clone(tool_filter),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schemas::Message;

    /// Builds a request containing a single human message, no tools, and a fresh state.
    fn empty_request() -> ModelRequest {
        let state = Arc::new(tokio::sync::Mutex::new(AgentState::new()));
        let messages = vec![Message::new_human_message("Hello")];
        ModelRequest::new(messages, vec![], state)
    }

    #[tokio::test]
    async fn test_dynamic_tools_exclude() {
        let middleware = DynamicToolsMiddleware::exclude_tools(vec![
            "delete_tool".to_string(),
            "admin_tool".to_string(),
        ]);
        let request = empty_request();
        let mut middleware_context = MiddlewareContext::new();
        let result = middleware
            .before_model_call(&request, &mut middleware_context)
            .await;
        // Nothing can be excluded from an empty tool list, so the request must be untouched.
        // (`is_ok()` alone was vacuous: this middleware never returns `Err`.)
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn test_dynamic_tools_allow_prefixes_without_tools() {
        let middleware = DynamicToolsMiddleware::allow_prefixes(vec!["read_".to_string()]);
        let request = empty_request();
        let mut middleware_context = MiddlewareContext::new();
        let result = middleware
            .before_model_call(&request, &mut middleware_context)
            .await;
        assert!(matches!(result, Ok(None)));
    }
}