use crate::errors::{ClaudeError, Result};
use crate::internal::query_full::QueryFull;
use crate::internal::transport::Transport;
use crate::types::mcp::McpSdkServerConfig;
use crate::types::permissions::CanUseToolCallback;
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::{debug, info};
use uuid::Uuid;
pub struct QueryManager {
queries: DashMap<String, Arc<QueryFull>>,
session_queries: DashMap<String, String>,
transport_factory: Arc<Mutex<dyn Fn() -> Result<Box<dyn Transport>> + Send + Sync + 'static>>,
hooks: Arc<Mutex<Option<HashMap<String, Vec<crate::types::hooks::HookMatcher>>>>>,
sdk_mcp_servers: Arc<Mutex<HashMap<String, McpSdkServerConfig>>>,
can_use_tool: Arc<Mutex<Option<CanUseToolCallback>>>,
control_request_timeout: Option<Duration>,
agents: Arc<Mutex<Option<serde_json::Value>>>,
}
impl QueryManager {
    /// Returns every tracked query as `(query_id, query)` pairs.
    ///
    /// This is not an atomic snapshot: queries inserted or removed concurrently
    /// may or may not be included. Order is unspecified.
    pub fn get_all_queries(&self) -> Vec<(String, Arc<QueryFull>)> {
        self.queries
            .iter()
            .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
            .collect()
    }

    /// Creates a manager with no queries that builds transports via `transport_factory`.
    ///
    /// Hooks, SDK MCP servers, the `can_use_tool` callback, the control-request
    /// timeout and agents all start unset. Configure them with the `set_*`
    /// methods before creating queries.
    pub fn new<F>(transport_factory: F) -> Self
    where
        F: Fn() -> Result<Box<dyn Transport>> + Send + Sync + 'static,
    {
        Self {
            queries: DashMap::new(),
            session_queries: DashMap::new(),
            transport_factory: Arc::new(Mutex::new(transport_factory)),
            hooks: Arc::new(Mutex::new(None)),
            sdk_mcp_servers: Arc::new(Mutex::new(HashMap::new())),
            can_use_tool: Arc::new(Mutex::new(None)),
            control_request_timeout: None,
            agents: Arc::new(Mutex::new(None)),
        }
    }

    /// Sets the hooks passed to `QueryFull::initialize` for queries created
    /// after this call. Existing queries are not affected.
    pub async fn set_hooks(
        &self,
        hooks: Option<HashMap<String, Vec<crate::types::hooks::HookMatcher>>>,
    ) {
        *self.hooks.lock().await = hooks;
    }

    /// Sets the SDK MCP servers installed on queries created after this call.
    pub async fn set_sdk_mcp_servers(&self, servers: HashMap<String, McpSdkServerConfig>) {
        *self.sdk_mcp_servers.lock().await = servers;
    }

    /// Sets the tool-permission callback for queries created after this call.
    ///
    /// `None` means the manager installs no callback on new queries.
    pub async fn set_can_use_tool(&self, callback: Option<CanUseToolCallback>) {
        *self.can_use_tool.lock().await = callback;
    }

    /// Sets the control-request timeout for queries created after this call.
    ///
    /// Takes `&mut self` because the value is stored without a lock. `None`
    /// leaves each new query's own default timeout in place.
    pub fn set_control_request_timeout(&mut self, timeout: Option<Duration>) {
        self.control_request_timeout = timeout;
    }

    /// Sets the agents value passed to `QueryFull::initialize` for queries
    /// created after this call.
    pub async fn set_agents(&self, agents: Option<serde_json::Value>) {
        *self.agents.lock().await = agents;
    }

    /// Creates, connects, starts and initializes a new isolated query, then
    /// registers it and returns its generated id.
    ///
    /// The current manager configuration (timeout, SDK MCP servers, permission
    /// callback, hooks, agents) is copied into the query at creation time.
    ///
    /// # Errors
    ///
    /// Returns `ClaudeError::Transport` if the factory or `connect` fails.
    /// Propagates any error from `QueryFull::start` or `QueryFull::initialize`.
    /// In every error case the query is not registered.
    pub async fn create_query(&self) -> Result<String> {
        // UUID gives uniqueness; the timestamp suffix makes ids roughly sortable in logs.
        let query_id = format!(
            "query_{}_{}",
            Uuid::new_v4().simple(),
            chrono::Utc::now().timestamp_micros()
        );

        let mut transport = {
            let factory = self.transport_factory.lock().await;
            (factory)()
                .map_err(|e| ClaudeError::Transport(format!("Failed to create transport: {}", e)))?
        };
        transport
            .connect()
            .await
            .map_err(|e| ClaudeError::Transport(format!("Failed to connect transport: {}", e)))?;

        // Grab the stdin handle before the transport is moved into the query.
        let stdin = transport.get_stdin();
        let mut query = QueryFull::new(transport);
        if let Some(stdin_ref) = stdin {
            query.set_stdin(stdin_ref);
        }
        if let Some(timeout) = self.control_request_timeout {
            query.set_control_request_timeout(Some(timeout));
        }

        // Each lock guard is a temporary released at the end of its statement,
        // so no manager lock is held across the query's own awaits.
        let servers = self.sdk_mcp_servers.lock().await.clone();
        query.set_sdk_mcp_servers(servers).await;
        let callback = self.can_use_tool.lock().await.clone();
        if let Some(cb) = callback {
            query.set_can_use_tool(Some(cb)).await;
        }

        // NOTE(review): if `start`/`initialize` fail, the connected query is simply
        // dropped; confirm QueryFull/Transport release their resources on Drop.
        query.start().await?;
        let hooks = self.hooks.lock().await.clone();
        let agents = self.agents.lock().await.clone();
        query.initialize(hooks, agents).await?;

        self.queries.insert(query_id.clone(), Arc::new(query));
        info!(query_id = %query_id, "Created new isolated query");
        Ok(query_id)
    }

