1use std::collections::HashMap;
15use std::future::Future;
16use std::panic::{self, AssertUnwindSafe};
17use std::sync::Arc;
18use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
19use std::time::Duration;
20
21use futures::{FutureExt, Stream, StreamExt};
22use serde_json::{Map, Value, json};
23use tokio::sync::{Mutex, mpsc, oneshot};
24use tokio::task::JoinHandle;
25use tracing::{debug, warn};
26
27use crate::errors::{Error, Result};
28use crate::message_parser::parse_message;
29use crate::sdk_mcp::McpSdkServer;
30use crate::transport::{TransportCloseHandle, TransportReader, TransportWriter};
31use crate::types::{
32 AgentDefinition, CanUseToolCallback, HookCallback, HookMatcher, McpStatusResponse, Message,
33 PermissionResult, ToolPermissionContext,
34};
35
/// Capacity of the bounded channel that carries parsed messages from the
/// background reader task to the consumer. When the channel is full, the
/// reader waits on `send` until the consumer catches up.
const MESSAGE_CHANNEL_BUFFER: usize = 100;
38
39fn convert_hook_output_for_cli(output: Value) -> Value {
44 let Some(obj) = output.as_object() else {
45 return output;
46 };
47
48 let mut converted = Map::new();
49 for (key, value) in obj {
50 match key.as_str() {
51 "async_" => {
52 converted.insert("async".to_string(), value.clone());
53 }
54 "continue_" => {
55 converted.insert("continue".to_string(), value.clone());
56 }
57 _ => {
58 converted.insert(key.clone(), value.clone());
59 }
60 }
61 }
62 Value::Object(converted)
63}
64
/// Extracts a readable message from a caught panic payload.
///
/// `&str` payloads (literal panics) and `String` payloads (formatted panics)
/// are returned as text. Any other payload type yields the fixed placeholder
/// `"unknown panic payload"`.
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
74
75fn callback_panic_error(callback_type: &str, payload: Box<dyn std::any::Any + Send>) -> Error {
76 let panic_message = panic_payload_to_string(payload);
77 warn!(
78 callback_type,
79 panic_message, "Caught panic in callback invocation"
80 );
81 Error::Other(format!(
82 "{callback_type} callback panicked: {panic_message}"
83 ))
84}
85
86async fn await_callback_with_panic_isolation<T, F>(
87 callback_type: &str,
88 callback_future: F,
89) -> Result<T>
90where
91 F: Future<Output = Result<T>>,
92{
93 match AssertUnwindSafe(callback_future).catch_unwind().await {
94 Ok(result) => result,
95 Err(payload) => Err(callback_panic_error(callback_type, payload)),
96 }
97}
98
/// Bookkeeping for control requests sent to the CLI that still await a
/// `control_response`.
struct PendingControlsState {
    /// One-shot response slots for in-flight requests, keyed by `request_id`.
    senders: HashMap<String, oneshot::Sender<Result<Value>>>,
    /// Responses that arrived while no sender was registered for their
    /// `request_id` (filled by `handle_control_response`).
    buffered: HashMap<String, Result<Value>>,
}
108
/// State shared between a [`Query`] handle and its background reader task.
struct QuerySharedState {
    /// Optional permission callback, invoked for `can_use_tool` control requests.
    can_use_tool: Option<CanUseToolCallback>,
    /// Hook callbacks keyed by callback id, looked up for `hook_callback`
    /// control requests.
    hook_callbacks: Mutex<HashMap<String, HookCallback>>,
    /// In-process SDK MCP servers by name, used for `mcp_message` control requests.
    sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
    /// Control requests we sent that are still waiting for a response.
    pending_controls: Mutex<PendingControlsState>,
    /// Transport writer. Every outgoing JSON line is written while holding this lock.
    writer: Arc<Mutex<Box<dyn TransportWriter>>>,
    /// Set when closing stdin was deferred until the first result arrives.
    pending_stdin_close: AtomicBool,
    /// How long the reader waits for the first result while a deferred stdin
    /// close is pending.
    stream_close_timeout: Duration,
    /// Set by `mark_reader_terminated`. Once true, new control requests fail
    /// immediately.
    reader_terminated: AtomicBool,
    /// First recorded reason for the reader terminating, if any.
    reader_termination_reason: Mutex<Option<String>>,
}
127
/// Handle for one session with the CLI over a transport.
///
/// Owns the background reader task, the receiving end of the parsed-message
/// channel, and the transport close handle. Dropping it aborts the reader and
/// closes the transport (see the `Drop` impl). `close` does the same, but
/// also waits for the reader task to finish and reports close errors.
pub struct Query {
    /// State shared with the background reader; `None` once shut down.
    state: Option<Arc<QuerySharedState>>,

