use serde::{Deserialize, Serialize};
use tokitai_core::ToolDefinition;
#[cfg(feature = "mcp")]
use async_trait::async_trait;
/// A tool descriptor in the shape this module exposes to MCP clients.
///
/// NOTE(review): fields serialize with their Rust (snake_case) names, e.g.
/// `input_schema`; the MCP specification spells this `inputSchema` — confirm
/// which form consumers of this type expect.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parsed JSON schema describing the tool's arguments.
    pub input_schema: serde_json::Value,
    /// Optional schema for the tool's output. Left out of the serialized form
    /// when `None`, and defaulted to `None` when absent on deserialization.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<serde_json::Value>,
}
/// Converts `tokitai_core` tool definitions into [`McpTool`] descriptors.
///
/// Each definition's `input_schema` text is parsed as JSON. A definition whose
/// schema does not parse is skipped with a `log::warn!`, so the returned vector
/// can be shorter than `tools`. `output_schema` is always `None`.
pub fn to_mcp_tools(tools: &[ToolDefinition]) -> Vec<McpTool> {
    let mut converted = Vec::with_capacity(tools.len());

    for def in tools {
        let schema: serde_json::Value = match serde_json::from_str(&def.input_schema) {
            Ok(parsed) => parsed,
            Err(err) => {
                // Best-effort: one malformed schema must not hide the other tools.
                log::warn!("failed to parse schema for tool '{}': {}", def.name, err);
                continue;
            }
        };

        converted.push(McpTool {
            name: def.name.to_string(),
            description: def.description.to_string(),
            input_schema: schema,
            output_schema: None,
        });
    }

    converted
}
/// Request payload for invoking a tool: the tool's name plus JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCall {
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments; deserializes to `Value::Null` when the field is absent.
    #[serde(default)]
    pub arguments: serde_json::Value,
}

/// Outcome of a tool invocation.
///
/// The constructors [`McpToolResponse::success`] and [`McpToolResponse::error`]
/// each fill exactly one of `result` / `error`. Fields that are `None` are
/// omitted when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolResponse {
    /// `true` when the tool ran successfully.
    pub success: bool,
    /// The tool's return value on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error message on failure.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

impl McpToolResponse {
    /// Builds a successful response carrying `result`.
    pub fn success(result: serde_json::Value) -> Self {
        let result = Some(result);
        Self {
            result,
            error: None,
            success: true,
        }
    }

    /// Builds a failed response carrying `message`.
    pub fn error(message: impl Into<String>) -> Self {
        let error = Some(message.into());
        Self {
            error,
            result: None,
            success: false,
        }
    }
}
/// Asynchronous interface of something that can list and execute MCP tools.
///
/// There is deliberately no `Sized` supertrait: every method is object-safe
/// after `#[async_trait]` desugaring, so implementors can be used behind
/// `Box<dyn McpServer>` or `Arc<dyn McpServer>`. The `Sync` supertrait keeps
/// the future returned by the default `tool_count` `Send`.
#[cfg(feature = "mcp")]
#[async_trait]
pub trait McpServer: Send + Sync {
    /// Returns every tool this server exposes.
    async fn list_tools(&self) -> Vec<McpTool>;

    /// Invokes the tool called `name` with `arguments` and reports the outcome.
    async fn call_tool(&self, name: &str, arguments: &serde_json::Value) -> McpToolResponse;

    /// Number of exposed tools.
    ///
    /// The default builds the whole list through [`McpServer::list_tools`] and
    /// counts it; implementors with a cheaper count can override this.
    async fn tool_count(&self) -> usize {
        self.list_tools().await.len()
    }
}
/// Adapter that exposes a `tokitai_core` tool type through the [`McpServer`] trait.
#[cfg(feature = "mcp")]
pub struct McpServerWrapper<T> {
    inner: T,
}

