use crate::error::{ClaudeSDKError, Result};
use crate::transport::subprocess::SubprocessTransport;
use crate::types::{HookContext, HookInput, HookOutput, PermissionResult, ToolPermissionContext};
use futures::Stream;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::process::ChildStdin;
use tokio::sync::{Mutex, mpsc, oneshot};
/// In-flight control requests, keyed by request id.
///
/// Each entry holds the sender that hands the CLI's `control_response` (or an error)
/// to the caller waiting on that request.
pub type PendingResponses = Arc<Mutex<HashMap<String, oneshot::Sender<Result<Value>>>>>;

/// Async callback invoked for `hook_callback` control requests.
///
/// Called with the parsed hook input, the request's `tool_use_id` (when present) and a
/// [`HookContext`]; the resulting [`HookOutput`] is serialized back to the CLI.
pub type HookCallbackFn = Box<
    dyn Fn(
            HookInput,
            Option<String>,
            HookContext,
        ) -> Pin<Box<dyn Future<Output = Result<HookOutput>> + Send>>
        + Send
        + Sync,
>;

/// Async callback invoked for `can_use_tool` control requests.
///
/// Called with the tool name, the raw tool input and a [`ToolPermissionContext`]; the
/// returned [`PermissionResult`] decides whether the tool call may proceed.
pub type CanUseToolFn = Box<
    dyn Fn(
            String,
            Value,
            ToolPermissionContext,
        ) -> Pin<Box<dyn Future<Output = Result<PermissionResult>> + Send>>
        + Send
        + Sync,
>;
pub struct Query {
transport: SubprocessTransport,
pending_responses: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Value>>>>>,
request_counter: Arc<Mutex<u64>>,
hook_callbacks: Arc<HashMap<String, Arc<HookCallbackFn>>>,
can_use_tool_callback: Arc<Option<Arc<CanUseToolFn>>>,
sdk_messages_tx: mpsc::Sender<Value>,
sdk_messages_rx: Arc<Mutex<mpsc::Receiver<Value>>>,
server_info: Arc<Mutex<Option<Value>>>,
}
impl Query {
    /// Creates a query over `transport` with no hook callbacks and no tool-permission
    /// callback.
    pub fn new(transport: SubprocessTransport) -> Self {
        Self::new_with_callbacks(transport, HashMap::new(), None)
    }

    /// Creates a query over `transport` that answers CLI control requests with the given
    /// callbacks.
    ///
    /// `hook_callbacks` is looked up by the `callback_id` of incoming `hook_callback`
    /// requests; `can_use_tool_callback` answers `can_use_tool` requests (those fail with
    /// `PermissionCallbackNotSet` when it is `None`).
    pub fn new_with_callbacks(
        transport: SubprocessTransport,
        hook_callbacks: HashMap<String, Arc<HookCallbackFn>>,
        can_use_tool_callback: Option<Arc<CanUseToolFn>>,
    ) -> Self {
        // Bounded buffer of SDK (non-control) messages between the reader task and the
        // consumer of `QueryHandle::read_messages`.
        let (tx, rx) = mpsc::channel(100);
        Self {
            transport,
            pending_responses: Arc::new(Mutex::new(HashMap::new())),
            request_counter: Arc::new(Mutex::new(0)),
            hook_callbacks: Arc::new(hook_callbacks),
            can_use_tool_callback: Arc::new(can_use_tool_callback),
            sdk_messages_tx: tx,
            sdk_messages_rx: Arc::new(Mutex::new(rx)),
            server_info: Arc::new(Mutex::new(None)),
        }
    }

