use std::{collections::HashSet, sync::Arc};
use futures::future::BoxFuture;
use rustc_hash::FxHashMap;
use schemars::{JsonSchema, generate::SchemaSettings};
use serde_json::{Map, Value};
use crate::{Error, Role};
use super::{McpConnectionTo, McpTool};
pub type McpToolSchema = Map<String, Value>;
/// Policy deciding which registered tools count as enabled.
///
/// `DenyList` enables every name that is *not* in its set; `AllowList` enables only the
/// names that *are* in its set. The default is an empty deny list, so every tool is enabled.
#[derive(Clone, Debug)]
pub enum EnabledTools {
    /// Every tool is enabled except the listed names.
    DenyList(HashSet<String>),
    /// Only the listed names are enabled.
    AllowList(HashSet<String>),
}

impl Default for EnabledTools {
    /// An empty deny list: nothing is disabled.
    fn default() -> Self {
        Self::DenyList(HashSet::new())
    }
}

impl EnabledTools {
    /// Returns `true` if the tool called `name` is enabled under this policy.
    #[must_use]
    pub fn is_enabled(&self, name: &str) -> bool {
        // Pair the set with whether membership in it means "enabled".
        let (names, listed_means_enabled) = match self {
            Self::DenyList(names) => (names, false),
            Self::AllowList(names) => (names, true),
        };
        names.contains(name) == listed_means_enabled
    }
}
/// Descriptive information captured from a tool at registration time: its identity text
/// plus the JSON Schemas generated for its input and (when object-shaped) output types.
#[derive(Clone, Debug)]
pub struct McpToolMetadata {
    /// Value of the tool's `name()` at registration time.
    name: String,
    /// Value of the tool's `title()` at registration time.
    title: Option<String>,
    /// Value of the tool's `description()` at registration time.
    description: String,
    /// Schema generated for the tool's `Input` type.
    input_schema: Arc<McpToolSchema>,
    /// Schema generated for the tool's `Output` type; `None` unless that schema's
    /// top-level `"type"` is `"object"`.
    output_schema: Option<Arc<McpToolSchema>>,
}

impl McpToolMetadata {
    /// Reads the descriptive fields from `tool` and generates schemas for its
    /// `Input` and `Output` associated types.
    fn from_tool<R: Role, M: McpTool<R>>(tool: &M) -> Self {
        let name = tool.name();
        let title = tool.title();
        let description = tool.description();
        Self {
            name,
            title,
            description,
            input_schema: schema_for_type::<M::Input>(),
            output_schema: schema_for_output::<M::Output>(),
        }
    }

    /// The tool's name.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The tool's optional human-readable title.
    #[must_use]
    pub fn title(&self) -> Option<&str> {
        Option::as_deref(&self.title)
    }

    /// The tool's description.
    #[must_use]
    pub fn description(&self) -> &str {
        self.description.as_str()
    }

    /// The JSON Schema generated for the tool's input type.
    #[must_use]
    pub fn input_schema(&self) -> &Arc<McpToolSchema> {
        &self.input_schema
    }

    /// The JSON Schema generated for the tool's output type, if it is object-shaped.
    #[must_use]
    pub fn output_schema(&self) -> Option<&Arc<McpToolSchema>> {
        Option::as_ref(&self.output_schema)
    }
}
/// A tool held by [`McpToolRegistry`]: the [`McpToolMetadata`] captured when it was
/// registered, plus a type-erased handle that takes and returns raw JSON values.
pub struct RegisteredMcpTool<Counterpart: Role> {
    metadata: McpToolMetadata,
    tool: Arc<dyn ErasedMcpTool<Counterpart>>,
}

// Only the metadata and a derived flag are printed. `Counterpart` appears solely inside the
// erased trait object, which is never formatted, so this impl deliberately does not require
// `Counterpart: Debug` (the previous bound needlessly excluded non-`Debug` roles).
impl<Counterpart: Role> std::fmt::Debug for RegisteredMcpTool<Counterpart> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegisteredMcpTool")
            .field("metadata", &self.metadata)
            .field("has_structured_output", &self.has_structured_output())
            .finish_non_exhaustive()
    }
}