    /// Looks up a tracked query by id.
    ///
    /// # Errors
    ///
    /// Returns `ClaudeError::InvalidConfig` if no query with `query_id` is tracked.
    pub fn get_query(&self, query_id: &str) -> Result<Arc<QueryFull>> {
        self.queries
            .get(query_id)
            .map(|v| Arc::clone(v.value()))
            .ok_or_else(|| ClaudeError::InvalidConfig(format!("Query not found: {}", query_id)))
    }

    /// Returns the id of the query bound to `session_id`, creating one if needed.
    ///
    /// An existing binding is reused only if its query is still tracked and
    /// active. An inactive query is removed together with its binding. A binding
    /// whose query is no longer tracked is discarded. Otherwise a new query is
    /// created via [`Self::create_query`] and bound to the session.
    ///
    /// NOTE(review): two concurrent calls for the same unbound session can both
    /// create a query; the later `insert` wins and the other query stays tracked
    /// (unbound) until it becomes inactive and the cleanup task prunes it.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::create_query`]. In that case this call
    /// does not bind the session.
    pub async fn get_or_create_session_query(&self, session_id: &str) -> Result<String> {
        // Copy the id out so the DashMap read guard is released immediately.
        // Calling `remove` on a map while holding a `Ref` into it can deadlock
        // on the shard lock.
        let bound_id = self
            .session_queries
            .get(session_id)
            .map(|entry| entry.value().clone());

        if let Some(query_id) = bound_id {
            // Same reason: clone the Arc instead of holding a `Ref` into `queries`
            // across the `remove` below.
            let bound_query = self.queries.get(&query_id).map(|q| Arc::clone(q.value()));
            match bound_query {
                Some(query) if query.is_active() => {
                    tracing::debug!(
                        session_id = %session_id,
                        query_id = %query_id,
                        "Reusing existing query for session"
                    );
                    return Ok(query_id);
                }
                Some(_) => {
                    tracing::warn!(
                        session_id = %session_id,
                        query_id = %query_id,
                        "Existing query for session is inactive, removing"
                    );
                    // Only drop the binding if it still points at the stale query,
                    // so a concurrently installed binding is not clobbered.
                    self.session_queries
                        .remove_if(session_id, |_, mapped| *mapped == query_id);
                    self.queries.remove(&query_id);
                }
                None => {
                    tracing::warn!(
                        session_id = %session_id,
                        query_id = %query_id,
                        "Query not found for session, removing stale mapping"
                    );
                    self.session_queries
                        .remove_if(session_id, |_, mapped| *mapped == query_id);
                }
            }
        }

        let query_id = self.create_query().await?;
        self.session_queries
            .insert(session_id.to_string(), query_id.clone());
        tracing::info!(
            session_id = %session_id,
            query_id = %query_id,
            "Created new query for session"
        );
        Ok(query_id)
    }

    /// Stops tracking the query with `query_id`.
    ///
    /// Only the manager's reference is dropped; other `Arc` holders keep the
    /// query alive. Session bindings pointing at it are left in place. They are
    /// discarded lazily by [`Self::get_or_create_session_query`] or by the
    /// cleanup task.
    ///
    /// # Errors
    ///
    /// Returns `ClaudeError::InvalidConfig` if no query with `query_id` is tracked.
    pub fn remove_query(&self, query_id: &str) -> Result<()> {
        self.queries
            .remove(query_id)
            .ok_or_else(|| ClaudeError::InvalidConfig(format!("Query not found: {}", query_id)))?;
        info!(query_id = %query_id, "Removed query");
        Ok(())
    }

    /// Number of tracked queries.
    ///
    /// This includes queries that have become inactive but have not yet been
    /// pruned by the cleanup task.
    pub fn active_count(&self) -> usize {
        self.queries.len()
    }

    /// Returns `true` if a query with `query_id` is currently tracked.
    pub fn contains_query(&self, query_id: &str) -> bool {
        self.queries.contains_key(query_id)
    }

    /// Spawns a background task that prunes the manager every 10 seconds.
    ///
    /// Each pass does two things:
    /// - removes queries whose `is_active()` returns `false`;
    /// - drops session bindings that point at queries no longer tracked.
    ///
    /// The first pass runs immediately. The task loops forever, so abort it via
    /// the returned handle. Must be called from within a Tokio runtime.
    pub fn start_cleanup_task(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(10));
            loop {
                interval.tick().await;

                let mut removed: usize = 0;
                let mut active = Vec::new();
                self.queries.retain(|query_id, query| {
                    if query.is_active() {
                        active.push(query_id.clone());
                        true
                    } else {
                        removed += 1;
                        false
                    }
                });

                // Without this, bindings for queries removed here or via
                // `remove_query` would accumulate for sessions never looked up again.
                let mut pruned_sessions: usize = 0;
                self.session_queries.retain(|_, query_id| {
                    let keep = self.queries.contains_key(query_id.as_str());
                    if !keep {
                        pruned_sessions += 1;
                    }
                    keep
                });

                if removed > 0 {
                    debug!(
                        removed,
                        active_count = active.len(),
                        "Cleaned up inactive queries"
                    );
                }
                if pruned_sessions > 0 {
                    debug!(pruned_sessions, "Pruned stale session mappings");
                }
                if cfg!(debug_assertions) && !active.is_empty() {
                    debug!(active_queries = ?active, "Currently active queries");
                }
            }
        })
    }
}