use std::process::Stdio;
use parking_lot::Mutex;
use rmcp::model::{CallToolRequestParams, ListToolsResult};
use rmcp::service::{RoleClient, RunningService};
use rmcp::transport::TokioChildProcess;
use rmcp::ServiceExt;
use serde_json::Value;
use tokio::process::Command;
use tokio::sync::Mutex as AsyncMutex;
use tokio::time::timeout;
use crate::error::{NikaError, Result};
use crate::mcp::retry::{retry_mcp_call, McpRetryConfig};
use crate::mcp::types::{
ContentBlock, McpConfig, McpErrorCode, ResourceContent, ToolCallResult, ToolDefinition,
};
use crate::util::{CONNECT_TIMEOUT, MCP_CALL_TIMEOUT};
/// Extracts a JSON-RPC error code from a formatted error message.
///
/// rmcp surfaces service errors as strings, so the numeric code is recovered
/// heuristically. The patterns are tried in order, and for each pattern only
/// its first match is considered:
///
/// 1. `code: <n>` — e.g. `"JSON-RPC error code: -32602"`
/// 2. `(<negative n>)` — e.g. `"Method not found (-32601)"`
/// 3. `error <negative n>` — e.g. `"error -32700 in parser"`
///
/// A parsed code is accepted only if it lies in `-32799..=-32000` or
/// `-32700..=-32600`. If it does not, the next pattern is tried.
///
/// Returns `None` when no pattern yields an accepted code.
fn extract_error_code(error: &str) -> Option<McpErrorCode> {
    // Compile the patterns once, on first use, instead of on every call. This
    // runs for every failed MCP request, and `Regex::new` costs far more than
    // matching does.
    static PATTERNS: std::sync::OnceLock<Vec<regex::Regex>> = std::sync::OnceLock::new();
    let patterns = PATTERNS.get_or_init(|| {
        [r"code:\s*(-?\d+)", r"\((-\d+)\)", r"error\s+(-\d+)"]
            .iter()
            .filter_map(|pattern| regex::Regex::new(pattern).ok())
            .collect()
    });

    patterns.iter().find_map(|re| {
        let code = re.captures(error)?.get(1)?.as_str().parse::<i32>().ok()?;
        // NOTE(review): the second range is contained in the first. Both are
        // kept verbatim to preserve the accepted set of codes.
        let accepted = (-32799..=-32000).contains(&code) || (-32700..=-32600).contains(&code);
        accepted.then(|| McpErrorCode::from_code(code))
    })
}
/// A running rmcp client session; the client-side handler is `()` (see `connect`).
type RmcpService = RunningService<RoleClient, ()>;
/// Adapter that drives an rmcp client session with an MCP server launched as a
/// child process, converting results into this crate's MCP types.
pub(crate) struct RmcpClientAdapter {
    /// Server name, copied from `config.name` at construction.
    name: String,
    /// Launch configuration (command, args, env) used by `connect`.
    config: McpConfig,
    /// The live session, or `None` while disconnected. This is an async mutex
    /// because `connect` holds it across `.await` during the handshake.
    service: AsyncMutex<Option<RmcpService>>,
    /// Protocol version reported by the server at connect time; cleared by `disconnect`.
    server_version: Mutex<Option<String>>,
    /// Tools returned by the most recent successful `list_tools` call.
    cached_tools: Mutex<Vec<ToolDefinition>>,
}
impl std::fmt::Debug for RmcpClientAdapter {
    /// Formats the adapter for diagnostics.
    ///
    /// Environment variable *values* are deliberately not printed. MCP server
    /// environments may carry credentials such as API keys or database
    /// passwords, and `Debug` output tends to end up in logs. Only the variable
    /// names are shown.
    ///
    /// `connected` is best-effort: it reads `false` while another task holds
    /// the service lock (see `is_connected_sync`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let env_keys: Vec<_> = self.config.env.iter().map(|(key, _)| key).collect();
        f.debug_struct("RmcpClientAdapter")
            .field("name", &self.name)
            .field("command", &self.config.command)
            .field("args", &self.config.args)
            .field("env_keys", &env_keys)
            .field("connected", &self.is_connected_sync())
            .finish_non_exhaustive()
    }
}
impl RmcpClientAdapter {
pub fn new(config: McpConfig) -> Self {
Self {
name: config.name.clone(),
config,
service: AsyncMutex::new(None),
server_version: Mutex::new(None),
cached_tools: Mutex::new(Vec::new()),
}
}
#[allow(dead_code)] pub fn name(&self) -> &str {
&self.name
}
pub fn is_connected_sync(&self) -> bool {
self.service
.try_lock()
.map(|guard| guard.is_some())
.unwrap_or(false)
}
pub async fn is_connected(&self) -> bool {
self.service.lock().await.is_some()
}
pub async fn connect(&self) -> Result<()> {
let mut guard = self.service.lock().await;
if guard.is_some() {
return Ok(()); }
let mut cmd = Command::new(&self.config.command);
cmd.args(&self.config.args);
cmd.stderr(Stdio::null());
cmd.env("RUST_LOG", "off");
for (key, value) in &self.config.env {
cmd.env(key, value);
}
let transport = TokioChildProcess::new(cmd).map_err(|e| NikaError::McpStartError {
name: self.name.clone(),
reason: format!("Failed to create transport: {}", e),
})?;
let service = timeout(CONNECT_TIMEOUT, ().serve(transport))
.await
.map_err(|_| NikaError::McpTimeout {
name: self.name.clone(),
operation: "connect".to_string(),
timeout_secs: CONNECT_TIMEOUT.as_secs(),
})?
.map_err(|e| NikaError::McpStartError {
name: self.name.clone(),
reason: format!("Failed to connect: {}", e),
})?;
if let Some(info) = service.peer_info() {
*self.server_version.lock() = Some(info.protocol_version.to_string());
}
*guard = Some(service);
Ok(())
}
pub async fn disconnect(&self) -> Result<()> {
let mut guard = self.service.lock().await;
if let Some(service) = guard.take() {
let _ = service.cancel().await;
}
*self.server_version.lock() = None;
Ok(())
}
pub async fn reconnect(&self) -> Result<()> {
tracing::info!(
mcp_server = %self.name,
"Attempting MCP server reconnection"
);
self.disconnect().await?;
self.connect().await
}
pub async fn call_tool(&self, name: &str, params: Value) -> Result<ToolCallResult> {
let peer = {
let guard = self.service.lock().await;
let service = guard.as_ref().ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
use std::ops::Deref;
service.deref().clone()
};
let arguments = params.as_object().cloned();
let request = CallToolRequestParams {
meta: None,
name: name.to_string().into(),
arguments,
task: None,
};
let result = timeout(MCP_CALL_TIMEOUT, peer.call_tool(request))
.await
.map_err(|_| NikaError::McpTimeout {
name: self.name.clone(),
operation: "call_tool".to_string(),
timeout_secs: MCP_CALL_TIMEOUT.as_secs(),
})?
.map_err(|e| {
let error_str = e.to_string();
NikaError::McpToolError {
tool: name.to_string(),
reason: error_str.clone(),
error_code: extract_error_code(&error_str),
}
})?;
let content: Vec<ContentBlock> = result
.content
.iter()
.filter_map(|c| {
c.as_text().map(|t| ContentBlock::text(t.text.clone()))
})
.collect();
Ok(ToolCallResult {
content,
is_error: result.is_error.unwrap_or(false),
})
}
#[allow(dead_code)] pub async fn call_tool_with_retry(
&self,
name: &str,
params: Value,
retry_config: Option<McpRetryConfig>,
) -> Result<ToolCallResult> {
let config = retry_config.unwrap_or_default();
let name_owned = name.to_string();
let server_name = self.name.clone();
tracing::debug!(
mcp_server = %server_name,
tool = %name_owned,
max_retries = config.max_retries,
"Calling MCP tool with retry"
);
retry_mcp_call(config, || {
let params = params.clone();
let name = name_owned.clone();
async move { self.call_tool(&name, params).await }
})
.await
}
pub async fn read_resource(&self, uri: &str) -> Result<ResourceContent> {
let peer = {
let guard = self.service.lock().await;
let service = guard.as_ref().ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
use std::ops::Deref;
service.deref().clone()
};
let request = rmcp::model::ReadResourceRequestParams {
meta: None,
uri: uri.into(),
};
let result = timeout(MCP_CALL_TIMEOUT, peer.read_resource(request))
.await
.map_err(|_| NikaError::McpTimeout {
name: self.name.clone(),
operation: "read_resource".to_string(),
timeout_secs: MCP_CALL_TIMEOUT.as_secs(),
})?
.map_err(|e| {
let error_str = e.to_string().to_lowercase();
if error_str.contains("not found") {
NikaError::McpResourceNotFound {
uri: uri.to_string(),
}
} else {
let error_str = e.to_string();
NikaError::McpToolError {
tool: "resources/read".to_string(),
reason: error_str.clone(),
error_code: extract_error_code(&error_str),
}
}
})?;
let resource = result
.contents
.first()
.ok_or_else(|| NikaError::McpResourceNotFound {
uri: uri.to_string(),
})?;
let text = serde_json::to_string(resource).map_err(|e| NikaError::McpToolError {
tool: "resources/read".to_string(),
reason: format!("Failed to serialize resource: {}", e),
error_code: None, })?;
let content = ResourceContent::new(uri)
.with_text(&text)
.with_mime_type("application/json");
Ok(content)
}
pub async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
let peer = {
let guard = self.service.lock().await;
let service = guard.as_ref().ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
use std::ops::Deref;
service.deref().clone()
};
let result: ListToolsResult =
timeout(MCP_CALL_TIMEOUT, peer.list_tools(Default::default()))
.await
.map_err(|_| NikaError::McpTimeout {
name: self.name.clone(),
operation: "list_tools".to_string(),
timeout_secs: MCP_CALL_TIMEOUT.as_secs(),
})?
.map_err(|e| {
let error_str = e.to_string();
NikaError::McpToolError {
tool: "tools/list".to_string(),
reason: error_str.clone(),
error_code: extract_error_code(&error_str),
}
})?;
let tools: Vec<ToolDefinition> = result
.tools
.into_iter()
.map(|t| {
let mut tool = ToolDefinition::new(t.name.as_ref());
if let Some(desc) = &t.description {
tool = tool.with_description(desc.as_ref());
}
let mut schema_map: serde_json::Map<String, serde_json::Value> =
(*t.input_schema).clone();
if !schema_map.contains_key("type") {
schema_map.insert("type".to_string(), serde_json::json!("object"));
}
tool = tool.with_input_schema(serde_json::Value::Object(schema_map));
tool
})
.collect();
*self.cached_tools.lock() = tools.clone();
Ok(tools)
}
pub fn get_cached_tools(&self) -> Vec<ToolDefinition> {
self.cached_tools.lock().clone()
}
}
impl Drop for RmcpClientAdapter {
    /// Traces the adapter's destruction at debug level.
    ///
    /// No explicit shutdown is performed here. The adapter's fields, including
    /// any held service, are dropped after this method returns.
    fn drop(&mut self) {
        let server = self.name.as_str();
        tracing::debug!(mcp_server = %server, "RmcpClientAdapter dropped");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adapter_new() {
        let config = McpConfig::new("test-server", "echo");
        let adapter = RmcpClientAdapter::new(config);
        assert_eq!(adapter.name(), "test-server");
    }

    #[test]
    fn test_adapter_new_with_args_and_env() {
        let config = McpConfig::new("novanet", "cargo")
            .with_arg("run")
            .with_env("NEO4J_URI", "bolt://localhost:7687");
        let adapter = RmcpClientAdapter::new(config);
        assert_eq!(adapter.name(), "novanet");
    }

    #[test]
    fn test_adapter_debug_not_connected() {
        let config = McpConfig::new("test-server", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let debug_str = format!("{:?}", adapter);
        assert!(debug_str.contains("RmcpClientAdapter"));
        assert!(debug_str.contains("test-server"));
        assert!(debug_str.contains("connected"));
    }

    #[test]
    fn test_adapter_debug_reports_disconnected_state() {
        let config = McpConfig::new("test-server", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let debug_str = format!("{:?}", adapter);
        assert!(
            debug_str.contains("connected: false"),
            "unexpected Debug output: {}",
            debug_str
        );
    }

    #[tokio::test]
    async fn test_adapter_not_connected_by_default() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        assert!(!adapter.is_connected().await);
    }

    #[test]
    fn test_adapter_not_connected_sync() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        assert!(!adapter.is_connected_sync());
    }

    #[tokio::test]
    async fn test_disconnect_when_not_connected_is_ok() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let result = adapter.disconnect().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_disconnect_clears_server_version() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        *adapter.server_version.lock() = Some("1.0".to_string());
        assert_eq!(
            adapter.server_version.lock().as_ref().map(|s| s.as_str()),
            Some("1.0")
        );
        adapter.disconnect().await.ok();
        assert_eq!(adapter.server_version.lock().as_ref(), None);
    }

