pub mod schema;
pub mod tool;
pub mod tools;
pub mod transport;
pub use {
crate::protocol::mcp::PROTOCOL_VERSION,
tool::{
Tool,
ToolError,
ToolRegistry,
},
tools::execute_command::CommandDescriptor,
};
use {
crate::{
connect::{
ipc::IpcHandle,
lsp::{
ClientKind,
DaemonConnection,
LspClient,
errors::LspClientError,
},
},
daemon::DaemonConfig,
protocol::{
jsonrpc::{
self,
Error,
ErrorCode,
Id,
Message,
Response,
},
lsp::{
ClientCapabilities,
ClientInfo,
GeneralClientCapabilities,
InitializeParams,
PositionEncodingKind,
},
mcp::{
CallToolParams,
CallToolResult,
Implementation,
InitializeResult,
ListToolsResult,
ServerCapabilities,
ToolsCapability,
},
},
},
std::{
collections::HashMap,
io,
},
};
/// Failure modes of [`McpServer::connect`].
#[derive(Debug)]
pub enum ConnectError {
    /// An I/O error while opening the connection to the daemon.
    Io(io::Error),
    /// The LSP `initialize` request sent to the daemon failed.
    Initialize(LspClientError),
    /// The daemon answered `initialize` without selecting the position
    /// encoding this server asked for.
    PositionEncoding {
        /// The encoding that was required.
        want: PositionEncodingKind,
        /// The encoding the daemon reported, if it reported one at all.
        got: Option<PositionEncodingKind>,
    },
}
impl From<io::Error> for ConnectError {
fn from(value: io::Error) -> Self {
Self::Io(value)
}
}
impl std::fmt::Display for ConnectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
| Self::Io(e) => write!(f, "failed to connect to laburnum daemon: {e}"),
| Self::Initialize(e) => write!(f, "LSP initialize failed: {e}"),
| Self::PositionEncoding { want, got } => write!(
f,
"daemon did not negotiate required position encoding {want:?} \
(got {got:?})"
),
}
}
}
impl std::error::Error for ConnectError {}
/// An MCP server that answers JSON-RPC requests for its registered tools,
/// handing each tool the LSP client connected to the laburnum daemon.
///
/// Build one with [`McpServer::connect`] followed by [`McpServerBuilder::build`].
pub struct McpServer {
    /// LSP client passed to every tool invocation.
    client: LspClient,
    /// Tools listed by `tools/list` and invoked by `tools/call`.
    registry: ToolRegistry,
    /// Identity reported in the MCP `initialize` response.
    server_info: Implementation,
    /// Never read; presumably retained so the daemon IPC connection lives as
    /// long as the server. `None` for test-constructed servers.
    _ipc_handle: Option<IpcHandle>,
}
impl McpServer {
pub async fn connect(
daemon_config: DaemonConfig,
server_info: Implementation,
) -> Result<McpServerBuilder, ConnectError> {
let mut metadata = HashMap::new();
metadata.insert("client_name".to_string(), server_info.name.clone());
metadata.insert("client_version".to_string(), server_info.version.clone());
let connection = DaemonConnection::connect_as(
daemon_config,
&server_info.version,
ClientKind::Mcp,
metadata,
)
.await?;
let init_params = InitializeParams {
capabilities: ClientCapabilities {
general: Some(GeneralClientCapabilities {
position_encodings: Some(vec![PositionEncodingKind::UTF8]),
..Default::default()
}),
..Default::default()
},
client_info: Some(ClientInfo {
name: server_info.name.clone(),
version: Some(server_info.version.clone()),
}),
..Default::default()
};
let init_result = connection
.client
.start(init_params)
.await
.map_err(ConnectError::Initialize)?;
let got = init_result.capabilities.position_encoding.clone();
if got.as_ref() != Some(&PositionEncodingKind::UTF8) {
return Err(ConnectError::PositionEncoding {
want: PositionEncodingKind::UTF8,
got,
});
}
Ok(McpServerBuilder {
client: connection.client,
registry: ToolRegistry::new(),
server_info,
ipc_handle: Some(connection.handle),
})
}
#[cfg(test)]
#[doc(hidden)]
pub(crate) fn from_client_for_test(client: LspClient) -> McpServerBuilder {
McpServerBuilder {
client,
registry: ToolRegistry::new(),
server_info: Implementation {
name: "laburnum-mcp-test".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
},
ipc_handle: None,
}
}
pub async fn run_stdio(self) -> io::Result<()> {
let mut stdin =
smol::io::BufReader::new(smol::Unblock::new(std::io::stdin()));
let mut stdout = smol::Unblock::new(std::io::stdout());
self.run(&mut stdin, &mut stdout).await
}
pub async fn run<R, W>(self, reader: &mut R, writer: &mut W) -> io::Result<()>
where
R: smol::io::AsyncBufReadExt + Unpin,
W: smol::io::AsyncWriteExt + Unpin,
{
while let Some(msg) = transport::read_message(reader).await? {
match msg {
| Message::Request(req) => {
let (method, id, params) = req.into_parts();
let response = self.dispatch_request(&method, id, params).await;
transport::write_message(writer, &Message::Response(response)).await?;
},
| Message::Notification(_) => {
},
| Message::Response(_) => {
},
}
}
Ok(())
}
#[doc(hidden)] pub(crate) async fn dispatch_request(
&self,
method: &str,
id: Id,
params: Option<serde_json::Value>,
) -> Response {
match method {
| "initialize" => self.handle_initialize(id),
| "tools/list" => self.handle_list_tools(id),
| "tools/call" => self.handle_call_tool(id, params).await,
| other => Response::from_error(
id,
Error {
code: ErrorCode::MethodNotFound,
message: format!("unknown method: {other}").into(),
data: None,
},
),
}
}
fn handle_initialize(&self, id: Id) -> Response {
let result = InitializeResult {
protocol_version: PROTOCOL_VERSION.to_string(),
capabilities: ServerCapabilities {
tools: Some(ToolsCapability { list_changed: false }),
},
server_info: self.server_info.clone(),
};
Response::from_ok(id, json_or_internal_error(&result))
}
fn handle_list_tools(&self, id: Id) -> Response {
let result = ListToolsResult {
tools: self.registry.descriptors(),
next_cursor: None,
};
Response::from_ok(id, json_or_internal_error(&result))
}
async fn handle_call_tool(
&self,
id: Id,
params: Option<serde_json::Value>,
) -> Response {
let params: CallToolParams = match params {
| Some(v) => match serde_json::from_value(v) {
| Ok(p) => p,
| Err(e) => return Response::from_error(id, invalid_params(e.to_string())),
},
| None => {
return Response::from_error(
id,
invalid_params("tools/call requires a params object"),
);
},
};
let Some(tool) = self.registry.get(¶ms.name) else {
return Response::from_error(
id,
Error {
code: ErrorCode::InvalidParams,
message: format!("unknown tool: {}", params.name).into(),
data: None,
},
);
};
let result: CallToolResult = tool.call(&self.client, params.arguments).await;
Response::from_ok(id, json_or_internal_error(&result))
}
}
/// Builds a JSON-RPC `InvalidParams` error carrying `message` and no data.
fn invalid_params(message: impl Into<String>) -> Error {
    let text: String = message.into();
    Error {
        code: ErrorCode::InvalidParams,
        data: None,
        message: text.into(),
    }
}
fn json_or_internal_error<T: serde::Serialize>(value: &T) -> serde_json::Value {
serde_json::to_value(value)
.unwrap_or_else(|e| serde_json::json!({ "error": e.to_string() }))
}
/// Configures the tool set of an [`McpServer`] before it starts serving.
///
/// Obtained from [`McpServer::connect`]; finish with [`McpServerBuilder::build`].
pub struct McpServerBuilder {
    /// LSP client handed over to the built server.
    client: LspClient,
    /// Tools registered so far.
    registry: ToolRegistry,
    /// Identity the built server reports in `initialize`.
    server_info: Implementation,
    /// Daemon IPC handle moved into the built server. `None` in tests.
    ipc_handle: Option<IpcHandle>,
}
impl McpServerBuilder {
    /// Registers the crate's default tool set via `tools::register_defaults`.
    #[must_use]
    pub fn with_default_tools(mut self) -> Self {
        tools::register_defaults(&mut self.registry);
        self
    }

