use futures::stream::StreamExt;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::AsyncWriteExt;
use tokio::sync::{Mutex, mpsc, oneshot};
use crate::errors::{ClaudeError, Result};
use crate::types::hooks::{HookCallback, HookContext, HookInput, HookMatcher};
use crate::types::mcp::McpSdkServerConfig;
use super::transport::Transport;
#[allow(dead_code)]
#[derive(Debug, serde::Serialize)]
struct ControlRequest {
#[serde(rename = "type")]
type_: String,
request_id: String,
request: serde_json::Value,
}
#[derive(Debug, serde::Deserialize)]
struct ControlResponse {
#[serde(rename = "type")]
#[allow(dead_code)]
type_: String,
response: ControlResponseData,
}
#[derive(Debug, serde::Deserialize)]
struct ControlResponseData {
#[allow(dead_code)]
subtype: String,
request_id: String,
#[serde(flatten)]
data: serde_json::Value,
}
#[derive(Debug, serde::Deserialize)]
struct IncomingControlRequest {
#[serde(rename = "type")]
#[allow(dead_code)]
type_: String,
request_id: String,
request: serde_json::Value,
}
pub struct QueryFull {
pub(crate) transport: Arc<Mutex<Box<dyn Transport>>>,
hook_callbacks: Arc<Mutex<HashMap<String, HookCallback>>>,
sdk_mcp_servers: Arc<Mutex<HashMap<String, McpSdkServerConfig>>>,
next_callback_id: Arc<AtomicU64>,
request_counter: Arc<AtomicU64>,
pending_responses: Arc<Mutex<HashMap<String, oneshot::Sender<serde_json::Value>>>>,
message_tx: mpsc::UnboundedSender<serde_json::Value>,
pub(crate) message_rx: Arc<Mutex<mpsc::UnboundedReceiver<serde_json::Value>>>,
pub(crate) stdin: Option<Arc<Mutex<Option<tokio::process::ChildStdin>>>>,
initialization_result: Arc<Mutex<Option<serde_json::Value>>>,
}
impl QueryFull {
pub fn new(transport: Box<dyn Transport>) -> Self {
let (message_tx, message_rx) = mpsc::unbounded_channel();
Self {
transport: Arc::new(Mutex::new(transport)),
hook_callbacks: Arc::new(Mutex::new(HashMap::new())),
sdk_mcp_servers: Arc::new(Mutex::new(HashMap::new())),
next_callback_id: Arc::new(AtomicU64::new(0)),
request_counter: Arc::new(AtomicU64::new(0)),
pending_responses: Arc::new(Mutex::new(HashMap::new())),
message_tx,
message_rx: Arc::new(Mutex::new(message_rx)),
stdin: None,
initialization_result: Arc::new(Mutex::new(None)),
}
}
pub fn set_stdin(&mut self, stdin: Arc<Mutex<Option<tokio::process::ChildStdin>>>) {
self.stdin = Some(stdin);
}
pub async fn set_sdk_mcp_servers(&mut self, servers: HashMap<String, McpSdkServerConfig>) {
*self.sdk_mcp_servers.lock().await = servers;
}
pub async fn initialize(
&self,
hooks: Option<HashMap<String, Vec<HookMatcher>>>,
) -> Result<serde_json::Value> {
let mut hooks_config: HashMap<String, Vec<serde_json::Value>> = HashMap::new();
if let Some(hooks_map) = hooks {
for (event, matchers) in hooks_map {
let mut event_matchers = Vec::new();
for matcher in matchers {
let mut callback_ids = Vec::new();
for callback in matcher.hooks {
let callback_id = format!(
"hook_{}",
self.next_callback_id.fetch_add(1, Ordering::SeqCst)
);
self.hook_callbacks
.lock()
.await
.insert(callback_id.clone(), callback);
callback_ids.push(callback_id);
}
let mut matcher_json = json!({
"matcher": matcher.matcher,
"hookCallbackIds": callback_ids
});
if let Some(timeout) = matcher.timeout {
matcher_json["timeout"] = json!(timeout);
}
event_matchers.push(matcher_json);
}
hooks_config.insert(event, event_matchers);
}
}
let request = json!({
"subtype": "initialize",
"hooks": if hooks_config.is_empty() { json!(null) } else { json!(hooks_config) }
});
let response = self.send_control_request(request).await?;
*self.initialization_result.lock().await = Some(response.clone());
Ok(response)
}
pub async fn start(&self) -> Result<()> {
let transport = Arc::clone(&self.transport);
let hook_callbacks = Arc::clone(&self.hook_callbacks);
let sdk_mcp_servers = Arc::clone(&self.sdk_mcp_servers);
let pending_responses = Arc::clone(&self.pending_responses);
let message_tx = self.message_tx.clone();
let stdin = self.stdin.clone();
let (ready_tx, ready_rx) = oneshot::channel();
tokio::spawn(async move {
let mut transport_guard = transport.lock().await;
let mut stream = transport_guard.read_messages();
let _ = ready_tx.send(());
while let Some(result) = stream.next().await {
match result {
Ok(message) => {
let msg_type = message.get("type").and_then(|v| v.as_str());
match msg_type {
Some("control_response") => {
if let Ok(response) =
serde_json::from_value::<ControlResponse>(message.clone())
{
let mut pending = pending_responses.lock().await;
if let Some(tx) = pending.remove(&response.response.request_id)
{
let _ = tx.send(response.response.data);
}
}
},
Some("control_request") => {
if let Ok(request) = serde_json::from_value::<IncomingControlRequest>(
message.clone(),
) {
let stdin_clone = stdin.clone();
let hook_callbacks_clone = Arc::clone(&hook_callbacks);
let sdk_mcp_servers_clone = Arc::clone(&sdk_mcp_servers);
tokio::spawn(async move {
if let Err(e) = Self::handle_control_request_with_stdin(
request,
stdin_clone,
hook_callbacks_clone,
sdk_mcp_servers_clone,
)
.await
{
eprintln!("Error handling control request: {}", e);
}
});
}
},
_ => {
let _ = message_tx.send(message);
},
}
},
Err(_) => break,
}
}
});
ready_rx
.await
.map_err(|_| ClaudeError::Transport("Background task failed to start".to_string()))?;
Ok(())
}
async fn handle_control_request_with_stdin(
request: IncomingControlRequest,
stdin: Option<Arc<Mutex<Option<tokio::process::ChildStdin>>>>,
hook_callbacks: Arc<Mutex<HashMap<String, HookCallback>>>,
sdk_mcp_servers: Arc<Mutex<HashMap<String, McpSdkServerConfig>>>,
) -> Result<()> {
let request_id = request.request_id;
let request_data = request.request;
let subtype = request_data
.get("subtype")
.and_then(|v| v.as_str())
.ok_or_else(|| ClaudeError::ControlProtocol("Missing subtype".to_string()))?;
let response_data: serde_json::Value = match subtype {
"hook_callback" => {
let callback_id = request_data
.get("callback_id")
.and_then(|v| v.as_str())
.ok_or_else(|| {
ClaudeError::ControlProtocol("Missing callback_id".to_string())
})?;
let callbacks = hook_callbacks.lock().await;
let callback = callbacks.get(callback_id).ok_or_else(|| {
ClaudeError::ControlProtocol(format!(
"Hook callback not found: {}",
callback_id
))
})?;
let input_json = request_data.get("input").cloned().unwrap_or(json!({}));
let hook_input: HookInput = serde_json::from_value(input_json).map_err(|e| {
ClaudeError::ControlProtocol(format!("Failed to parse hook input: {}", e))
})?;
let tool_use_id = request_data
.get("tool_use_id")
.and_then(|v| v.as_str())
.map(String::from);
let context = HookContext::default();
let hook_output = callback(hook_input, tool_use_id, context).await;
serde_json::to_value(&hook_output).map_err(|e| {
ClaudeError::ControlProtocol(format!("Failed to serialize hook output: {}", e))
})?
},
"mcp_message" => {
let server_name = request_data
.get("server_name")
.and_then(|v| v.as_str())
.ok_or_else(|| {
ClaudeError::ControlProtocol(
"Missing server_name for mcp_message".to_string(),
)
})?;
let mcp_message = request_data.get("message").ok_or_else(|| {
ClaudeError::ControlProtocol("Missing message for mcp_message".to_string())
})?;
let mcp_response =
Self::handle_sdk_mcp_request(sdk_mcp_servers, server_name, mcp_message.clone())
.await?;
json!({"mcp_response": mcp_response})
},
_ => {
return Err(ClaudeError::ControlProtocol(format!(
"Unsupported control request subtype: {}",
subtype
)));
},
};
let response = json!({
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": response_data
}
});
let response_str = serde_json::to_string(&response)
.map_err(|e| ClaudeError::Transport(format!("Failed to serialize response: {}", e)))?;
if let Some(ref stdin_arc) = stdin {
let mut stdin_guard = stdin_arc.lock().await;
if let Some(ref mut stdin_stream) = *stdin_guard {
use tokio::io::AsyncWriteExt;
stdin_stream
.write_all(response_str.as_bytes())
.await
.map_err(|e| {
ClaudeError::Transport(format!("Failed to write control response: {}", e))
})?;
stdin_stream.write_all(b"\n").await.map_err(|e| {
ClaudeError::Transport(format!("Failed to write newline: {}", e))
})?;
stdin_stream
.flush()
.await
.map_err(|e| ClaudeError::Transport(format!("Failed to flush: {}", e)))?;
} else {
return Err(ClaudeError::Transport("stdin not available".to_string()));
}
} else {
return Err(ClaudeError::Transport("stdin not set".to_string()));
}
Ok(())
}
async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
let request_id = format!(
"req_{}_{}",
self.request_counter.fetch_add(1, Ordering::SeqCst),
uuid::Uuid::new_v4().simple()
);
let (tx, rx) = oneshot::channel();
self.pending_responses
.lock()
.await
.insert(request_id.clone(), tx);
let control_request = json!({
"type": "control_request",
"request_id": request_id,
"request": request
});
let request_str = serde_json::to_string(&control_request)
.map_err(|e| ClaudeError::Transport(format!("Failed to serialize request: {}", e)))?;
if let Some(ref stdin) = self.stdin {
let mut stdin_guard = stdin.lock().await;
if let Some(ref mut stdin_stream) = *stdin_guard {
stdin_stream
.write_all(request_str.as_bytes())
.await
.map_err(|e| {
ClaudeError::Transport(format!("Failed to write control request: {}", e))
})?;
stdin_stream.write_all(b"\n").await.map_err(|e| {
ClaudeError::Transport(format!("Failed to write newline: {}", e))
})?;
stdin_stream
.flush()
.await
.map_err(|e| ClaudeError::Transport(format!("Failed to flush: {}", e)))?;
} else {
return Err(ClaudeError::Transport("stdin not available".to_string()));
}
} else {
return Err(ClaudeError::Transport("stdin not set".to_string()));
}
let response = rx.await.map_err(|_| {
ClaudeError::ControlProtocol("Control request response channel closed".to_string())
})?;
Ok(response)
}
#[allow(dead_code)]
pub async fn receive_messages(&self) -> Vec<serde_json::Value> {
let mut messages = Vec::new();
let mut rx = self.message_rx.lock().await;
while let Some(message) = rx.recv().await {
messages.push(message);
}
messages
}
pub async fn interrupt(&self) -> Result<()> {
let request = json!({
"subtype": "interrupt"
});
self.send_control_request(request).await?;
Ok(())
}
pub async fn set_permission_mode(
&self,
mode: crate::types::config::PermissionMode,
) -> Result<()> {
let mode_str = match mode {
crate::types::config::PermissionMode::Default => "default",
crate::types::config::PermissionMode::AcceptEdits => "acceptEdits",
crate::types::config::PermissionMode::Plan => "plan",
crate::types::config::PermissionMode::BypassPermissions => "bypassPermissions",
};
let request = json!({
"subtype": "set_permission_mode",
"mode": mode_str
});
self.send_control_request(request).await?;
Ok(())
}
pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
let request = json!({
"subtype": "set_model",
"model": model
});
self.send_control_request(request).await?;
Ok(())
}
pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
let request = json!({
"subtype": "rewind_files",
"user_message_id": user_message_id
});
self.send_control_request(request).await?;
Ok(())
}
pub async fn get_initialization_result(&self) -> Option<serde_json::Value> {
self.initialization_result.lock().await.clone()
}
async fn handle_sdk_mcp_request(
sdk_mcp_servers: Arc<Mutex<HashMap<String, McpSdkServerConfig>>>,
server_name: &str,
message: serde_json::Value,
) -> Result<serde_json::Value> {
let servers = sdk_mcp_servers.lock().await;
let server_config = servers.get(server_name).ok_or_else(|| {
ClaudeError::ControlProtocol(format!("SDK MCP server not found: {}", server_name))
})?;
server_config
.instance
.handle_message(message)
.await
.map_err(|e| ClaudeError::ControlProtocol(format!("MCP server error: {}", e)))
}
pub async fn get_buffer_metrics(&self) -> Option<crate::internal::transport::BufferMetricsSnapshot> {
let transport = self.transport.lock().await;
transport.get_buffer_metrics()
}
}