use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use rustc_hash::FxHasher;
use serde_json::Value;
use std::sync::Arc;
use crate::error::{NikaError, Result};
use crate::event::{EventKind, EventLog};
use crate::mcp::retry::{retry_mcp_call, McpRetryConfig};
use crate::mcp::rmcp_adapter::RmcpClientAdapter;
use crate::mcp::types::{ContentBlock, McpConfig, ResourceContent, ToolCallResult, ToolDefinition};
use crate::mcp::validation::{ErrorEnhancer, McpValidator, ValidationConfig, ValidationErrorKind};
/// Successful outcome of an MCP server health check (`McpClient::ping`).
#[derive(Clone, Debug)]
pub struct McpPingResult {
    /// Name of the server that answered.
    pub server: String,
    /// Time elapsed from the start of the ping until it completed.
    pub latency: Duration,
    /// Number of tools the server reported.
    pub tool_count: usize,
    /// Whether the client already considered itself connected before pinging.
    pub was_connected: bool,
}
/// Reasons an MCP server health check (`McpClient::ping`) can fail.
#[derive(Debug, Clone)]
pub enum McpPingError {
    /// Establishing the connection to the server failed.
    StartFailed { server: String, details: String },
    /// The server did not answer within `timeout`.
    Timeout { server: String, timeout: Duration },
    /// The connection attempt was refused.
    ConnectionRefused { server: String },
    /// The server answered, but with an error.
    ServerError { server: String, details: String },
}

impl std::fmt::Display for McpPingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            McpPingError::StartFailed { server, details } => {
                write!(f, "MCP server '{}' failed to start: {}", server, details)
            }
            McpPingError::Timeout { server, timeout } => {
                write!(f, "MCP server '{}' timed out after {:?}", server, timeout)
            }
            McpPingError::ConnectionRefused { server } => {
                write!(f, "MCP server '{}' connection refused", server)
            }
            McpPingError::ServerError { server, details } => {
                write!(f, "MCP server '{}' error: {}", server, details)
            }
        }
    }
}

// Lets ping failures be boxed as `dyn Error` and propagated with `?` like any
// other error type. No variant wraps an underlying error, so `source()` keeps
// its default of `None`.
impl std::error::Error for McpPingError {}

impl McpPingError {
    /// Returns a short, human-readable hint on how to resolve this failure.
    pub fn suggestion(&self) -> &'static str {
        match self {
            McpPingError::StartFailed { .. } => {
                "Check the MCP server command is correct and the executable exists"
            }
            McpPingError::Timeout { .. } => {
                "The MCP server may be slow to start. Try increasing the timeout"
            }
            McpPingError::ConnectionRefused { .. } => {
                "Ensure the MCP server is running and accessible"
            }
            McpPingError::ServerError { .. } => "Check the MCP server logs for more details",
        }
    }
}
/// Settings for the optional per-client cache of tool responses.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// How long a stored response is served before it counts as expired.
    pub ttl: Duration,
    /// Entry count at which inserting a new response triggers eviction.
    pub max_entries: usize,
}

impl Default for CacheConfig {
    /// A five-minute TTL with room for 1000 entries.
    fn default() -> Self {
        const DEFAULT_TTL_SECS: u64 = 300;
        const DEFAULT_MAX_ENTRIES: usize = 1000;
        CacheConfig {
            ttl: Duration::from_secs(DEFAULT_TTL_SECS),
            max_entries: DEFAULT_MAX_ENTRIES,
        }
    }
}
/// One cached tool response plus the instant it was stored.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Reference-counted so lookups can hand it out with a cheap `Arc::clone`.
    result: Arc<ToolCallResult>,
    /// When the entry was created; used for TTL checks and oldest-first eviction.
    created_at: Instant,
}

