use std::collections::HashMap;
use std::sync::Arc;
use rmcp::{
ErrorData,
model::{CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams},
service::RequestContext,
};
use crate::{
plugin::{ArcPlugin, ElicitPlugin, prefixed_name, strip_prefix},
rmcp::RoleServer,
};
/// Collection of plugins plus a lookup table from exposed tool name to the
/// plugin that serves it.
///
/// Built by chaining [`PluginRegistry::register`] and
/// [`PluginRegistry::register_flat`] calls on a value from
/// [`PluginRegistry::new`].
#[derive(Clone, Default)]
pub struct PluginRegistry {
/// `(prefix, plugin)` pairs in registration order. Plugins added through
/// `register_flat` are stored with an empty prefix.
plugins: Vec<(String, ArcPlugin)>,
/// Maps each exposed tool name to the index of its owning entry in `plugins`.
dispatch: HashMap<String, usize>,
}
impl PluginRegistry {
    /// Creates an empty registry with no plugins and no dispatch entries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `plugin`, exposing each of its tools under a name derived
    /// from `prefix`.
    ///
    /// Exposed names are produced by `prefixed_name`, the same helper that
    /// the tool listing uses, so every listed tool is dispatchable. An empty
    /// `prefix` behaves like [`PluginRegistry::register_flat`]: tools keep
    /// their bare names.
    ///
    /// # Panics
    ///
    /// Panics if any exposed tool name is already registered.
    #[tracing::instrument(skip(self, plugin), fields(prefix))]
    pub fn register(mut self, prefix: impl Into<String>, plugin: impl ElicitPlugin) -> Self {
        let prefix = prefix.into();
        let plugin: ArcPlugin = Arc::new(plugin);
        let tool_count = self.index_tools(&prefix, &plugin);
        tracing::debug!(prefix = %prefix, tool_count, "Registered plugin");
        self.plugins.push((prefix, plugin));
        self
    }

    /// Registers `plugin` without a prefix; its tools are exposed under
    /// their own names.
    ///
    /// # Panics
    ///
    /// Panics if any tool name is already registered.
    #[tracing::instrument(skip(self, plugin))]
    pub fn register_flat(mut self, plugin: impl ElicitPlugin) -> Self {
        let plugin: ArcPlugin = Arc::new(plugin);
        let tool_count = self.index_tools("", &plugin);
        tracing::debug!(tool_count, "Registered flat plugin");
        // The empty prefix is what marks this entry as flat for listing and dispatch.
        self.plugins.push((String::new(), plugin));
        self
    }

    /// Adds a dispatch entry for every tool of `plugin`, pointing at the slot
    /// the plugin is about to occupy in `self.plugins`, and returns the number
    /// of tools indexed.
    ///
    /// The caller must push the plugin onto `self.plugins` right after this
    /// returns so the recorded index stays valid.
    ///
    /// # Panics
    ///
    /// Panics if an exposed name is already present in the dispatch table.
    fn index_tools(&mut self, prefix: &str, plugin: &ArcPlugin) -> usize {
        let idx = self.plugins.len();
        // Query the plugin once; the same list provides both names and count.
        let tools = plugin.list_tools();
        let tool_count = tools.len();
        for tool in tools {
            // Mirror `all_tools`: bare name for an empty prefix, otherwise the
            // shared `prefixed_name` helper.
            let exposed = if prefix.is_empty() {
                tool.name.to_string()
            } else {
                prefixed_name(prefix, &tool.name).into_owned()
            };
            match self.dispatch.entry(exposed) {
                std::collections::hash_map::Entry::Occupied(taken) => {
                    panic!("tool name collision: `{}` already registered", taken.key())
                }
                std::collections::hash_map::Entry::Vacant(slot) => {
                    slot.insert(idx);
                }
            }
        }
        tool_count
    }

    /// Returns the tools of every registered plugin, in registration order,
    /// with names rewritten to their exposed (prefixed) form where the plugin
    /// has a non-empty prefix.
    fn all_tools(&self) -> Vec<rmcp::model::Tool> {
        let mut tools = Vec::new();
        for (prefix, plugin) in &self.plugins {
            for mut tool in plugin.list_tools() {
                if !prefix.is_empty() {
                    tool.name = prefixed_name(prefix, &tool.name);
                }
                tools.push(tool);
            }
        }
        tools
    }