impl<Counterpart: Role> RegisteredMcpTool<Counterpart> {
    /// Captures `tool`'s metadata (including its generated schemas), then erases its
    /// concrete type behind a JSON-in/JSON-out adapter.
    fn new(tool: impl McpTool<Counterpart> + 'static) -> Self {
        let metadata = McpToolMetadata::from_tool(&tool);
        Self {
            metadata,
            tool: make_erased_mcp_tool(tool),
        }
    }

    /// Metadata captured at registration time.
    #[must_use]
    pub fn metadata(&self) -> &McpToolMetadata {
        &self.metadata
    }

    /// The tool's name (shorthand for `self.metadata().name()`).
    #[must_use]
    pub fn name(&self) -> &str {
        self.metadata.name()
    }

    /// Returns `true` when an output schema was recorded for this tool.
    #[must_use]
    pub fn has_structured_output(&self) -> bool {
        self.metadata.output_schema().is_some()
    }

    /// Invokes the tool with a raw JSON `input`, resolving to its output as JSON.
    ///
    /// # Errors
    ///
    /// The future resolves to `Err` if `input` cannot be deserialized into the tool's
    /// input type, if the tool itself returns an error, or if the tool's output cannot be
    /// serialized back to JSON.
    pub fn call_tool(
        &self,
        input: Value,
        connection: McpConnectionTo<Counterpart>,
    ) -> BoxFuture<'_, Result<Value, Error>> {
        self.tool.call_tool(input, connection)
    }
}
#[derive(Debug)]
pub struct McpToolRegistry<Counterpart: Role> {
instructions: Option<String>,
tool_indices: FxHashMap<String, usize>,
tools: Vec<RegisteredMcpTool<Counterpart>>,
enabled_tools: EnabledTools,
}
impl<Counterpart: Role> Default for McpToolRegistry<Counterpart> {
fn default() -> Self {
Self {
instructions: None,
tool_indices: FxHashMap::default(),
tools: Vec::new(),
enabled_tools: EnabledTools::default(),
}
}
}
impl<Counterpart: Role> McpToolRegistry<Counterpart> {
pub fn set_instructions(&mut self, instructions: impl ToString) {
self.instructions = Some(instructions.to_string());
}
#[must_use]
pub fn instructions(&self) -> Option<&str> {
self.instructions.as_deref()
}
pub fn register_tool(&mut self, tool: impl McpTool<Counterpart> + 'static) {
let registered_tool = RegisteredMcpTool::new(tool);
let name = registered_tool.name().to_string();
if let Some(&index) = self.tool_indices.get(&name) {
self.tools[index] = registered_tool;
} else {
self.tool_indices.insert(name, self.tools.len());
self.tools.push(registered_tool);
}
}
pub fn tools(&self) -> impl Iterator<Item = &RegisteredMcpTool<Counterpart>> {
self.tools.iter()
}
pub fn enabled_tools(&self) -> impl Iterator<Item = &RegisteredMcpTool<Counterpart>> {
self.tools
.iter()
.filter(|tool| self.enabled_tools.is_enabled(tool.name()))
}
#[must_use]
pub fn tool(&self, name: &str) -> Option<&RegisteredMcpTool<Counterpart>> {
self.tool_indices
.get(name)
.and_then(|&index| self.tools.get(index))
}
#[must_use]
pub fn enabled_tool(&self, name: &str) -> Option<&RegisteredMcpTool<Counterpart>> {
self.tool(name)
.filter(|tool| self.enabled_tools.is_enabled(tool.name()))
}
#[must_use]
pub fn contains_tool(&self, name: &str) -> bool {
self.tool_indices.contains_key(name)
}
pub fn disable_all_tools(&mut self) {
self.enabled_tools = EnabledTools::AllowList(HashSet::new());
}
pub fn enable_all_tools(&mut self) {
self.enabled_tools = EnabledTools::DenyList(HashSet::new());
}
pub fn disable_tool(&mut self, name: &str) -> Result<(), Error> {
if !self.contains_tool(name) {
return Err(Error::invalid_request().data(format!("unknown tool: {name}")));
}
match &mut self.enabled_tools {
EnabledTools::DenyList(deny) => {
deny.insert(name.to_string());
}
EnabledTools::AllowList(allow) => {
allow.remove(name);
}
}
Ok(())
}
pub fn enable_tool(&mut self, name: &str) -> Result<(), Error> {
if !self.contains_tool(name) {
return Err(Error::invalid_request().data(format!("unknown tool: {name}")));
}
match &mut self.enabled_tools {
EnabledTools::DenyList(deny) => {
deny.remove(name);
}
EnabledTools::AllowList(allow) => {
allow.insert(name.to_string());
}
}
Ok(())
}
}
/// Object-safe, JSON-in/JSON-out view of an [`McpTool`], letting tools with different
/// `Input`/`Output` types be stored behind one `Arc<dyn ErasedMcpTool<_>>` type.
trait ErasedMcpTool<Counterpart>: Send + Sync
where
    Counterpart: Role,
{
    /// Runs the underlying tool on the raw JSON `input`, resolving to its JSON output.
    fn call_tool(
        &self,
        input: Value,
        connection: McpConnectionTo<Counterpart>,
    ) -> BoxFuture<'_, Result<Value, Error>>;
}
/// Wraps `tool` in an adapter that deserializes JSON input into `M::Input`, runs the tool,
/// and serializes its output back to JSON, returning the adapter as a shared trait object.
fn make_erased_mcp_tool<R, M>(tool: M) -> Arc<dyn ErasedMcpTool<R>>
where
    R: Role,
    M: McpTool<R> + 'static,
{
    /// Thin adapter owning the concrete tool.
    struct JsonAdapter<T>(T);

    impl<C, T> ErasedMcpTool<C> for JsonAdapter<T>
    where
        C: Role,
        T: McpTool<C>,
    {
        fn call_tool(
            &self,
            input: Value,
            connection: McpConnectionTo<C>,
        ) -> BoxFuture<'_, Result<Value, Error>> {
            let JsonAdapter(inner) = self;
            Box::pin(async move {
                // NOTE(review): input that fails to deserialize is reported through
                // `internal_error`, even though it presumably originates from the client;
                // confirm this is the intended error category.
                let typed_input =
                    serde_json::from_value(input).map_err(crate::util::internal_error)?;
                let typed_output = inner.call_tool(typed_input, connection).await?;
                serde_json::to_value(typed_output).map_err(crate::util::internal_error)
            })
        }
    }

    Arc::new(JsonAdapter(tool))
}
/// Generates a draft 2020-12 root JSON Schema for `T` and returns it as a JSON object.
///
/// # Panics
///
/// Panics if the schema cannot be serialized, or if it serializes to a JSON value other
/// than an object.
fn schema_for_type<T: JsonSchema>() -> Arc<McpToolSchema> {
    let root = SchemaSettings::draft2020_12()
        .into_generator()
        .into_root_schema_for::<T>();
    match serde_json::to_value(root).expect("failed to serialize schema") {
        Value::Object(map) => Arc::new(map),
        _ => panic!("Schema serialization produced non-object value: expected JSON object"),
    }
}
/// Returns the schema for `T` only when its top-level `"type"` is the string `"object"`;
/// any other schema (a different type, a type array, or no `"type"` key) yields `None`.
///
/// NOTE(review): presumably this exists because structured tool output must be a JSON
/// object; confirm against the protocol spec.
fn schema_for_output<T: JsonSchema>() -> Option<Arc<McpToolSchema>> {
    let candidate = schema_for_type::<T>();
    let is_object = candidate.get("type").and_then(Value::as_str) == Some("object");
    is_object.then_some(candidate)
}