impl CacheEntry {
    /// Wraps `result`, stamping it with the current instant.
    fn new(result: Arc<ToolCallResult>) -> Self {
        let created_at = Instant::now();
        CacheEntry { result, created_at }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.created_at.elapsed();
        ttl < age
    }
}
#[derive(Debug)]
struct ResponseCache {
config: CacheConfig,
entries: DashMap<String, CacheEntry, rustc_hash::FxBuildHasher>,
hits: AtomicU64,
misses: AtomicU64,
}
impl ResponseCache {
fn new(config: CacheConfig) -> Self {
Self {
config,
entries: DashMap::default(),
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
}
}
fn cache_key(tool: &str, params: &Value) -> String {
let mut hasher = FxHasher::default();
let canonical = Self::canonicalize_value(params);
let params_str = match serde_json::to_string(&canonical) {
Ok(s) => s,
Err(e) => {
tracing::warn!(
tool = tool,
error = %e,
"JSON serialization failed for cache key, using Debug format"
);
format!("{:?}", params)
}
};
params_str.hash(&mut hasher);
format!("{}:{:016x}", tool, hasher.finish())
}
const MAX_CANONICALIZE_DEPTH: usize = 128;
fn canonicalize_value(value: &Value) -> Value {
Self::canonicalize_value_inner(value, 0)
}
fn canonicalize_value_inner(value: &Value, depth: usize) -> Value {
if depth >= Self::MAX_CANONICALIZE_DEPTH {
return value.clone();
}
match value {
Value::Object(map) => {
let mut sorted: serde_json::Map<String, Value> = serde_json::Map::new();
let mut keys: Vec<&String> = map.keys().collect();
keys.sort();
for key in keys {
sorted.insert(
key.clone(),
Self::canonicalize_value_inner(&map[key], depth + 1),
);
}
Value::Object(sorted)
}
Value::Array(arr) => Value::Array(
arr.iter()
.map(|v| Self::canonicalize_value_inner(v, depth + 1))
.collect(),
),
other => other.clone(),
}
}
fn get(&self, tool: &str, params: &Value) -> Option<Arc<ToolCallResult>> {
let key = Self::cache_key(tool, params);
if let Some(entry) = self.entries.get(&key) {
if entry.is_expired(self.config.ttl) {
drop(entry);
self.entries.remove(&key);
self.misses.fetch_add(1, Ordering::Relaxed);
return None;
}
self.hits.fetch_add(1, Ordering::Relaxed);
return Some(Arc::clone(&entry.result));
}
self.misses.fetch_add(1, Ordering::Relaxed);
None
}
fn put(&self, tool: &str, params: &Value, result: ToolCallResult) {
if result.is_error {
return;
}
let key = Self::cache_key(tool, params);
if self.entries.len() >= self.config.max_entries {
self.evict_oldest();
}
self.entries.insert(key, CacheEntry::new(Arc::new(result)));
}
fn evict_oldest(&self) {
let to_remove = (self.config.max_entries / 10).max(1);
let mut entries: Vec<(String, Instant)> = self
.entries
.iter()
.map(|e| (e.key().clone(), e.created_at))
.collect();
if entries.len() <= to_remove {
for (key, _) in &entries {
self.entries.remove(key);
}
return;
}
entries.select_nth_unstable_by_key(to_remove - 1, |(_, created)| *created);
for (key, _) in entries.iter().take(to_remove) {
self.entries.remove(key);
}
}
fn clear(&self) {
self.entries.clear();
self.hits.store(0, Ordering::Relaxed);
self.misses.store(0, Ordering::Relaxed);
}
fn stats(&self) -> ResponseCacheStats {
ResponseCacheStats {
entries: self.entries.len(),
hits: self.hits.load(Ordering::Relaxed),
misses: self.misses.load(Ordering::Relaxed),
}
}
}
/// Point-in-time counters for a client's response cache.
#[derive(Clone, Debug, Default)]
pub struct ResponseCacheStats {
    /// Entries currently stored. This may include expired entries that have
    /// not been looked up since they expired.
    pub entries: usize,
    /// Lookups served from the cache.
    pub hits: u64,
    /// Lookups that found no live entry.
    pub misses: u64,
}