    /// Takes the subprocess stdin, spawns the background reader task, and returns the
    /// handle used to send requests and consume messages.
    ///
    /// The reader task routes each incoming message by its `type`:
    /// - `control_response`: completes the matching pending control request;
    /// - `control_request`: answered on its own spawned task;
    /// - anything else: forwarded to the SDK message channel, recording the first
    ///   `session_id` seen.
    ///
    /// The task stops on the first read error or when the transport stream ends. It then
    /// fails every still-pending control request and closes the SDK message channel.
    ///
    /// Must be called from within a Tokio runtime (it uses `tokio::spawn`).
    pub fn start(mut self) -> QueryHandle {
        let stdin = Arc::new(Mutex::new(self.transport.take_stdin()));
        let pending_responses = self.pending_responses;
        let hook_callbacks = self.hook_callbacks;
        let can_use_tool_callback = self.can_use_tool_callback;
        let sdk_messages_tx = self.sdk_messages_tx.clone();
        let sdk_messages_rx = self.sdk_messages_rx;
        let server_info = self.server_info;
        let request_counter = self.request_counter;
        let current_session_id = Arc::new(std::sync::OnceLock::new());

        // Handles shared between the reader task and the returned `QueryHandle`.
        let pending_responses_task = Arc::clone(&pending_responses);
        let stdin_task = Arc::clone(&stdin);
        let current_session_id_task = Arc::clone(&current_session_id);

        tokio::spawn(async move {
            use futures::StreamExt;
            let mut stream = Box::pin(self.transport.read_messages());
            while let Some(result) = stream.next().await {
                let msg = match result {
                    Ok(msg) => msg,
                    Err(e) => {
                        eprintln!("Error reading message: {}", e);
                        break;
                    }
                };
                match msg.get("type").and_then(|v| v.as_str()) {
                    Some("control_response") => {
                        if let Err(e) =
                            Self::handle_control_response_static(msg, &pending_responses_task)
                                .await
                        {
                            eprintln!("Error handling control response: {}", e);
                        }
                    }
                    Some("control_request") => {
                        // Answer on a separate task. A callback may itself wait for a
                        // control response (e.g. by sending a request through the handle),
                        // and only this loop delivers responses. Awaiting inline would block
                        // the loop until the request times out and stall all other messages.
                        let hooks = Arc::clone(&hook_callbacks);
                        let permission_callback = Arc::clone(&can_use_tool_callback);
                        let writer = Arc::clone(&stdin_task);
                        tokio::spawn(async move {
                            if let Err(e) = Self::handle_control_request_static(
                                msg,
                                &hooks,
                                &permission_callback,
                                &writer,
                            )
                            .await
                            {
                                eprintln!("Error handling control request: {}", e);
                            }
                        });
                    }
                    _ => {
                        if let Some(session_id) = msg.get("session_id").and_then(|v| v.as_str())
                        {
                            // `OnceLock::set` only succeeds the first time, so later session
                            // ids are intentionally ignored.
                            let _ = current_session_id_task.set(session_id.to_string());
                        }
                        if sdk_messages_tx.send(msg).await.is_err() {
                            break;
                        }
                    }
                }
            }

            // No further control responses can arrive once the reader stops, so fail any
            // outstanding requests now instead of leaving their callers to time out.
            for (_, tx) in pending_responses_task.lock().await.drain() {
                let _ = tx.send(Err(ClaudeSDKError::other(
                    "Transport closed before control response was received",
                )));
            }
            drop(sdk_messages_tx);
        });

        QueryHandle {
            pending_responses,
            request_counter,
            sdk_messages_rx,
            server_info,
            stdin,
            current_session_id,
        }
    }

    /// Completes the pending request named by `response.request_id`.
    ///
    /// A `subtype` of `"error"` resolves the waiter with an error built from
    /// `response.error` (default `"Unknown error"`); any other subtype resolves it with
    /// `response.response`. Responses for unknown (e.g. already timed-out) ids are
    /// ignored.
    ///
    /// # Errors
    /// Returns a message-parse error when `response.request_id` is missing.
    async fn handle_control_response_static(
        msg: Value,
        pending_responses: &PendingResponses,
    ) -> Result<()> {
        let request_id = msg["response"]["request_id"].as_str().ok_or_else(|| {
            ClaudeSDKError::message_parse("Missing request_id in control_response")
        })?;
        // Take the sender out and release the lock before completing the waiter.
        let waiter = pending_responses.lock().await.remove(request_id);
        let Some(tx) = waiter else {
            return Ok(());
        };
        let response = &msg["response"];
        let outcome = if response["subtype"].as_str() == Some("error") {
            let error_msg = response["error"].as_str().unwrap_or("Unknown error");
            Err(ClaudeSDKError::other(error_msg))
        } else {
            Ok(response["response"].clone())
        };
        // The waiter may have given up already; that is not an error here.
        let _ = tx.send(outcome);
        Ok(())
    }