    #[tokio::test]
    async fn test_call_tool_when_not_connected_returns_error() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let result = adapter.call_tool("test_tool", serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            NikaError::McpNotConnected { name } => assert_eq!(name, "test"),
            e => panic!("Expected McpNotConnected, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_call_tool_with_object_params() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let params = serde_json::json!({
            "entity": "qr-code",
            "locale": "fr-FR"
        });
        let result = adapter.call_tool("novanet_generate", params).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            NikaError::McpNotConnected { name } => assert_eq!(name, "test"),
            e => panic!("Expected McpNotConnected, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_call_tool_with_non_object_params() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let params = serde_json::json!("not-an-object");
        let result = adapter.call_tool("test_tool", params).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            NikaError::McpNotConnected { name } => assert_eq!(name, "test"),
            e => panic!("Expected McpNotConnected, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_read_resource_when_not_connected_returns_error() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let result = adapter.read_resource("neo4j://entity/test").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            NikaError::McpNotConnected { name } => assert_eq!(name, "test"),
            e => panic!("Expected McpNotConnected, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_read_resource_with_various_uris() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let uris = [
            "neo4j://entity/qr-code",
            "neo4j://page/landing",
            "neo4j://block/hero",
            "file:///path/to/file",
        ];
        for uri in uris {
            // Pin the specific error: `is_err()` alone would also accept a
            // timeout or a mis-mapped resource error.
            match adapter.read_resource(uri).await {
                Err(NikaError::McpNotConnected { name }) => assert_eq!(name, "test"),
                other => panic!("Expected McpNotConnected for {}, got: {:?}", uri, other),
            }
        }
    }

    #[tokio::test]
    async fn test_list_tools_when_not_connected_returns_error() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let result = adapter.list_tools().await;
        assert!(result.is_err());
        match result.unwrap_err() {
            NikaError::McpNotConnected { name } => assert_eq!(name, "test"),
            e => panic!("Expected McpNotConnected, got: {:?}", e),
        }
    }

