use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::session::SessionManager;
use crate::tool::{McpTool, compute_tool_fingerprint, tool_to_descriptor};
/// Runtime registry of the tools compiled into this server, tracking which of them are
/// currently active (i.e. listed to and callable by clients).
///
/// The set of compiled tools is fixed when the registry is built; only the active/inactive
/// status of those tools changes at runtime. Status changes are written to a shared
/// `ServerStateStorage` and can be picked up from it again (see `sync_from_storage`,
/// `check_for_changes` and `start_polling`). Connected sessions are told about changes via
/// `notifications/tools/list_changed` broadcasts.
pub struct ToolRegistry {
// Every tool registered at build time, keyed by tool name; only read after construction.
compiled_tools: HashMap<String, Arc<dyn McpTool>>,
// Active tool names plus their fingerprint, guarded together so they stay consistent.
state: RwLock<ToolState>,
// Used to broadcast `notifications/tools/list_changed` to sessions.
session_manager: Arc<SessionManager>,
// Shared storage used to persist and reload tool state (the polling code treats changes
// found here as coming "from another instance").
server_state: Arc<dyn turul_mcp_server_state_storage::ServerStateStorage>,
// When `check_for_changes` last successfully read the stored fingerprint; `None` before
// the first successful check.
last_check: RwLock<Option<std::time::Instant>>,
// Minimum time between storage reads in `check_for_changes`; taken from the
// `TURUL_TOOL_CHECK_TTL_SECS` environment variable in `new` (default 10 seconds).
check_ttl: std::time::Duration,
}
/// Mutable part of the registry: the names of the active tools and the fingerprint derived
/// from them.
///
/// Both fields are always updated together while the registry's write lock is held, so a
/// reader never sees a fingerprint that belongs to a different active set.
struct ToolState {
// Names of the tools currently considered active.
active: HashSet<String>,
// Fingerprint of the active compiled tools, as produced by
// `ToolRegistry::compute_fingerprint_for`.
fingerprint: String,
}
impl ToolRegistry {
    /// Creates a registry in which every compiled tool starts out active.
    ///
    /// `compiled_tools` is the fixed set of tools registered at build time; only tools in
    /// this map can ever be activated or deactivated. The minimum interval between storage
    /// reads in [`ToolRegistry::check_for_changes`] is read from the
    /// `TURUL_TOOL_CHECK_TTL_SECS` environment variable (whole seconds). It defaults to 10
    /// seconds when the variable is unset or does not parse as a `u64`.
    pub fn new(
        compiled_tools: HashMap<String, Arc<dyn McpTool>>,
        session_manager: Arc<SessionManager>,
        server_state: Arc<dyn turul_mcp_server_state_storage::ServerStateStorage>,
    ) -> Self {
        let active: HashSet<String> = compiled_tools.keys().cloned().collect();
        let fingerprint = Self::compute_fingerprint_for(&compiled_tools, &active);
        let check_ttl_secs: u64 = std::env::var("TURUL_TOOL_CHECK_TTL_SECS")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(10);
        Self {
            compiled_tools,
            state: RwLock::new(ToolState {
                active,
                fingerprint,
            }),
            session_manager,
            server_state,
            last_check: RwLock::new(None),
            check_ttl: std::time::Duration::from_secs(check_ttl_secs),
        }
    }

    /// Marks the compiled tool `name` as active.
    ///
    /// Returns `Ok(true)` when the tool was inactive and is now active. Returns `Ok(false)`
    /// when it was already active; nothing is persisted or broadcast in that case.
    ///
    /// # Errors
    ///
    /// * [`ToolRegistryError::NotCompiled`] if `name` was not registered at build time.
    /// * [`ToolRegistryError::NotificationFailed`] if broadcasting
    ///   `notifications/tools/list_changed` fails. By then the activation has already been
    ///   applied locally and (best effort) written to storage.
    pub async fn activate_tool(&self, name: &str) -> Result<bool, ToolRegistryError> {
        self.set_tool_active(name, true).await
    }

    /// Marks the compiled tool `name` as inactive.
    ///
    /// Returns `Ok(true)` when the tool was active and is now inactive. Returns `Ok(false)`
    /// when it was already inactive; nothing is persisted or broadcast in that case.
    ///
    /// # Errors
    ///
    /// * [`ToolRegistryError::NotCompiled`] if `name` was not registered at build time.
    /// * [`ToolRegistryError::NotificationFailed`] if broadcasting
    ///   `notifications/tools/list_changed` fails. By then the deactivation has already been
    ///   applied locally and (best effort) written to storage.
    pub async fn deactivate_tool(&self, name: &str) -> Result<bool, ToolRegistryError> {
        self.set_tool_active(name, false).await
    }

    /// Shared implementation of [`ToolRegistry::activate_tool`] and
    /// [`ToolRegistry::deactivate_tool`].
    ///
    /// The change is applied and the fingerprint recomputed under the write lock. Only if
    /// the active set actually changed is the change persisted and clients notified.
    async fn set_tool_active(&self, name: &str, active: bool) -> Result<bool, ToolRegistryError> {
        if !self.compiled_tools.contains_key(name) {
            return Err(ToolRegistryError::NotCompiled(name.to_string()));
        }
        let changed = {
            let mut state = self.state.write().await;
            let changed = if active {
                state.active.insert(name.to_string())
            } else {
                state.active.remove(name)
            };
            if changed {
                state.fingerprint =
                    Self::compute_fingerprint_for(&self.compiled_tools, &state.active);
            }
            changed
        };
        if !changed {
            if active {
                debug!("Tool '{}' already active", name);
            } else {
                debug!("Tool '{}' already inactive", name);
            }
            return Ok(false);
        }
        if active {
            info!("Tool '{}' activated", name);
        } else {
            info!("Tool '{}' deactivated", name);
        }
        // Persist before notifying. If the broadcast fails the caller still gets an error,
        // but the change has reached storage. Otherwise the next `check_for_changes`/poll
        // would see the stale stored fingerprint and silently revert this change.
        self.persist_entity_change(name).await;
        self.broadcast_notification().await?;
        Ok(true)
    }