    /// Consumes the registry and wraps it in a [`Toolchain`] that holds
    /// `filter` as its visibility predicate.
    pub fn filter<F>(self, filter: F) -> Toolchain<F>
    where
        F: Fn(&rmcp::model::Tool) -> bool + Send + Sync + 'static,
    {
        Toolchain {
            registry: self,
            filter,
        }
    }
}
impl rmcp::ServerHandler for PluginRegistry {
    /// Lists every tool of every registered plugin under its exposed name.
    ///
    /// The pagination request is ignored: all tools are returned in a single
    /// result whose remaining fields are left at their defaults.
    #[tracing::instrument(skip(self, _request, _context))]
    fn list_tools(
        &self,
        _request: Option<PaginatedRequestParams>,
        _context: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<ListToolsResult, ErrorData>> + Send + '_ {
        let tools = self.all_tools();
        tracing::debug!(count = tools.len(), "Listing tools");
        let listing = ListToolsResult {
            tools,
            ..Default::default()
        };
        std::future::ready(Ok(listing))
    }

    /// Routes a tool call to the plugin that owns `request.name`.
    ///
    /// The exposed name is looked up in the dispatch table, the plugin's
    /// prefix (if any) is stripped, and the request is forwarded with the
    /// bare tool name.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_params` error when the name is not registered or
    /// when the owning plugin's prefix cannot be stripped from it; otherwise
    /// returns whatever the plugin returns.
    #[tracing::instrument(skip(self, context), fields(tool = %request.name))]
    fn call_tool(
        &self,
        request: CallToolRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<CallToolResult, ErrorData>> + Send + '_ {
        async move {
            let Some(&slot) = self.dispatch.get(request.name.as_ref()) else {
                return Err(ErrorData::invalid_params(
                    format!("tool `{}` not found", request.name),
                    None,
                ));
            };
            let (prefix, plugin) = &self.plugins[slot];

            // Flat plugins (empty prefix) receive the name untouched.
            let bare_name: String = if !prefix.is_empty() {
                strip_prefix(prefix, request.name.as_ref())
                    .ok_or_else(|| ErrorData::invalid_params("prefix mismatch", None))?
                    .to_string()
            } else {
                request.name.as_ref().to_string()
            };

            // Log before the name is moved into the forwarded request.
            tracing::debug!(bare = %bare_name, "Dispatching to plugin");
            let mut forwarded = request;
            forwarded.name = std::borrow::Cow::Owned(bare_name);
            plugin.call_tool(forwarded, context).await
        }
    }
}
/// A [`PluginRegistry`] paired with a predicate over tools.
///
/// The `ServerHandler` implementation lists only the tools for which the
/// predicate returns `true` and rejects calls to any other tool.
pub struct Toolchain<F = fn(&rmcp::model::Tool) -> bool>
where
F: Fn(&rmcp::model::Tool) -> bool + Send + Sync + 'static,
{
/// Registry that supplies the tools and performs the actual dispatch.
registry: PluginRegistry,
/// Predicate applied to each tool (with its exposed name); `true` means visible.
filter: F,
}
impl<F> Toolchain<F>
where
F: Fn(&rmcp::model::Tool) -> bool + Send + Sync + 'static,
{
pub fn new(registry: PluginRegistry, filter: F) -> Self {
Self { registry, filter }
}
}
impl<F> rmcp::ServerHandler for Toolchain<F>
where
    F: Fn(&rmcp::model::Tool) -> bool + Send + Sync + 'static,
{
    /// Lists the registry's tools that the filter predicate accepts.
    ///
    /// The pagination request is ignored: all visible tools are returned in
    /// a single result.
    #[tracing::instrument(skip(self, _request, _context))]
    fn list_tools(
        &self,
        _request: Option<PaginatedRequestParams>,
        _context: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<ListToolsResult, ErrorData>> + Send + '_ {
        // Start from the full listing and drop whatever the predicate rejects.
        let mut tools = self.registry.all_tools();
        tools.retain(|tool| (self.filter)(tool));
        tracing::debug!(visible = tools.len(), "Listing toolchain tools");
        let listing = ListToolsResult {
            tools,
            ..Default::default()
        };
        std::future::ready(Ok(listing))
    }

    /// Forwards a tool call to the underlying registry, but only when the
    /// requested tool is both known to the registry and accepted by the
    /// filter predicate.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_params` error when the tool is not visible in
    /// this toolchain; otherwise returns the registry's result.
    #[tracing::instrument(skip(self, context), fields(tool = %request.name))]
    fn call_tool(
        &self,
        request: CallToolRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<CallToolResult, ErrorData>> + Send + '_ {
        async move {
            // A tool is callable only if it appears in the listing under the
            // requested name and passes the predicate.
            let is_visible = self
                .registry
                .all_tools()
                .iter()
                .any(|tool| tool.name == request.name && (self.filter)(tool));

            if is_visible {
                self.registry.call_tool(request, context).await
            } else {
                Err(ErrorData::invalid_params(
                    format!("tool `{}` not in toolchain", request.name),
                    None,
                ))
            }
        }
    }
}