    #[test]
    fn test_get_cached_tools_returns_empty_initially() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let cached = adapter.get_cached_tools();
        assert!(cached.is_empty());
    }

    #[test]
    fn test_get_cached_tools_with_populated_cache() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let tools = vec![
            ToolDefinition::new("novanet_generate"),
            ToolDefinition::new("novanet_traverse"),
        ];
        *adapter.cached_tools.lock() = tools.clone();
        let cached = adapter.get_cached_tools();
        assert_eq!(cached.len(), 2);
        assert_eq!(cached[0].name, "novanet_generate");
        assert_eq!(cached[1].name, "novanet_traverse");
    }

    #[test]
    fn test_cached_tools_independence() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        let tool1 = ToolDefinition::new("tool1");
        *adapter.cached_tools.lock() = vec![tool1];
        let cached1 = adapter.get_cached_tools();
        assert_eq!(cached1.len(), 1);
        let tool2 = ToolDefinition::new("tool2");
        *adapter.cached_tools.lock() = vec![tool2];
        let cached2 = adapter.get_cached_tools();
        assert_eq!(cached2.len(), 1);
        assert_eq!(cached2[0].name, "tool2");
    }

    #[test]
    fn test_adapter_drop_does_not_panic() {
        let config = McpConfig::new("test", "echo");
        let adapter = RmcpClientAdapter::new(config);
        drop(adapter);
    }

    #[test]
    fn test_extract_error_code_with_code_pattern() {
        let error = "JSON-RPC error code: -32602 - Invalid params";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::InvalidParams));
    }

    #[test]
    fn test_extract_error_code_with_parentheses_pattern() {
        let error = "Method not found (-32601)";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::MethodNotFound));
    }

    #[test]
    fn test_extract_error_code_with_error_word_pattern() {
        let error = "error -32700 in parser";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::ParseError));
    }

    #[test]
    fn test_extract_error_code_parse_error() {
        let error = "Parse error code: -32700";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::ParseError));
    }

    #[test]
    fn test_extract_error_code_invalid_request() {
        let error = "Invalid request (-32600)";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::InvalidRequest));
    }

    #[test]
    fn test_extract_error_code_internal_error() {
        let error = "code: -32603";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::InternalError));
    }

    #[test]
    fn test_extract_error_code_server_error() {
        let error = "Server error (-32050): Internal failure";
        let code = extract_error_code(error);
        assert!(matches!(code, Some(McpErrorCode::ServerError(-32050))));
    }

    #[test]
    fn test_extract_error_code_various_server_errors() {
        let test_cases = vec![
            ("error -32000", Some(McpErrorCode::ServerError(-32000))),
            ("code: -32050", Some(McpErrorCode::ServerError(-32050))),
            ("(-32099)", Some(McpErrorCode::ServerError(-32099))),
        ];
        for (error, expected) in test_cases {
            let code = extract_error_code(error);
            assert_eq!(code, expected, "Failed for error: {}", error);
        }
    }

    #[test]
    fn test_extract_error_code_no_code_present() {
        let error = "Connection refused";
        let code = extract_error_code(error);
        assert_eq!(code, None);
    }

    #[test]
    fn test_extract_error_code_non_jsonrpc_code() {
        let error = "HTTP error code: 404";
        let code = extract_error_code(error);
        assert_eq!(code, None);
    }

    #[test]
    fn test_extract_error_code_negative_but_outside_range() {
        let error = "error code: -100";
        let code = extract_error_code(error);
        assert_eq!(code, None);
    }

    #[test]
    fn test_extract_error_code_positive_numbers_ignored() {
        let error = "code: 500 or code: 400";
        let code = extract_error_code(error);
        assert_eq!(code, None);
    }

    #[test]
    fn test_extract_error_code_multiple_codes_uses_first() {
        let error = "code: -32602 and also code: -32601";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::InvalidParams));
    }

    #[test]
    fn test_extract_error_code_falls_through_to_next_pattern() {
        // The `code:` pattern matches 404 (out of range), so the
        // parenthesised pattern must still be tried.
        let error = "HTTP code: 404 wrapping Method not found (-32601)";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::MethodNotFound));
    }

    #[test]
    fn test_extract_error_code_is_stable_across_calls() {
        // The patterns are compiled once and reused; repeated calls must agree.
        for _ in 0..3 {
            assert_eq!(
                extract_error_code("code: -32602"),
                Some(McpErrorCode::InvalidParams)
            );
            assert_eq!(extract_error_code("Connection refused"), None);
        }
    }

    #[test]
    fn test_extract_error_code_empty_string() {
        let error = "";
        let code = extract_error_code(error);
        assert_eq!(code, None);
    }

    #[test]
    fn test_extract_error_code_with_text_around_code() {
        let error = "Request failed: code: -32602 - invalid parameters supplied";
        let code = extract_error_code(error);
        assert_eq!(code, Some(McpErrorCode::InvalidParams));
    }

    #[test]
    fn test_adapter_preserves_config_command() {
        let config = McpConfig::new("myserver", "python").with_args(["script.py", "--flag"]);
        let adapter = RmcpClientAdapter::new(config);
        assert_eq!(adapter.name(), "myserver");
    }

    #[test]
    fn test_adapter_with_complex_config() {
        let config = McpConfig::new("complex-server", "node")
            .with_arg("--require")
            .with_arg("dotenv/config")
            .with_arg("index.js")
            .with_env("LOG_LEVEL", "debug")
            .with_env("PORT", "3000");
        let adapter = RmcpClientAdapter::new(config);
        assert_eq!(adapter.name(), "complex-server");
    }

    #[test]
    fn test_adapter_name_accessor() {
        let config = McpConfig::new("my-test-server", "echo");
        let adapter = RmcpClientAdapter::new(config);
        assert_eq!(adapter.name(), "my-test-server");
    }
}