use {
crate::{
connect::lsp::{
LspClient,
errors::LspClientError,
},
protocol::mcp::{
CallToolResult,
ToolDescriptor,
},
},
std::{
collections::HashMap,
sync::Arc,
},
};
/// Error produced by a [`Tool::call`] implementation.
///
/// When a tool runs through the registry's adapter, this error is rendered
/// with `Display` (`to_string()`) into the text of an error `CallToolResult`.
#[derive(Debug)]
pub enum ToolError {
    /// A request to the language server failed; displayed as
    /// `LSP request failed: <inner error>`.
    Lsp(LspClientError),
    /// Any other failure, carried as a message that is displayed verbatim.
    Other(String),
}
impl From<LspClientError> for ToolError {
fn from(value: LspClientError) -> Self {
Self::Lsp(value)
}
}
impl std::fmt::Display for ToolError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
| Self::Lsp(e) => write!(f, "LSP request failed: {e}"),
| Self::Other(s) => f.write_str(s),
}
}
}
/// A statically-typed MCP tool that does its work through an [`LspClient`].
///
/// This trait is not object-safe (it has associated constants and methods
/// without `self`). Register implementors with [`ToolRegistry::register`],
/// which wraps them in an adapter exposing the object-safe [`DynTool`]
/// interface.
pub trait Tool: Send + Sync + 'static {
    /// Typed arguments, deserialized from the JSON `arguments` of a tool call.
    /// If deserialization fails, the call is answered with an
    /// `invalid arguments: ...` error result instead of running the tool.
    type Input: 'static + Send + Sync + serde::de::DeserializeOwned;
    /// Typed result, serialized to JSON when the call succeeds.
    type Output: 'static + Send + Sync + serde::Serialize;
    /// Name the tool is registered and advertised under. It must be unique
    /// within a registry, because duplicate registration panics.
    const NAME: &'static str;
    /// Description advertised alongside the tool.
    const DESCRIPTION: &'static str;
    /// JSON value advertised as this tool's `input_schema` in its descriptor.
    /// Presumably a JSON Schema describing `Input`; confirm with the MCP spec.
    fn input_schema() -> serde_json::Value;
    /// Runs the tool against `client` with already-decoded `input`.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`], whose `Display` text becomes the error result
    /// shown to the caller.
    fn call(
        client: &LspClient,
        input: Self::Input,
    ) -> impl std::future::Future<Output = Result<Self::Output, ToolError>> + Send;
}
/// Object-safe, type-erased view of a tool.
///
/// This lets heterogeneous tools be stored together as `Arc<dyn DynTool>` and
/// invoked with raw JSON arguments.
pub trait DynTool: Send + Sync {
    /// Name the tool is registered and advertised under.
    fn name(&self) -> &str;
    /// Human-readable description advertised alongside the tool.
    fn description(&self) -> &str;
    /// JSON value advertised as the tool's input schema.
    fn input_schema(&self) -> serde_json::Value;
    /// Invokes the tool with untyped JSON `arguments`.
    ///
    /// The future is boxed and pinned rather than returned as `impl Future`
    /// so that the trait stays usable as `dyn DynTool`. The output type is a
    /// plain `CallToolResult`, so any failure must be expressed inside that
    /// result rather than as a Rust `Err`.
    fn call<'a>(
        &'a self,
        client: &'a LspClient,
        arguments: serde_json::Value,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = CallToolResult> + 'a + Send>>;
}
/// Zero-sized bridge that exposes a statically-typed [`Tool`] through the
/// object-safe [`DynTool`] interface.
///
/// The marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>`, so the
/// adapter does not act as if it owns a `T`. Function pointers are always
/// `Send + Sync`, so the marker adds no auto-trait requirements on `T`.
struct Adapter<T: Tool>(std::marker::PhantomData<fn() -> T>);
impl<T: Tool> Adapter<T> {
    /// Builds the (stateless) adapter. It is `const`, so it also works in
    /// constant contexts.
    const fn new() -> Self {
        Adapter(std::marker::PhantomData)
    }
}
/// Forwards the type-erased [`DynTool`] API to the static items of `T`.
impl<T: Tool> DynTool for Adapter<T> {
    fn name(&self) -> &str {
        T::NAME
    }