#[cfg(feature = "mcp")]
impl<T> McpServerWrapper<T>
where
    T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Clone + Send + Sync + 'static,
{
    /// Wraps `inner`.
    pub fn new(inner: T) -> Self {
        McpServerWrapper { inner }
    }

    /// Borrows the wrapped value.
    pub fn inner(&self) -> &T {
        let Self { inner } = self;
        inner
    }

    /// Converts `T::tool_definitions()` into MCP descriptors via the free
    /// function `to_mcp_tools` (definitions with unparsable schemas are skipped).
    pub fn to_mcp_tools(&self) -> Vec<McpTool> {
        let definitions = T::tool_definitions();
        to_mcp_tools(definitions)
    }
}
#[cfg(feature = "mcp")]
#[async_trait]
impl<T> McpServer for McpServerWrapper<T>
where
    T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Clone + Send + Sync + 'static,
{
    /// Lists the wrapped type's tools using the wrapper's inherent `to_mcp_tools`.
    async fn list_tools(&self) -> Vec<McpTool> {
        Self::to_mcp_tools(self)
    }

    /// Runs the call synchronously on the wrapped value and maps `Ok` to a
    /// success response and `Err` to an error response carrying the error's
    /// `Display` text.
    async fn call_tool(&self, name: &str, arguments: &serde_json::Value) -> McpToolResponse {
        let outcome = self.inner.call_tool(name, arguments);
        match outcome {
            Err(err) => McpToolResponse::error(err.to_string()),
            Ok(value) => McpToolResponse::success(value),
        }
    }
}
/// Generates MCP helper functions on `$type`.
///
/// With this crate's `mcp` feature enabled, the macro adds two associated
/// functions to `$type`:
/// - `new_mcp_server()` wraps `Self::default()` in `McpServerWrapper`, so
///   `$type` must implement `Default`.
/// - `mcp_tool_definitions()` converts `<Self as ToolProvider>::tool_definitions()`
///   with `to_mcp_tools`.
///
/// The feature gate sits on the macro definition rather than inside the
/// expansion on purpose. A `#[cfg(...)]` emitted by a `macro_rules!` expansion
/// is evaluated against the features of the crate that *invokes* the macro. That
/// made the generated methods depend on whether the caller happened to declare
/// its own `mcp` feature, instead of whether this crate provides
/// `McpServerWrapper`.
#[cfg(feature = "mcp")]
#[macro_export]
macro_rules! impl_mcp_server {
    ($type:ty) => {
        impl $type {
            /// Creates an MCP server that wraps `Self::default()`.
            pub fn new_mcp_server() -> $crate::mcp::McpServerWrapper<Self> {
                $crate::mcp::McpServerWrapper::new(Self::default())
            }

            /// Returns this type's tools as MCP tool descriptors.
            pub fn mcp_tool_definitions() -> Vec<$crate::mcp::McpTool> {
                $crate::mcp::to_mcp_tools(<Self as $crate::ToolProvider>::tool_definitions())
            }
        }
    };
}

/// Generates MCP helper functions on `$type`.
///
/// This crate was built without its `mcp` feature, so the expansion is an empty
/// inherent `impl` block. Invocations keep compiling but add no methods.
#[cfg(not(feature = "mcp"))]
#[macro_export]
macro_rules! impl_mcp_server {
    ($type:ty) => {
        impl $type {}
    };
}
/// Network settings for the HTTP MCP front-end.
#[cfg(feature = "http-server")]
#[derive(Debug, Clone)]
pub struct McpHttpConfig {
    /// Host or interface address to bind.
    pub host: String,
    /// TCP port to bind.
    pub port: u16,
    /// Whether CORS should be enabled.
    /// NOTE(review): nothing in this module applies this flag yet.
    pub cors_enabled: bool,
}

#[cfg(feature = "http-server")]
impl Default for McpHttpConfig {
    /// Loopback-only `127.0.0.1`, port `8080`, CORS flag set.
    fn default() -> Self {
        let host = String::from("127.0.0.1");
        Self {
            host,
            port: 8080,
            cors_enabled: true,
        }
    }
}
/// Minimal HTTP front-end that serves a tool provider's tool list.
///
/// Routes registered by [`McpHttpServer::run`]:
/// - `GET /tools`: JSON array of [`McpTool`] converted from `T::tool_definitions()`.
/// - `POST /call`: accepts an [`McpToolCall`] body but always answers with an
///   error [`McpToolResponse`]; tool execution is not wired up here.
/// - `GET /health`: plain-text `OK`.
#[cfg(feature = "http-server")]
pub struct McpHttpServer<T> {
    // Never read: the tool list comes from the associated fn `T::tool_definitions()`
    // and `/call` is a stub, so no method is ever called on the instance itself.
    #[allow(dead_code)]
    inner: T,
    // Supplies the fallback host/port used by `run` when its `addr` omits them.
    config: McpHttpConfig,
}