impl ResponseCacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`. Returns `0.0`
    /// before any lookup has happened.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
/// Client for a single MCP server, either backed by a real adapter or mocked.
pub struct McpClient {
    /// Server name from the configuration (or the mock's name).
    name: String,
    /// Connection flag. `is_connected` consults it only for mocks; real clients
    /// ask the adapter instead.
    connected: AtomicBool,
    /// `true` for clients built with `McpClient::mock`.
    is_mock: bool,
    /// Transport to the real server; `None` for mocks.
    adapter: Option<RmcpClientAdapter>,
    /// Optional parameter validator enabled via `with_validation`.
    validator: Option<McpValidator>,
    /// Optional response cache enabled via `with_cache`.
    cache: Option<ResponseCache>,
    /// Whether the most recent tool-call cache lookup was a hit. Shared by all callers.
    last_cache_hit: AtomicBool,
}
impl std::fmt::Debug for McpClient {
    /// Formats the client while reporting only whether the adapter, validator
    /// and cache are present, not their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_adapter = self.adapter.is_some();
        let has_validator = self.validator.is_some();
        let has_cache = self.cache.is_some();
        let mut out = f.debug_struct("McpClient");
        out.field("name", &self.name);
        out.field("connected", &self.connected);
        out.field("is_mock", &self.is_mock);
        out.field("has_adapter", &has_adapter)
            .field("has_validator", &has_validator)
            .field("has_cache", &has_cache)
            .field("last_cache_hit", &self.last_cache_hit);
        out.finish()
    }
}
impl McpClient {
pub fn new(config: McpConfig) -> Result<Self> {
if config.name.is_empty() {
return Err(NikaError::ValidationError {
reason: "MCP server name cannot be empty".to_string(),
});
}
if config.command.is_empty() {
return Err(NikaError::ValidationError {
reason: "MCP server command cannot be empty".to_string(),
});
}
let name = config.name.clone();
let adapter = RmcpClientAdapter::new(config);
Ok(Self {
name,
connected: AtomicBool::new(false),
is_mock: false,
adapter: Some(adapter),
validator: None,
cache: None,
last_cache_hit: AtomicBool::new(false),
})
}
pub fn with_validation(mut self, config: ValidationConfig) -> Self {
self.validator = Some(McpValidator::new(config));
self
}
pub fn with_cache(mut self, config: CacheConfig) -> Self {
self.cache = Some(ResponseCache::new(config));
self
}
pub fn cache_stats(&self) -> Option<ResponseCacheStats> {
self.cache.as_ref().map(|c| c.stats())
}
pub fn was_last_call_cached(&self) -> bool {
self.last_cache_hit.load(Ordering::SeqCst)
}
pub fn mock(name: &str) -> Self {
Self {
name: name.to_string(),
connected: AtomicBool::new(true), is_mock: true,
adapter: None,
validator: None,
cache: None,
last_cache_hit: AtomicBool::new(false),
}
}
pub fn name(&self) -> &str {
&self.name
}
pub fn is_connected(&self) -> bool {
if self.is_mock {
return self.connected.load(Ordering::SeqCst);
}
self.adapter
.as_ref()
.map(|a| a.is_connected_sync())
.unwrap_or(false)
}
pub async fn is_connected_async(&self) -> bool {
if self.is_mock {
return self.connected.load(Ordering::SeqCst);
}
if let Some(adapter) = &self.adapter {
adapter.is_connected().await
} else {
false
}
}
pub async fn ping(&self) -> std::result::Result<McpPingResult, McpPingError> {
let start = Instant::now();
let was_connected = self.is_connected_async().await;
if self.is_mock {
return Ok(McpPingResult {
server: self.name.clone(),
latency: start.elapsed(),
tool_count: self.mock_list_tools().len(),
was_connected: true,
});
}
if !was_connected {
if let Err(e) = self.connect().await {
let error_msg = e.to_string().to_lowercase();
if error_msg.contains("refused") || error_msg.contains("connection") {
return Err(McpPingError::ConnectionRefused {
server: self.name.clone(),
});
}
return Err(McpPingError::StartFailed {
server: self.name.clone(),
details: e.to_string(),
});
}
}
match tokio::time::timeout(Duration::from_secs(10), self.list_tools()).await {
Ok(Ok(tools)) => Ok(McpPingResult {
server: self.name.clone(),
latency: start.elapsed(),
tool_count: tools.len(),
was_connected,
}),
Ok(Err(e)) => Err(McpPingError::ServerError {
server: self.name.clone(),
details: e.to_string(),
}),
Err(_) => Err(McpPingError::Timeout {
server: self.name.clone(),
timeout: Duration::from_secs(10),
}),
}
}
pub fn is_configured(&self) -> bool {
self.is_mock || self.adapter.is_some()
}
pub async fn connect(&self) -> Result<()> {
if self.is_mock {
self.connected.store(true, Ordering::SeqCst);
if let Some(ref validator) = self.validator {
let tools = self.mock_list_tools();
validator
.cache()
.populate(&self.name, &tools)
.map_err(|e| NikaError::McpSchemaError {
tool: "*".to_string(),
reason: format!("Failed to cache mock tool schemas: {}", e),
})?;
}
return Ok(());
}
let adapter = self
.adapter
.as_ref()
.ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
adapter.connect().await?;
self.connected.store(true, Ordering::SeqCst);
if let Some(ref validator) = self.validator {
let tools = adapter.list_tools().await?;
validator
.cache()
.populate(&self.name, &tools)
.map_err(|e| NikaError::McpSchemaError {
tool: "*".to_string(),
reason: format!("Failed to cache tool schemas: {}", e),
})?;
tracing::debug!(
mcp_server = %self.name,
tools_cached = tools.len(),
"Cached tool schemas for validation"
);
}
Ok(())
}
pub async fn disconnect(&self) -> Result<()> {
if self.is_mock {
self.connected.store(false, Ordering::SeqCst);
return Ok(());
}
if let Some(adapter) = &self.adapter {
adapter.disconnect().await?;
}
if let Some(ref cache) = self.cache {
cache.clear();
}
if let Some(ref validator) = self.validator {
validator.cache().clear();
}
self.connected.store(false, Ordering::SeqCst);
Ok(())
}
pub async fn reconnect(&self) -> Result<()> {
if self.is_mock {
self.connected.store(true, Ordering::SeqCst);
return Ok(());
}
let adapter = self
.adapter
.as_ref()
.ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
adapter.reconnect().await?;
self.connected.store(true, Ordering::SeqCst);
Ok(())
}
pub fn is_connection_error(error: &NikaError) -> bool {
let error_str = error.to_string().to_lowercase();
error_str.contains("broken pipe")
|| error_str.contains("connection reset")
|| error_str.contains("connection refused")
|| error_str.contains("eof")
|| error_str.contains("stdin not available")
|| error_str.contains("stdout not available")
}
fn enhance_error(&self, tool_name: &str, error: NikaError) -> NikaError {
if let Some(ref validator) = self.validator {
if validator.config().enhance_errors {
let enhancer = ErrorEnhancer::new(validator.cache());
return enhancer.enhance(&self.name, tool_name, error);
}
}
error
}
pub async fn call_tool(&self, name: &str, params: Value) -> Result<ToolCallResult> {
if let Some(ref validator) = self.validator {
if validator.config().pre_validate {
let result = validator.validate(&self.name, name, ¶ms);
if !result.is_valid {
let missing: Vec<String> = result
.errors
.iter()
.filter_map(|e| {
if let ValidationErrorKind::MissingRequired { field } = &e.kind {
Some(field.clone())
} else {
None
}
})
.collect();
let suggestions: Vec<String> = result
.errors
.iter()
.filter_map(|e| {
if let ValidationErrorKind::UnknownField { suggestions, .. } = &e.kind {
Some(suggestions.clone())
} else {
None
}
})
.flatten()
.collect();
let details = result
.errors
.iter()
.map(|e| e.message.clone())
.collect::<Vec<_>>()
.join("; ");
return Err(NikaError::McpValidationFailed {
tool: name.to_string(),
details,
missing,
suggestions,
});
}
}
}
if let Some(ref cache) = self.cache {
if let Some(cached_result) = cache.get(name, ¶ms) {
self.last_cache_hit.store(true, Ordering::SeqCst);
tracing::debug!(
mcp_server = %self.name,
tool = %name,
"Cache hit for MCP tool call"
);
return Ok((*cached_result).clone());
}
}
self.last_cache_hit.store(false, Ordering::SeqCst);
if self.is_mock {
if !self.connected.load(Ordering::SeqCst) {
return Err(NikaError::McpNotConnected {
name: self.name.clone(),
});
}
let result = self.mock_tool_call(name, ¶ms);
if let Some(ref cache) = self.cache {
cache.put(name, ¶ms, result.clone());
}
return Ok(result);
}
let adapter = self
.adapter
.as_ref()
.ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
let result = retry_mcp_call(McpRetryConfig::default(), || {
let params = params.clone();
async move {
match adapter.call_tool(name, params).await {
Ok(result) => Ok(result),
Err(e) => {
let enhanced = self.enhance_error(name, e);
if Self::is_connection_error(&enhanced) {
tracing::warn!(
mcp_server = %self.name,
tool = %name,
error = %enhanced,
"Connection error, attempting reconnect"
);
if let Err(reconnect_err) = adapter.reconnect().await {
tracing::error!(
mcp_server = %self.name,
error = %reconnect_err,
"Failed to reconnect"
);
}
}
Err(enhanced)
}
}
}
})
.await?;
if let Some(ref cache) = self.cache {
cache.put(name, ¶ms, result.clone());
tracing::debug!(
mcp_server = %self.name,
tool = %name,
"Cached MCP tool response"
);
}
Ok(result)
}
pub async fn call_tool_with_retry_events(
&self,
name: &str,
params: Value,
task_id: &Arc<str>,
event_log: &EventLog,
) -> Result<ToolCallResult> {
if let Some(ref validator) = self.validator {
if validator.config().pre_validate {
let result = validator.validate(&self.name, name, ¶ms);
if !result.is_valid {
let missing: Vec<String> = result
.errors
.iter()
.filter_map(|e| {
if let ValidationErrorKind::MissingRequired { field } = &e.kind {
Some(field.clone())
} else {
None
}
})
.collect();
let suggestions: Vec<String> = result
.errors
.iter()
.filter_map(|e| {
if let ValidationErrorKind::UnknownField { suggestions, .. } = &e.kind {
Some(suggestions.clone())
} else {
None
}
})
.flatten()
.collect();
let details = result
.errors
.iter()
.map(|e| e.message.clone())
.collect::<Vec<_>>()
.join("; ");
return Err(NikaError::McpValidationFailed {
tool: name.to_string(),
details,
missing,
suggestions,
});
}
}
}
if let Some(ref cache) = self.cache {
if let Some(cached_result) = cache.get(name, ¶ms) {
self.last_cache_hit.store(true, Ordering::SeqCst);
tracing::debug!(
mcp_server = %self.name,
tool = %name,
"Cache hit for MCP tool call"
);
return Ok((*cached_result).clone());
}
}
self.last_cache_hit.store(false, Ordering::SeqCst);
if self.is_mock {
if !self.connected.load(Ordering::SeqCst) {
return Err(NikaError::McpNotConnected {
name: self.name.clone(),
});
}
let result = self.mock_tool_call(name, ¶ms);
if let Some(ref cache) = self.cache {
cache.put(name, ¶ms, result.clone());
}
return Ok(result);
}
let adapter = self
.adapter
.as_ref()
.ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
let config = McpRetryConfig::default();
let max_attempts = config.max_retries + 1; let attempt_counter = std::sync::atomic::AtomicU32::new(0);
let result = retry_mcp_call(config, || {
let params = params.clone();
async {
let attempt = attempt_counter.fetch_add(1, Ordering::SeqCst);
match adapter.call_tool(name, params).await {
Ok(result) => Ok(result),
Err(e) => {
let enhanced = self.enhance_error(name, e);
if Self::is_connection_error(&enhanced) {
event_log.emit(EventKind::McpRetry {
task_id: Arc::clone(task_id),
server_name: self.name.clone(),
operation: name.to_string(),
attempt: attempt + 1,
max_attempts: max_attempts as u32,
error: enhanced.to_string(),
});
tracing::warn!(
mcp_server = %self.name,
tool = %name,
attempt = attempt + 1,
error = %enhanced,
"Connection error, attempting reconnect (McpRetry event emitted)"
);
if let Err(reconnect_err) = adapter.reconnect().await {
tracing::error!(
mcp_server = %self.name,
error = %reconnect_err,
"Failed to reconnect"
);
}
}
Err(enhanced)
}
}
}
})
.await?;
if let Some(ref cache) = self.cache {
cache.put(name, ¶ms, result.clone());
tracing::debug!(
mcp_server = %self.name,
tool = %name,
"Cached MCP tool response"
);
}
Ok(result)
}
pub async fn read_resource(&self, uri: &str) -> Result<ResourceContent> {
if self.is_mock {
if !self.connected.load(Ordering::SeqCst) {
return Err(NikaError::McpNotConnected {
name: self.name.clone(),
});
}
return Ok(self.mock_read_resource(uri));
}
let adapter = self
.adapter
.as_ref()
.ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
retry_mcp_call(McpRetryConfig::default(), || {
async move {
match adapter.read_resource(uri).await {
Ok(result) => Ok(result),
Err(e) => {
if Self::is_connection_error(&e) {
tracing::warn!(
mcp_server = %self.name,
uri = %uri,
error = %e,
"Connection error, attempting reconnect"
);
if let Err(reconnect_err) = adapter.reconnect().await {
tracing::error!(
mcp_server = %self.name,
error = %reconnect_err,
"Failed to reconnect"
);
}
}
Err(e)
}
}
}
})
.await
}
pub async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
if self.is_mock {
if !self.connected.load(Ordering::SeqCst) {
return Err(NikaError::McpNotConnected {
name: self.name.clone(),
});
}
return Ok(self.mock_list_tools());
}
let adapter = self
.adapter
.as_ref()
.ok_or_else(|| NikaError::McpNotConnected {
name: self.name.clone(),
})?;
adapter.list_tools().await
}
fn mock_tool_call(&self, name: &str, params: &Value) -> ToolCallResult {
match name {
"novanet_describe" => {
let response = serde_json::json!({
"nodes": 61,
"arcs": 182,
"labels": ["Entity", "EntityNative", "Page", "Block"],
"relationships": ["HAS_NATIVE", "CONTAINS", "FLOWS_TO"]
});
ToolCallResult::success(vec![ContentBlock::text(response.to_string())])
}
"novanet_context" => {
let entity = params
.get("focus_key")
.or_else(|| params.get("entity"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let locale = params
.get("locale")
.and_then(|v| v.as_str())
.unwrap_or("en-US");
let response = serde_json::json!({
"entity": entity,
"locale": locale,
"context": {
"title": format!("{} - Generated Title", entity),
"description": format!("Auto-generated content for {} in {}", entity, locale),
"keywords": ["generated", "mock", entity]
}
});
ToolCallResult::success(vec![ContentBlock::text(response.to_string())])
}
_ => {
let response = serde_json::json!({
"tool": name,
"status": "success",
"message": "Mock tool call completed"
});
ToolCallResult::success(vec![ContentBlock::text(response.to_string())])
}
}
}
fn mock_read_resource(&self, uri: &str) -> ResourceContent {
let text = if uri.starts_with("neo4j://entity/") {
let entity = uri.strip_prefix("neo4j://entity/").unwrap_or("unknown");
serde_json::json!({
"id": entity,
"type": "Entity",
"properties": {
"name": entity,
"created": "2024-01-01T00:00:00Z"
}
})
.to_string()
} else if uri.starts_with("file://") {
"Mock file content".to_string()
} else {
serde_json::json!({
"uri": uri,
"content": "Mock resource content"
})
.to_string()
};
ResourceContent::new(uri)
.with_mime_type("application/json")
.with_text(text)
}
pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
if self.is_mock {
self.mock_list_tools()
} else if let Some(ref adapter) = self.adapter {
adapter.get_cached_tools()
} else {
Vec::new()
}
}
pub fn is_tool_cache_fresh(&self, ttl: std::time::Duration) -> bool {
if self.is_mock {
true
} else if let Some(ref adapter) = self.adapter {
adapter.is_tool_cache_fresh(ttl)
} else {
false
}
}
pub fn invalidate_tool_cache(&self) {
if !self.is_mock {
if let Some(ref adapter) = self.adapter {
adapter.invalidate_tool_cache();
}
}
}
fn mock_list_tools(&self) -> Vec<ToolDefinition> {
vec![
ToolDefinition::new("novanet_describe")
.with_description("Bootstrap understanding of the graph"),
ToolDefinition::new("novanet_search")
.with_description("Find nodes via 5 modes: fulltext, property, hybrid, walk, triggers"),
ToolDefinition::new("novanet_context")
.with_description("Unified context assembly for LLM content generation")
.with_input_schema(serde_json::json!({
"type": "object",
"properties": {
"mode": {"type": "string", "description": "Context mode (page, block, knowledge, assemble)"},
"focus_key": {"type": "string", "description": "Focus node key"},
"locale": {"type": "string", "description": "Target locale (e.g., fr-FR)"}
},
"required": ["mode", "locale"]
})),
]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_multiple_sequential_calls() {
        let client = McpClient::mock("test");
        for i in 0..10 {
            let result = client
                .call_tool("test_tool", serde_json::json!({"iteration": i}))
                .await;
            assert!(
                result.is_ok(),
                "Call {} should succeed: {:?}",
                i,
                result.err()
            );
        }
    }

    #[tokio::test]
    async fn test_concurrent_calls() {
        let client = std::sync::Arc::new(McpClient::mock("test"));
        let handles: Vec<_> = (0..20)
            .map(|i| {
                let client = std::sync::Arc::clone(&client);
                tokio::spawn(async move {
                    client
                        .call_tool("test_tool", serde_json::json!({"iteration": i}))
                        .await
                })
            })
            .collect();
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.expect("Task should not panic");
            assert!(result.is_ok(), "Concurrent call {} should succeed", i);
        }
    }

    #[test]
    fn test_client_name_accessor() {
        let config = McpConfig::new("test-server", "echo");
        let client = McpClient::new(config).unwrap();
        assert_eq!(client.name(), "test-server");
    }

    #[test]
    fn test_mock_client_is_pre_connected() {
        let client = McpClient::mock("test");
        assert!(client.is_connected());
        assert!(client.is_mock);
    }

    #[test]
    fn test_real_client_starts_disconnected() {
        let config = McpConfig::new("test", "echo");
        let client = McpClient::new(config).unwrap();
        assert!(!client.is_connected());
        assert!(!client.is_mock);
    }

    #[tokio::test]
    async fn test_mock_tool_call_returns_success() {
        let client = McpClient::mock("test");
        let result = client
            .call_tool("unknown_tool", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
        assert!(!result.unwrap().is_error);
    }

    #[tokio::test]
    async fn test_mock_read_resource_entity() {
        let client = McpClient::mock("test");
        let result = client.read_resource("neo4j://entity/qr-code").await;
        assert!(result.is_ok());
        let resource = result.unwrap();
        assert_eq!(resource.uri, "neo4j://entity/qr-code");
        assert!(resource.text.is_some());
    }

    #[tokio::test]
    async fn test_mock_read_resource_file() {
        let client = McpClient::mock("test");
        let result = client.read_resource("file:///tmp/test.txt").await;
        assert!(result.is_ok());
        let resource = result.unwrap();
        assert_eq!(resource.uri, "file:///tmp/test.txt");
    }

    #[test]
    fn test_mock_client_drop_is_noop() {
        let client = McpClient::mock("test");
        assert!(client.is_mock);
        drop(client);
    }

    #[test]
    fn test_real_client_drop_without_process() {
        let config = McpConfig::new("test", "echo");
        let client = McpClient::new(config).unwrap();
        assert!(!client.is_mock);
        drop(client);
    }

    #[test]
    fn test_with_validation_enables_validator() {
        let config = McpConfig::new("test", "echo");
        let client = McpClient::new(config)
            .unwrap()
            .with_validation(ValidationConfig::default());
        assert!(client.validator.is_some());
    }

    #[tokio::test]
    async fn test_mock_connect_populates_schema_cache_when_validation_enabled() {
        let client = McpClient::mock("novanet").with_validation(ValidationConfig::default());
        client.connect().await.unwrap();
        let validator = client.validator.as_ref().unwrap();
        let stats = validator.cache().stats();
        assert!(stats.tool_count > 0, "Should have cached tools");
    }

    #[tokio::test]
    async fn test_call_tool_validates_missing_required_field() {
        let client = McpClient::mock("novanet").with_validation(ValidationConfig::default());
        client.connect().await.unwrap();
        let result = client
            .call_tool(
                "novanet_context",
                serde_json::json!({
                    "focus_key": "qr-code"
                }),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, NikaError::McpValidationFailed { .. }));
        if let NikaError::McpValidationFailed {
            missing, details, ..
        } = err
        {
            assert!(missing.contains(&"mode".to_string()));
            assert!(details.contains("mode"));
        }
    }

    #[tokio::test]
    async fn test_call_tool_passes_validation_with_valid_params() {
        let client = McpClient::mock("novanet").with_validation(ValidationConfig::default());
        client.connect().await.unwrap();
        let result = client
            .call_tool(
                "novanet_context",
                serde_json::json!({
                    "mode": "page",
                    "focus_key": "qr-code",
                    "locale": "fr-FR"
                }),
            )
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_call_tool_skips_validation_when_disabled() {
        let config = ValidationConfig {
            pre_validate: false,
            ..Default::default()
        };
        let client = McpClient::mock("novanet").with_validation(config);
        client.connect().await.unwrap();
        let result = client
            .call_tool(
                "novanet_context",
                serde_json::json!({
                    "focus_key": "qr-code"
                }),
            )
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_call_tool_without_validation_works() {
        let client = McpClient::mock("novanet");
        let result = client
            .call_tool("novanet_context", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_validation_for_unknown_tool_passes() {
        let client = McpClient::mock("novanet").with_validation(ValidationConfig::default());
        client.connect().await.unwrap();
        let result = client
            .call_tool(
                "unknown_tool",
                serde_json::json!({
                    "anything": "goes"
                }),
            )
            .await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_with_cache_enables_caching() {
        let config = McpConfig::new("test", "echo");
        let client = McpClient::new(config)
            .unwrap()
            .with_cache(CacheConfig::default());
        assert!(client.cache.is_some());
    }

    #[test]
    fn test_cache_stats_returns_none_when_disabled() {
        let client = McpClient::mock("test");
        assert!(client.cache_stats().is_none());
    }

    #[test]
    fn test_cache_stats_returns_some_when_enabled() {
        let client = McpClient::mock("test").with_cache(CacheConfig::default());
        let stats = client.cache_stats();
        assert!(stats.is_some());
        let stats = stats.unwrap();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[tokio::test]
    async fn test_cache_hit_returns_cached_result() {
        let client = McpClient::mock("test").with_cache(CacheConfig::default());
        let params = serde_json::json!({"entity": "qr-code"});
        let result1 = client.call_tool("novanet_context", params.clone()).await;
        assert!(result1.is_ok());
        let stats = client.cache_stats().unwrap();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.entries, 1);
        let result2 = client.call_tool("novanet_context", params.clone()).await;
        assert!(result2.is_ok());
        let stats = client.cache_stats().unwrap();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        let r1 = result1.unwrap();
        let r2 = result2.unwrap();
        assert_eq!(r1.content.len(), r2.content.len());
    }

    #[tokio::test]
    async fn test_cache_different_params_miss() {
        let client = McpClient::mock("test").with_cache(CacheConfig::default());
        let params_a = serde_json::json!({"focus_key": "qr-code"});
        client.call_tool("novanet_context", params_a).await.unwrap();
        let params_b = serde_json::json!({"focus_key": "barcode"});
        client.call_tool("novanet_context", params_b).await.unwrap();
        let stats = client.cache_stats().unwrap();
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.entries, 2);
    }

    #[tokio::test]
    async fn test_cache_different_tools_miss() {
        let client = McpClient::mock("test").with_cache(CacheConfig::default());
        let params = serde_json::json!({});
        client
            .call_tool("novanet_describe", params.clone())
            .await
            .unwrap();
        client
            .call_tool("novanet_search", params.clone())
            .await
            .unwrap();
        let stats = client.cache_stats().unwrap();
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
    }

    #[tokio::test]
    async fn test_cache_ttl_expiration() {
        use std::time::Duration;
        let client = McpClient::mock("test").with_cache(CacheConfig {
            ttl: Duration::from_millis(50),
            max_entries: 100,
        });
        let params = serde_json::json!({"test": true});
        client.call_tool("test_tool", params.clone()).await.unwrap();
        assert_eq!(client.cache_stats().unwrap().entries, 1);
        tokio::time::sleep(Duration::from_millis(60)).await;
        client.call_tool("test_tool", params.clone()).await.unwrap();
        let stats = client.cache_stats().unwrap();
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
    }

    #[test]
    fn test_cache_hit_rate_calculation() {
        let stats = super::ResponseCacheStats {
            entries: 10,
            hits: 80,
            misses: 20,
        };
        assert!((stats.hit_rate() - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_cache_hit_rate_zero_total() {
        let stats = super::ResponseCacheStats {
            entries: 0,
            hits: 0,
            misses: 0,
        };
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_key_deterministic() {
        let params = serde_json::json!({"entity": "qr-code", "locale": "fr-FR"});
        let key1 = super::ResponseCache::cache_key("tool", &params);
        let key2 = super::ResponseCache::cache_key("tool", &params);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_cache_key_different_for_different_params() {
        let params1 = serde_json::json!({"entity": "qr-code"});
        let params2 = serde_json::json!({"entity": "barcode"});
        let key1 = super::ResponseCache::cache_key("tool", &params1);
        let key2 = super::ResponseCache::cache_key("tool", &params2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_cache_key_different_for_different_tools() {
        let params = serde_json::json!({"test": true});
        let key1 = super::ResponseCache::cache_key("tool_a", &params);
        let key2 = super::ResponseCache::cache_key("tool_b", &params);
        assert_ne!(key1, key2);
    }

    #[tokio::test]
    async fn test_ping_mock_client_succeeds() {
        let client = McpClient::mock("test");
        let result = client.ping().await;
        assert!(result.is_ok());
        let ping = result.unwrap();
        assert_eq!(ping.server, "test");
        assert!(ping.was_connected);
        assert!(ping.tool_count > 0);
        assert!(ping.latency.as_millis() < 100);
    }

    #[test]
    fn test_mcp_ping_error_types() {
        let start_failed = super::McpPingError::StartFailed {
            server: "novanet".to_string(),
            details: "command not found".to_string(),
        };
        assert!(start_failed.to_string().contains("failed to start"));
        assert!(!start_failed.suggestion().is_empty());
        let timeout = super::McpPingError::Timeout {
            server: "slow-server".to_string(),
            timeout: std::time::Duration::from_secs(10),
        };
        assert!(timeout.to_string().contains("timed out"));
        let refused = super::McpPingError::ConnectionRefused {
            server: "offline".to_string(),
        };
        assert!(refused.to_string().contains("refused"));
        let server_err = super::McpPingError::ServerError {
            server: "broken".to_string(),
            details: "internal error".to_string(),
        };
        assert!(server_err.to_string().contains("error"));
    }

    #[tokio::test]
    async fn test_ping_result_has_valid_fields() {
        let client = McpClient::mock("novanet");
        let result = client.ping().await.unwrap();
        assert_eq!(result.server, "novanet");
        assert!(result.tool_count >= 3);
        assert!(result.was_connected);
    }

    #[test]
    fn test_is_configured_returns_true_for_mock() {
        let client = McpClient::mock("test");
        assert!(client.is_configured());
    }

    #[test]
    fn test_is_configured_returns_true_for_real_client() {
        let config = McpConfig::new("test", "echo");
        let client = McpClient::new(config).unwrap();
        assert!(client.is_configured());
    }

    #[tokio::test]
    async fn test_call_tool_with_retry_events_mock_success() {
        use crate::event::EventLog;
        let client = McpClient::mock("novanet");
        let event_log = EventLog::new();
        let task_id: Arc<str> = Arc::from("test_retry_events");
        let result = client
            .call_tool_with_retry_events(
                "novanet_context",
                serde_json::json!({"focus_key": "qr-code"}),
                &task_id,
                &event_log,
            )
            .await;
        assert!(
            result.is_ok(),
            "Mock call should succeed: {:?}",
            result.err()
        );
        let events = event_log.filter_task("test_retry_events");
        let retry_events: Vec<_> = events
            .iter()
            .filter(|e| matches!(e.kind, EventKind::McpRetry { .. }))
            .collect();
        assert!(
            retry_events.is_empty(),
            "No retry events for successful calls"
        );
    }

    #[tokio::test]
    async fn test_call_tool_with_retry_events_uses_cache() {
        use crate::event::EventLog;
        use std::time::Duration;
        let client = McpClient::mock("novanet").with_cache(CacheConfig {
            ttl: Duration::from_secs(60),
            max_entries: 100,
        });
        let event_log = EventLog::new();
        let task_id: Arc<str> = Arc::from("test_cache_hit");
        let params = serde_json::json!({"focus_key": "qr-code"});
        let _result1 = client
            .call_tool_with_retry_events("novanet_context", params.clone(), &task_id, &event_log)
            .await
            .unwrap();
        assert!(!client.was_last_call_cached());
        let _result2 = client
            .call_tool_with_retry_events("novanet_context", params.clone(), &task_id, &event_log)
            .await
            .unwrap();
        assert!(client.was_last_call_cached());
    }

    #[tokio::test]
    async fn test_call_tool_with_retry_events_not_connected_fails() {
        use crate::event::EventLog;
        let config = McpConfig::new("test", "nonexistent_command");
        let client = McpClient::new(config).unwrap();
        let event_log = EventLog::new();
        let task_id: Arc<str> = Arc::from("test_not_connected");
        let result = client
            .call_tool_with_retry_events("some_tool", serde_json::json!({}), &task_id, &event_log)
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            NikaError::McpNotConnected { .. } => {}
            err => panic!("Expected McpNotConnected, got: {err:?}"),
        }
    }

    #[tokio::test]
    async fn test_disconnect_clears_response_cache() {
        let client = McpClient::mock("test_cache");
        assert!(client.is_connected());
        client.disconnect().await.unwrap();
        assert!(!client.is_connected());
    }

    #[tokio::test]
    async fn test_disconnect_clears_response_cache_with_entries() {
        let cache_config = CacheConfig {
            ttl: std::time::Duration::from_secs(300),
            max_entries: 100,
        };
        let client = McpClient::mock("test_cache_entries").with_cache(cache_config);
        let _ = client
            .call_tool("novanet_describe", serde_json::json!({}))
            .await;
        let stats = client.cache_stats();
        assert!(stats.is_some());
        assert_eq!(stats.unwrap().entries, 1, "The call should have been cached");
        client.disconnect().await.unwrap();
        assert!(!client.is_connected());
    }

    #[tokio::test]
    async fn test_disconnect_invalidates_tool_cache_via_adapter() {
        let config = McpConfig::new("test_adapter_cache", "echo");
        let client = McpClient::new(config).unwrap();
        client.disconnect().await.unwrap();
        assert!(!client.is_connected());
    }

    #[test]
    fn wave2_cache_key_canonical_json_ordering() {
        use serde_json::json;
        let mut map_a = serde_json::Map::new();
        map_a.insert("alpha".to_string(), json!("first"));
        map_a.insert("beta".to_string(), json!("second"));
        map_a.insert("gamma".to_string(), json!("third"));
        let mut map_b = serde_json::Map::new();
        map_b.insert("gamma".to_string(), json!("third"));
        map_b.insert("alpha".to_string(), json!("first"));
        map_b.insert("beta".to_string(), json!("second"));
        let value_a = Value::Object(map_a);
        let value_b = Value::Object(map_b);
        let json_a = serde_json::to_string(&value_a).unwrap();
        let json_b = serde_json::to_string(&value_b).unwrap();
        let key_a = ResponseCache::cache_key("test_tool", &value_a);
        let key_b = ResponseCache::cache_key("test_tool", &value_b);
        assert_eq!(
            key_a, key_b,
            "Canonical cache keys should match regardless of key insertion order. \
             json_a='{}', json_b='{}'",
            json_a, json_b
        );
    }

    #[test]
    fn wave2_evict_oldest_collects_all_entries() {
        use std::time::Duration;
        let cache = ResponseCache::new(CacheConfig {
            ttl: Duration::from_secs(300),
            max_entries: 5,
        });
        for i in 0..6 {
            let params = serde_json::json!({"i": i});
            cache.put(
                &format!("tool_{}", i),
                &params,
                ToolCallResult::success(vec![ContentBlock::text(format!("result_{}", i))]),
            );
        }
        // Inserting the sixth key at capacity first evicts max(5 / 10, 1) = 1 entry.
        assert_eq!(
            cache.stats().entries,
            5,
            "Eviction must keep the cache within max_entries"
        );
        // The newest entry is inserted after eviction, so it always survives.
        assert!(cache.get("tool_5", &serde_json::json!({"i": 5})).is_some());
    }
}