    fn description(&self) -> &str {
        T::DESCRIPTION
    }

    fn input_schema(&self) -> serde_json::Value {
        T::input_schema()
    }

    /// Decodes `arguments` into `T::Input`, runs `T::call`, and encodes the
    /// output as JSON.
    ///
    /// Each failure becomes an error `CallToolResult` with a text message:
    /// - malformed arguments produce `invalid arguments: ...`;
    /// - a `ToolError` from the tool produces its `Display` text;
    /// - an unserializable output produces `serialization failed: ...`.
    fn call<'a>(
        &'a self,
        client: &'a LspClient,
        arguments: serde_json::Value,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = CallToolResult> + Send + 'a>,
    > {
        Box::pin(async move {
            // Run the three fallible stages as one pipeline. The first failure
            // short-circuits with the message that is reported to the caller.
            let outcome: Result<serde_json::Value, String> = async move {
                let input = serde_json::from_value::<T::Input>(arguments)
                    .map_err(|err| format!("invalid arguments: {err}"))?;
                let output = T::call(client, input)
                    .await
                    .map_err(|err| err.to_string())?;
                serde_json::to_value(&output)
                    .map_err(|err| format!("serialization failed: {err}"))
            }
            .await;

            match outcome {
                | Ok(json) => CallToolResult::ok_json(&json),
                | Err(message) => CallToolResult::error_text(message),
            }
        })
    }
}
/// Name-indexed collection of MCP tools.
///
/// Tools are stored type-erased as `Arc<dyn DynTool>`, so statically-typed
/// [`Tool`]s and hand-written [`DynTool`]s can be registered side by side.
#[derive(Default)]
pub struct ToolRegistry {
    // Tools keyed by registered name. Registering a name twice panics, so
    // each key maps to exactly one tool.
    tools: HashMap<String, Arc<dyn DynTool>>,
}
impl ToolRegistry {
    /// Creates an empty registry (same as `ToolRegistry::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers the statically-typed tool `T` under [`Tool::NAME`].
    ///
    /// # Panics
    ///
    /// Panics if a tool with the same name is already registered.
    pub fn register<T: Tool>(&mut self) {
        self.insert(T::NAME.to_string(), Arc::new(Adapter::<T>::new()));
    }

    /// Registers an already type-erased tool under the name it reports via
    /// [`DynTool::name`] at registration time.
    ///
    /// # Panics
    ///
    /// Panics if a tool with the same name is already registered.
    pub fn register_dyn(&mut self, tool: Arc<dyn DynTool>) {
        let name = tool.name().to_string();
        self.insert(name, tool);
    }

    /// Stores `tool` under `name` without ever overwriting an existing entry.
    ///
    /// Uses the entry API so that the name is hashed and looked up only once.
    ///
    /// # Panics
    ///
    /// Panics with `duplicate MCP tool registration: <name>` if `name` is
    /// already taken. Duplicate registration is treated as a programming bug,
    /// not a recoverable error.
    fn insert(&mut self, name: String, tool: Arc<dyn DynTool>) {
        match self.tools.entry(name) {
            | std::collections::hash_map::Entry::Occupied(slot) => {
                panic!("duplicate MCP tool registration: {}", slot.key());
            },
            | std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(tool);
            },
        }
    }

    /// Looks up a tool by its registered name and returns a cheap `Arc`
    /// clone, or `None` if no tool has that name.
    pub(crate) fn get(&self, name: &str) -> Option<Arc<dyn DynTool>> {
        self.tools.get(name).cloned()
    }

    /// Returns a descriptor for every registered tool, sorted by name.
    ///
    /// Sorting makes the output deterministic regardless of `HashMap`
    /// iteration order.
    pub fn descriptors(&self) -> Vec<ToolDescriptor> {
        let mut out: Vec<ToolDescriptor> = self
            .tools
            .values()
            .map(|t| ToolDescriptor {
                name: t.name().to_string(),
                description: t.description().to_string(),
                input_schema: t.input_schema(),
            })
            .collect();
        // Registered names are unique, so a stable sort buys nothing here;
        // the unstable sort gives the same order without the extra buffer.
        out.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        out
    }

    /// Number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}