1use crate::{
7 errors::{Result, SdkError},
8 transport::{InputMessage, Transport},
9 types::{
10 CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11 SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12 SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13 ToolPermissionContext,
14 },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
/// Control layer on top of a `Transport`.
///
/// It forwards transport messages into a channel (`start` / `receive_messages`).
/// It answers control requests coming from the other side: tool permissions, hook
/// callbacks and SDK MCP messages. It also sends its own control requests
/// (`initialize`, `interrupt`, ...) and matches their replies by request id.
pub struct Query {
    /// Shared transport; every send/receive in this file locks it.
    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
    /// Stored by `new` but not read anywhere in this file.
    #[allow(dead_code)]
    is_streaming_mode: bool,
    /// Optional callback consulted for `can_use_tool` control requests.
    can_use_tool: Option<Arc<dyn CanUseTool>>,
    /// Hook matchers keyed by event name; `initialize` assigns callback ids to them.
    hooks: Option<HashMap<String, Vec<HookMatcher>>>,
    /// Type-erased SDK MCP servers keyed by name; the control handler downcasts them
    /// to `crate::sdk_mcp::SdkMcpServer`.
    sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
    /// Sending half of the message channel; cloned into the task spawned by `start`.
    #[allow(dead_code)]
    message_tx: mpsc::Sender<Result<Message>>,
    /// Receiving half of the message channel; moved out by `receive_messages`.
    #[allow(dead_code)]
    message_rx: Option<mpsc::Receiver<Result<Message>>>,
    /// Result of the `initialize` control request, once it has completed.
    initialization_result: Option<JsonValue>,
    /// Hook callbacks keyed by the callback ids sent in the `initialize` request.
    hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
    /// Counter used as the numeric part of hook callback ids.
    callback_counter: Arc<Mutex<u64>>,
    /// Counter used as the numeric part of control request ids.
    request_counter: Arc<Mutex<u64>>,
    /// Senders for control requests still awaiting a `control_response`, keyed by
    /// request id.
    pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
}
55
56impl Query {
57 pub fn new(
59 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60 is_streaming_mode: bool,
61 can_use_tool: Option<Arc<dyn CanUseTool>>,
62 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64 ) -> Self {
65 let (tx, rx) = mpsc::channel(100);
66
67 Self {
68 transport,
69 is_streaming_mode,
70 can_use_tool,
71 hooks,
72 sdk_mcp_servers,
73 message_tx: tx,
74 message_rx: Some(rx),
75 initialization_result: None,
76 hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77 callback_counter: Arc::new(Mutex::new(0)),
78 request_counter: Arc::new(Mutex::new(0)),
79 pending_responses: Arc::new(RwLock::new(HashMap::new())),
80 }
81 }
82
83 pub async fn register_hook_callback_for_test(
88 &self,
89 callback_id: String,
90 callback: Arc<dyn HookCallback>,
91 ) {
92 let mut map = self.hook_callbacks.write().await;
93 map.insert(callback_id, callback);
94 }
95
96 pub async fn start(&mut self) -> Result<()> {
98 self.start_control_handler().await;
100
101 let transport = self.transport.clone();
103 let tx = self.message_tx.clone();
104 tokio::spawn(async move {
105 loop {
106 let next = {
107 let mut guard = transport.lock().await;
108 let mut stream = guard.receive_messages();
109 stream.next().await
110 };
111
112 match next {
113 Some(Ok(msg)) => {
114 if tx.send(Ok(msg)).await.is_err() { break; }
115 }
116 Some(Err(e)) => {
117 let _ = tx.send(Err(e)).await;
118 break;
119 }
120 None => {
121 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
123 }
124 }
125 }
126 });
127 Ok(())
128 }
129
130 pub async fn initialize(&mut self) -> Result<()> {
132 let hooks_with_ids = if let Some(ref hooks) = self.hooks {
134 let mut counter = self.callback_counter.lock().await;
135 let mut callbacks_map = self.hook_callbacks.write().await;
136
137 let hooks_json: HashMap<String, JsonValue> = hooks
138 .iter()
139 .map(|(event_name, matchers)| {
140 let matchers_with_ids: Vec<JsonValue> = matchers
141 .iter()
142 .map(|matcher| {
143 let callback_ids: Vec<String> = matcher
145 .hooks
146 .iter()
147 .map(|hook_callback| {
148 *counter += 1;
149 let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
150
151 callbacks_map.insert(callback_id.clone(), hook_callback.clone());
153
154 callback_id
155 })
156 .collect();
157
158 serde_json::json!({
159 "matcher": matcher.matcher.clone(),
160 "hookCallbackIds": callback_ids
161 })
162 })
163 .collect();
164
165 (event_name.clone(), serde_json::json!(matchers_with_ids))
166 })
167 .collect();
168
169 Some(hooks_json)
170 } else {
171 None
172 };
173
174 let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
176 subtype: "initialize".to_string(),
177 hooks: hooks_with_ids,
178 });
179
180 let result = self.send_control_request(init_request).await?;
182 self.initialization_result = Some(result);
183
184 debug!("Initialization request sent with hook callback IDs");
185 Ok(())
186 }
187
188 async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
190 let request_id = {
192 let mut counter = self.request_counter.lock().await;
193 *counter += 1;
194 format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
195 };
196
197 let (tx, rx) = tokio::sync::oneshot::channel();
199
200 {
202 let mut pending = self.pending_responses.write().await;
203 pending.insert(request_id.clone(), tx);
204 }
205
206 let control_request = serde_json::json!({
208 "type": "control_request",
209 "request_id": request_id,
210 "request": request
211 });
212
213 debug!("Sending control request: {:?}", control_request);
214
215 {
217 let mut transport = self.transport.lock().await;
218 transport.send_sdk_control_request(control_request).await?;
219 }
220
221 match timeout(Duration::from_secs(60), rx).await {
223 Ok(Ok(response)) => {
224 debug!("Received control response for {}", request_id);
225 Ok(response)
226 }
227 Ok(Err(_)) => Err(SdkError::ControlRequestError(
228 "Response channel closed".to_string(),
229 )),
230 Err(_) => {
231 let mut pending = self.pending_responses.write().await;
233 pending.remove(&request_id);
234 Err(SdkError::Timeout { seconds: 60 })
235 }
236 }
237 }
238
239 #[allow(dead_code)]
241 async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
242 if let Some(ref can_use_tool) = self.can_use_tool {
243 let context = ToolPermissionContext {
244 signal: None,
245 suggestions: request.permission_suggestions.unwrap_or_default(),
246 };
247
248 let result = can_use_tool
249 .can_use_tool(&request.tool_name, &request.input, &context)
250 .await;
251
252 let response = match result {
254 PermissionResult::Allow(allow) => {
255 let mut obj = serde_json::json!({ "allow": true });
256 if let Some(updated) = allow.updated_input {
257 obj["input"] = updated;
258 }
259 obj
260 }
261 PermissionResult::Deny(deny) => {
262 let mut obj = serde_json::json!({ "allow": false });
263 if !deny.message.is_empty() {
264 obj["reason"] = serde_json::json!(deny.message);
265 }
266 obj
267 }
268 };
269
270 let mut transport = self.transport.lock().await;
272 transport.send_sdk_control_response(response).await?;
273 debug!("Permission response sent");
274 }
275
276 Ok(())
277 }
278
279 fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
281 msg.get("requestId")
282 .or_else(|| msg.get("request_id"))
283 .cloned()
284 }
285
286 async fn start_control_handler(&mut self) {
288 let transport = self.transport.clone();
289 let can_use_tool = self.can_use_tool.clone();
290 let hook_callbacks = self.hook_callbacks.clone();
291 let sdk_mcp_servers = self.sdk_mcp_servers.clone();
292 let pending_responses = self.pending_responses.clone();
293
294 let sdk_control_rx = {
296 let mut transport_lock = transport.lock().await;
297 transport_lock.take_sdk_control_receiver()
298 }; if let Some(mut control_rx) = sdk_control_rx {
301 tokio::spawn(async move {
302 let transport_for_control = transport;
304 let can_use_tool_clone = can_use_tool;
305 let hook_callbacks_clone = hook_callbacks;
306 let sdk_mcp_servers_clone = sdk_mcp_servers;
307 let pending_responses_clone = pending_responses;
308
309 loop {
310 let control_message = control_rx.recv().await;
312
313 if let Some(control_message) = control_message {
314 debug!("Received control message: {:?}", control_message);
315
316 if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
318 if let Some(resp_obj) = control_message.get("response") {
320 let request_id = resp_obj
321 .get("request_id")
322 .or_else(|| resp_obj.get("requestId"))
323 .and_then(|v| v.as_str());
324
325 if let Some(request_id) = request_id {
326 let mut pending = pending_responses_clone.write().await;
327 if let Some(tx) = pending.remove(request_id) {
328 let _ = tx.send(resp_obj.clone());
330 debug!("Control response delivered for request_id: {}", request_id);
331 } else {
332 warn!("No pending request found for request_id: {}", request_id);
333 }
334 } else {
335 warn!("Control response missing request_id: {:?}", control_message);
336 }
337 } else {
338 warn!("Control response missing 'response' payload: {:?}", control_message);
339 }
340 continue;
341 }
342
343 let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
346 control_message.get("request").cloned().unwrap_or(control_message.clone())
347 } else {
348 control_message.clone()
349 };
350
351 if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
352 match subtype {
353 "can_use_tool" => {
354 if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
356 if let Some(ref can_use_tool) = can_use_tool_clone {
358 let context = ToolPermissionContext {
359 signal: None,
360 suggestions: request.permission_suggestions.unwrap_or_default(),
361 };
362
363 let result = can_use_tool
364 .can_use_tool(&request.tool_name, &request.input, &context)
365 .await;
366
367 let permission_response = match result {
369 PermissionResult::Allow(allow) => {
370 let mut resp = serde_json::json!({
371 "allow": true,
372 });
373 if let Some(input) = allow.updated_input {
374 resp["input"] = input;
375 }
376 if let Some(perms) = allow.updated_permissions {
377 resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
378 }
379 resp
380 }
381 PermissionResult::Deny(deny) => {
382 let mut resp = serde_json::json!({
383 "allow": false,
384 });
385 if !deny.message.is_empty() {
386 resp["reason"] = serde_json::json!(deny.message);
387 }
388 if deny.interrupt {
389 resp["interrupt"] = serde_json::json!(true);
390 }
391 resp
392 }
393 };
394
395 let response = serde_json::json!({
398 "subtype": "success",
399 "request_id": Self::extract_request_id(&control_message),
400 "response": permission_response
401 });
402
403 let mut transport = transport_for_control.lock().await;
405 if let Err(e) = transport.send_sdk_control_response(response).await {
406 error!("Failed to send permission response: {}", e);
407 }
408 }
409 } else {
410 if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str()) {
412 if let Some(input_val) = request_data.get("input").cloned() {
413 if let Some(ref can_use_tool) = can_use_tool_clone {
414 let suggestions: Vec<PermissionUpdate> = request_data
416 .get("permission_suggestions")
417 .cloned()
418 .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
419 .unwrap_or_default();
420
421 let context = ToolPermissionContext { signal: None, suggestions };
422 let result = can_use_tool
423 .can_use_tool(tool_name, &input_val, &context)
424 .await;
425
426 let permission_response = match result {
427 PermissionResult::Allow(allow) => {
428 let mut resp = serde_json::json!({ "allow": true });
429 if let Some(input) = allow.updated_input { resp["input"] = input; }
430 if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
431 resp
432 }
433 PermissionResult::Deny(deny) => {
434 let mut resp = serde_json::json!({ "allow": false });
435 if !deny.message.is_empty() { resp["reason"] = serde_json::json!(deny.message); }
436 if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
437 resp
438 }
439 };
440
441 let response = serde_json::json!({
442 "subtype": "success",
443 "request_id": Self::extract_request_id(&control_message),
444 "response": permission_response
445 });
446 let mut transport = transport_for_control.lock().await;
447 if let Err(e) = transport.send_sdk_control_response(response).await {
448 error!("Failed to send permission response (fallback): {}", e);
449 }
450 }
451 }
452 }
453 }
454 }
455 "hook_callback" => {
456 if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
458 let callbacks = hook_callbacks_clone.read().await;
459
460 if let Some(callback) = callbacks.get(&request.callback_id) {
461 let context = HookContext { signal: None };
462
463 let response = callback
464 .execute(&request.input, request.tool_use_id.as_deref(), &context)
465 .await;
466
467 let response_json = serde_json::json!({
469 "subtype": "success",
470 "request_id": Self::extract_request_id(&control_message),
471 "response": response
472 });
473
474 let mut transport = transport_for_control.lock().await;
475 if let Err(e) = transport.send_sdk_control_response(response_json).await {
476 error!("Failed to send hook callback response: {}", e);
477 }
478 } else {
479 warn!("No hook callback found for ID: {}", request.callback_id);
480 }
481 } else {
482 let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
484 let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
485 let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
486
487 if let Some(callback_id) = callback_id {
488 let callbacks = hook_callbacks_clone.read().await;
489 if let Some(callback) = callbacks.get(callback_id) {
490 let context = HookContext { signal: None };
491 let response = callback
492 .execute(&input, tool_use_id.as_deref(), &context)
493 .await;
494
495 let response_json = serde_json::json!({
496 "subtype": "success",
497 "request_id": Self::extract_request_id(&control_message),
498 "response": response
499 });
500 let mut transport = transport_for_control.lock().await;
501 if let Err(e) = transport.send_sdk_control_response(response_json).await {
502 error!("Failed to send hook callback response (fallback): {}", e);
503 }
504 } else {
505 warn!("No hook callback found for ID: {}", callback_id);
506 }
507 } else {
508 warn!("Invalid hook_callback control message: missing callback_id");
509 }
510 }
511 }
512 "mcp_message" => {
513 if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str()) {
515 if let Some(message) = request_data.get("message") {
516 debug!("Processing MCP message for SDK server: {}", server_name);
517
518 if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
520 if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
522 match sdk_server.handle_message(message.clone()).await {
524 Ok(mcp_result) => {
525 let response = serde_json::json!({
527 "subtype": "success",
528 "request_id": Self::extract_request_id(&control_message),
529 "response": {
530 "mcp_response": mcp_result
531 }
532 });
533
534 let mut transport = transport_for_control.lock().await;
535 if let Err(e) = transport.send_sdk_control_response(response).await {
536 error!("Failed to send MCP response: {}", e);
537 }
538 }
539 Err(e) => {
540 error!("SDK MCP server error: {}", e);
541 let error_response = serde_json::json!({
542 "subtype": "error",
543 "request_id": Self::extract_request_id(&control_message),
544 "error": format!("MCP server error: {}", e)
545 });
546
547 let mut transport = transport_for_control.lock().await;
548 if let Err(e) = transport.send_sdk_control_response(error_response).await {
549 error!("Failed to send MCP error response: {}", e);
550 }
551 }
552 }
553 } else {
554 warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
555 }
556 } else {
557 warn!("No SDK MCP server found with name: {}", server_name);
558 let error_response = serde_json::json!({
559 "subtype": "error",
560 "request_id": Self::extract_request_id(&control_message),
561 "error": format!("Server '{}' not found", server_name)
562 });
563
564 let mut transport = transport_for_control.lock().await;
565 if let Err(e) = transport.send_sdk_control_response(error_response).await {
566 error!("Failed to send MCP error response: {}", e);
567 }
568 }
569 }
570 }
571 }
572 _ => {
573 debug!("Unknown SDK control subtype: {}", subtype);
574 }
575 }
576 }
577 }
578 }
579 });
580 }
581 }
582
583 #[allow(dead_code)]
585 pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
586 where
587 S: Stream<Item = JsonValue> + Send + 'static,
588 {
589 let transport = self.transport.clone();
590
591 tokio::spawn(async move {
592 use futures::StreamExt;
593 let mut stream = Box::pin(input_stream);
594
595 while let Some(value) = stream.next().await {
596 let input_msg_opt = Self::json_to_input_message(value);
598 if let Some(input_msg) = input_msg_opt {
599 let mut guard = transport.lock().await;
600 if let Err(e) = guard.send_message(input_msg).await {
601 warn!("Failed to send streaming input message: {}", e);
602 }
603 } else {
604 warn!("Invalid streaming input JSON; expected user message shape");
605 }
606 }
607
608 let mut guard = transport.lock().await;
610 if let Err(e) = guard.end_input().await {
611 warn!("Failed to signal end_input: {}", e);
612 }
613 });
614 Ok(())
615 }
616
617 #[allow(dead_code)]
619 pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
620 self.message_rx.take().expect("Receiver already taken")
621 }
622
623 pub async fn interrupt(&mut self) -> Result<()> {
625 let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
626 subtype: "interrupt".to_string(),
627 });
628
629 self.send_control_request(interrupt_request).await?;
630 Ok(())
631 }
632
633 #[allow(dead_code)]
635 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
636 let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
637 subtype: "set_permission_mode".to_string(),
638 mode: mode.to_string(),
639 });
640 let _ = self.send_control_request(req).await?;
642 Ok(())
643 }
644
645 #[allow(dead_code)]
647 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
648 let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
649 subtype: "set_model".to_string(),
650 model,
651 });
652 let _ = self.send_control_request(req).await?;
653 Ok(())
654 }
655
656 #[allow(dead_code)]
658 async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
659 if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
661 debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
664 Ok(serde_json::json!({
665 "jsonrpc": "2.0",
666 "id": message.get("id"),
667 "result": {
668 "content": "MCP server response placeholder"
669 }
670 }))
671 } else {
672 Err(SdkError::InvalidState {
673 message: format!("No SDK MCP server found with name: {}", server_name),
674 })
675 }
676 }
677
678 #[allow(dead_code)]
680 pub async fn close(&mut self) -> Result<()> {
681 let mut transport = self.transport.lock().await;
683 transport.disconnect().await?;
684 Ok(())
685 }
686
687 pub fn get_initialization_result(&self) -> Option<&JsonValue> {
689 self.initialization_result.as_ref()
690 }
691
692 #[allow(dead_code)]
694 fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
695 if let Some(obj) = v.as_object() {
697 if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message")) {
698 if t.as_str() == Some("user") {
699 let parent = obj
700 .get("parent_tool_use_id")
701 .and_then(|p| p.as_str().map(|s| s.to_string()));
702 let session_id = obj
703 .get("session_id")
704 .and_then(|s| s.as_str())
705 .unwrap_or("default")
706 .to_string();
707
708 let im = InputMessage {
709 r#type: "user".to_string(),
710 message: message.clone(),
711 parent_tool_use_id: parent,
712 session_id,
713 };
714 return Some(im);
715 }
716 }
717
718 if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
720 let session_id = obj
721 .get("session_id")
722 .and_then(|s| s.as_str())
723 .unwrap_or("default")
724 .to_string();
725 return Some(InputMessage::user(content.to_string(), session_id));
726 }
727 }
728
729 if let Some(s) = v.as_str() {
731 return Some(InputMessage::user(s.to_string(), "default".to_string()));
732 }
733
734 None
735 }
736}
737
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let snake = serde_json::json!({"request_id": "req_1"});
        let camel = serde_json::json!({"requestId": "req_2"});
        assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
        assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
    }

    /// camelCase wins when both keys are present; messages without an id yield `None`.
    #[test]
    fn test_extract_request_id_precedence_and_missing() {
        let both = serde_json::json!({"requestId": "camel", "request_id": "snake"});
        assert_eq!(Query::extract_request_id(&both), Some(serde_json::json!("camel")));
        assert_eq!(Query::extract_request_id(&serde_json::json!({})), None);
        assert_eq!(Query::extract_request_id(&serde_json::json!("not an object")), None);
    }

    #[test]
    fn test_json_to_input_message_from_string() {
        let v = serde_json::json!("Hello");
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_json_to_input_message_from_object_content() {
        let v = serde_json::json!({"content":"Ping","session_id":"s1"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let v = serde_json::json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str().unwrap(), "user");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
    }

    /// `parent_tool_use_id` is carried through and a missing session falls back to "default".
    #[test]
    fn test_json_to_input_message_keeps_parent_tool_use_id() {
        let v = serde_json::json!({
            "type": "user",
            "message": {"role": "user", "content": "Hi"},
            "parent_tool_use_id": "tool_1"
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.parent_tool_use_id.as_deref(), Some("tool_1"));
        assert_eq!(im.session_id, "default");
    }

    /// Values that match none of the accepted shapes are rejected instead of guessed at.
    #[test]
    fn test_json_to_input_message_rejects_unsupported_shapes() {
        assert!(Query::json_to_input_message(serde_json::json!(42)).is_none());
        assert!(Query::json_to_input_message(serde_json::json!(null)).is_none());
        assert!(Query::json_to_input_message(serde_json::json!({"type": "user"})).is_none());
        assert!(Query::json_to_input_message(serde_json::json!({"content": 5})).is_none());
    }
}