use serde_json::Value;
use super::{error_codes, McpServer, Request};
/// Builds a JSON-RPC `"2.0"` [`Request`] for `method`.
///
/// The request id is always the number `1`, and `params` is always wrapped
/// in `Some`, even when the caller passes `Value::Null`.
pub(super) fn req(method: &str, params: Value) -> Request {
    let id = Value::from(1u64);
    Request {
        method: method.into(),
        params: Some(params),
        id: Some(id),
        jsonrpc: Some("2.0".into()),
    }
}
/// A request declaring `jsonrpc: "1.0"` must produce an `INVALID_REQUEST`
/// error, and the response must carry the id of the offending request.
#[tokio::test]
async fn rejects_wrong_jsonrpc_version() {
    let request_id = Value::from(7u64);
    let wrong_version = Request {
        jsonrpc: Some("1.0".into()),
        id: Some(request_id.clone()),
        method: "search_health".into(),
        params: None,
    };

    let server = McpServer::new("http://127.0.0.1:1");
    let response = server.dispatch(wrong_version).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::INVALID_REQUEST);
    assert_eq!(response.id, Some(request_id));
}
/// Dispatching the method name `not_a_tool` must yield `METHOD_NOT_FOUND`.
#[tokio::test]
async fn unknown_tool_returns_method_not_found() {
    let request = req("not_a_tool", Value::Null);
    let server = McpServer::new("http://127.0.0.1:1");

    let response = server.dispatch(request).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::METHOD_NOT_FOUND);
}
/// Calling `index_file` with an empty params object must yield
/// `INVALID_PARAMS`.
#[tokio::test]
async fn missing_params_returns_invalid_params() {
    let empty_params = serde_json::json!({});
    let server = McpServer::new("http://127.0.0.1:1");

    let response = server.dispatch(req("index_file", empty_params)).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::INVALID_PARAMS);
}
/// `tools/list` must return a `tools` array with at least six entries that
/// includes every name in the baseline tool set.
#[tokio::test]
async fn tools_list_returns_all_tools() {
    // Baseline tool names this test insists on.
    const BASELINE: [&str; 6] = [
        "search",
        "index_file",
        "remove_file",
        "list_indexes",
        "create_index",
        "search_health",
    ];

    let server = McpServer::new("http://127.0.0.1:1");
    let response = server.dispatch(req("tools/list", Value::Null)).await;
    let payload = response.result.expect("expected result");
    let tools = payload.get("tools").and_then(Value::as_array).expect("array");

    assert!(
        tools.len() >= 6,
        "expected at least 6 tools, got {}",
        tools.len()
    );

    // Entries without a string `name` are skipped rather than failing here.
    let names: Vec<&str> = tools
        .iter()
        .filter_map(|tool| tool.get("name").and_then(Value::as_str))
        .collect();

    for required in BASELINE {
        let present = names.iter().any(|name| *name == required);
        assert!(present, "missing required tool: {required}");
    }
}
#[tokio::test]
async fn test_initialize_response() {
let server = McpServer::new("http://127.0.0.1:1");
let r = Request {
jsonrpc: Some("2.0".into()),
id: Some(Value::from(1u64)),
method: "initialize".into(),
params: Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": { "name": "test", "version": "0.0.0" }
})),
};
let resp = server.dispatch(r).await;
assert!(resp.error.is_none(), "initialize must not error");
let result = resp.result.expect("expected result");
assert_eq!(result["protocolVersion"], "2024-11-05");
assert!(result["capabilities"].get("tools").is_some());
assert_eq!(result["serverInfo"]["name"], "trusty-search");
assert!(result["serverInfo"]["version"].is_string());
}
/// `tools/list` must include the baseline tool names, and every listed tool
/// must expose both a `name` and an `inputSchema` key.
#[tokio::test]
async fn test_tools_list_response() {
    const BASELINE: [&str; 6] = [
        "search",
        "index_file",
        "remove_file",
        "list_indexes",
        "create_index",
        "search_health",
    ];

    let server = McpServer::new("http://127.0.0.1:1");
    let response = server.dispatch(req("tools/list", Value::Null)).await;
    let payload = response.result.expect("expected result");
    let tools = payload.get("tools").and_then(Value::as_array).expect("array");

    let names: Vec<&str> = tools
        .iter()
        .filter_map(|tool| tool.get("name").and_then(Value::as_str))
        .collect();

    for required in BASELINE {
        let present = names.iter().any(|name| *name == required);
        assert!(present, "tools/list missing '{required}' (got {names:?})");
    }

    // Shape check: each tool entry needs both keys, whatever their values.
    for t in tools {
        assert!(t.get("name").is_some());
        assert!(t.get("inputSchema").is_some());
    }
}
/// Dispatching the method name `definitely_not_a_method` must yield
/// `METHOD_NOT_FOUND`.
#[tokio::test]
async fn test_unknown_method_returns_error() {
    let request = req("definitely_not_a_method", Value::Null);
    let server = McpServer::new("http://127.0.0.1:1");

    let response = server.dispatch(request).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::METHOD_NOT_FOUND);
}
/// An id-less `notifications/initialized` request must come back with the
/// response's `suppress` flag set.
#[tokio::test]
async fn notification_initialized_is_suppressed() {
    // Built by hand because `req` always assigns an id.
    let notification = Request {
        jsonrpc: Some("2.0".into()),
        id: None,
        method: "notifications/initialized".into(),
        params: None,
    };

    let server = McpServer::new("http://127.0.0.1:1");
    let response = server.dispatch(notification).await;

    assert!(response.suppress, "notifications must be suppressed");
}
/// `tools/list` must advertise the full set of twelve tool names checked
/// below.
#[tokio::test]
async fn test_tools_list_complete() {
    const EXPECTED_TOOLS: [&str; 12] = [
        "search",
        "index_file",
        "remove_file",
        "list_indexes",
        "create_index",
        "search_health",
        "delete_index",
        "reindex",
        "index_status",
        "list_chunks",
        "chat",
        "search_all",
    ];

    let server = McpServer::new("http://127.0.0.1:1");
    let response = server.dispatch(req("tools/list", Value::Null)).await;
    let payload = response.result.expect("expected result");
    let tools = payload.get("tools").and_then(Value::as_array).expect("array");

    let names: Vec<&str> = tools
        .iter()
        .filter_map(|tool| tool.get("name").and_then(Value::as_str))
        .collect();

    for required in EXPECTED_TOOLS {
        let present = names.iter().any(|name| *name == required);
        assert!(present, "tools/list missing '{required}' (got {names:?})");
    }
}
/// Calling `search_all` with an empty params object must yield
/// `INVALID_PARAMS`.
#[tokio::test]
async fn search_all_missing_query_returns_invalid_params() {
    let empty_params = serde_json::json!({});
    let server = McpServer::new("http://127.0.0.1:1");

    let response = server.dispatch(req("search_all", empty_params)).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::INVALID_PARAMS);
}
/// Calling `tools/call` with an empty params object must yield
/// `INVALID_PARAMS`.
#[tokio::test]
async fn tools_call_without_name_returns_invalid_params() {
    let empty_params = serde_json::json!({});
    let server = McpServer::new("http://127.0.0.1:1");

    let response = server.dispatch(req("tools/call", empty_params)).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::INVALID_PARAMS);
}
/// Calling `grep` with an empty params object must yield `INVALID_PARAMS`.
#[tokio::test]
async fn grep_missing_pattern_returns_invalid_params() {
    let empty_params = serde_json::json!({});
    let server = McpServer::new("http://127.0.0.1:1");

    let response = server.dispatch(req("grep", empty_params)).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::INVALID_PARAMS);
}
/// A `grep` call that passes `max_count` must reach the daemon with that
/// value under the `max_results` key.
///
/// A throwaway axum server on an ephemeral local port stands in for the
/// daemon and records the JSON body posted to `/indexes/idx/grep`.
#[tokio::test]
async fn grep_max_count_alias_forwarded_as_max_results() {
    use axum::extract::State;
    use axum::routing::post;
    use axum::{Json, Router};
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Shared slot holding the most recent request body seen by the mock.
    type BodySlot = Arc<Mutex<Option<Value>>>;

    /// Stores the posted body and answers with an empty grep result.
    async fn record_body(State(slot): State<BodySlot>, Json(payload): Json<Value>) -> Json<Value> {
        *slot.lock().await = Some(payload);
        Json(serde_json::json!({ "matches": [], "total": 0, "truncated": false }))
    }

    let captured: BodySlot = Arc::new(Mutex::new(None));
    let app = Router::new()
        .route("/indexes/idx/grep", post(record_body))
        .with_state(Arc::clone(&captured));

    // Port 0 lets the OS choose a free port; the real address is read back.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        let _ = axum::serve(listener, app).await;
    });

    let server = McpServer::new(format!("http://{addr}"));
    let grep_params = serde_json::json!({
        "pattern": "fn foo",
        "index_id": "idx",
        "max_count": 5_u64,
    });
    let response = server.dispatch(req("grep", grep_params)).await;
    assert!(
        response.error.is_none(),
        "unexpected error: {:?}",
        response.error
    );

    let recorded = captured.lock().await.clone();
    let body = recorded.expect("no request captured");
    assert_eq!(
        body.get("max_results").and_then(Value::as_u64),
        Some(5),
        "max_count must be forwarded as max_results; got body: {body:?}"
    );
}
/// `tools/list` must contain a `grep` tool whose `inputSchema.required`
/// array names `pattern`.
#[tokio::test]
async fn grep_listed_in_tools_with_required_pattern() {
    let server = McpServer::new("http://127.0.0.1:1");
    let response = server.dispatch(req("tools/list", Value::Null)).await;
    let payload = response.result.expect("expected result");
    let tools = payload.get("tools").and_then(Value::as_array).expect("array");

    let grep_tool = tools
        .iter()
        .find(|tool| matches!(tool.get("name").and_then(Value::as_str), Some("grep")))
        .expect("grep tool missing from tools/list");

    let required_fields = grep_tool
        .get("inputSchema")
        .and_then(|schema| schema.get("required"))
        .and_then(Value::as_array)
        .expect("required array");

    let requires_pattern = required_fields
        .iter()
        .filter_map(Value::as_str)
        .any(|field| field == "pattern");
    assert!(requires_pattern, "grep schema must require 'pattern'");
}
/// Spawns an in-process axum server that imitates the daemon's status and
/// search endpoints on an ephemeral `127.0.0.1` port.
///
/// * `GET /indexes/{id}/status` replies with `status_response`; when that
///   value is a JSON object, its `index_id` key is set to the `{id}` from the
///   request path.
/// * `POST /indexes/{id}/search` records the request path and JSON body, then
///   replies with `search_response`.
///
/// Returns `(base_url, captured_bodies, captured_paths)`. The two shared
/// vectors receive one entry per search request.
///
/// # Panics
///
/// Panics if the listener cannot be bound or its local address cannot be read.
pub(super) async fn spawn_mock_daemon(
    status_response: Value,
    search_response: Value,
) -> (
    String,
    std::sync::Arc<tokio::sync::Mutex<Vec<Value>>>,
    std::sync::Arc<tokio::sync::Mutex<Vec<String>>>,
) {
    use axum::extract::{Path, State};
    use axum::routing::{get, post};
    use axum::{Json, Router};
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Canned replies plus the recorders shared with the caller.
    #[derive(Clone)]
    struct DaemonState {
        status_reply: Value,
        search_reply: Value,
        bodies: Arc<Mutex<Vec<Value>>>,
        paths: Arc<Mutex<Vec<String>>>,
    }

    /// Serves the canned status, stamped with the requested index id.
    async fn serve_status(
        Path(index_id): Path<String>,
        State(state): State<DaemonState>,
    ) -> Json<Value> {
        let mut reply = state.status_reply.clone();
        // Non-object replies are returned untouched.
        if let Some(fields) = reply.as_object_mut() {
            fields.insert("index_id".to_string(), Value::String(index_id));
        }
        Json(reply)
    }

    /// Records the search call (path first, then body) and serves the canned reply.
    async fn serve_search(
        Path(index_id): Path<String>,
        State(state): State<DaemonState>,
        Json(payload): Json<Value>,
    ) -> Json<Value> {
        let path = format!("/indexes/{index_id}/search");
        state.paths.lock().await.push(path);
        state.bodies.lock().await.push(payload);
        Json(state.search_reply.clone())
    }

    let captured_bodies: Arc<Mutex<Vec<Value>>> = Arc::new(Mutex::new(Vec::new()));
    let captured_paths: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));

    let app = Router::new()
        .route("/indexes/{id}/status", get(serve_status))
        .route("/indexes/{id}/search", post(serve_search))
        .with_state(DaemonState {
            status_reply: status_response,
            search_reply: search_response,
            bodies: Arc::clone(&captured_bodies),
            paths: Arc::clone(&captured_paths),
        });

    // Port 0 lets the OS choose a free port; the real address is read back.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        let _ = axum::serve(listener, app).await;
    });

    (format!("http://{addr}"), captured_bodies, captured_paths)
}
/// `resolve_index_id` must prefer a non-empty explicit `index_id`, fall back
/// to the pinned index otherwise, and return `None` when neither exists.
#[test]
fn resolve_index_id_prefers_explicit_then_pinned() {
    use serde_json::json;

    let owned = |id: &str| Some(id.to_string());

    // Server with a pinned index.
    let pinned = McpServer::new("http://127.0.0.1:1").with_pinned_index("my-project");
    assert_eq!(pinned.resolve_index_id(&json!({})), owned("my-project"));
    assert_eq!(
        pinned.resolve_index_id(&json!({ "index_id": "other" })),
        owned("other")
    );
    // An empty explicit id falls back to the pin.
    assert_eq!(
        pinned.resolve_index_id(&json!({ "index_id": "" })),
        owned("my-project")
    );

    // Server without a pin.
    let unpinned = McpServer::new("http://127.0.0.1:1");
    assert_eq!(unpinned.resolve_index_id(&json!({})), None);
    assert_eq!(
        unpinned.resolve_index_id(&json!({ "index_id": "x" })),
        owned("x")
    );
}
/// Pinning a single-space index name must behave like having no pin: with no
/// explicit `index_id`, resolution returns `None`.
#[test]
fn blank_pin_is_treated_as_no_pin() {
    let no_params = serde_json::json!({});
    let server = McpServer::new("http://127.0.0.1:1").with_pinned_index(" ");

    assert_eq!(server.resolve_index_id(&no_params), None);
}
/// On an unpinned server, a `search` call that supplies only `query` must
/// yield `INVALID_PARAMS`.
#[tokio::test]
async fn search_without_pin_requires_index_id() {
    let query_only = serde_json::json!({ "query": "fn main" });
    let server = McpServer::new("http://127.0.0.1:1");

    let response = server.dispatch(req("search", query_only)).await;

    let err = response.error.expect("expected error");
    assert_eq!(err.code, error_codes::INVALID_PARAMS);
}
/// On a server pinned to `pinned-proj`, a `search` call with no `index_id`
/// must succeed and hit the mock daemon exactly once, at
/// `/indexes/pinned-proj/search`.
#[tokio::test]
async fn pinned_search_defaults_index_id_to_pin() {
    let status_reply = serde_json::json!({ "search_capabilities": ["vector", "kg"] });
    let search_reply =
        serde_json::json!({ "results": [], "intent": "Definition", "latency_ms": 1 });
    let (base_url, _bodies, paths) = spawn_mock_daemon(status_reply, search_reply).await;

    let http = reqwest::Client::new();
    let server = McpServer::with_client(base_url, http).with_pinned_index("pinned-proj");

    let query_only = serde_json::json!({ "query": "fn main" });
    let resp = server.dispatch(req("search", query_only)).await;
    assert!(
        resp.error.is_none(),
        "pinned search should succeed: {resp:?}"
    );

    let seen = paths.lock().await;
    assert_eq!(seen.len(), 1, "exactly one daemon search call: {seen:?}");
    assert_eq!(seen[0], "/indexes/pinned-proj/search");
}
/// On a server pinned to `pinned-proj`, a `grep` call with no `index_id`
/// must hit only `/indexes/pinned-proj/grep`, never the global `/grep` route.
///
/// The mock server exposes both routes and logs which one each request hits.
#[tokio::test]
async fn pinned_grep_scopes_to_pinned_index() {
    use axum::extract::{Path, State};
    use axum::routing::post;
    use axum::{Json, Router};
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Shared log of the route paths hit on the mock server.
    type HitLog = Arc<Mutex<Vec<String>>>;

    /// Logs a hit on the per-index grep route.
    async fn scoped_grep(
        Path(index_id): Path<String>,
        State(log): State<HitLog>,
        Json(_payload): Json<Value>,
    ) -> Json<Value> {
        let hit = format!("/indexes/{index_id}/grep");
        log.lock().await.push(hit);
        Json(serde_json::json!({ "matches": [] }))
    }

    /// Logs a hit on the global grep route.
    async fn unscoped_grep(State(log): State<HitLog>, Json(_payload): Json<Value>) -> Json<Value> {
        log.lock().await.push("/grep".to_string());
        Json(serde_json::json!({ "matches": [] }))
    }

    let captured: HitLog = Arc::new(Mutex::new(Vec::new()));
    let app = Router::new()
        .route("/indexes/{id}/grep", post(scoped_grep))
        .route("/grep", post(unscoped_grep))
        .with_state(Arc::clone(&captured));

    // Port 0 lets the OS choose a free port; the real address is read back.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        let _ = axum::serve(listener, app).await;
    });

    let http = reqwest::Client::new();
    let server =
        McpServer::with_client(format!("http://{addr}"), http).with_pinned_index("pinned-proj");

    let pattern_only = serde_json::json!({ "pattern": "fn foo" });
    let resp = server.dispatch(req("grep", pattern_only)).await;
    assert!(resp.error.is_none(), "pinned grep should succeed: {resp:?}");

    let seen = captured.lock().await;
    assert_eq!(seen.as_slice(), &["/indexes/pinned-proj/grep".to_string()]);
}