use crate::cache::{CacheConfig, SearchCache};
use crate::tools::{self, fetch, search};
use crate::types::{
DaedraError, DaedraResult, PageContent, SearchArgs, SearchResponse, VisitPageArgs,
search_args_schema, visit_page_args_schema,
};
use crate::{SERVER_NAME, VERSION};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument};
/// MCP protocol revision reported as `protocolVersion` in the `initialize` response.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
/// Transport the server listens on; consumed by [`DaedraServer::run`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportType {
    /// Newline-delimited JSON-RPC over stdin/stdout (the default).
    #[default]
    Stdio,
    /// HTTP server exposing `/health`, `/sse` and `/rpc`, bound to `host:port`.
    Sse {
        /// TCP port to bind.
        port: u16,
        /// IPv4 address to bind, as four octets (e.g. `[127, 0, 0, 1]`).
        host: [u8; 4],
    },
}
/// Configuration used to build a [`DaedraServer`] and its [`DaedraHandler`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Settings passed to `SearchCache::new` when the handler is created.
    pub cache: CacheConfig,
    /// Verbose flag. Not read anywhere in this file; presumably consumed by logging
    /// setup elsewhere — verify.
    pub verbose: bool,
    /// Intended cap on concurrently running tool calls. NOTE(review): not enforced
    /// anywhere in this file — confirm whether another module uses it.
    pub max_concurrent_tools: usize,
}
impl Default for ServerConfig {
    /// Default cache settings, verbose output off, and a tool-concurrency value of 10.
    fn default() -> Self {
        let cache = CacheConfig::default();
        let max_concurrent_tools = 10;
        Self {
            verbose: false,
            max_concurrent_tools,
            cache,
        }
    }
}
/// An incoming JSON-RPC 2.0 request or notification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string as sent by the client; not validated here.
    pub jsonrpc: String,
    /// Request id. When this is `None` (member absent, or an explicit JSON `null`, which
    /// serde maps to `None`) the stdio transport treats the message as a notification
    /// and sends no response.
    pub id: Option<Value>,
    /// Method name, e.g. `initialize`, `tools/list`, `tools/call` or `ping`.
    pub method: String,
    /// Method parameters; `None` when the member is absent.
    #[serde(default)]
    pub params: Option<Value>,
}
/// An outgoing JSON-RPC 2.0 response.
///
/// Build it with [`JsonRpcResponse::success`] or [`JsonRpcResponse::error`], which set
/// exactly one of `result` / `error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
    /// Protocol version; the constructors always set `"2.0"`.
    pub jsonrpc: String,
    /// Id of the request being answered.
    ///
    /// JSON-RPC 2.0 requires this member in *every* response and requires it to be
    /// `null` when the request id could not be determined (e.g. a parse error). So
    /// `None` is serialized as `null` instead of being omitted.
    #[serde(default)]
    pub id: Option<Value>,
    /// Result payload on success; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error payload on failure; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// JSON-RPC error code; this file uses -32700 (parse error), -32601 (method/tool not
    /// found) and -32602 (invalid params).
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
    /// Optional structured detail; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
impl JsonRpcResponse {
    /// Builds a successful `"2.0"` response carrying `result` for the request `id`.
    pub fn success(id: Option<Value>, result: Value) -> Self {
        let jsonrpc = String::from("2.0");
        Self {
            jsonrpc,
            id,
            error: None,
            result: Some(result),
        }
    }

