llmy-agent 0.16.1

All-in-one LLM utilities.
Documentation
//! MCP server support — expose a [`ToolBox`] as an MCP server over
//! stdio or Streamable HTTP.

use std::future::Future;
use std::sync::Arc;

use llmy_types::error::LLMYError;
use rmcp::model::{
    CallToolRequestParams, CallToolResult, Content, ListToolsResult, PaginatedRequestParams,
    ServerInfo,
};
use rmcp::transport::io::stdio;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use rmcp::transport::{StreamableHttpServerConfig, StreamableHttpService};
use rmcp::{ErrorData, ServerHandler, serve_server};

use crate::tool::ToolBox;

/// MCP server adapter around a [`ToolBox`].
///
/// Pairs the tools to expose with the [`ServerInfo`] reported to clients and
/// implements [`ServerHandler`] for the pair. It can therefore be served over
/// stdio or Streamable HTTP.
///
/// Cloning is cheap only if [`ToolBox`] and [`ServerInfo`] are cheap to clone.
/// The HTTP transport clones the server whenever it needs a new service
/// instance.
#[derive(Clone)]
pub struct McpToolBox {
    /// Tools that are listed to, and invoked on behalf of, MCP clients.
    toolbox: ToolBox,
    /// Server metadata handed back verbatim by `get_info`.
    server_info: ServerInfo,
}

impl McpToolBox {
    /// Builds an MCP server that exposes `toolbox` and reports `server_info`.
    pub fn new(toolbox: ToolBox, server_info: ServerInfo) -> Self {
        Self {
            server_info,
            toolbox,
        }
    }

    /// Runs the server on the process's stdin/stdout and waits for it to stop.
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - the server fails to start on the stdio transport;
    /// - its background task cannot be joined.
    pub async fn serve_stdio(self) -> Result<(), LLMYError> {
        let running = serve_server(self, stdio()).await?;
        // The quit reason is intentionally discarded; only join failures matter here.
        running.waiting().await?;
        Ok(())
    }

    /// Runs the server over Streamable HTTP, listening on `addr`.
    ///
    /// Sessions are tracked in memory by a [`LocalSessionManager`]. Whenever the
    /// HTTP service needs a new handler, it receives a fresh clone of `self`.
    /// Every request path is routed to the MCP service through the router's
    /// fallback.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails, or if the HTTP server stops
    /// with an I/O error.
    pub async fn serve_http(
        self,
        addr: impl tokio::net::ToSocketAddrs,
        config: StreamableHttpServerConfig,
    ) -> Result<(), LLMYError> {
        let sessions = Arc::new(LocalSessionManager::default());
        let mcp_service = StreamableHttpService::new(move || Ok(self.clone()), sessions, config);
        let app = axum::Router::new().fallback_service(mcp_service);

        let listener = tokio::net::TcpListener::bind(addr).await?;
        axum::serve(listener, app)
            .await
            .map_err(|err| LLMYError::IO(err.into()))
    }

    /// Collects the MCP tool descriptors for every tool in the toolbox.
    fn to_mcp_tools(&self) -> Vec<rmcp::model::Tool> {
        self.toolbox.mcp_tools()
    }
}

impl ServerHandler for McpToolBox {
    fn get_info(&self) -> ServerInfo {
        self.server_info.clone()
    }

    fn list_tools(
        &self,
        _request: Option<PaginatedRequestParams>,
        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
    ) -> impl Future<Output = Result<ListToolsResult, ErrorData>> + Send + '_ {
        std::future::ready(Ok(ListToolsResult::with_all_items(self.to_mcp_tools())))
    }

    fn call_tool(
        &self,
        request: CallToolRequestParams,
        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
    ) -> impl Future<Output = Result<CallToolResult, ErrorData>> + Send + '_ {
        async move {
            let name = request.name.to_string();
            let arguments = request
                .arguments
                .map(serde_json::Value::Object)
                .unwrap_or(serde_json::Value::Object(Default::default()));
            tracing::info!("MCP server call_tool: {} {}", &name, &arguments);

            match self.toolbox.invoke_value(name, arguments).await {
                Some(Ok(result)) => Ok(CallToolResult::success(vec![Content::text(result)])),
                Some(Err(e)) => Ok(CallToolResult::error(vec![Content::text(e.to_string())])),
                None => Err(ErrorData::invalid_params(
                    "tool not found",
                    Some(serde_json::json!({ "name": request.name })),
                )),
            }
        }
    }
}