    /// Answers one CLI `control_request` and writes the `control_response` line to stdin.
    ///
    /// Handler failures, including a missing or unknown `subtype`, are reported back to
    /// the CLI as an `"error"` response carrying the same `request_id`, so the CLI is
    /// never left waiting.
    ///
    /// # Errors
    /// Returns an error when `request_id` is missing, or when serializing or writing the
    /// response fails.
    async fn handle_control_request_static(
        msg: Value,
        hook_callbacks: &Arc<HashMap<String, Arc<HookCallbackFn>>>,
        can_use_tool_callback: &Arc<Option<Arc<CanUseToolFn>>>,
        stdin: &Arc<Mutex<Option<ChildStdin>>>,
    ) -> Result<()> {
        let request_id = msg["request_id"]
            .as_str()
            .ok_or_else(|| ClaudeSDKError::message_parse("Missing request_id in control_request"))?
            .to_string();
        let request = &msg["request"];
        let response_result: Result<Value> = match request["subtype"].as_str() {
            Some("hook_callback") => Self::handle_hook_callback(request.clone(), hook_callbacks).await,
            Some("can_use_tool") => {
                Self::handle_can_use_tool(request.clone(), can_use_tool_callback).await
            }
            Some("mcp_message") => Err(ClaudeSDKError::other("MCP bridging not yet implemented")),
            Some(subtype) => Err(ClaudeSDKError::message_parse(format!(
                "Unknown control request subtype: {}",
                subtype
            ))),
            None => Err(ClaudeSDKError::message_parse(
                "Missing subtype in control_request",
            )),
        };
        let response = match response_result {
            Ok(payload) => json!({
                "subtype": "success",
                "request_id": request_id,
                "response": payload
            }),
            Err(e) => json!({
                "subtype": "error",
                "request_id": request_id,
                "error": e.to_string()
            }),
        };
        let response_msg = json!({
            "type": "control_response",
            "response": response
        });
        let response_str = serde_json::to_string(&response_msg)?;
        Self::write_to_stdin(stdin, &response_str).await
    }

    /// Writes `data` plus a trailing newline to the child's stdin and flushes it.
    ///
    /// # Errors
    /// Returns `TransportNotReady` when stdin is `None` (closed or never taken), or the
    /// underlying I/O error.
    async fn write_to_stdin(stdin: &Arc<Mutex<Option<ChildStdin>>>, data: &str) -> Result<()> {
        // The lock is held for the whole write, so concurrent writers cannot interleave
        // partial lines.
        let mut stdin_guard = stdin.lock().await;
        let stdin_ref = stdin_guard
            .as_mut()
            .ok_or(ClaudeSDKError::TransportNotReady)?;
        stdin_ref.write_all(data.as_bytes()).await?;
        stdin_ref.write_all(b"\n").await?;
        stdin_ref.flush().await?;
        Ok(())
    }

    /// Runs the hook callback named by `request.callback_id` and returns its output in the
    /// CLI's JSON shape.
    ///
    /// # Errors
    /// Fails when `callback_id` is missing, no callback is registered under it,
    /// `request.input` does not deserialize into `HookInput`, or the callback fails.
    async fn handle_hook_callback(
        request: Value,
        hook_callbacks: &Arc<HashMap<String, Arc<HookCallbackFn>>>,
    ) -> Result<Value> {
        let callback_id = request["callback_id"]
            .as_str()
            .ok_or_else(|| ClaudeSDKError::message_parse("Missing callback_id"))?;
        let callback = hook_callbacks
            .get(callback_id)
            .ok_or_else(|| ClaudeSDKError::HookNotFound(callback_id.to_string()))?;
        let input: HookInput = serde_json::from_value(request["input"].clone())?;
        let tool_use_id = request["tool_use_id"].as_str().map(str::to_string);
        let context = HookContext { signal: None };
        let output = callback.as_ref()(input, tool_use_id, context).await?;
        Self::convert_hook_output_for_cli(output)
    }