    /// Returns protocol descriptors for every active compiled tool, sorted by tool name.
    pub async fn list_active_tools(&self) -> Vec<turul_mcp_protocol::Tool> {
        let state = self.state.read().await;
        let mut tools: Vec<turul_mcp_protocol::Tool> = self
            .compiled_tools
            .iter()
            .filter(|(name, _)| state.active.contains(*name))
            .map(|(_, tool)| tool_to_descriptor(tool.as_ref()))
            .collect();
        tools.sort_by(|a, b| a.name.cmp(&b.name));
        tools
    }

    /// Returns the tool named `name` if it is compiled in and currently active.
    pub async fn get_tool(&self, name: &str) -> Option<Arc<dyn McpTool>> {
        let state = self.state.read().await;
        if state.active.contains(name) {
            self.compiled_tools.get(name).cloned()
        } else {
            None
        }
    }

    /// Returns the fingerprint of the current set of active tools.
    pub async fn fingerprint(&self) -> String {
        self.state.read().await.fingerprint.clone()
    }

    /// Returns the minimum interval between storage reads in
    /// [`ToolRegistry::check_for_changes`].
    pub fn check_ttl(&self) -> std::time::Duration {
        self.check_ttl
    }

    /// Returns the names of all compiled tools, active or not, in sorted order.
    pub fn compiled_tool_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.compiled_tools.keys().cloned().collect();
        names.sort();
        names
    }

    /// Broadcasts a JSON-RPC `notifications/tools/list_changed` notification through the
    /// session manager.
    ///
    /// # Errors
    ///
    /// [`ToolRegistryError::NotificationFailed`] carrying the session manager's error text.
    ///
    /// # Panics
    ///
    /// Panics if the notification cannot be serialized to JSON. That would indicate a bug
    /// in `JsonRpcNotification`'s serialization, not a runtime condition.
    async fn broadcast_notification(&self) -> Result<(), ToolRegistryError> {
        let notification = turul_mcp_protocol::JsonRpcNotification::new(
            "notifications/tools/list_changed".to_string(),
        );
        let data = serde_json::to_value(&notification)
            .expect("JsonRpcNotification serialization must not fail");
        self.session_manager
            .broadcast_event(crate::session::SessionEvent::Custom {
                event_type: "notifications/tools/list_changed".to_string(),
                data,
            })
            .await
            .map_err(ToolRegistryError::NotificationFailed)
    }

    /// Reconciles local tool state with shared storage.
    ///
    /// * No stored fingerprint: local state is written to storage
    ///   ([`SyncResult::InitializedStorage`]).
    /// * Stored fingerprint equals the local one: the active set is reloaded from storage
    ///   ([`SyncResult::InSync`]).
    /// * Fingerprints differ: local state overwrites storage
    ///   ([`SyncResult::UpdatedStorage`], carrying the replaced fingerprint).
    ///
    /// # Errors
    ///
    /// [`ToolRegistryError::StorageError`] if any storage call fails.
    pub async fn sync_from_storage(&self) -> Result<SyncResult, ToolRegistryError> {
        let storage = &self.server_state;
        let local_fp = self.fingerprint().await;
        let stored_fp = storage
            .get_fingerprint("tools")
            .await
            .map_err(|e| ToolRegistryError::StorageError(e.to_string()))?;
        match stored_fp {
            None => {
                self.write_state_to_storage().await?;
                Ok(SyncResult::InitializedStorage)
            }
            Some(stored) if stored == local_fp => {
                self.load_state_from_storage().await?;
                Ok(SyncResult::InSync)
            }
            Some(stored) => {
                warn!(
                    "Tool fingerprint mismatch: local={}, storage={}. Updating storage.",
                    local_fp, stored
                );
                self.write_state_to_storage().await?;
                Ok(SyncResult::UpdatedStorage {
                    old_fingerprint: stored,
                })
            }
        }
    }

    /// Writes the complete local tool state to storage: every compiled tool, active or not,
    /// followed by the fingerprint.
    ///
    /// Inactive tools are written explicitly as `active: false`. Otherwise an entry left at
    /// `active: true` by an earlier writer would survive, and a later reload elsewhere would
    /// resurrect a tool this instance has deactivated.
    async fn write_state_to_storage(&self) -> Result<(), ToolRegistryError> {
        let storage = &self.server_state;
        // The read lock is held across all writes so activate/deactivate (which need the
        // write lock) cannot change state between the entity writes and the fingerprint.
        let state = self.state.read().await;
        for name in self.compiled_tools.keys() {
            let entity = turul_mcp_server_state_storage::EntityState {
                entity_id: name.clone(),
                active: state.active.contains(name),
                metadata: None,
                updated_at: chrono::Utc::now().to_rfc3339(),
            };
            storage
                .set_entity_state("tools", name, entity)
                .await
                .map_err(|e| ToolRegistryError::StorageError(e.to_string()))?;
        }
        storage
            .set_fingerprint("tools", state.fingerprint.clone())
            .await
            .map_err(|e| ToolRegistryError::StorageError(e.to_string()))?;
        Ok(())
    }

    /// Replaces the local active set with the active entities recorded in storage.
    ///
    /// Stored ids that are not compiled into this instance are ignored. Returns `true` when
    /// the reload changed the local fingerprint, i.e. the active tools visible to clients
    /// changed.
    async fn load_state_from_storage(&self) -> Result<bool, ToolRegistryError> {
        let active_ids = self
            .server_state
            .get_active_entities("tools")
            .await
            .map_err(|e| ToolRegistryError::StorageError(e.to_string()))?;
        let active: HashSet<String> = active_ids
            .into_iter()
            .filter(|id| self.compiled_tools.contains_key(id))
            .collect();
        // `compiled_tools` is immutable, so the fingerprint can be computed before locking.
        let fingerprint = Self::compute_fingerprint_for(&self.compiled_tools, &active);
        let mut state = self.state.write().await;
        let changed = state.fingerprint != fingerprint;
        state.active = active;
        state.fingerprint = fingerprint;
        Ok(changed)
    }

    /// Best-effort write of one tool's current status and the current fingerprint to
    /// storage.
    ///
    /// Both values are read from the live state under a read lock that is held for the
    /// whole write. A concurrent activate/deactivate therefore cannot slip in between and
    /// leave storage with an entity state that disagrees with the stored fingerprint.
    /// Storage failures are logged, not returned.
    async fn persist_entity_change(&self, name: &str) {
        let state = self.state.read().await;
        let entity = turul_mcp_server_state_storage::EntityState {
            entity_id: name.to_string(),
            active: state.active.contains(name),
            metadata: None,
            updated_at: chrono::Utc::now().to_rfc3339(),
        };
        if let Err(e) = self
            .server_state
            .set_entity_state("tools", name, entity)
            .await
        {
            warn!("Failed to persist tool state to storage: {}", e);
        }
        if let Err(e) = self
            .server_state
            .set_fingerprint("tools", state.fingerprint.clone())
            .await
        {
            warn!("Failed to persist fingerprint to storage: {}", e);
        }
    }

    /// Checks shared storage for a tool-state change made outside this registry and applies
    /// it.
    ///
    /// Storage is only consulted once at least `check_ttl` has passed since the last
    /// successful check; earlier calls return `Ok(false)` immediately.
    ///
    /// When the stored fingerprint differs from the local one, the active set is reloaded
    /// from storage. If that changes the local fingerprint,
    /// `notifications/tools/list_changed` is broadcast and `Ok(true)` is returned. A reload
    /// that leaves the local fingerprint unchanged returns `Ok(false)` without notifying
    /// clients. This happens, for example, when a concurrent check already applied the
    /// change, or when the stored fingerprint came from a different compiled tool set.
    ///
    /// # Errors
    ///
    /// [`ToolRegistryError::StorageError`] or [`ToolRegistryError::NotificationFailed`].
    pub async fn check_for_changes(&self) -> Result<bool, ToolRegistryError> {
        let last_check = *self.last_check.read().await;
        if last_check.is_some_and(|instant| instant.elapsed() < self.check_ttl) {
            return Ok(false);
        }
        let stored_fp = self
            .server_state
            .get_fingerprint("tools")
            .await
            .map_err(|e| ToolRegistryError::StorageError(e.to_string()))?;
        *self.last_check.write().await = Some(std::time::Instant::now());
        let local_fp = self.fingerprint().await;
        match stored_fp {
            Some(fp) if fp != local_fp => {
                debug!(
                    "Dynamic: external tool change detected (stored={}, local={})",
                    fp, local_fp
                );
                if !self.load_state_from_storage().await? {
                    debug!("Dynamic: reloaded tool state is unchanged; skipping notification");
                    return Ok(false);
                }
                self.broadcast_notification().await?;
                debug!("Dynamic: tool state reloaded and clients notified");
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// Spawns a background task that compares the stored fingerprint with the local one
    /// every `interval`.
    ///
    /// On a mismatch, the task reloads the active set from storage. It broadcasts
    /// `notifications/tools/list_changed` only if the reload changed the local fingerprint.
    /// Unlike [`ToolRegistry::check_for_changes`], this ignores `check_ttl`. Failures are
    /// logged and polling continues.
    ///
    /// The task holds only a weak reference to the registry, so it stops by itself once
    /// the last `Arc<ToolRegistry>` is dropped. Abort the returned handle to stop it earlier.
    pub fn start_polling(
        self: &Arc<Self>,
        interval: std::time::Duration,
    ) -> tokio::task::JoinHandle<()> {
        let weak_registry = Arc::downgrade(self);
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(interval).await;
                // Upgrade per iteration so the task itself never keeps the registry alive.
                let Some(registry) = weak_registry.upgrade() else {
                    debug!("Tool registry dropped; stopping storage polling");
                    break;
                };
                match registry.server_state.get_fingerprint("tools").await {
                    Ok(Some(stored_fp)) => {
                        let local_fp = registry.fingerprint().await;
                        if stored_fp != local_fp {
                            debug!(
                                "Dynamic: detected tool change from another instance (stored={}, local={})",
                                stored_fp, local_fp
                            );
                            match registry.load_state_from_storage().await {
                                Ok(true) => {}
                                Ok(false) => {
                                    debug!(
                                        "Dynamic: reloaded tool state is unchanged; skipping notification"
                                    );
                                    continue;
                                }
                                Err(e) => {
                                    warn!("Failed to reload tool state from storage: {}", e);
                                    continue;
                                }
                            }
                            if let Err(e) = registry.broadcast_notification().await {
                                warn!("Failed to persist tool change notification: {}", e);
                            }
                            debug!("Dynamic: tool state reloaded and clients notified");
                        }
                    }
                    Ok(None) => {
                        debug!("No fingerprint in storage yet");
                    }
                    Err(e) => {
                        warn!("Failed to check storage fingerprint: {}", e);
                    }
                }
            }
        })
    }

    /// Computes the fingerprint of the tools in `compiled` whose names appear in `active`.
    ///
    /// Names in `active` that are not compiled contribute nothing to the fingerprint.
    fn compute_fingerprint_for(
        compiled: &HashMap<String, Arc<dyn McpTool>>,
        active: &HashSet<String>,
    ) -> String {
        let active_tools: HashMap<String, Arc<dyn McpTool>> = compiled
            .iter()
            .filter(|(name, _)| active.contains(*name))
            .map(|(name, tool)| (name.clone(), Arc::clone(tool)))
            .collect();
        compute_tool_fingerprint(&active_tools)
    }
}
/// Errors returned by `ToolRegistry` operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolRegistryError {
/// The named tool is not in the compiled tool set the registry was built with.
#[error(
"Tool '{0}' is not a compiled tool — cannot activate/deactivate tools that were not registered at build time"
)]
NotCompiled(String),
/// A `ServerStateStorage` call failed; holds the storage error's `Display` text.
#[error("Storage error: {0}")]
StorageError(String),
/// Broadcasting `notifications/tools/list_changed` through the session manager failed;
/// holds the error string returned by `SessionManager::broadcast_event`.
#[error("Notification persistence failed: {0}")]
NotificationFailed(String),
}
/// Outcome of `ToolRegistry::sync_from_storage`.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can store the result and compare it directly
/// (e.g. `assert_eq!`) instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SyncResult {
    /// Storage held no tool fingerprint; the local state was written to it.
    InitializedStorage,
    /// The stored fingerprint matched the local one; the active set was reloaded from storage.
    InSync,
    /// The stored fingerprint differed; the local state overwrote storage.
    UpdatedStorage {
        /// The fingerprint that was in storage before it was overwritten.
        old_fingerprint: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::Value;
    use turul_mcp_builders::prelude::*;
    use turul_mcp_protocol::McpResult;
    use turul_mcp_protocol::tools::{CallToolResult, ToolResult, ToolSchema};

    /// Minimal tool with a fixed name and an empty object input schema.
    struct TestDynTool {
        tool_name: &'static str,
    }

    impl HasBaseMetadata for TestDynTool {
        fn name(&self) -> &str {
            self.tool_name
        }
    }

    impl HasDescription for TestDynTool {
        fn description(&self) -> Option<&str> {
            Some("test tool")
        }
    }

    impl HasInputSchema for TestDynTool {
        fn input_schema(&self) -> &ToolSchema {
            static SCHEMA: std::sync::OnceLock<ToolSchema> = std::sync::OnceLock::new();
            SCHEMA.get_or_init(ToolSchema::object)
        }
    }

    impl HasOutputSchema for TestDynTool {
        fn output_schema(&self) -> Option<&ToolSchema> {
            None
        }
    }

    impl HasAnnotations for TestDynTool {
        fn annotations(&self) -> Option<&turul_mcp_protocol::tools::ToolAnnotations> {
            None
        }
    }

    impl HasToolMeta for TestDynTool {
        fn tool_meta(&self) -> Option<&HashMap<String, Value>> {
            None
        }
    }

    impl HasIcons for TestDynTool {}
    impl HasExecution for TestDynTool {}

    #[async_trait]
    impl McpTool for TestDynTool {
        async fn call(
            &self,
            _args: Value,
            _session: Option<crate::session::SessionContext>,
        ) -> McpResult<CallToolResult> {
            Ok(CallToolResult::success(vec![ToolResult::text("ok")]))
        }
    }

    /// Tool whose input schema is supplied per instance, used to check that fingerprints do
    /// not depend on HashMap ordering inside schemas.
    struct SchemaTestTool {
        tool_name: &'static str,
        schema: turul_mcp_protocol::tools::ToolSchema,
    }

    impl HasBaseMetadata for SchemaTestTool {
        fn name(&self) -> &str {
            self.tool_name
        }
    }

    impl HasDescription for SchemaTestTool {
        fn description(&self) -> Option<&str> {
            Some("schema test tool")
        }
    }

    impl HasInputSchema for SchemaTestTool {
        fn input_schema(&self) -> &turul_mcp_protocol::tools::ToolSchema {
            &self.schema
        }
    }

    impl HasOutputSchema for SchemaTestTool {
        fn output_schema(&self) -> Option<&turul_mcp_protocol::tools::ToolSchema> {
            None
        }
    }

    impl HasAnnotations for SchemaTestTool {
        fn annotations(&self) -> Option<&turul_mcp_protocol::tools::ToolAnnotations> {
            None
        }
    }

    impl HasToolMeta for SchemaTestTool {
        fn tool_meta(&self) -> Option<&HashMap<String, Value>> {
            None
        }
    }

    impl HasIcons for SchemaTestTool {}
    impl HasExecution for SchemaTestTool {}

    #[async_trait]
    impl McpTool for SchemaTestTool {
        async fn call(
            &self,
            _args: Value,
            _session: Option<crate::session::SessionContext>,
        ) -> McpResult<CallToolResult> {
            Ok(CallToolResult::success(vec![ToolResult::text("ok")]))
        }
    }

    /// Three compiled tools: `alpha`, `beta`, `gamma`.
    fn test_tools() -> HashMap<String, Arc<dyn McpTool>> {
        let mut tools: HashMap<String, Arc<dyn McpTool>> = HashMap::new();
        tools.insert(
            "alpha".to_string(),
            Arc::new(TestDynTool { tool_name: "alpha" }),
        );
        tools.insert(
            "beta".to_string(),
            Arc::new(TestDynTool { tool_name: "beta" }),
        );
        tools.insert(
            "gamma".to_string(),
            Arc::new(TestDynTool { tool_name: "gamma" }),
        );
        tools
    }

    fn test_session_manager() -> Arc<SessionManager> {
        Arc::new(SessionManager::new(
            turul_mcp_protocol::ServerCapabilities::default(),
        ))
    }

    fn test_storage() -> Arc<dyn turul_mcp_server_state_storage::ServerStateStorage> {
        Arc::new(turul_mcp_server_state_storage::InMemoryServerStateStorage::new())
    }

    fn test_registry() -> ToolRegistry {
        ToolRegistry::new(test_tools(), test_session_manager(), test_storage())
    }

    // --- Local activation state ---

    #[tokio::test]
    async fn test_all_tools_active_by_default() {
        let registry = test_registry();
        let active = registry.list_active_tools().await;
        assert_eq!(active.len(), 3);
    }

    #[tokio::test]
    async fn test_deactivate_tool() {
        let registry = test_registry();
        let result = registry.deactivate_tool("beta").await.unwrap();
        assert!(result, "beta should have been deactivated");
        let active = registry.list_active_tools().await;
        assert_eq!(active.len(), 2);
        assert!(active.iter().all(|t| t.name != "beta"));
    }

    #[tokio::test]
    async fn test_activate_tool() {
        let registry = test_registry();
        registry.deactivate_tool("beta").await.unwrap();
        assert_eq!(registry.list_active_tools().await.len(), 2);
        let result = registry.activate_tool("beta").await.unwrap();
        assert!(result, "beta should have been newly activated");
        assert_eq!(registry.list_active_tools().await.len(), 3);
    }

    #[tokio::test]
    async fn test_activate_already_active() {
        let registry = test_registry();
        let result = registry.activate_tool("alpha").await.unwrap();
        assert!(!result, "alpha was already active");
    }

    #[tokio::test]
    async fn test_deactivate_already_inactive() {
        let registry = test_registry();
        registry.deactivate_tool("beta").await.unwrap();
        let result = registry.deactivate_tool("beta").await.unwrap();
        assert!(!result, "beta was already inactive");
    }

    #[tokio::test]
    async fn test_activate_nonexistent_tool_errors() {
        let registry = test_registry();
        let result = registry.activate_tool("nonexistent").await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ToolRegistryError::NotCompiled(_)
        ));
    }

    #[tokio::test]
    async fn test_deactivate_nonexistent_tool_errors() {
        let registry = test_registry();
        let result = registry.deactivate_tool("nonexistent").await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ToolRegistryError::NotCompiled(_)
        ));
    }

    #[tokio::test]
    async fn test_get_tool_active() {
        let registry = test_registry();
        let tool = registry.get_tool("alpha").await;
        assert!(tool.is_some());
    }

    #[tokio::test]
    async fn test_get_tool_inactive() {
        let registry = test_registry();
        registry.deactivate_tool("alpha").await.unwrap();
        let tool = registry.get_tool("alpha").await;
        assert!(tool.is_none(), "Inactive tool should return None");
    }

    // --- Fingerprints ---

    #[tokio::test]
    async fn test_fingerprint_changes_on_mutation() {
        let registry = test_registry();
        let fp_before = registry.fingerprint().await;
        registry.deactivate_tool("beta").await.unwrap();
        let fp_after = registry.fingerprint().await;
        assert_ne!(
            fp_before, fp_after,
            "Fingerprint must change when active set changes"
        );
    }

    #[tokio::test]
    async fn test_fingerprint_stable_without_mutation() {
        let registry = test_registry();
        let fp1 = registry.fingerprint().await;
        let fp2 = registry.fingerprint().await;
        assert_eq!(fp1, fp2);
    }

    #[tokio::test]
    async fn test_compiled_tool_names() {
        let registry = test_registry();
        let names = registry.compiled_tool_names();
        assert_eq!(names, vec!["alpha", "beta", "gamma"]);
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let registry = Arc::new(test_registry());
        let mut handles = Vec::new();
        for i in 0..20 {
            let reg = Arc::clone(&registry);
            let handle = tokio::spawn(async move {
                if i % 3 == 0 {
                    let _ = reg.deactivate_tool("beta").await;
                } else if i % 3 == 1 {
                    let _ = reg.activate_tool("beta").await;
                } else {
                    let _ = reg.list_active_tools().await;
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let active = registry.list_active_tools().await;
        assert!(
            (2..=3).contains(&active.len()),
            "Only beta is toggled, so 2 or 3 tools must remain active, got {}",
            active.len()
        );
    }

    // --- Notifications ---

    #[tokio::test]
    async fn test_activate_tool_emits_notification_event() {
        let session_manager = test_session_manager();
        let session_id = session_manager.create_session().await;
        let mut receiver = session_manager.subscribe_all_session_events();
        let registry = ToolRegistry::new(test_tools(), session_manager.clone(), test_storage());
        registry.deactivate_tool("beta").await.unwrap();
        let mut found_notification = false;
        while let Ok((recv_session_id, event)) =
            tokio::time::timeout(std::time::Duration::from_millis(100), receiver.recv())
                .await
                .unwrap_or(Err(tokio::sync::broadcast::error::RecvError::Closed))
        {
            if let crate::session::SessionEvent::Custom { event_type, .. } = &event {
                if event_type == "notifications/tools/list_changed" {
                    found_notification = true;
                    assert_eq!(
                        recv_session_id, session_id,
                        "Notification should be sent to the existing session"
                    );
                    break;
                }
            }
        }
        assert!(
            found_notification,
            "deactivate_tool() must broadcast notifications/tools/list_changed via SessionManager"
        );
    }

    #[tokio::test]
    async fn test_notification_payload_matches_mcp_wire_format() {
        let session_manager = test_session_manager();
        let _session_id = session_manager.create_session().await;
        let mut receiver = session_manager.subscribe_all_session_events();
        let registry = ToolRegistry::new(test_tools(), session_manager, test_storage());
        registry.deactivate_tool("alpha").await.unwrap();
        let (_sid, event) =
            tokio::time::timeout(std::time::Duration::from_millis(100), receiver.recv())
                .await
                .expect("Timeout waiting for event")
                .expect("Channel closed");
        if let crate::session::SessionEvent::Custom { event_type, data } = event {
            assert_eq!(event_type, "notifications/tools/list_changed");
            assert_eq!(
                data.get("jsonrpc").and_then(|j| j.as_str()),
                Some("2.0"),
                "Must contain jsonrpc: \"2.0\" per JSON-RPC 2.0 spec"
            );
            assert_eq!(
                data.get("method").and_then(|m| m.as_str()),
                Some("notifications/tools/list_changed"),
                "Must contain method field per MCP spec"
            );
            assert!(
                data.get("params").is_none() || data.get("params").unwrap().is_null(),
                "No params field for list_changed notification"
            );
        } else {
            panic!("Expected SessionEvent::Custom, got {:?}", event);
        }
    }

    #[tokio::test]
    async fn test_activate_tool_also_emits_notification() {
        let session_manager = test_session_manager();
        let _session_id = session_manager.create_session().await;
        let registry = ToolRegistry::new(test_tools(), session_manager.clone(), test_storage());
        registry.deactivate_tool("beta").await.unwrap();
        let mut receiver = session_manager.subscribe_all_session_events();
        registry.activate_tool("beta").await.unwrap();
        let mut found = false;
        while let Ok((_sid, event)) =
            tokio::time::timeout(std::time::Duration::from_millis(100), receiver.recv())
                .await
                .unwrap_or(Err(tokio::sync::broadcast::error::RecvError::Closed))
        {
            if let crate::session::SessionEvent::Custom { event_type, .. } = &event {
                if event_type == "notifications/tools/list_changed" {
                    found = true;
                    break;
                }
            }
        }
        assert!(found, "activate_tool() must also broadcast notification");
    }

    #[tokio::test]
    async fn test_fingerprint_round_trip() {
        let registry = test_registry();
        let fp_initial = registry.fingerprint().await;
        registry.deactivate_tool("beta").await.unwrap();
        let fp_deactivated = registry.fingerprint().await;
        assert_ne!(fp_initial, fp_deactivated);
        registry.activate_tool("beta").await.unwrap();
        let fp_reactivated = registry.fingerprint().await;
        assert_eq!(
            fp_initial, fp_reactivated,
            "Restoring same active set must restore same fingerprint"
        );
    }

    #[tokio::test]
    async fn test_deactivate_all_tools() {
        let registry = test_registry();
        let fp_full = registry.fingerprint().await;
        registry.deactivate_tool("alpha").await.unwrap();
        registry.deactivate_tool("beta").await.unwrap();
        registry.deactivate_tool("gamma").await.unwrap();
        let active = registry.list_active_tools().await;
        assert!(active.is_empty(), "All tools deactivated → empty list");
        let fp_empty = registry.fingerprint().await;
        assert_ne!(
            fp_full, fp_empty,
            "Empty set fingerprint differs from full set"
        );
        assert_eq!(
            fp_empty.len(),
            16,
            "Empty set still produces valid fingerprint"
        );
        assert!(registry.get_tool("alpha").await.is_none());
    }

    #[tokio::test]
    async fn test_notification_does_not_prevent_fingerprint_change() {
        let registry = test_registry();
        let fp_before = registry.fingerprint().await;
        registry.deactivate_tool("beta").await.unwrap();
        let fp_after = registry.fingerprint().await;
        assert_ne!(
            fp_before, fp_after,
            "After tool mutation, fingerprint MUST change. \
             Existing sessions with the old fingerprint MUST be rejected (404). \
             The notification is advisory only and does not bypass this."
        );
    }

    // --- Storage synchronization ---

    #[tokio::test]
    async fn test_sync_from_storage_initializes_empty_storage() {
        let storage = test_storage();
        let registry = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        let result = registry.sync_from_storage().await.unwrap();
        assert!(matches!(result, SyncResult::InitializedStorage));
        let stored_fp = storage.get_fingerprint("tools").await.unwrap();
        assert!(stored_fp.is_some());
        assert_eq!(stored_fp.unwrap(), registry.fingerprint().await);
    }

    #[tokio::test]
    async fn test_sync_from_storage_in_sync() {
        let storage = test_storage();
        let registry = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        registry.sync_from_storage().await.unwrap();
        let registry2 = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        let result = registry2.sync_from_storage().await.unwrap();
        assert!(matches!(result, SyncResult::InSync));
    }

    #[tokio::test]
    async fn test_independent_registries_same_tools_no_spurious_mismatch() {
        use turul_mcp_protocol::schema::JsonSchema;
        use turul_mcp_protocol::tools::ToolSchema;
        let mut props_a = HashMap::new();
        props_a.insert("name".to_string(), JsonSchema::string());
        props_a.insert("age".to_string(), JsonSchema::number());
        props_a.insert("active".to_string(), JsonSchema::boolean());
        let mut tools_a: HashMap<String, Arc<dyn McpTool>> = HashMap::new();
        tools_a.insert(
            "alpha".to_string(),
            Arc::new(TestDynTool { tool_name: "alpha" }),
        );
        tools_a.insert(
            "complex".to_string(),
            Arc::new(SchemaTestTool {
                tool_name: "complex",
                schema: ToolSchema::object()
                    .with_properties(props_a)
                    .with_required(vec!["name".to_string()]),
            }),
        );
        // Same logical tools, inserted in a different order.
        let mut props_b = HashMap::new();
        props_b.insert("active".to_string(), JsonSchema::boolean());
        props_b.insert("name".to_string(), JsonSchema::string());
        props_b.insert("age".to_string(), JsonSchema::number());
        let mut tools_b: HashMap<String, Arc<dyn McpTool>> = HashMap::new();
        tools_b.insert(
            "complex".to_string(),
            Arc::new(SchemaTestTool {
                tool_name: "complex",
                schema: ToolSchema::object()
                    .with_properties(props_b)
                    .with_required(vec!["name".to_string()]),
            }),
        );
        tools_b.insert(
            "alpha".to_string(),
            Arc::new(TestDynTool { tool_name: "alpha" }),
        );
        let storage = test_storage();
        let registry_a = ToolRegistry::new(tools_a, test_session_manager(), storage.clone());
        let result_a = registry_a.sync_from_storage().await.unwrap();
        assert!(matches!(result_a, SyncResult::InitializedStorage));
        let registry_b = ToolRegistry::new(tools_b, test_session_manager(), storage.clone());
        let result_b = registry_b.sync_from_storage().await.unwrap();
        assert!(
            matches!(result_b, SyncResult::InSync),
            "Identically-configured registries with different HashMap order must sync as InSync, got {:?}",
            result_b
        );
        assert_eq!(
            registry_a.fingerprint().await,
            registry_b.fingerprint().await,
            "Same logical tools must produce same fingerprint regardless of HashMap insertion order"
        );
    }

    #[tokio::test]
    async fn test_sync_from_storage_detects_newer_tools() {
        let storage = test_storage();
        let registry1 = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        registry1.sync_from_storage().await.unwrap();
        let old_fp = storage.get_fingerprint("tools").await.unwrap().unwrap();
        let mut fewer_tools: HashMap<String, Arc<dyn McpTool>> = HashMap::new();
        fewer_tools.insert(
            "alpha".to_string(),
            Arc::new(TestDynTool { tool_name: "alpha" }),
        );
        fewer_tools.insert(
            "beta".to_string(),
            Arc::new(TestDynTool { tool_name: "beta" }),
        );
        let registry2 = ToolRegistry::new(fewer_tools, test_session_manager(), storage.clone());
        let result = registry2.sync_from_storage().await.unwrap();
        match result {
            SyncResult::UpdatedStorage { old_fingerprint } => {
                assert_eq!(old_fingerprint, old_fp);
            }
            other => panic!("Expected UpdatedStorage, got {:?}", other),
        }
        let new_fp = storage.get_fingerprint("tools").await.unwrap().unwrap();
        assert_eq!(new_fp, registry2.fingerprint().await);
        assert_ne!(new_fp, old_fp);
    }

    #[tokio::test]
    async fn test_activate_persists_to_storage() {
        let storage = test_storage();
        let registry = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        registry.deactivate_tool("beta").await.unwrap();
        registry.activate_tool("beta").await.unwrap();
        let state = storage.get_entity_state("tools", "beta").await.unwrap();
        assert!(state.is_some());
        assert!(state.unwrap().active);
        let stored_fp = storage.get_fingerprint("tools").await.unwrap();
        assert_eq!(stored_fp, Some(registry.fingerprint().await));
    }

    #[tokio::test]
    async fn test_polling_detects_external_fingerprint_change() {
        let storage = test_storage();
        let registry = Arc::new(ToolRegistry::new(
            test_tools(),
            test_session_manager(),
            storage.clone(),
        ));
        registry.sync_from_storage().await.unwrap();
        let initial_fp = registry.fingerprint().await;
        // Simulate another instance deactivating gamma.
        let entity = turul_mcp_server_state_storage::EntityState {
            entity_id: "gamma".to_string(),
            active: false,
            metadata: None,
            updated_at: chrono::Utc::now().to_rfc3339(),
        };
        storage
            .set_entity_state("tools", "gamma", entity)
            .await
            .unwrap();
        storage
            .set_fingerprint("tools", "external_change".to_string())
            .await
            .unwrap();
        let handle = registry.start_polling(std::time::Duration::from_millis(50));
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        let new_fp = registry.fingerprint().await;
        assert_ne!(
            new_fp, initial_fp,
            "Polling should detect external fingerprint change and reload state"
        );
        let active = registry.list_active_tools().await;
        assert_eq!(
            active.len(),
            2,
            "gamma should have been deactivated by external change"
        );
        assert!(
            active.iter().all(|t| t.name != "gamma"),
            "gamma should not be in the active tool list"
        );
        handle.abort();
    }

    #[tokio::test]
    async fn test_polling_noop_when_fingerprints_match() {
        let storage = test_storage();
        let registry = Arc::new(ToolRegistry::new(
            test_tools(),
            test_session_manager(),
            storage.clone(),
        ));
        registry.sync_from_storage().await.unwrap();
        let initial_fp = registry.fingerprint().await;
        let handle = registry.start_polling(std::time::Duration::from_millis(50));
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        let fp_after = registry.fingerprint().await;
        assert_eq!(
            fp_after, initial_fp,
            "Fingerprint should remain unchanged when storage matches"
        );
        assert_eq!(registry.list_active_tools().await.len(), 3);
        handle.abort();
    }

    #[tokio::test]
    async fn test_deactivate_persists_to_storage() {
        let storage = test_storage();
        let registry = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        registry.deactivate_tool("gamma").await.unwrap();
        let state = storage.get_entity_state("tools", "gamma").await.unwrap();
        assert!(state.is_some());
        assert!(!state.unwrap().active);
    }

    #[tokio::test]
    async fn test_check_for_changes_detects_external_change() {
        let storage = test_storage();
        let registry = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        registry.sync_from_storage().await.unwrap();
        let initial_fp = registry.fingerprint().await;
        // Simulate another instance deactivating gamma.
        let entity = turul_mcp_server_state_storage::EntityState {
            entity_id: "gamma".to_string(),
            active: false,
            metadata: None,
            updated_at: chrono::Utc::now().to_rfc3339(),
        };
        storage
            .set_entity_state("tools", "gamma", entity)
            .await
            .unwrap();
        storage
            .set_fingerprint("tools", "external_change".to_string())
            .await
            .unwrap();
        let changed = registry.check_for_changes().await.unwrap();
        assert!(
            changed,
            "check_for_changes must return true when storage fingerprint differs"
        );
        let new_fp = registry.fingerprint().await;
        assert_ne!(new_fp, initial_fp, "Fingerprint must change after reload");
        let active = registry.list_active_tools().await;
        assert_eq!(
            active.len(),
            2,
            "gamma should have been deactivated by external change"
        );
        assert!(
            active.iter().all(|t| t.name != "gamma"),
            "gamma should not be in the active tool list"
        );
    }

    #[tokio::test]
    async fn test_check_for_changes_noop_when_matching() {
        let storage = test_storage();
        let registry = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        registry.sync_from_storage().await.unwrap();
        let initial_fp = registry.fingerprint().await;
        let changed = registry.check_for_changes().await.unwrap();
        assert!(
            !changed,
            "check_for_changes must return false when fingerprints match"
        );
        assert_eq!(registry.fingerprint().await, initial_fp);
        assert_eq!(registry.list_active_tools().await.len(), 3);
    }

    // --- Event dispatch (persisted notifications) ---

    /// Dispatcher that records every `(session_id, event_type, data)` it is asked to deliver.
    struct RecordingDispatcher {
        events: tokio::sync::Mutex<Vec<(String, String, serde_json::Value)>>,
    }

    impl RecordingDispatcher {
        fn new() -> Self {
            Self {
                events: tokio::sync::Mutex::new(Vec::new()),
            }
        }

        async fn event_count(&self) -> usize {
            self.events.lock().await.len()
        }

        async fn events_for_type(
            &self,
            event_type: &str,
        ) -> Vec<(String, String, serde_json::Value)> {
            self.events
                .lock()
                .await
                .iter()
                .filter(|(_, et, _)| et == event_type)
                .cloned()
                .collect()
        }
    }

    #[async_trait]
    impl crate::session::SessionEventDispatcher for RecordingDispatcher {
        async fn dispatch_to_session(
            &self,
            session_id: &str,
            event_type: String,
            data: serde_json::Value,
        ) -> std::result::Result<(), String> {
            self.events
                .lock()
                .await
                .push((session_id.to_string(), event_type, data));
            Ok(())
        }
    }

    /// Returns a session manager and an unattached recording dispatcher; callers attach it
    /// with `set_event_dispatcher`.
    fn test_session_manager_with_dispatcher() -> (Arc<SessionManager>, Arc<RecordingDispatcher>) {
        let sm = Arc::new(SessionManager::new(
            turul_mcp_protocol::ServerCapabilities::default(),
        ));
        let dispatcher = Arc::new(RecordingDispatcher::new());
        (sm, dispatcher)
    }

    #[tokio::test]
    async fn test_deactivate_stores_exactly_one_event() {
        let (sm, dispatcher) = test_session_manager_with_dispatcher();
        sm.set_event_dispatcher(dispatcher.clone()).await;
        let _session_id = sm.create_session().await;
        let registry = ToolRegistry::new(test_tools(), sm, test_storage());
        registry.deactivate_tool("beta").await.unwrap();
        let events = dispatcher
            .events_for_type("notifications/tools/list_changed")
            .await;
        assert_eq!(
            events.len(),
            1,
            "deactivate_tool must persist exactly 1 notification, got {}",
            events.len()
        );
    }

    #[tokio::test]
    async fn test_activate_stores_exactly_one_event() {
        let (sm, dispatcher) = test_session_manager_with_dispatcher();
        sm.set_event_dispatcher(dispatcher.clone()).await;
        let _session_id = sm.create_session().await;
        let registry = ToolRegistry::new(test_tools(), sm, test_storage());
        registry.deactivate_tool("beta").await.unwrap();
        let count_after_deactivate = dispatcher.event_count().await;
        registry.activate_tool("beta").await.unwrap();
        let count_after_activate = dispatcher.event_count().await;
        assert_eq!(
            count_after_activate - count_after_deactivate,
            1,
            "activate_tool must persist exactly 1 additional notification"
        );
    }

    #[tokio::test]
    async fn test_check_for_changes_stores_event_before_return() {
        let storage = test_storage();
        let (sm, dispatcher) = test_session_manager_with_dispatcher();
        sm.set_event_dispatcher(dispatcher.clone()).await;
        let _session_id = sm.create_session().await;
        // registry_a plays the "other instance"; its events go to a separate session manager.
        let registry_a = ToolRegistry::new(test_tools(), test_session_manager(), storage.clone());
        registry_a.sync_from_storage().await.unwrap();
        registry_a.deactivate_tool("gamma").await.unwrap();
        registry_a.sync_from_storage().await.unwrap();
        let registry_b = ToolRegistry::new(test_tools(), sm, storage.clone());
        let changed = registry_b.check_for_changes().await.unwrap();
        assert!(changed, "Should detect fingerprint mismatch");
        let events = dispatcher
            .events_for_type("notifications/tools/list_changed")
            .await;
        assert_eq!(
            events.len(),
            1,
            "check_for_changes must persist exactly 1 notification before returning, got {}",
            events.len()
        );
    }

    #[tokio::test]
    async fn test_new_session_after_mutation_gets_live_fingerprint() {
        let (sm, dispatcher) = test_session_manager_with_dispatcher();
        sm.set_event_dispatcher(dispatcher.clone()).await;
        let registry = ToolRegistry::new(test_tools(), sm.clone(), test_storage());
        registry.deactivate_tool("gamma").await.unwrap();
        let live_fp = registry.fingerprint().await;
        let session_id = sm.create_session().await;
        sm.set_session_state(
            &session_id,
            "mcp:tool_fingerprint",
            serde_json::json!(live_fp),
        )
        .await;
        let stored_fp = sm
            .get_session_state(&session_id, "mcp:tool_fingerprint")
            .await
            .and_then(|v| v.as_str().map(|s| s.to_string()))
            .expect("fingerprint should be stored");
        assert_eq!(
            stored_fp, live_fp,
            "New session must get live registry fingerprint, not build-time"
        );
        let events_before = dispatcher.event_count().await;
        let session_fp = stored_fp;
        let current_fp = registry.fingerprint().await;
        assert_eq!(
            session_fp, current_fp,
            "Session fingerprint must match current — no mismatch, no spurious notification"
        );
        assert_eq!(
            dispatcher.event_count().await,
            events_before,
            "No spurious notification should be emitted for a correctly initialized session"
        );
    }

    #[tokio::test]
    async fn test_static_mode_initialize_stores_no_fingerprint() {
        let sm = test_session_manager();
        let session_id = sm.create_session().await;
        let static_fingerprint = String::new();
        if !static_fingerprint.is_empty() {
            sm.set_session_state(
                &session_id,
                "mcp:tool_fingerprint",
                serde_json::json!(static_fingerprint),
            )
            .await;
        }
        let stored = sm
            .get_session_state(&session_id, "mcp:tool_fingerprint")
            .await;
        assert!(
            stored.is_none(),
            "Static mode must NOT store mcp:tool_fingerprint, got {:?}",
            stored
        );
    }

    #[tokio::test]
    async fn test_dispatcher_targets_all_sessions() {
        let (sm, dispatcher) = test_session_manager_with_dispatcher();
        sm.set_event_dispatcher(dispatcher.clone()).await;
        let session_a = sm.create_session().await;
        let session_b = sm.create_session().await;
        let registry = ToolRegistry::new(test_tools(), sm, test_storage());
        registry.deactivate_tool("alpha").await.unwrap();
        let events = dispatcher
            .events_for_type("notifications/tools/list_changed")
            .await;
        assert_eq!(
            events.len(),
            2,
            "Should dispatch to both sessions, got {}",
            events.len()
        );
        let session_ids: Vec<&str> = events.iter().map(|(s, _, _)| s.as_str()).collect();
        assert!(
            session_ids.contains(&session_a.as_str()),
            "Should dispatch to session A"
        );
        assert!(
            session_ids.contains(&session_b.as_str()),
            "Should dispatch to session B"
        );
    }
}