    /// Builds a `"2.0"` error response with the given `code` and `message` and no `data`.
    pub fn error(id: Option<Value>, code: i32, message: String) -> Self {
        let payload = JsonRpcError {
            code,
            message,
            data: None,
        };
        Self {
            jsonrpc: String::from("2.0"),
            id,
            result: None,
            error: Some(payload),
        }
    }
}
/// A tool definition as returned by `tools/list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
    /// Tool name clients pass as `name` in `tools/call`.
    pub name: String,
    /// Human-readable description. Note: not skipped when `None`, so it would serialize
    /// as `null`.
    pub description: Option<String>,
    /// JSON Schema of the tool arguments, serialized as `inputSchema`.
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,
}
/// Transport-independent request handler: dispatches JSON-RPC methods and executes the
/// `web_search` / `visit_page` tools, caching their results.
///
/// The `Arc` fields are shared between clones; how `SearchCache` clones is defined in
/// `crate::cache`.
#[derive(Clone)]
pub struct DaedraHandler {
    // Cache for search responses and fetched pages.
    cache: SearchCache,
    // Search backend used on cache misses.
    search_provider: Arc<tools::SearchProvider>,
    // HTTP client used by `visit_page` and by search-result enrichment.
    fetch_client: Arc<fetch::FetchClient>,
    // Set to `true` when an `initialize` request is handled; not read elsewhere in this file.
    initialized: Arc<RwLock<bool>>,
}
impl DaedraHandler {
pub fn new(config: ServerConfig) -> DaedraResult<Self> {
Ok(Self {
cache: SearchCache::new(config.cache),
search_provider: Arc::new(tools::SearchProvider::auto()),
fetch_client: Arc::new(fetch::FetchClient::new()?),
initialized: Arc::new(RwLock::new(false)),
})
}
pub fn get_server_info(&self) -> Value {
json!({
"protocolVersion": MCP_PROTOCOL_VERSION,
"capabilities": {
"tools": {}
},
"serverInfo": {
"name": SERVER_NAME,
"version": VERSION
}
})
}
pub fn list_tools(&self) -> Vec<McpTool> {
vec![
McpTool {
name: "web_search".to_string(),
description: Some(
"Search the web using 9 backends (Wikipedia, StackOverflow, GitHub, Wiby, Bing, Serper, Tavily, DDG Instant, DDG). Returns aggregated results from multiple sources."
.to_string(),
),
input_schema: search_args_schema(),
},
McpTool {
name: "visit_page".to_string(),
description: Some(
"Visit a webpage and extract its content as Markdown. Useful for reading articles, documentation, or any web page."
.to_string(),
),
input_schema: visit_page_args_schema(),
},
]
}
#[instrument(skip(self))]
pub async fn execute_search(&self, args: SearchArgs) -> DaedraResult<SearchResponse> {
let options = args.options.clone().unwrap_or_default();
if let Some(cached) = self
.cache
.get_search(
&args.query,
&options.region,
&options.safe_search.to_string(),
)
.await
{
info!(query = %args.query, "Returning cached search results");
return Ok(cached);
}
let mut response = self.search_provider.search(&args).await?;
let enrich_count = 3.min(response.data.len());
if enrich_count > 0 {
let fetch_client = self.fetch_client.clone();
let futures: Vec<_> = response.data[..enrich_count].iter()
.filter(|r| r.description.len() < 100) .map(|r| {
let url = r.url.clone();
let client = fetch_client.clone();
async move {
let args = crate::types::VisitPageArgs {
url: url.clone(),
selector: None,
include_images: false,
};
match tokio::time::timeout(
std::time::Duration::from_secs(5),
client.fetch(&args),
).await {
Ok(Ok(page)) => {
let snippet: String = page.content.chars().take(300).collect();
Some((url, snippet))
}
_ => None,
}
}
})
.collect();
let enrichments = futures::future::join_all(futures).await;
for enrichment in enrichments.into_iter().flatten() {
if let Some(result) = response.data.iter_mut().find(|r| r.url == enrichment.0) {
if result.description.len() < 100 {
result.description = enrichment.1;
}
}
}
}
self.cache
.set_search(
&args.query,
&options.region,
&options.safe_search.to_string(),
response.clone(),
)
.await;
Ok(response)
}
#[instrument(skip(self))]
pub async fn execute_fetch(&self, args: VisitPageArgs) -> DaedraResult<PageContent> {
if let Some(cached) = self
.cache
.get_page(&args.url, args.selector.as_deref())
.await
{
info!(url = %args.url, "Returning cached page content");
return Ok(cached);
}
let content = self.fetch_client.fetch(&args).await?;
self.cache
.set_page(&args.url, args.selector.as_deref(), content.clone())
.await;
Ok(content)
}
pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
debug!(method = %request.method, "Handling request");
match request.method.as_str() {
"initialize" => {
let mut initialized = self.initialized.write().await;
*initialized = true;
JsonRpcResponse::success(request.id, self.get_server_info())
},
"initialized" | "notifications/initialized" => {
JsonRpcResponse::success(request.id, json!({}))
},
"tools/list" => {
let tools = self.list_tools();
JsonRpcResponse::success(request.id, json!({ "tools": tools }))
},
"tools/call" => {
let params = match request.params {
Some(p) => p,
None => {
return JsonRpcResponse::error(
request.id,
-32602,
"Missing parameters".to_string(),
);
},
};
let tool_name = params
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default();
let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
self.call_tool(request.id, tool_name, arguments).await
},
"ping" => JsonRpcResponse::success(request.id, json!({})),
_ => JsonRpcResponse::error(
request.id,
-32601,
format!("Method not found: {}", request.method),
),
}
}
async fn call_tool(&self, id: Option<Value>, name: &str, arguments: Value) -> JsonRpcResponse {
info!(tool = %name, "Executing tool");
match name {
"web_search" | "search_duckduckgo" => {
let args: SearchArgs = match serde_json::from_value(arguments) {
Ok(a) => a,
Err(e) => {
return JsonRpcResponse::error(
id,
-32602,
format!("Invalid search arguments: {}", e),
);
},
};
match self.execute_search(args).await {
Ok(response) => {
let text = serde_json::to_string_pretty(&response).unwrap_or_default();
JsonRpcResponse::success(
id,
json!({
"content": [{ "type": "text", "text": text }],
"isError": false
}),
)
},
Err(e) => {
error!(error = %e, "Search failed");
JsonRpcResponse::success(
id,
json!({
"content": [{ "type": "text", "text": format!("Search failed: {}", e) }],
"isError": true
}),
)
},
}
},
"visit_page" => {
let args: VisitPageArgs = match serde_json::from_value(arguments) {
Ok(a) => a,
Err(e) => {
return JsonRpcResponse::error(
id,
-32602,
format!("Invalid fetch arguments: {}", e),
);
},
};
if !fetch::is_valid_url(&args.url) {
return JsonRpcResponse::success(
id,
json!({
"content": [{ "type": "text", "text": "Invalid URL: must be HTTP or HTTPS" }],
"isError": true
}),
);
}
match self.execute_fetch(args).await {
Ok(content) => {
let output = format!(
"# {}\n\n**URL:** {}\n**Fetched:** {}\n**Words:** {}\n\n---\n\n{}",
content.title,
content.url,
content.timestamp,
content.word_count,
content.content
);
JsonRpcResponse::success(
id,
json!({
"content": [{ "type": "text", "text": output }],
"isError": false
}),
)
},
Err(e) => {
error!(error = %e, "Fetch failed");
JsonRpcResponse::success(
id,
json!({
"content": [{ "type": "text", "text": format!("Failed to fetch page: {}", e) }],
"isError": true
}),
)
},
}
},
_ => JsonRpcResponse::error(id, -32601, format!("Unknown tool: {}", name)),
}
}
pub fn cache(&self) -> &SearchCache {
&self.cache
}
}
/// Top-level server: owns a [`DaedraHandler`] and runs it over a [`TransportType`].
pub struct DaedraServer {
    handler: DaedraHandler,
    // Stored alongside the handler but not read anywhere in this file, hence the allow.
    #[allow(dead_code)]
    config: ServerConfig,
}
impl DaedraServer {
pub fn new(config: ServerConfig) -> DaedraResult<Self> {
let handler = DaedraHandler::new(config.clone())?;
Ok(Self { handler, config })
}
pub fn with_defaults() -> DaedraResult<Self> {
Self::new(ServerConfig::default())
}
#[instrument(skip(self))]
pub async fn run(self, transport: TransportType) -> DaedraResult<()> {
info!(
server = SERVER_NAME,
version = VERSION,
"Starting Daedra MCP server"
);
match transport {
TransportType::Stdio => self.run_stdio().await,
TransportType::Sse { port, host } => self.run_sse(host, port).await,
}
}
async fn run_stdio(self) -> DaedraResult<()> {
info!("Starting STDIO transport");
let stdin = tokio::io::stdin();
let mut stdout = tokio::io::stdout();
let reader = BufReader::new(stdin);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
if line.trim().is_empty() {
continue;
}
debug!(request = %line, "Received request");
let request: JsonRpcRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(e) => {
let error_response =
JsonRpcResponse::error(None, -32700, format!("Parse error: {}", e));
let response_str = serde_json::to_string(&error_response).unwrap();
stdout.write_all(response_str.as_bytes()).await?;
stdout.write_all(b"\n").await?;
stdout.flush().await?;
continue;
},
};
let is_notification = request.id.is_none();
let response = self.handler.handle_request(request).await;
if !is_notification {
let response_str = serde_json::to_string(&response).unwrap();
debug!(response = %response_str, "Sending response");
stdout.write_all(response_str.as_bytes()).await?;
stdout.write_all(b"\n").await?;
stdout.flush().await?;
}
}
info!("STDIO server stopped");
Ok(())
}
async fn run_sse(self, host: [u8; 4], port: u16) -> DaedraResult<()> {
use axum::{
Json, Router,
extract::State,
response::sse::{Event, Sse},
routing::{get, post},
};
use futures::stream::{self, Stream};
use std::convert::Infallible;
use tower_http::cors::CorsLayer;
info!(host = ?host, port = port, "Starting SSE transport");
let handler = Arc::new(self.handler);
async fn health() -> &'static str {
"OK"
}
async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let stream = stream::once(async { Ok(Event::default().data("connected")) });
Sse::new(stream)
}
async fn rpc_handler(
State(handler): State<Arc<DaedraHandler>>,
Json(request): Json<JsonRpcRequest>,
) -> Json<JsonRpcResponse> {
let response = handler.handle_request(request).await;
Json(response)
}
let app = Router::new()
.route("/health", get(health))
.route("/sse", get(sse_handler))
.route("/rpc", post(rpc_handler))
.layer(CorsLayer::permissive())
.with_state(handler);
let addr = std::net::SocketAddr::from((host, port));
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
DaedraError::ServerError(format!(
"Failed to bind to {}:{}: {}",
host.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join("."),
port,
e
))
})?;
info!(
"SSE server listening on http://{}:{}",
host.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join("."),
port
);
axum::serve(listener, app)
.await
.map_err(|e| DaedraError::ServerError(format!("Server error: {}", e)))?;
Ok(())
}
pub fn cache_stats(&self) -> crate::cache::CacheStats {
self.handler.cache.stats()
}
pub async fn clear_cache(&self) {
self.handler.cache.clear().await;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handler from the default configuration, panicking if that fails.
    fn default_handler() -> DaedraHandler {
        DaedraHandler::new(ServerConfig::default()).unwrap()
    }

    /// Builds a JSON-RPC 2.0 request with id `1`, the given `method`, and no params.
    fn request_for(method: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(json!(1)),
            method: method.to_string(),
            params: None,
        }
    }

    #[test]
    fn test_server_config_default() {
        let defaults = ServerConfig::default();
        assert!(!defaults.verbose);
        assert_eq!(defaults.max_concurrent_tools, 10);
    }

    #[test]
    fn test_transport_type_default() {
        assert_eq!(TransportType::default(), TransportType::Stdio);
    }

    #[tokio::test]
    async fn test_handler_creation() {
        assert!(DaedraHandler::new(ServerConfig::default()).is_ok());
    }

    #[test]
    fn test_list_tools() {
        let names: Vec<String> = default_handler()
            .list_tools()
            .into_iter()
            .map(|tool| tool.name)
            .collect();
        assert_eq!(names.len(), 2);
        assert!(names.iter().any(|n| n == "web_search"));
        assert!(names.iter().any(|n| n == "visit_page"));
    }

    #[test]
    fn test_json_rpc_response_success() {
        let ok = JsonRpcResponse::success(Some(json!(1)), json!({"status": "ok"}));
        assert_eq!(ok.jsonrpc, "2.0");
        assert!(ok.result.is_some());
        assert!(ok.error.is_none());
    }

    #[test]
    fn test_json_rpc_response_error() {
        let failed =
            JsonRpcResponse::error(Some(json!(1)), -32600, "Invalid request".to_string());
        assert_eq!(failed.jsonrpc, "2.0");
        assert!(failed.result.is_none());
        assert!(failed.error.is_some());
        assert_eq!(failed.error.unwrap().code, -32600);
    }

    #[tokio::test]
    async fn test_handle_ping() {
        let reply = default_handler().handle_request(request_for("ping")).await;
        assert!(reply.result.is_some());
        assert!(reply.error.is_none());
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let reply = default_handler()
            .handle_request(request_for("initialize"))
            .await;
        assert!(reply.result.is_some());
        let info = reply.result.unwrap();
        assert_eq!(info["protocolVersion"], MCP_PROTOCOL_VERSION);
        assert_eq!(info["serverInfo"]["name"], SERVER_NAME);
    }

    #[tokio::test]
    async fn test_handle_tools_list() {
        let reply = default_handler()
            .handle_request(request_for("tools/list"))
            .await;
        assert!(reply.result.is_some());
        let payload = reply.result.unwrap();
        let listed = payload["tools"].as_array().unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn test_handle_unknown_method() {
        let reply = default_handler()
            .handle_request(request_for("unknown/method"))
            .await;
        assert!(reply.error.is_some());
        assert_eq!(reply.error.unwrap().code, -32601);
    }

    #[tokio::test]
    async fn test_handle_notifications_initialized() {
        let reply = default_handler()
            .handle_request(request_for("notifications/initialized"))
            .await;
        assert!(
            reply.error.is_none(),
            "notifications/initialized should not return error"
        );
        assert!(reply.result.is_some());
    }

    #[tokio::test]
    async fn test_handle_initialized_without_prefix() {
        let reply = default_handler()
            .handle_request(request_for("initialized"))
            .await;
        assert!(
            reply.error.is_none(),
            "initialized should not return error"
        );
        assert!(reply.result.is_some());
    }
}