    /// Parsed messages from the reader. `None` once shut down or after
    /// `take_message_receiver`.
    message_rx: Option<mpsc::Receiver<Result<Message>>>,

    /// Background reader task; aborted on shutdown or drop.
    reader_task: Option<JoinHandle<()>>,

    /// Closes the underlying transport; consumed on shutdown or drop.
    close_handle: Option<Box<dyn TransportCloseHandle>>,

    /// Source of unique control request ids (`req_1`, `req_2`, ...).
    request_counter: Arc<AtomicUsize>,

    /// Control requests are only allowed in streaming mode, and `initialize`
    /// is a no-op otherwise.
    is_streaming_mode: bool,

    /// Agent definitions sent with the `initialize` request.
    agents: Option<HashMap<String, AgentDefinition>>,

    /// Set once `initialize` has succeeded.
    initialized: bool,

    /// Cached response of the `initialize` request.
    initialization_result: Option<Value>,

    /// Timeout for the `initialize` control request.
    initialize_timeout: Duration,

    /// Whether hooks or SDK MCP servers are configured. When true, closing
    /// stdin after streamed input is deferred until the first result.
    has_hooks_or_mcp: bool,
}
174
175impl Query {
176 #[allow(clippy::too_many_arguments)]
193 pub(crate) fn start(
194 reader: Box<dyn TransportReader>,
195 writer: Box<dyn TransportWriter>,
196 close_handle: Box<dyn TransportCloseHandle>,
197 is_streaming_mode: bool,
198 can_use_tool: Option<CanUseToolCallback>,
199 hook_callbacks: HashMap<String, HookCallback>,
200 sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
201 agents: Option<HashMap<String, AgentDefinition>>,
202 initialize_timeout: Duration,
203 ) -> Self {
204 let stream_close_timeout_ms: u64 = std::env::var("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT")
205 .ok()
206 .and_then(|v| v.parse().ok())
207 .unwrap_or(60_000);
208 let stream_close_timeout =
209 Duration::from_millis(stream_close_timeout_ms).max(Duration::from_secs(60));
210
211 let has_hooks_or_mcp = !hook_callbacks.is_empty() || !sdk_mcp_servers.is_empty();
212 let writer = Arc::new(Mutex::new(writer));
213
214 let state = Arc::new(QuerySharedState {
215 can_use_tool,
216 hook_callbacks: Mutex::new(hook_callbacks),
217 sdk_mcp_servers,
218 pending_controls: Mutex::new(PendingControlsState {
219 senders: HashMap::new(),
220 buffered: HashMap::new(),
221 }),
222 writer: writer.clone(),
223 pending_stdin_close: AtomicBool::new(false),
224 stream_close_timeout,
225 reader_terminated: AtomicBool::new(false),
226 reader_termination_reason: Mutex::new(None),
227 });
228
229 let (message_tx, message_rx) = mpsc::channel(MESSAGE_CHANNEL_BUFFER);
230
231 let reader_state = state.clone();
232 let reader_task = tokio::spawn(async move {
233 background_reader_task(reader, reader_state, message_tx).await;
234 });
235
236 Self {
237 state: Some(state),
238 message_rx: Some(message_rx),
239 reader_task: Some(reader_task),
240 close_handle: Some(close_handle),
241 request_counter: Arc::new(AtomicUsize::new(0)),
242 is_streaming_mode,
243 agents,
244 initialized: false,
245 initialization_result: None,
246 initialize_timeout,
247 has_hooks_or_mcp,
248 }
249 }
250
251 pub async fn initialize(&mut self, hooks_config: Map<String, Value>) -> Result<Option<Value>> {
271 if !self.is_streaming_mode {
272 return Ok(None);
273 }
274
275 let mut request = Map::new();
276 request.insert(
277 "subtype".to_string(),
278 Value::String("initialize".to_string()),
279 );
280 request.insert(
281 "hooks".to_string(),
282 if hooks_config.is_empty() {
283 Value::Null
284 } else {
285 Value::Object(hooks_config)
286 },
287 );
288
289 if let Some(agents) = &self.agents {
290 request.insert(
291 "agents".to_string(),
292 serde_json::to_value(agents).unwrap_or(Value::Null),
293 );
294 }
295
296 let response = self
297 .send_control_request(Value::Object(request), self.initialize_timeout)
298 .await?;
299 self.initialized = true;
300 self.initialization_result = Some(response.clone());
301 Ok(Some(response))
302 }
303
304 pub fn initialization_result(&self) -> Option<Value> {
318 self.initialization_result.clone()
319 }
320
321 async fn send_control_request(&self, request: Value, timeout: Duration) -> Result<Value> {
326 if !self.is_streaming_mode {
327 return Err(Error::Other(
328 "Control requests require streaming mode".to_string(),
329 ));
330 }
331
332 let state = self
333 .state
334 .as_ref()
335 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
336
337 let request_id = format!(
338 "req_{}",
339 self.request_counter.fetch_add(1, Ordering::SeqCst) + 1
340 );
341
342 let control_request = json!({
344 "type": "control_request",
345 "request_id": request_id,
346 "request": request,
347 });
348 state
349 .writer
350 .lock()
351 .await
352 .write(&(control_request.to_string() + "\n"))
353 .await?;
354
355 let (tx, rx) = oneshot::channel();
359 {
360 let mut controls = state.pending_controls.lock().await;
361 if let Some(result) = controls.buffered.remove(&request_id) {
362 return result;
363 }
364 controls.senders.insert(request_id.clone(), tx);
365 }
366 if state.reader_terminated.load(Ordering::SeqCst) {
367 state
368 .pending_controls
369 .lock()
370 .await
371 .senders
372 .remove(&request_id);
373 let reason = reader_termination_reason(state).await;
374 return Err(Error::Other(format!(
375 "Background reader task terminated: {reason}"
376 )));
377 }
378
379 let result = tokio::time::timeout(timeout, rx).await;
381 match result {
382 Ok(Ok(value)) => value,
383 Ok(Err(_)) => {
384 Err(Error::Other(
386 "Background reader task terminated while waiting for control response"
387 .to_string(),
388 ))
389 }
390 Err(_) => {
391 let subtype = request
393 .get("subtype")
394 .and_then(Value::as_str)
395 .unwrap_or("unknown");
396 state
397 .pending_controls
398 .lock()
399 .await
400 .senders
401 .remove(&request_id);
402 Err(Error::Other(format!("Control request timeout: {subtype}")))
403 }
404 }
405 }
406
407 pub async fn send_user_message(&self, prompt: &str, session_id: &str) -> Result<()> {
420 let message = json!({
421 "type": "user",
422 "message": {"role": "user", "content": prompt},
423 "parent_tool_use_id": Value::Null,
424 "session_id": session_id
425 });
426 self.write_message(&message).await
427 }
428
429 pub async fn send_raw_message(&self, message: Value) -> Result<()> {
443 self.write_message(&message).await
444 }
445
446 async fn write_message(&self, message: &Value) -> Result<()> {
448 let state = self
449 .state
450 .as_ref()
451 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
452 state
453 .writer
454 .lock()
455 .await
456 .write(&(message.to_string() + "\n"))
457 .await
458 }
459
460 pub async fn send_input_messages(&self, messages: Vec<Value>) -> Result<()> {
476 for message in messages {
477 self.send_raw_message(message).await?;
478 }
479 Ok(())
480 }
481
482 pub async fn send_input_from_stream<S>(&self, mut messages: S) -> Result<()>
499 where
500 S: Stream<Item = Value> + Unpin,
501 {
502 while let Some(message) = messages.next().await {
503 self.send_raw_message(message).await?;
504 }
505 Ok(())
506 }
507
508 pub fn spawn_input_from_stream<S>(&self, mut messages: S) -> Result<JoinHandle<Result<()>>>
532 where
533 S: Stream<Item = Value> + Send + Unpin + 'static,
534 {
535 let state = self
536 .state
537 .as_ref()
538 .cloned()
539 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
540
541 Ok(tokio::spawn(async move {
542 while let Some(message) = messages.next().await {
543 state
544 .writer
545 .lock()
546 .await
547 .write(&(message.to_string() + "\n"))
548 .await?;
549 }
550 Ok(())
551 }))
552 }
553
554 pub async fn stream_input(&self, messages: Vec<Value>) -> Result<()> {
573 self.send_input_messages(messages).await?;
574 self.finalize_stream_input().await
575 }
576
577 pub async fn stream_input_from_stream<S>(&self, mut messages: S) -> Result<()>
594 where
595 S: Stream<Item = Value> + Unpin,
596 {
597 self.send_input_from_stream(&mut messages).await?;
598 self.finalize_stream_input().await
599 }
600
601 async fn finalize_stream_input(&self) -> Result<()> {
602 let state = self
603 .state
604 .as_ref()
605 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
606
607 if self.has_hooks_or_mcp {
608 debug!(
609 has_hooks_or_mcp = self.has_hooks_or_mcp,
610 "Deferring stdin close until first result"
611 );
612 state.pending_stdin_close.store(true, Ordering::SeqCst);
613 } else {
614 state.writer.lock().await.end_input().await?;
615 }
616 Ok(())
617 }
618
619 pub async fn end_input(&self) -> Result<()> {
632 let state = self
633 .state
634 .as_ref()
635 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
636 state.writer.lock().await.end_input().await
637 }
638
639 pub async fn receive_next_message(&mut self) -> Result<Option<Message>> {
659 let rx = self
660 .message_rx
661 .as_mut()
662 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
663
664 match rx.recv().await {
665 Some(Ok(message)) => Ok(Some(message)),
666 Some(Err(err)) => Err(err),
667 None => Ok(None),
668 }
669 }
670
671 pub async fn get_mcp_status(&self) -> Result<McpStatusResponse> {
684 let raw = self
685 .send_control_request(json!({ "subtype": "mcp_status" }), Duration::from_secs(60))
686 .await?;
687 serde_json::from_value(raw).map_err(|err| {
688 Error::Other(format!("Failed to decode typed MCP status response: {err}"))
689 })
690 }
691
692 pub async fn interrupt(&self) -> Result<()> {
705 self.send_control_request(json!({ "subtype": "interrupt" }), Duration::from_secs(60))
706 .await?;
707 Ok(())
708 }
709
710 pub async fn set_permission_mode(&self, mode: &str) -> Result<()> {
723 self.send_control_request(
724 json!({ "subtype": "set_permission_mode", "mode": mode }),
725 Duration::from_secs(60),
726 )
727 .await?;
728 Ok(())
729 }
730
731 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
744 self.send_control_request(
745 json!({ "subtype": "set_model", "model": model }),
746 Duration::from_secs(60),
747 )
748 .await?;
749 Ok(())
750 }
751
752 pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
765 self.send_control_request(
766 json!({ "subtype": "rewind_files", "user_message_id": user_message_id }),
767 Duration::from_secs(60),
768 )
769 .await?;
770 Ok(())
771 }
772
773 pub async fn reconnect_mcp_server(&self, server_name: &str) -> Result<()> {
775 self.send_control_request(
776 json!({ "subtype": "mcp_reconnect", "serverName": server_name }),
777 Duration::from_secs(60),
778 )
779 .await?;
780 Ok(())
781 }
782
783 pub async fn toggle_mcp_server(&self, server_name: &str, enabled: bool) -> Result<()> {
785 self.send_control_request(
786 json!({ "subtype": "mcp_toggle", "serverName": server_name, "enabled": enabled }),
787 Duration::from_secs(60),
788 )
789 .await?;
790 Ok(())
791 }
792
793 pub async fn stop_task(&self, task_id: &str) -> Result<()> {
795 self.send_control_request(
796 json!({ "subtype": "stop_task", "task_id": task_id }),
797 Duration::from_secs(60),
798 )
799 .await?;
800 Ok(())
801 }
802
803 pub async fn close(mut self) -> Result<()> {
816 self.shutdown().await
817 }
818
819 async fn shutdown(&mut self) -> Result<()> {
821 self.message_rx.take();
822 self.state.take();
823
824 if let Some(task) = self.reader_task.take() {
825 task.abort();
826 let _ = task.await;
827 }
828
829 if let Some(close_handle) = self.close_handle.take() {
830 close_handle.close().await?;
831 }
832
833 Ok(())
834 }
835
836 pub(crate) fn take_message_receiver(&mut self) -> Option<mpsc::Receiver<Result<Message>>> {
838 self.message_rx.take()
839 }
840}
841
842impl Drop for Query {
843 fn drop(&mut self) {
844 if let Some(task) = self.reader_task.take() {
845 task.abort();
846 }
847
848 if let Some(close_handle) = self.close_handle.take() {
849 if let Ok(handle) = tokio::runtime::Handle::try_current() {
853 handle.spawn(async move {
854 let _ = close_handle.close().await;
855 });
856 } else if let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
857 .enable_all()
858 .build()
859 {
860 let _ = runtime.block_on(async move { close_handle.close().await });
861 }
862 }
863 }
864}
865
866async fn background_reader_task(
873 mut reader: Box<dyn TransportReader>,
874 state: Arc<QuerySharedState>,
875 message_tx: mpsc::Sender<Result<Message>>,
876) {
877 loop {
878 let read_result = if state.pending_stdin_close.load(Ordering::SeqCst) {
880 let timeout_dur = state.stream_close_timeout;
881 match tokio::time::timeout(timeout_dur, reader.read_next_message()).await {
882 Ok(result) => result,
883 Err(_) => {
884 debug!("Timed out waiting for first result, closing input stream");
885 try_close_deferred_stdin(&state).await;
886 continue;
887 }
888 }
889 } else {
890 reader.read_next_message().await
891 };
892
893 let raw = match read_result {
894 Ok(Some(raw)) => raw,
895 Ok(None) => {
896 try_close_deferred_stdin(&state).await;
897 break;
898 }
899 Err(err) => {
900 mark_reader_terminated(&state, err.to_string()).await;
901 let _ = message_tx.send(Err(err)).await;
902 break;
903 }
904 };
905
906 let msg_type = raw.get("type").and_then(Value::as_str).unwrap_or_default();
907
908 if msg_type == "control_response" {
909 handle_control_response(&state, &raw).await;
910 continue;
911 }
912
913 if msg_type == "control_request" {
914 if let Err(err) = handle_control_request(&state, raw).await {
915 debug!("Error handling control request: {err}");
916 }
917 continue;
918 }
919
920 if msg_type == "control_cancel_request" {
921 continue;
922 }
923
924 match parse_message(&raw) {
926 Ok(Some(msg)) => {
927 if matches!(msg, Message::Result(_))
928 && state.pending_stdin_close.load(Ordering::SeqCst)
929 {
930 debug!("Received first result, closing input stream");
931 try_close_deferred_stdin(&state).await;
932 }
933
934 if message_tx.send(Ok(msg)).await.is_err() {
935 break;
936 }
937 }
938 Ok(None) => {}
939 Err(err) => {
940 if message_tx
941 .send(Err(Error::MessageParse(err)))
942 .await
943 .is_err()
944 {
945 break;
946 }
947 }
948 }
949 }
950}
951
952async fn mark_reader_terminated(state: &QuerySharedState, reason: String) {
954 state.reader_terminated.store(true, Ordering::SeqCst);
955 let stored_reason = {
956 let mut termination_reason = state.reader_termination_reason.lock().await;
957 if termination_reason.is_none() {
958 *termination_reason = Some(reason);
959 }
960 termination_reason
961 .clone()
962 .unwrap_or_else(|| "Unknown reason".to_string())
963 };
964
965 let mut controls = state.pending_controls.lock().await;
966 for (_, sender) in controls.senders.drain() {
967 let _ = sender.send(Err(Error::Other(format!(
968 "Background reader task terminated: {stored_reason}"
969 ))));
970 }
971}
972
973async fn reader_termination_reason(state: &QuerySharedState) -> String {
975 state
976 .reader_termination_reason
977 .lock()
978 .await
979 .clone()
980 .unwrap_or_else(|| "Unknown reason".to_string())
981}
982
983async fn try_close_deferred_stdin(state: &QuerySharedState) {
985 if state
986 .pending_stdin_close
987 .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
988 .is_ok()
989 {
990 if let Err(e) = state.writer.lock().await.end_input().await {
991 debug!("Error closing deferred stdin: {e}");
992 }
993 }
994}
995
996async fn handle_control_response(state: &QuerySharedState, raw: &Value) {
1001 let Some(response) = raw.get("response").and_then(Value::as_object) else {
1002 return;
1003 };
1004 let response_request_id = response
1005 .get("request_id")
1006 .and_then(Value::as_str)
1007 .unwrap_or_default();
1008
1009 let subtype = response
1010 .get("subtype")
1011 .and_then(Value::as_str)
1012 .unwrap_or_default();
1013
1014 let result: Result<Value> = if subtype == "error" {
1015 let error = response
1016 .get("error")
1017 .and_then(Value::as_str)
1018 .unwrap_or("Unknown error");
1019 Err(Error::Other(error.to_string()))
1020 } else {
1021 Ok(response
1022 .get("response")
1023 .cloned()
1024 .unwrap_or_else(|| json!({})))
1025 };
1026
1027 let mut controls = state.pending_controls.lock().await;
1028 if let Some(sender) = controls.senders.remove(response_request_id) {
1029 let _ = sender.send(result);
1030 } else {
1031 controls
1033 .buffered
1034 .insert(response_request_id.to_string(), result);
1035 }
1036}
1037
1038async fn handle_can_use_tool_request(
1039 state: &QuerySharedState,
1040 request_data: &Map<String, Value>,
1041) -> Result<Value> {
1042 let callback = state
1043 .can_use_tool
1044 .clone()
1045 .ok_or_else(|| Error::Other("canUseTool callback is not provided".to_string()))?;
1046 let tool_name = request_data
1047 .get("tool_name")
1048 .and_then(Value::as_str)
1049 .unwrap_or_default()
1050 .to_string();
1051 let input = request_data
1052 .get("input")
1053 .cloned()
1054 .unwrap_or_else(|| json!({}));
1055 let suggestions = request_data
1056 .get("permission_suggestions")
1057 .and_then(Value::as_array)
1058 .cloned()
1059 .unwrap_or_default()
1060 .into_iter()
1061 .filter_map(|value| serde_json::from_value(value).ok())
1062 .collect();
1063 let blocked_path = request_data
1064 .get("blocked_path")
1065 .and_then(Value::as_str)
1066 .map(ToString::to_string);
1067 let context = ToolPermissionContext {
1068 suggestions,
1069 blocked_path,
1070 signal: None,
1071 };
1072
1073 let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1074 callback(tool_name, input.clone(), context)
1075 }))
1076 .map_err(|payload| callback_panic_error("can_use_tool", payload))?;
1077 let callback_result =
1078 await_callback_with_panic_isolation("can_use_tool", callback_future).await?;
1079 let output = match callback_result {
1080 PermissionResult::Allow(allow) => {
1081 let mut obj = Map::new();
1082 obj.insert("behavior".to_string(), Value::String("allow".to_string()));
1083 obj.insert(
1084 "updatedInput".to_string(),
1085 allow.updated_input.unwrap_or(input),
1086 );
1087 if let Some(updated_permissions) = allow.updated_permissions {
1088 let permissions_json: Vec<Value> = updated_permissions
1089 .into_iter()
1090 .map(|permission| permission.to_cli_dict())
1091 .collect();
1092 obj.insert(
1093 "updatedPermissions".to_string(),
1094 Value::Array(permissions_json),
1095 );
1096 }
1097 Value::Object(obj)
1098 }
1099 PermissionResult::Deny(deny) => {
1100 let mut obj = Map::new();
1101 obj.insert("behavior".to_string(), Value::String("deny".to_string()));
1102 obj.insert("message".to_string(), Value::String(deny.message));
1103 if deny.interrupt {
1104 obj.insert("interrupt".to_string(), Value::Bool(true));
1105 }
1106 Value::Object(obj)
1107 }
1108 };
1109 Ok(output)
1110}
1111
1112async fn handle_hook_callback_request(
1113 state: &QuerySharedState,
1114 request_data: &Map<String, Value>,
1115) -> Result<Value> {
1116 let callback_id = request_data
1117 .get("callback_id")
1118 .and_then(Value::as_str)
1119 .ok_or_else(|| Error::Other("Missing callback_id in hook_callback".to_string()))?;
1120 let callback = state
1121 .hook_callbacks
1122 .lock()
1123 .await
1124 .get(callback_id)
1125 .cloned()
1126 .ok_or_else(|| Error::Other(format!("No hook callback found for ID: {callback_id}")))?;
1127 let input = request_data.get("input").cloned().unwrap_or(Value::Null);
1128 let tool_use_id = request_data
1129 .get("tool_use_id")
1130 .and_then(Value::as_str)
1131 .map(ToString::to_string);
1132 let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1133 callback(input, tool_use_id, Default::default())
1134 }))
1135 .map_err(|payload| callback_panic_error("hook", payload))?;
1136 let output = await_callback_with_panic_isolation("hook", callback_future).await?;
1137 Ok(convert_hook_output_for_cli(output))
1138}
1139
1140async fn handle_mcp_message_request(
1141 state: &QuerySharedState,
1142 request_data: &Map<String, Value>,
1143) -> Result<Value> {
1144 let server_name = request_data
1145 .get("server_name")
1146 .and_then(Value::as_str)
1147 .ok_or_else(|| Error::Other("Missing server_name in mcp_message".to_string()))?;
1148 let message = request_data
1149 .get("message")
1150 .cloned()
1151 .ok_or_else(|| Error::Other("Missing message in mcp_message".to_string()))?;
1152 let response = handle_sdk_mcp_request(&state.sdk_mcp_servers, server_name, &message).await;
1153 Ok(json!({ "mcp_response": response }))
1154}
1155
1156async fn handle_control_request(state: &QuerySharedState, request: Value) -> Result<()> {
1158 let Some(request_obj) = request.as_object() else {
1159 return Err(Error::Other("Invalid control request format".to_string()));
1160 };
1161 let request_id = request_obj
1162 .get("request_id")
1163 .and_then(Value::as_str)
1164 .ok_or_else(|| Error::Other("Missing request_id in control request".to_string()))?
1165 .to_string();
1166 let request_data = request_obj
1167 .get("request")
1168 .and_then(Value::as_object)
1169 .ok_or_else(|| Error::Other("Missing request payload".to_string()))?;
1170 let subtype = request_data
1171 .get("subtype")
1172 .and_then(Value::as_str)
1173 .ok_or_else(|| Error::Other("Missing request subtype".to_string()))?;
1174
1175 let result: Result<Value> = match subtype {
1176 "can_use_tool" => handle_can_use_tool_request(state, request_data).await,
1177 "hook_callback" => handle_hook_callback_request(state, request_data).await,
1178 "mcp_message" => handle_mcp_message_request(state, request_data).await,
1179 _ => Err(Error::Other(format!(
1180 "Unsupported control request subtype: {subtype}"
1181 ))),
1182 };
1183
1184 let response_json = match result {
1185 Ok(payload) => json!({
1186 "type": "control_response",
1187 "response": {
1188 "subtype": "success",
1189 "request_id": request_id,
1190 "response": payload
1191 }
1192 }),
1193 Err(err) => json!({
1194 "type": "control_response",
1195 "response": {
1196 "subtype": "error",
1197 "request_id": request_id,
1198 "error": err.to_string()
1199 }
1200 }),
1201 };
1202
1203 state
1204 .writer
1205 .lock()
1206 .await
1207 .write(&(response_json.to_string() + "\n"))
1208 .await
1209}
1210
1211pub async fn handle_sdk_mcp_request(
1248 sdk_mcp_servers: &HashMap<String, Arc<McpSdkServer>>,
1249 server_name: &str,
1250 message: &Value,
1251) -> Value {
1252 let Some(server) = sdk_mcp_servers.get(server_name) else {
1253 return json!({
1254 "jsonrpc": "2.0",
1255 "id": message.get("id").cloned().unwrap_or(Value::Null),
1256 "error": {
1257 "code": -32601,
1258 "message": format!("Server '{server_name}' not found")
1259 }
1260 });
1261 };
1262
1263 let method = message
1264 .get("method")
1265 .and_then(Value::as_str)
1266 .unwrap_or_default();
1267 let id = message.get("id").cloned().unwrap_or(Value::Null);
1268 let params = message.get("params").cloned().unwrap_or_else(|| json!({}));
1269
1270 match method {
1271 "initialize" => json!({
1272 "jsonrpc": "2.0",
1273 "id": id,
1274 "result": {
1275 "protocolVersion": "2024-11-05",
1276 "capabilities": {"tools": {}},
1277 "serverInfo": {
1278 "name": server.name,
1279 "version": server.version
1280 }
1281 }
1282 }),
1283 "tools/list" => json!({
1284 "jsonrpc": "2.0",
1285 "id": id,
1286 "result": {
1287 "tools": server.list_tools_json()
1288 }
1289 }),
1290 "tools/call" => {
1291 let tool_name = params
1292 .get("name")
1293 .and_then(Value::as_str)
1294 .unwrap_or_default();
1295 let arguments = params
1296 .get("arguments")
1297 .cloned()
1298 .unwrap_or_else(|| json!({}));
1299 let result = server.call_tool_json(tool_name, arguments).await;
1300 json!({
1301 "jsonrpc": "2.0",
1302 "id": id,
1303 "result": result
1304 })
1305 }
1306 "notifications/initialized" => json!({
1307 "jsonrpc": "2.0",
1308 "result": {}
1309 }),
1310 _ => json!({
1311 "jsonrpc": "2.0",
1312 "id": id,
1313 "error": {
1314 "code": -32601,
1315 "message": format!("Method '{method}' not found")
1316 }
1317 }),
1318 }
1319}
1320
1321pub(crate) fn build_hooks_config(
1328 hooks: &HashMap<String, Vec<HookMatcher>>,
1329) -> (Map<String, Value>, HashMap<String, HookCallback>) {
1330 let mut hooks_config = Map::new();
1331 let mut hook_callbacks = HashMap::new();
1332 let mut next_callback_id: usize = 0;
1333
1334 for (event, matchers) in hooks {
1335 if matchers.is_empty() {
1336 continue;
1337 }
1338 let mut event_matchers = Vec::new();
1339 for matcher in matchers {
1340 let mut callback_ids = Vec::new();
1341 for callback in &matcher.hooks {
1342 let callback_id = format!("hook_{}", next_callback_id);
1343 next_callback_id += 1;
1344 hook_callbacks.insert(callback_id.clone(), callback.clone());
1345 callback_ids.push(callback_id);
1346 }
1347
1348 let mut matcher_obj = Map::new();
1349 matcher_obj.insert(
1350 "matcher".to_string(),
1351 matcher
1352 .matcher
1353 .as_ref()
1354 .map(|m| Value::String(m.clone()))
1355 .unwrap_or(Value::Null),
1356 );
1357 matcher_obj.insert("hookCallbackIds".to_string(), json!(callback_ids));
1358 if let Some(timeout) = matcher.timeout {
1359 matcher_obj.insert("timeout".to_string(), json!(timeout));
1360 }
1361 event_matchers.push(Value::Object(matcher_obj));
1362 }
1363 hooks_config.insert(event.clone(), Value::Array(event_matchers));
1364 }
1365
1366 (hooks_config, hook_callbacks)
1367}