use crate::cache::{CacheConfig, SearchCache};
use crate::tools::{self, fetch, crawl_site};
use crate::types::{
CrawlArgs, DaedraError, DaedraResult, PageContent, SearchArgs, SearchResponse, SearchResult,
VisitPageArgs, crawl_args_schema, search_args_schema, visit_page_args_schema,
};
use crate::{SERVER_NAME, VERSION};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, error, info, instrument};
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransportType {
#[default]
Stdio,
Sse {
port: u16,
host: [u8; 4],
},
}
#[derive(Debug, Clone)]
pub struct ServerConfig {
pub cache: CacheConfig,
pub verbose: bool,
pub max_concurrent_tools: usize,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
cache: CacheConfig::default(),
verbose: false,
max_concurrent_tools: 10,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub id: Option<Value>,
pub method: String,
#[serde(default)]
pub params: Option<Value>,
}
pub fn is_notification(request: &JsonRpcRequest) -> bool {
request.id.is_none() || matches!(&request.id, Some(Value::Null))
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
}
impl JsonRpcResponse {
pub fn success(id: Option<Value>, result: Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: Some(result),
error: None,
}
}
pub fn error(id: Option<Value>, code: i32, message: String) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(JsonRpcError {
code,
message,
data: None,
}),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
pub name: String,
pub description: Option<String>,
#[serde(rename = "inputSchema")]
pub input_schema: Value,
}
#[derive(Clone)]
pub struct DaedraHandler {
cache: SearchCache,
search_provider: Arc<tools::SearchProvider>,
fetch_client: Arc<fetch::FetchClient>,
initialized: Arc<RwLock<bool>>,
}
impl DaedraHandler {
pub fn new(config: ServerConfig) -> DaedraResult<Self> {
Ok(Self {
cache: SearchCache::new(config.cache),
search_provider: Arc::new(tools::SearchProvider::auto()),
fetch_client: Arc::new(fetch::FetchClient::new()?),
initialized: Arc::new(RwLock::new(false)),
})
}
pub fn get_server_info(&self) -> Value {
json!({
"protocolVersion": MCP_PROTOCOL_VERSION,
"capabilities": {
"tools": {}
},
"serverInfo": {
"name": SERVER_NAME,
"version": VERSION
}
})
}
pub fn list_tools(&self) -> Vec<McpTool> {
vec![
McpTool {
name: "web_search".to_string(),
description: Some(
"Search the web using 9 backends (Wikipedia, StackOverflow, GitHub, Wiby, Bing, Serper, Tavily, DDG Instant, DDG). Returns aggregated results from multiple sources."
.to_string(),
),
input_schema: search_args_schema(),
},
McpTool {
name: "search_duckduckgo".to_string(),
description: Some(
"Alias for web_search (backward compatibility). Search the web using 9 backends (Wikipedia, StackOverflow, GitHub, Wiby, Bing, Serper, Tavily, DDG Instant, DDG). Returns aggregated results from multiple sources."
.to_string(),
),
input_schema: search_args_schema(),
},
McpTool {
name: "visit_page".to_string(),
description: Some(
"Visit a webpage and extract its content as Markdown. Useful for reading articles, documentation, or any web page."
.to_string(),
),
input_schema: visit_page_args_schema(),
},
McpTool {
name: "crawl_site".to_string(),
description: Some(
"Crawl a website starting from a root URL. Discovers pages via sitemap.xml or link following, fetches up to max_pages concurrently, and returns Markdown content for each page."
.to_string(),
),
input_schema: crawl_args_schema(),
},
]
}
#[instrument(skip(self))]
pub async fn execute_search(&self, args: SearchArgs) -> DaedraResult<SearchResponse> {
let options = args.options.clone().unwrap_or_default();
if let Some(cached) = self
.cache
.get_search(
&args.query,
&options.region,
&options.safe_search.to_string(),
)
.await
{
info!(query = %args.query, "Returning cached search results");
return Ok(cached);
}
let mut response = self.search_provider.search(&args).await?;
self.enrich_sparse_results(&mut response.data, 3).await;
self.cache
.set_search(
&args.query,
&options.region,
&options.safe_search.to_string(),
response.clone(),
)
.await;
Ok(response)
}
async fn enrich_sparse_results(&self, results: &mut [SearchResult], count: usize) {
let enrich_count = count.min(results.len());
if enrich_count == 0 {
return;
}
let fetch_client = self.fetch_client.clone();
let enrich_semaphore = Arc::new(Semaphore::new(2));
let futures: Vec<_> = results[..enrich_count]
.iter()
.filter(|r| r.description.len() < 100)
.map(|r| {
let url = r.url.clone();
let client = fetch_client.clone();
let semaphore = enrich_semaphore.clone();
async move {
let _permit = semaphore.acquire_owned().await.unwrap();
let args = VisitPageArgs {
url: url.clone(),
selector: None,
include_images: false,
};
match tokio::time::timeout(
std::time::Duration::from_secs(5),
client.fetch(&args),
)
.await
{
Ok(Ok(page)) => {
let snippet: String = page.content.chars().take(300).collect();
Some((url, snippet))
}
_ => None,
}
}
})
.collect();
let enrichments = futures::future::join_all(futures).await;
for enrichment in enrichments.into_iter().flatten() {
if let Some(result) = results.iter_mut().find(|r| r.url == enrichment.0) {
if result.description.len() < 100 {
result.description = enrichment.1;
}
}
}
}
#[instrument(skip(self))]
pub async fn execute_fetch(&self, args: VisitPageArgs) -> DaedraResult<PageContent> {
if let Some(cached) = self
.cache
.get_page(&args.url, args.selector.as_deref())
.await
{
info!(url = %args.url, "Returning cached page content");
return Ok(cached);
}
let content = self.fetch_client.fetch(&args).await?;
self.cache
.set_page(&args.url, args.selector.as_deref(), content.clone())
.await;
Ok(content)
}
pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
debug!(method = %request.method, "Handling request");
if request.method == "initialize" {
let mut initialized = self.initialized.write().await;
*initialized = true;
}
self.handle_method(&request.method, request.id, request.params)
.await
}
async fn handle_method(
&self,
method: &str,
id: Option<Value>,
params: Option<Value>,
) -> JsonRpcResponse {
match method {
"initialize" => JsonRpcResponse::success(id, self.get_server_info()),
"initialized" | "notifications/initialized" => JsonRpcResponse::success(id, json!({})),
"tools/list" => JsonRpcResponse::success(id, json!({ "tools": self.list_tools() })),
"tools/call" => match parse_tool_call_params(params, id.clone()) {
Ok((name, args)) => self.call_tool(id, &name, args).await,
Err(resp) => resp,
},
"ping" => JsonRpcResponse::success(id, json!({})),
_ => JsonRpcResponse::error(
id,
-32601,
format!("Method not found: {}", method),
),
}
}
async fn handle_web_search(&self, id: Option<Value>, arguments: Value) -> JsonRpcResponse {
let args: SearchArgs = match serde_json::from_value(arguments) {
Ok(a) => a,
Err(e) => {
return JsonRpcResponse::error(
id,
-32602,
format!("Invalid search arguments: {}", e),
);
},
};
match self.execute_search(args).await {
Ok(response) => {
let text = serde_json::to_string_pretty(&response).unwrap_or_default();
tool_success_response(id, text)
}
Err(e) => {
error!(error = %e, "Search failed");
tool_error_response(id, &format!("Search failed: {}", e))
}
}
}
async fn handle_visit_page(&self, id: Option<Value>, arguments: Value) -> JsonRpcResponse {
let args: VisitPageArgs = match serde_json::from_value(arguments) {
Ok(a) => a,
Err(e) => {
return JsonRpcResponse::error(
id,
-32602,
format!("Invalid fetch arguments: {}", e),
);
},
};
if !fetch::is_valid_url(&args.url) {
return tool_error_response(id, "Invalid URL: must be HTTP or HTTPS");
}
match self.execute_fetch(args).await {
Ok(content) => tool_success_response(id, format_page_result(&content)),
Err(e) => {
error!(error = %e, "Fetch failed");
tool_error_response(id, &format!("Failed to fetch page: {}", e))
}
}
}
async fn handle_crawl_site(&self, id: Option<Value>, arguments: Value) -> JsonRpcResponse {
let args: CrawlArgs = match serde_json::from_value(arguments) {
Ok(a) => a,
Err(e) => {
return JsonRpcResponse::error(
id,
-32602,
format!("Invalid crawl arguments: {}", e),
);
},
};
match crawl_site(args).await {
Ok(result) => {
let text = serde_json::to_string_pretty(&result).unwrap_or_default();
tool_success_response(id, text)
}
Err(e) => {
error!(error = %e, "Crawl failed");
tool_error_response(id, &format!("Crawl failed: {}", e))
}
}
}
async fn call_tool(&self, id: Option<Value>, name: &str, arguments: Value) -> JsonRpcResponse {
info!(tool = %name, "Executing tool");
match name {
"web_search" | "search_duckduckgo" => self.handle_web_search(id, arguments).await,
"visit_page" => self.handle_visit_page(id, arguments).await,
"crawl_site" => self.handle_crawl_site(id, arguments).await,
_ => JsonRpcResponse::error(id, -32601, format!("Unknown tool: {}", name)),
}
}
pub fn cache(&self) -> &SearchCache {
&self.cache
}
}
fn parse_tool_call_params(
params: Option<Value>,
id: Option<Value>,
) -> Result<(String, Value), JsonRpcResponse> {
let params = match params {
Some(p) => p,
None => {
return Err(JsonRpcResponse::error(
id,
-32602,
"Missing parameters".to_string(),
));
}
};
let tool_name = params
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
Ok((tool_name, arguments))
}
fn format_page_result(content: &PageContent) -> String {
format!(
"# {}
**URL:** {}
**Fetched:** {}
**Words:** {}
---
{}",
content.title, content.url, content.timestamp, content.word_count, content.content
)
}
fn tool_error_response(id: Option<Value>, message: &str) -> JsonRpcResponse {
JsonRpcResponse::success(
id,
json!({
"content": [{ "type": "text", "text": message }],
"isError": true
}),
)
}
fn tool_success_response(id: Option<Value>, text: String) -> JsonRpcResponse {
JsonRpcResponse::success(
id,
json!({
"content": [{ "type": "text", "text": text }],
"isError": false
}),
)
}
async fn process_stdio_line(line: &str, handler: &DaedraHandler) -> Option<JsonRpcResponse> {
if line.trim().is_empty() {
return None;
}
debug!(request = %line, "Received request");
let request: JsonRpcRequest = match serde_json::from_str(line) {
Ok(r) => r,
Err(e) => {
return Some(JsonRpcResponse::error(
None,
-32700,
format!("Parse error: {}", e),
));
}
};
let response = handler.handle_request(request.clone()).await;
if is_notification(&request) {
None
} else {
Some(response)
}
}
async fn write_stdio_response(
response: JsonRpcResponse,
stdout: &mut tokio::io::BufWriter<tokio::io::Stdout>,
) -> std::io::Result<()> {
let response_str = serde_json::to_string(&response).unwrap();
debug!(response = %response_str, "Sending response");
stdout.write_all(response_str.as_bytes()).await?;
stdout.write_all(b"
").await?;
stdout.flush().await
}
pub struct DaedraServer {
handler: DaedraHandler,
#[allow(dead_code)]
config: ServerConfig,
}
impl DaedraServer {
pub fn new(config: ServerConfig) -> DaedraResult<Self> {
let handler = DaedraHandler::new(config.clone())?;
Ok(Self { handler, config })
}
pub fn with_defaults() -> DaedraResult<Self> {
Self::new(ServerConfig::default())
}
#[instrument(skip(self))]
pub async fn run(self, transport: TransportType) -> DaedraResult<()> {
info!(
server = SERVER_NAME,
version = VERSION,
"Starting Daedra MCP server"
);
match transport {
TransportType::Stdio => self.run_stdio().await,
TransportType::Sse { port, host } => self.run_sse(host, port).await,
}
}
async fn run_stdio(self) -> DaedraResult<()> {
info!("Starting STDIO transport");
let stdin = tokio::io::stdin();
let mut stdout = tokio::io::BufWriter::new(tokio::io::stdout());
let reader = BufReader::new(stdin);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
if let Some(response) = process_stdio_line(&line, &self.handler).await {
write_stdio_response(response, &mut stdout).await?;
}
}
info!("STDIO server stopped");
Ok(())
}
async fn run_sse(self, host: [u8; 4], port: u16) -> DaedraResult<()> {
use axum::{
Json, Router,
extract::State,
response::sse::{Event, Sse},
routing::{get, post},
};
use futures::stream::{self, Stream};
use std::convert::Infallible;
use tower_http::cors::CorsLayer;
info!(host = ?host, port = port, "Starting SSE transport");
let handler = Arc::new(self.handler);
async fn health() -> &'static str {
"OK"
}
async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let stream = stream::once(async { Ok(Event::default().data("connected")) });
Sse::new(stream)
}
async fn rpc_handler(
State(handler): State<Arc<DaedraHandler>>,
Json(request): Json<JsonRpcRequest>,
) -> Json<JsonRpcResponse> {
let response = handler.handle_request(request).await;
Json(response)
}
let app = Router::new()
.route("/health", get(health))
.route("/sse", get(sse_handler))
.route("/rpc", post(rpc_handler))
.layer(CorsLayer::permissive())
.with_state(handler);
let addr = std::net::SocketAddr::from((host, port));
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
DaedraError::ServerError(format!(
"Failed to bind to {}:{}: {}",
host.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join("."),
port,
e
))
})?;
info!(
"SSE server listening on http://{}:{}",
host.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join("."),
port
);
axum::serve(listener, app)
.await
.map_err(|e| DaedraError::ServerError(format!("Server error: {}", e)))?;
Ok(())
}
pub fn cache_stats(&self) -> crate::cache::CacheStats {
self.handler.cache.stats()
}
pub async fn clear_cache(&self) {
self.handler.cache.clear().await;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_server_config_default() {
let config = ServerConfig::default();
assert!(!config.verbose);
assert_eq!(config.max_concurrent_tools, 10);
}
#[test]
fn test_transport_type_default() {
assert_eq!(TransportType::default(), TransportType::Stdio);
}
#[tokio::test]
async fn test_handler_creation() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config);
assert!(handler.is_ok());
}
#[test]
fn test_list_tools() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let tools = handler.list_tools();
assert_eq!(tools.len(), 4);
assert!(tools.iter().any(|t| t.name == "web_search"));
assert!(tools.iter().any(|t| t.name == "search_duckduckgo"));
assert!(tools.iter().any(|t| t.name == "visit_page"));
assert!(tools.iter().any(|t| t.name == "crawl_site"));
}
#[test]
fn test_json_rpc_response_success() {
let response = JsonRpcResponse::success(Some(json!(1)), json!({"status": "ok"}));
assert_eq!(response.jsonrpc, "2.0");
assert!(response.result.is_some());
assert!(response.error.is_none());
}
#[test]
fn test_json_rpc_response_error() {
let response =
JsonRpcResponse::error(Some(json!(1)), -32600, "Invalid request".to_string());
assert_eq!(response.jsonrpc, "2.0");
assert!(response.result.is_none());
assert!(response.error.is_some());
assert_eq!(response.error.unwrap().code, -32600);
}
#[tokio::test]
async fn test_handle_ping() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "ping".to_string(),
params: None,
};
let response = handler.handle_request(request).await;
assert!(response.result.is_some());
assert!(response.error.is_none());
}
#[tokio::test]
async fn test_handle_initialize() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "initialize".to_string(),
params: None,
};
let response = handler.handle_request(request).await;
assert!(response.result.is_some());
let result = response.result.unwrap();
assert_eq!(result["protocolVersion"], MCP_PROTOCOL_VERSION);
assert_eq!(result["serverInfo"]["name"], SERVER_NAME);
}
#[tokio::test]
async fn test_handle_tools_list() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "tools/list".to_string(),
params: None,
};
let response = handler.handle_request(request).await;
assert!(response.result.is_some());
let result = response.result.unwrap();
let tools = result["tools"].as_array().unwrap();
assert_eq!(tools.len(), 4);
}
#[tokio::test]
async fn test_handle_unknown_method() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "unknown/method".to_string(),
params: None,
};
let response = handler.handle_request(request).await;
assert!(response.error.is_some());
assert_eq!(response.error.unwrap().code, -32601);
}
#[tokio::test]
async fn test_handle_notifications_initialized() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "notifications/initialized".to_string(),
params: None,
};
let response = handler.handle_request(request).await;
assert!(
response.error.is_none(),
"notifications/initialized should not return error"
);
assert!(response.result.is_some());
}
#[tokio::test]
async fn test_handle_initialized_without_prefix() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "initialized".to_string(),
params: None,
};
let response = handler.handle_request(request).await;
assert!(
response.error.is_none(),
"initialized should not return error"
);
assert!(response.result.is_some());
}
#[tokio::test]
#[ignore = "network"]
async fn test_handle_call_tool_web_search() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "tools/call".to_string(),
params: Some(json!({"name": "web_search", "arguments": {"query": "test"}})),
};
let response = handler.handle_request(request).await;
assert!(response.error.is_none());
let result = response.result.unwrap();
assert_eq!(result["isError"], false);
assert!(result["content"].as_array().is_some());
}
#[tokio::test]
async fn test_handle_call_tool_unknown() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "tools/call".to_string(),
params: Some(json!({"name": "nonexistent", "arguments": {}})),
};
let response = handler.handle_request(request).await;
assert!(response.error.is_some());
assert_eq!(response.error.unwrap().code, -32601);
}
#[tokio::test]
async fn test_handle_call_tool_missing_params() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "tools/call".to_string(),
params: None,
};
let response = handler.handle_request(request).await;
assert!(response.error.is_some());
assert_eq!(response.error.unwrap().code, -32602);
}
#[test]
fn test_is_notification_no_id() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: None,
method: "initialized".to_string(),
params: None,
};
assert!(is_notification(&request));
}
#[test]
fn test_is_notification_null_id() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(Value::Null),
method: "initialized".to_string(),
params: None,
};
assert!(is_notification(&request));
}
#[test]
fn test_is_notification_with_id() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "ping".to_string(),
params: None,
};
assert!(!is_notification(&request));
}
#[tokio::test]
async fn test_json_rpc_parse_error() {
let config = ServerConfig::default();
let handler = DaedraHandler::new(config).unwrap();
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(json!(1)),
method: "tools/call".to_string(),
params: Some(json!({"name": "web_search", "arguments": {"not_query": true}})),
};
let response = handler.handle_request(request).await;
assert!(response.error.is_some());
assert_eq!(response.error.unwrap().code, -32602);
}
#[tokio::test]
async fn test_execute_search_caches_results() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let args = SearchArgs {
query: "cache-test-unique-query-xyz".to_string(),
options: None,
};
let options = args.options.clone().unwrap_or_default();
let cached_response = SearchResponse::new(args.query.clone(), vec![], &options);
handler
.cache()
.set_search(
&args.query,
&options.region,
&options.safe_search.to_string(),
cached_response.clone(),
)
.await;
let result = handler.execute_search(args).await.unwrap();
assert_eq!(result.data.len(), cached_response.data.len());
assert_eq!(result.metadata.query, cached_response.metadata.query);
}
#[tokio::test]
async fn test_handle_method_initialize() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method("initialize", Some(json!(1)), None)
.await;
assert!(response.result.is_some());
assert!(response.error.is_none());
let result = response.result.unwrap();
assert_eq!(result["protocolVersion"], MCP_PROTOCOL_VERSION);
assert_eq!(result["serverInfo"]["name"], SERVER_NAME);
}
#[tokio::test]
async fn test_handle_method_ping() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler.handle_method("ping", Some(json!(1)), None).await;
assert!(response.result.is_some());
assert!(response.error.is_none());
}
#[tokio::test]
async fn test_handle_method_tools_list() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler.handle_method("tools/list", Some(json!(1)), None).await;
assert!(response.result.is_some());
let result = response.result.unwrap();
let tools = result["tools"].as_array().unwrap();
assert_eq!(tools.len(), 4);
}
#[tokio::test]
async fn test_handle_method_unknown() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler.handle_method("foo", Some(json!(1)), None).await;
let err = response.error.unwrap();
assert_eq!(err.code, -32601);
assert!(err.message.contains("foo"));
}
#[tokio::test]
async fn test_handle_method_initialized() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler.handle_method("initialized", Some(json!(1)), None).await;
assert!(response.error.is_none());
assert_eq!(response.result.unwrap(), json!({}));
}
#[tokio::test]
async fn test_handle_method_notifications_initialized() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method("notifications/initialized", Some(json!(1)), None)
.await;
assert!(response.error.is_none());
assert_eq!(response.result.unwrap(), json!({}));
}
#[tokio::test]
async fn test_handle_method_tools_call_missing_params() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method("tools/call", Some(json!(1)), None)
.await;
assert!(response.result.is_none());
let err = response.error.unwrap();
assert_eq!(err.code, -32602);
assert!(err.message.contains("Missing parameters"));
}
#[tokio::test]
async fn test_handle_method_tools_call_unknown_tool() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method(
"tools/call",
Some(json!(1)),
Some(json!({"name": "unknown", "arguments": {}})),
)
.await;
assert!(response.result.is_none());
let err = response.error.unwrap();
assert_eq!(err.code, -32601);
assert!(err.message.contains("unknown"));
}
#[tokio::test]
#[ignore = "network"]
async fn test_handle_method_tools_call_web_search_no_args() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method(
"tools/call",
Some(json!(1)),
Some(json!({"name": "web_search", "arguments": {}})),
)
.await;
assert!(response.error.is_none());
let result = response.result.unwrap();
assert!(result.get("isError").is_some());
}
#[tokio::test]
#[ignore = "network"]
async fn test_execute_search_returns_results() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let args = SearchArgs {
query: "Rust programming language".to_string(),
options: Some(crate::types::SearchOptions {
num_results: 5,
..Default::default()
}),
};
let response = handler.execute_search(args).await.unwrap();
assert!(!response.data.is_empty(), "search should return results");
}
#[tokio::test]
#[ignore = "network"]
async fn test_execute_search_caches_on_second_call() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let args = SearchArgs {
query: "cache-second-call-unique-query-abc".to_string(),
options: None,
};
let first = handler.execute_search(args.clone()).await;
let second = handler.execute_search(args).await;
assert!(first.is_ok(), "first search should succeed: {:?}", first.err());
assert!(second.is_ok(), "second search should succeed: {:?}", second.err());
assert!(!first.unwrap().data.is_empty());
assert!(!second.unwrap().data.is_empty());
}
#[tokio::test]
async fn test_handle_visit_page_malformed_args() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_visit_page(Some(json!(1)), json!({"url": 12345}))
.await;
assert!(response.result.is_none());
let err = response.error.unwrap();
assert_eq!(err.code, -32602);
assert!(err.message.contains("Invalid fetch arguments"));
}
#[tokio::test]
async fn test_handle_visit_page_invalid_url() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method(
"tools/call",
Some(json!(1)),
Some(json!({
"name": "visit_page",
"arguments": {"url": "ftp://example.com"}
})),
)
.await;
assert!(response.error.is_none());
let result = response.result.unwrap();
assert_eq!(result["isError"], true);
let text = result["content"][0]["text"].as_str().unwrap();
assert!(text.contains("Invalid URL"));
}
#[tokio::test]
#[ignore = "network"]
async fn test_handle_method_tools_call_web_search() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method(
"tools/call",
Some(json!(1)),
Some(json!({
"name": "web_search",
"arguments": {"query": "test"}
})),
)
.await;
assert!(response.error.is_none());
let result = response.result.unwrap();
assert_eq!(result["isError"], false);
assert!(result["content"].as_array().is_some());
}
#[tokio::test]
async fn test_handle_method_tools_call_visit_page_invalid() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_method(
"tools/call",
Some(json!(1)),
Some(json!({
"name": "visit_page",
"arguments": {"url": "not-a-url"}
})),
)
.await;
assert!(response.error.is_none());
let result = response.result.unwrap();
assert_eq!(result["isError"], true);
}
#[test]
fn test_parse_tool_call_params_valid() {
let result = parse_tool_call_params(
Some(json!({"name": "web_search", "arguments": {}})),
Some(json!(1)),
);
assert!(result.is_ok());
let (name, args) = result.unwrap();
assert_eq!(name, "web_search");
assert_eq!(args, json!({}));
}
#[test]
fn test_parse_tool_call_params_missing() {
let result = parse_tool_call_params(None, Some(json!(1)));
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.error.unwrap().code, -32602);
}
#[test]
fn test_parse_tool_call_params_no_name() {
let result = parse_tool_call_params(Some(json!({})), Some(json!(1)));
assert!(result.is_ok());
let (name, args) = result.unwrap();
assert_eq!(name, "");
assert_eq!(args, json!({}));
}
#[test]
fn test_parse_tool_call_params_with_args() {
let result = parse_tool_call_params(
Some(json!({
"name": "visit_page",
"arguments": {"url": "https://example.com"}
})),
Some(json!(1)),
);
assert!(result.is_ok());
let (name, args) = result.unwrap();
assert_eq!(name, "visit_page");
assert_eq!(args["url"], "https://example.com");
}
#[test]
fn test_tool_error_response_has_is_error() {
let response = tool_error_response(Some(json!(1)), "something went wrong");
let result = response.result.unwrap();
assert_eq!(result["isError"], true);
assert_eq!(
result["content"][0]["text"].as_str().unwrap(),
"something went wrong"
);
}
#[test]
fn test_tool_success_response_no_error() {
let response = tool_success_response(Some(json!(1)), "ok".to_string());
let result = response.result.unwrap();
assert_eq!(result["isError"], false);
assert_eq!(result["content"][0]["text"].as_str().unwrap(), "ok");
}
#[test]
fn test_format_page_result() {
let content = PageContent {
url: "https://example.com".to_string(),
title: "Example".to_string(),
content: "Hello world".to_string(),
timestamp: "2024-01-01T00:00:00Z".to_string(),
word_count: 2,
links: None,
};
let formatted = format_page_result(&content);
assert!(formatted.contains("Example"));
assert!(formatted.contains("https://example.com"));
assert!(formatted.contains("**Words:** 2"));
assert!(formatted.contains("Hello world"));
}
#[tokio::test]
#[ignore = "network"]
async fn test_handle_visit_page_valid_url() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_visit_page(
Some(json!(1)),
json!({"url": "https://example.com"}),
)
.await;
assert!(response.error.is_none());
let result = response.result.unwrap();
assert_eq!(result["isError"], false);
let text = result["content"][0]["text"].as_str().unwrap();
assert!(text.contains("Example") || text.contains("example.com"));
}
#[tokio::test]
#[ignore = "network"]
async fn test_handle_visit_page_valid_url_fetch_fails() {
let handler = DaedraHandler::new(ServerConfig::default()).unwrap();
let response = handler
.handle_visit_page(
Some(json!(1)),
json!({"url": "https://127.0.0.1:1/"}),
)
.await;
assert!(response.error.is_none());
let result = response.result.unwrap();
assert_eq!(result["isError"], true);
let text = result["content"][0]["text"].as_str().unwrap();
assert!(text.contains("Failed to fetch"));
}
}