    /// Asks the registered permission callback whether a tool call may proceed.
    ///
    /// Returns `{"behavior": "allow", "updatedInput", ["updatedPermissions"]}` or
    /// `{"behavior": "deny", "message", "interrupt"}`. When an allow result carries no
    /// updated input, the original `request.input` is echoed back.
    ///
    /// # Errors
    /// Fails with `PermissionCallbackNotSet` when no callback is registered, when
    /// `tool_name` is missing, or when the callback fails.
    async fn handle_can_use_tool(
        request: Value,
        can_use_tool_callback: &Arc<Option<Arc<CanUseToolFn>>>,
    ) -> Result<Value> {
        let callback = can_use_tool_callback
            .as_ref()
            .as_ref()
            .ok_or(ClaudeSDKError::PermissionCallbackNotSet)?;
        let tool_name = request["tool_name"]
            .as_str()
            .ok_or_else(|| ClaudeSDKError::message_parse("Missing tool_name"))?
            .to_string();
        let input = request["input"].clone();
        // Suggestions that fail to deserialize are skipped rather than failing the request.
        let suggestions = request["permission_suggestions"]
            .as_array()
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| serde_json::from_value(v.clone()).ok())
                    .collect()
            })
            .unwrap_or_default();
        let context = ToolPermissionContext {
            signal: None,
            suggestions,
        };
        // `input` is kept so it can be echoed back when the callback leaves it unchanged.
        match callback.as_ref()(tool_name, input.clone(), context).await? {
            PermissionResult::Allow(allow) => {
                let mut obj = json!({
                    "behavior": "allow",
                    "updatedInput": allow.updated_input.unwrap_or(input)
                });
                if let Some(updated_permissions) = allow.updated_permissions {
                    obj["updatedPermissions"] = serde_json::to_value(updated_permissions)?;
                }
                Ok(obj)
            }
            PermissionResult::Deny(deny) => Ok(json!({
                "behavior": "deny",
                "message": deny.message,
                "interrupt": deny.interrupt
            })),
        }
    }

    /// Serializes a hook output and renames keyword-escaped keys to what the CLI expects.
    ///
    /// Rust fields named after keywords carry a trailing underscore (`continue_`,
    /// `async_`); the CLI expects the bare keyword.
    fn convert_hook_output_for_cli(output: HookOutput) -> Result<Value> {
        let mut json = serde_json::to_value(&output)?;
        let (rust_key, cli_key) = match output {
            HookOutput::Sync(_) => ("continue_", "continue"),
            HookOutput::Async(_) => ("async_", "async"),
        };
        if let Some(obj) = json.as_object_mut()
            && let Some(value) = obj.remove(rust_key)
        {
            obj.insert(cli_key.to_string(), value);
        }
        Ok(json)
    }
}
pub struct QueryHandle {
pending_responses: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Value>>>>>,
request_counter: Arc<Mutex<u64>>,
sdk_messages_rx: Arc<Mutex<mpsc::Receiver<Value>>>,
server_info: Arc<Mutex<Option<Value>>>,
stdin: Arc<Mutex<Option<ChildStdin>>>,
current_session_id: Arc<std::sync::OnceLock<String>>,
}
impl QueryHandle {
    /// Returns a stream of the SDK (non-control) messages forwarded by the reader task.
    ///
    /// The stream holds the receiver lock while it is being consumed, so only one stream
    /// receives messages at a time. It ends once the reader task drops its sender.
    pub fn read_messages(&self) -> impl Stream<Item = Result<Value>> + '_ {
        let rx = self.sdk_messages_rx.clone();
        async_stream::stream! {
            let mut rx_guard = rx.lock().await;
            while let Some(msg) = rx_guard.recv().await {
                yield Ok(msg);
            }
        }
    }

    /// Returns the response stored by a successful [`QueryHandle::initialize`], if any.
    pub async fn get_server_info(&self) -> Option<Value> {
        self.server_info.lock().await.clone()
    }

    /// Returns the first `session_id` seen on an SDK message, if any.
    pub fn get_session_id(&self) -> Option<String> {
        self.current_session_id.get().cloned()
    }

    /// Sends a control request and waits up to 60 seconds for the matching response.
    ///
    /// The request's pending entry is removed again when the write fails or the wait
    /// times out, so abandoned requests do not accumulate in `pending_responses`.
    ///
    /// # Errors
    /// Returns the write or serialization error, a control-timeout error after 60s, a
    /// "Response channel closed" error if the waiter is dropped without a reply, or the
    /// error the CLI reported for this request.
    async fn send_control_request(&self, request: Value) -> Result<Value> {
        let request_id = {
            let mut counter = self.request_counter.lock().await;
            *counter += 1;
            // The random suffix keeps ids distinct even across handles sharing a counter
            // value.
            format!("req_{}_{:08x}", *counter, rand::random::<u32>())
        };
        let (tx, rx) = oneshot::channel();
        self.pending_responses
            .lock()
            .await
            .insert(request_id.clone(), tx);

        let control_msg = json!({
            "type": "control_request",
            "request_id": request_id,
            "request": request
        });
        let write_result = match serde_json::to_string(&control_msg) {
            Ok(control_str) => Query::write_to_stdin(&self.stdin, &control_str).await,
            Err(e) => Err(e.into()),
        };
        if let Err(e) = write_result {
            // The request never reached the CLI, so nothing will answer it.
            self.pending_responses.lock().await.remove(&request_id);
            return Err(e);
        }

        // The 60s wait must match the value reported in the timeout error below.
        match tokio::time::timeout(std::time::Duration::from_secs(60), rx).await {
            Ok(Ok(response)) => response,
            Ok(Err(_)) => Err(ClaudeSDKError::other("Response channel closed")),
            Err(_) => {
                self.pending_responses.lock().await.remove(&request_id);
                Err(ClaudeSDKError::control_timeout(
                    60,
                    request["subtype"].as_str().unwrap_or("unknown").to_string(),
                ))
            }
        }
    }

    /// Sends the `initialize` control request, optionally with a hooks configuration,
    /// and stores the response as the server info.
    pub async fn initialize(&self, hooks: Option<Value>) -> Result<Value> {
        let mut request = json!({
            "subtype": "initialize"
        });
        if let Some(hooks_val) = hooks {
            request["hooks"] = hooks_val;
        }
        let response = self.send_control_request(request).await?;
        *self.server_info.lock().await = Some(response.clone());
        Ok(response)
    }

    /// Sends the `interrupt` control request.
    pub async fn interrupt(&self) -> Result<()> {
        let request = json!({
            "subtype": "interrupt"
        });
        self.send_control_request(request).await?;
        Ok(())
    }

    /// Sends the `set_permission_mode` control request with the given `mode`.
    pub async fn set_permission_mode(&self, mode: &str) -> Result<()> {
        let request = json!({
            "subtype": "set_permission_mode",
            "mode": mode
        });
        self.send_control_request(request).await?;
        Ok(())
    }

    /// Sends the `set_model` control request; `None` is sent as JSON `null`.
    pub async fn set_model(&self, model: Option<String>) -> Result<()> {
        let request = json!({
            "subtype": "set_model",
            "model": model
        });
        self.send_control_request(request).await?;
        Ok(())
    }

    /// Sends the `rewind_files` control request for `user_message_id`.
    pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
        let request = json!({
            "subtype": "rewind_files",
            "user_message_id": user_message_id
        });
        self.send_control_request(request).await?;
        Ok(())
    }

    /// Sends the `mcp_status` control request and returns the raw response.
    pub async fn get_mcp_status(&self) -> Result<Value> {
        let request = json!({
            "subtype": "mcp_status"
        });
        self.send_control_request(request).await
    }

    /// Writes a `user` message with `prompt` as its content to the child's stdin.
    pub async fn send_user_message(&self, prompt: &str) -> Result<()> {
        let message = json!({
            "type": "user",
            "message": {
                "role": "user",
                "content": prompt
            }
        });
        let message_str = serde_json::to_string(&message)?;
        Query::write_to_stdin(&self.stdin, &message_str).await?;
        Ok(())
    }

    /// Drops the child's stdin handle, closing the pipe.
    ///
    /// Later writes fail with `TransportNotReady`.
    pub async fn close_stdin(&self) {
        *self.stdin.lock().await = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AsyncHookOutput, SyncHookOutput};

    /// `continue_` must reach the CLI as `continue`, with its value preserved.
    #[test]
    fn test_hook_output_field_conversion_sync() {
        let sync_output = HookOutput::Sync(Box::new(SyncHookOutput {
            continue_: Some(true),
            suppress_output: None,
            stop_reason: None,
            decision: None,
            system_message: None,
            reason: None,
            hook_specific_output: None,
        }));
        let json = Query::convert_hook_output_for_cli(sync_output).unwrap();
        assert_eq!(json.get("continue"), Some(&serde_json::Value::Bool(true)));
        assert!(json.get("continue_").is_none());
    }

    /// `async_` must reach the CLI as `async`, with its value preserved.
    #[test]
    fn test_hook_output_field_conversion_async() {
        let async_output = HookOutput::Async(AsyncHookOutput {
            async_: true,
            async_timeout: Some(5000),
        });
        let json = Query::convert_hook_output_for_cli(async_output).unwrap();
        assert_eq!(json.get("async"), Some(&serde_json::Value::Bool(true)));
        assert!(json.get("async_").is_none());
    }
}