#[cfg(feature = "http-server")]
impl<T> McpHttpServer<T>
where
    T: tokitai_core::ToolProvider + Clone + Send + Sync + 'static,
{
    /// Creates a server for `inner` using [`McpHttpConfig::default`].
    pub fn new(inner: T) -> Self {
        Self::with_config(inner, McpHttpConfig::default())
    }

    /// Creates a server for `inner` with an explicit configuration.
    ///
    /// `config.host` and `config.port` are the fallbacks [`Self::run`] uses when
    /// its `addr` argument leaves them out. `config.cors_enabled` is not applied.
    pub fn with_config(inner: T, config: McpHttpConfig) -> Self {
        Self { inner, config }
    }

    /// Binds a TCP listener and serves `/tools`, `/call` and `/health` until
    /// `axum::serve` returns.
    ///
    /// `addr` has the form `host:port`. It is split at its *last* `:`, so a
    /// bracketed IPv6 literal such as `[::1]:9000` works. An empty or missing host
    /// falls back to the configured host. A missing or unparsable port falls back
    /// to the configured port, and an unparsable one is logged. With the default
    /// configuration these fallbacks are `127.0.0.1` and `8080`.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the listener fails or if `axum::serve` fails.
    pub async fn run(&self, addr: &str) -> Result<(), Box<dyn std::error::Error>> {
        use axum::{extract::State, routing::get, Json, Router};
        use std::sync::Arc;

        let (host, port) = self.resolve_bind_addr(addr);
        let app_state = Arc::new(AppState::new(to_mcp_tools(T::tool_definitions())));

        async fn list_tools_handler(State(state): State<Arc<AppState>>) -> Json<Vec<McpTool>> {
            Json(state.tools.clone())
        }

        // Stub: `T` is only bounded by `ToolProvider`, so nothing here can execute
        // a tool; every request receives the same error response.
        async fn call_tool_handler(Json(_request): Json<McpToolCall>) -> Json<McpToolResponse> {
            Json(McpToolResponse::error("Tool execution requires concrete implementation. Use examples/mcp_http_server.rs for a complete example."))
        }

        async fn health_handler() -> &'static str {
            "OK"
        }

        let app = Router::new()
            .route("/tools", get(list_tools_handler))
            .route("/call", axum::routing::post(call_tool_handler))
            .route("/health", get(health_handler))
            .with_state(app_state);

        let listener = tokio::net::TcpListener::bind(format!("{}:{}", host, port)).await?;
        println!("MCP Server listening on http://{}:{}", host, port);
        println!(" - GET /tools - List available tools");
        println!(" - POST /call - Call a tool");
        println!(" - GET /health - Health check");
        axum::serve(listener, app).await?;
        Ok(())
    }

    /// Splits `addr` into `(host, port)` and substitutes the configured value
    /// for any part that is missing or unusable.
    fn resolve_bind_addr(&self, addr: &str) -> (String, u16) {
        let (raw_host, raw_port) = match addr.rsplit_once(':') {
            // A "port" ending in `]` is really the tail of a bracketed IPv6
            // literal given without a port, e.g. `[::1]`.
            Some((host, port)) if !port.ends_with(']') => (host, Some(port)),
            _ => (addr, None),
        };

        let host = if raw_host.is_empty() {
            self.config.host.clone()
        } else {
            raw_host.to_string()
        };

        let port = match raw_port {
            None | Some("") => self.config.port,
            Some(text) => text.parse::<u16>().unwrap_or_else(|err| {
                log::warn!(
                    "invalid port '{}' in bind address '{}': {}; falling back to {}",
                    text,
                    addr,
                    err,
                    self.config.port
                );
                self.config.port
            }),
        };

        (host, port)
    }
}

/// Shared state handed to the axum route handlers of [`McpHttpServer`].
#[cfg(feature = "http-server")]
pub struct AppState {
    /// Tool descriptors returned (cloned) by `GET /tools`.
    pub tools: Vec<McpTool>,
}

#[cfg(feature = "http-server")]
impl AppState {
    /// Wraps a precomputed tool list.
    pub fn new(tools: Vec<McpTool>) -> Self {
        Self { tools }
    }
}
/// An event-style message: an event name plus a JSON payload.
#[cfg(all(feature = "http-server", feature = "runtime"))]
#[derive(Debug, Clone, Serialize)]
pub struct McpSseMessage {
    /// Event name, e.g. `"tools/list"` or `"tool/result"`.
    pub event: String,
    /// Event payload; `Value::Null` if the payload failed to serialize.
    pub data: serde_json::Value,
}

#[cfg(all(feature = "http-server", feature = "runtime"))]
impl McpSseMessage {
    /// Builds a `"tools/list"` event whose payload is `tools` as a JSON array.
    pub fn tool_list(tools: Vec<McpTool>) -> Self {
        Self::with_payload("tools/list", &tools)
    }

    /// Builds a `"tool/result"` event whose payload is the serialized `result`.
    pub fn tool_result(result: McpToolResponse) -> Self {
        Self::with_payload("tool/result", &result)
    }

    /// Shared constructor; a serialization failure yields `Value::Null`
    /// (the `Default` of `serde_json::Value`) rather than an error.
    fn with_payload<P: Serialize>(event: &str, payload: &P) -> Self {
        Self {
            event: event.to_owned(),
            data: serde_json::to_value(payload).unwrap_or_default(),
        }
    }
}