    /// Registers a single tool type `T`.
    #[must_use]
    pub fn tool<T: Tool>(mut self) -> Self {
        self.registry.register::<T>();
        self
    }

    /// Registers tools for the given command descriptors via
    /// `tools::execute_command::register_commands`.
    #[must_use]
    pub fn commands(
        mut self,
        commands: impl IntoIterator<Item = CommandDescriptor>,
    ) -> Self {
        tools::execute_command::register_commands(&mut self.registry, commands);
        self
    }

    /// Finishes configuration and produces the server.
    ///
    /// Marked `#[must_use]` like the other builder steps: discarding the
    /// result drops the server together with its daemon client and IPC handle.
    #[must_use]
    pub fn build(self) -> McpServer {
        McpServer {
            client: self.client,
            registry: self.registry,
            server_info: self.server_info,
            _ipc_handle: self.ipc_handle,
        }
    }
}
#[allow(unused_imports)]
use jsonrpc as _jsonrpc;
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::connect::ipc::Connection,
        serde::{
            Deserialize,
            Serialize,
        },
    };

    /// Dispatches one request and returns the response serialized as JSON.
    async fn roundtrip(
        server: &McpServer,
        method: &str,
        id: Id,
        params: Option<serde_json::Value>,
    ) -> serde_json::Value {
        let response = server.dispatch_request(method, id, params).await;
        serde_json::to_value(&response).unwrap()
    }

    /// Builds a test server whose registry holds only [`EchoTool`].
    fn server_with_echo() -> McpServer {
        let (server_conn, _client_conn) = Connection::memory();
        let builder = McpServer::from_client_for_test(LspClient::new(server_conn));
        builder.tool::<EchoTool>().build()
    }

    /// Input and output of [`EchoTool`].
    #[derive(Deserialize, Serialize)]
    struct EchoInput {
        message: String,
    }

    /// A tool that returns its input unchanged.
    enum EchoTool {}

    impl Tool for EchoTool {
        type Input = EchoInput;
        type Output = EchoInput;

        const NAME: &'static str = "echo";
        const DESCRIPTION: &'static str = "Echoes its arguments back as text.";

        fn input_schema() -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": { "message": { "type": "string" } },
                "required": ["message"]
            })
        }

        async fn call(
            _client: &LspClient,
            input: Self::Input,
        ) -> Result<Self::Output, ToolError> {
            Ok(input)
        }
    }

    #[test]
    fn initialize_declares_tools_capability() {
        smol::block_on(async {
            let server = server_with_echo();
            let body = roundtrip(&server, "initialize", Id::Number(1), None).await;
            let result = &body["result"];
            assert_eq!(body["id"], 1);
            assert_eq!(result["protocolVersion"], PROTOCOL_VERSION);
            assert!(result["capabilities"]["tools"].is_object());
            assert!(result["serverInfo"]["name"].is_string());
        });
    }

    #[test]
    fn tools_list_returns_registered_tools_sorted() {
        smol::block_on(async {
            let server = server_with_echo();
            let body = roundtrip(&server, "tools/list", Id::Number(2), None).await;
            let listed = body["result"]["tools"].as_array().unwrap();
            assert_eq!(listed.len(), 1);
            let echo = &listed[0];
            assert_eq!(echo["name"], "echo");
            assert!(echo["description"].is_string());
            assert!(echo["inputSchema"].is_object());
        });
    }

    #[test]
    fn tools_call_dispatches_to_tool() {
        smol::block_on(async {
            let server = server_with_echo();
            let request = serde_json::json!({
                "name": "echo",
                "arguments": { "message": "hello" }
            });
            let body =
                roundtrip(&server, "tools/call", Id::Number(3), Some(request)).await;
            let first = &body["result"]["content"][0];
            assert_eq!(first["type"], "text");
            let text = first["text"].as_str().unwrap();
            assert!(text.contains("\"hello\""));
            assert_eq!(body["result"]["isError"], false);
        });
    }

    #[test]
    fn tools_call_unknown_tool_returns_invalid_params() {
        smol::block_on(async {
            let server = server_with_echo();
            let request = serde_json::json!({ "name": "missing", "arguments": {} });
            let body =
                roundtrip(&server, "tools/call", Id::Number(4), Some(request)).await;
            let error = &body["error"];
            assert_eq!(error["code"], -32602);
            assert!(error["message"].as_str().unwrap().contains("missing"));
        });
    }

    #[test]
    fn tools_call_invalid_args_returns_is_error() {
        smol::block_on(async {
            // `_client_conn` stays bound until the end of this block so the
            // peer end of the memory connection outlives the dispatch.
            let (server_conn, _client_conn) = Connection::memory();
            let server = McpServer::from_client_for_test(LspClient::new(server_conn))
                .with_default_tools()
                .build();
            let request = serde_json::json!({
                "name": "find_symbol",
                "arguments": {}
            });
            let body =
                roundtrip(&server, "tools/call", Id::Number(6), Some(request)).await;
            assert!(body["error"].is_null());
            let result = &body["result"];
            assert_eq!(result["isError"], true);
            let text = result["content"][0]["text"].as_str().unwrap();
            assert!(text.contains("invalid arguments"));
        });
    }

    #[test]
    fn unknown_method_returns_method_not_found() {
        smol::block_on(async {
            let server = server_with_echo();
            let body = roundtrip(&server, "bogus/method", Id::Number(5), None).await;
            assert_eq!(body["error"]["code"], -32601);
        });
    }
}