1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::{mpsc, oneshot, Mutex};
7use tokio_util::sync::CancellationToken;
8
9use crate::error::{Error, Result};
10use crate::message_parser::parse_message;
11use crate::types::control::{SDKCapabilities, SDKControlCommand, SDKInitMessage};
12use crate::types::hooks::{
13 HookDecision, HookDefinition, HookEvent, HookInput, NotificationInput, PostToolUseInput,
14 PreToolUseInput, StopInput,
15};
16use crate::types::messages::Message;
17use crate::types::permissions::{CanUseToolCallback, CanUseToolInput};
18use crate::transport::{Transport, TransportWriter};
19
/// Fallback for how long a control request waits for its response; used by
/// `Query::new` when the caller passes `None` for `control_timeout`.
const DEFAULT_CONTROL_TIMEOUT: Duration = Duration::from_secs(30);
21
/// Shared, thread-safe async callback for incoming `mcp_message` control
/// requests.
///
/// `handle_mcp_message` invokes it with the request's `server_name` and
/// `message` fields; the `Value` produced by the returned future is used as
/// the body of the control response.
pub type McpMessageHandler = Arc<
    dyn Fn(String, Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Value> + Send>>
        + Send
        + Sync,
>;
28
/// Client side of the control protocol spoken over a [`Transport`].
///
/// Sends user messages and control requests, and runs a background router
/// task that matches control responses to their waiters, answers incoming
/// control requests via the registered callbacks, and forwards everything
/// else to the consumer channel returned by `connect`.
pub struct Query {
    /// Underlying transport; connected in `connect`, closed in `close`.
    transport: Box<dyn Transport>,
    /// Write half obtained from `transport.connect()`; `None` until `connect`
    /// succeeds in connecting the transport and again after `close`.
    writer: Option<TransportWriter>,
    /// Registered hooks; a `hook_callback` request with id `hook_<i>` selects
    /// the hook at index `i`.
    hooks: Vec<HookDefinition>,
    /// Optional callback consulted for `can_use_tool` control requests.
    can_use_tool: Option<CanUseToolCallback>,
    /// Optional handler for `mcp_message` control requests.
    mcp_handler: Option<McpMessageHandler>,
    /// Outstanding control requests keyed by request id; the router completes
    /// the matching sender when a `control_response` arrives.
    pending_responses: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
    /// Cancelled by `close` and on drop to stop the router task.
    cancel: CancellationToken,
    /// Maximum time to wait for the response to a control request.
    control_timeout: Duration,
    /// Response to the `initialize` control request, once received.
    server_info: Arc<Mutex<Option<Value>>>,
}
44
45impl Query {
46 pub fn new(
47 transport: Box<dyn Transport>,
48 hooks: Vec<HookDefinition>,
49 can_use_tool: Option<CanUseToolCallback>,
50 mcp_handler: Option<McpMessageHandler>,
51 control_timeout: Option<Duration>,
52 ) -> Self {
53 Self {
54 transport,
55 writer: None,
56 hooks,
57 can_use_tool,
58 mcp_handler,
59 pending_responses: Arc::new(Mutex::new(HashMap::new())),
60 cancel: CancellationToken::new(),
61 control_timeout: control_timeout.unwrap_or(DEFAULT_CONTROL_TIMEOUT),
62 server_info: Arc::new(Mutex::new(None)),
63 }
64 }
65
66 pub async fn connect(&mut self) -> Result<mpsc::Receiver<Result<Message>>> {
68 let (raw_rx, writer) = self.transport.connect().await?;
69 self.writer = Some(writer.clone());
70
71 let (consumer_tx, consumer_rx) = mpsc::channel::<Result<Message>>(256);
72
73 self.spawn_router(raw_rx, consumer_tx, writer.clone());
75
76 self.initialize().await?;
78
79 Ok(consumer_rx)
80 }
81
82 pub async fn send_message(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
84 let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
85 let msg = serde_json::json!({
86 "type": "user",
87 "message": {
88 "role": "user",
89 "content": prompt
90 },
91 "session_id": session_id.unwrap_or(""),
92 "parent_tool_use_id": null
93 });
94 writer.write(msg).await
95 }
96
97 pub async fn send_control_command(&self, command: SDKControlCommand) -> Result<Value> {
99 let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
100 let request_id = generate_request_id();
101
102 let mut request = serde_json::json!({
103 "type": "control_request",
104 "request_id": request_id,
105 "request": {
106 "subtype": command.command_type,
107 }
108 });
109
110 if let Value::Object(params) = command.params {
111 if let Value::Object(ref mut req) = request["request"] {
112 for (k, v) in params {
113 req.insert(k, v);
114 }
115 }
116 }
117
118 let (tx, rx) = oneshot::channel();
119 {
120 let mut pending = self.pending_responses.lock().await;
121 pending.insert(request_id.clone(), tx);
122 }
123
124 writer.write(request).await?;
125
126 let response = tokio::time::timeout(self.control_timeout, rx)
127 .await
128 .map_err(|_| Error::ControlTimeout(self.control_timeout))?
129 .map_err(|_| Error::ControlProtocol("response channel dropped".into()))?;
130
131 Ok(response)
132 }
133
134 pub async fn interrupt(&self) -> Result<Value> {
135 self.send_control_command(SDKControlCommand::interrupt())
136 .await
137 }
138
139 pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
140 self.send_control_command(SDKControlCommand::set_permission_mode(mode))
141 .await
142 }
143
144 pub async fn set_model(&self, model: &str) -> Result<Value> {
145 self.send_control_command(SDKControlCommand::set_model(model))
146 .await
147 }
148
149 pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
150 self.send_control_command(SDKControlCommand::rewind_files(user_message_id))
151 .await
152 }
153
154 pub async fn get_mcp_status(&self) -> Result<Value> {
155 self.send_control_command(SDKControlCommand::get_mcp_status())
156 .await
157 }
158
159 pub async fn get_server_info(&self) -> Option<Value> {
160 self.server_info.lock().await.clone()
161 }
162
163 pub async fn end_input(&self) -> Result<()> {
164 self.transport.end_input().await
165 }
166
167 pub async fn close(&mut self) -> Result<()> {
168 self.cancel.cancel();
169 self.writer = None;
170 self.transport.close().await
171 }
172
173 async fn initialize(&self) -> Result<()> {
174 let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
175
176 let capabilities = SDKCapabilities {
177 hooks: !self.hooks.is_empty(),
178 permissions: self.can_use_tool.is_some(),
179 mcp: self.mcp_handler.is_some(),
180 agent_definitions: vec![],
181 mcp_servers: vec![],
182 };
183
184 let init_msg = SDKInitMessage::new(capabilities);
185 let init_value = serde_json::to_value(&init_msg)?;
186
187 let request_id = generate_request_id();
188 let request = serde_json::json!({
189 "type": "control_request",
190 "request_id": request_id,
191 "request": {
192 "subtype": "initialize",
193 "protocol_version": "1",
194 "capabilities": init_value.get("capabilities"),
195 }
196 });
197
198 let (tx, rx) = oneshot::channel();
199 {
200 let mut pending = self.pending_responses.lock().await;
201 pending.insert(request_id.clone(), tx);
202 }
203
204 writer.write(request).await?;
205
206 let response = tokio::time::timeout(self.control_timeout, rx)
207 .await
208 .map_err(|_| Error::ControlTimeout(self.control_timeout))?
209 .map_err(|_| Error::ControlProtocol("init response channel dropped".into()))?;
210
211 {
212 let mut info = self.server_info.lock().await;
213 *info = Some(response);
214 }
215
216 Ok(())
217 }
218
219 fn spawn_router(
220 &self,
221 mut raw_rx: mpsc::Receiver<Result<Value>>,
222 consumer_tx: mpsc::Sender<Result<Message>>,
223 writer: TransportWriter,
224 ) {
225 let pending = self.pending_responses.clone();
226 let hooks = self.hooks.clone();
227 let can_use_tool = self.can_use_tool.clone();
228 let mcp_handler = self.mcp_handler.clone();
229 let cancel = self.cancel.clone();
230
231 tokio::spawn(async move {
232 loop {
233 tokio::select! {
234 _ = cancel.cancelled() => break,
235 msg = raw_rx.recv() => {
236 match msg {
237 Some(Ok(value)) => {
238 let msg_type = value.get("type")
239 .and_then(|v| v.as_str())
240 .unwrap_or("");
241
242 match msg_type {
243 "control_response" => {
244 route_control_response(&pending, &value).await;
245 }
246 "control_request" => {
247 dispatch_control_request(
248 &value,
249 &hooks,
250 &can_use_tool,
251 &mcp_handler,
252 &writer,
253 ).await;
254 }
255 _ => {
256 let parsed = parse_message(value);
257 if consumer_tx.send(parsed).await.is_err() {
258 break;
259 }
260 }
261 }
262 }
263 Some(Err(e)) => {
264 let _ = consumer_tx.send(Err(e)).await;
265 break;
266 }
267 None => break,
268 }
269 }
270 }
271 }
272 });
273 }
274}
275
276async fn route_control_response(
277 pending: &Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
278 value: &Value,
279) {
280 let response = value.get("response").cloned().unwrap_or(value.clone());
281 let request_id = response
282 .get("request_id")
283 .and_then(|v| v.as_str())
284 .unwrap_or("");
285
286 let mut pending = pending.lock().await;
287 if let Some(tx) = pending.remove(request_id) {
288 let _ = tx.send(response);
289 } else {
290 tracing::warn!(request_id, "control response for unknown request");
291 }
292}
293
294async fn dispatch_control_request(
295 value: &Value,
296 hooks: &[HookDefinition],
297 can_use_tool: &Option<CanUseToolCallback>,
298 mcp_handler: &Option<McpMessageHandler>,
299 writer: &TransportWriter,
300) {
301 let request_id = value
302 .get("request_id")
303 .and_then(|v| v.as_str())
304 .unwrap_or("")
305 .to_string();
306
307 let request = match value.get("request") {
308 Some(r) => r,
309 None => {
310 tracing::warn!("control request missing 'request' field");
311 return;
312 }
313 };
314
315 let subtype = request
316 .get("subtype")
317 .and_then(|v| v.as_str())
318 .unwrap_or("");
319
320 let response_body = match subtype {
321 "can_use_tool" => handle_can_use_tool(request, can_use_tool).await,
322 "hook_callback" => handle_hook_callback(request, hooks).await,
323 "mcp_message" => handle_mcp_message(request, mcp_handler).await,
324 other => {
325 tracing::warn!(subtype = other, "unknown control request subtype");
326 serde_json::json!({"error": format!("unknown subtype: {other}")})
327 }
328 };
329
330 let control_response = serde_json::json!({
331 "type": "control_response",
332 "response": {
333 "subtype": "success",
334 "request_id": request_id,
335 "response": response_body,
336 }
337 });
338
339 if let Err(e) = writer.write(control_response).await {
340 tracing::error!("failed to send control response: {e}");
341 }
342}
343
344async fn handle_can_use_tool(request: &Value, callback: &Option<CanUseToolCallback>) -> Value {
345 let tool_name = request
346 .get("tool_name")
347 .and_then(|v| v.as_str())
348 .unwrap_or("")
349 .to_string();
350 let input = request.get("input").cloned().unwrap_or(Value::Null);
351
352 if let Some(cb) = callback {
353 let result = cb(CanUseToolInput { tool_name, input }).await;
354 if result.allowed {
355 serde_json::json!({"behavior": "allow"})
356 } else {
357 serde_json::json!({
358 "behavior": "deny",
359 "message": result.reason.unwrap_or_default()
360 })
361 }
362 } else {
363 serde_json::json!({"behavior": "allow"})
364 }
365}
366
367async fn handle_hook_callback(request: &Value, hooks: &[HookDefinition]) -> Value {
368 let callback_id = request
369 .get("callback_id")
370 .and_then(|v| v.as_str())
371 .unwrap_or("");
372 let hook_input = request.get("input").cloned().unwrap_or(Value::Null);
373
374 let hook_index: Option<usize> = callback_id
375 .strip_prefix("hook_")
376 .and_then(|s| s.parse().ok());
377
378 let hook = hook_index.and_then(|i| hooks.get(i));
379
380 if let Some(hook) = hook {
381 let typed_input = match hook.event {
382 HookEvent::PreToolUse => {
383 let pre: PreToolUseInput =
384 serde_json::from_value(hook_input).unwrap_or(PreToolUseInput {
385 tool_name: String::new(),
386 tool_input: Value::Null,
387 });
388 HookInput::PreToolUse(pre)
389 }
390 HookEvent::PostToolUse => {
391 let post: PostToolUseInput =
392 serde_json::from_value(hook_input).unwrap_or(PostToolUseInput {
393 tool_name: String::new(),
394 tool_input: Value::Null,
395 tool_output: Value::Null,
396 });
397 HookInput::PostToolUse(post)
398 }
399 HookEvent::Notification => {
400 let notif: NotificationInput =
401 serde_json::from_value(hook_input).unwrap_or(NotificationInput {
402 title: String::new(),
403 message: None,
404 });
405 HookInput::Notification(notif)
406 }
407 HookEvent::Stop | HookEvent::SubagentStop => {
408 let stop: StopInput =
409 serde_json::from_value(hook_input).unwrap_or(StopInput { reason: None });
410 HookInput::Stop(stop)
411 }
412 };
413
414 let output = (hook.callback)(typed_input).await;
415 let mut result = serde_json::json!({"continue": true});
416 if let Some(decision) = &output.decision {
417 let hook_specific = serde_json::json!({
418 "hookEventName": match hook.event {
419 HookEvent::PreToolUse => "PreToolUse",
420 HookEvent::PostToolUse => "PostToolUse",
421 HookEvent::Notification => "Notification",
422 HookEvent::Stop => "Stop",
423 HookEvent::SubagentStop => "SubagentStop",
424 },
425 "permissionDecision": match decision {
426 HookDecision::Approve => "approve",
427 HookDecision::Block => "deny",
428 HookDecision::Ignore => "ignore",
429 },
430 "permissionDecisionReason": output.reason.as_deref().unwrap_or(""),
431 });
432 result["hookSpecificOutput"] = hook_specific;
433
434 if *decision == HookDecision::Block {
435 result["continue"] = Value::Bool(false);
436 }
437 }
438 result
439 } else {
440 tracing::warn!(callback_id, "hook callback not found");
441 serde_json::json!({"continue": true})
442 }
443}
444
445async fn handle_mcp_message(request: &Value, handler: &Option<McpMessageHandler>) -> Value {
446 let server_name = request
447 .get("server_name")
448 .and_then(|v| v.as_str())
449 .unwrap_or("")
450 .to_string();
451 let message = request.get("message").cloned().unwrap_or(Value::Null);
452
453 if let Some(handler) = handler {
454 handler(server_name, message).await
455 } else {
456 serde_json::json!({"error": "no MCP handler registered"})
457 }
458}
459
460fn generate_request_id() -> String {
461 use rand::Rng;
462 let mut rng = rand::rng();
463 let suffix: u64 = rng.random();
464 format!("req_{suffix:016x}")
465}
466
/// Cancels the token when the `Query` is dropped, so the router task (which
/// selects on `cancel.cancelled()`) stops even if `close` was never called.
impl Drop for Query {
    fn drop(&mut self) {
        self.cancel.cancel();
    }
}