1use std::collections::HashMap;
15use std::future::Future;
16use std::panic::{self, AssertUnwindSafe};
17use std::sync::Arc;
18use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
19use std::time::Duration;
20
21use futures::{FutureExt, Stream, StreamExt};
22use serde_json::{Map, Value, json};
23use tokio::sync::{Mutex, mpsc, oneshot};
24use tokio::task::JoinHandle;
25use tracing::{debug, warn};
26
27use crate::errors::{Error, Result};
28use crate::message_parser::parse_message;
29use crate::sdk_mcp::McpSdkServer;
30use crate::transport::{TransportCloseHandle, TransportReader, TransportWriter};
31use crate::types::{
32 AgentDefinition, CanUseToolCallback, HookCallback, HookMatcher, Message, PermissionResult,
33 ToolPermissionContext,
34};
35
/// Capacity of the bounded channel that carries parsed messages from the
/// background reader task to the consumer. When the channel is full the reader
/// waits in `send(..).await` before reading further transport messages.
const MESSAGE_CHANNEL_BUFFER: usize = 100;
38
39fn convert_hook_output_for_cli(output: Value) -> Value {
44 let Some(obj) = output.as_object() else {
45 return output;
46 };
47
48 let mut converted = Map::new();
49 for (key, value) in obj {
50 match key.as_str() {
51 "async_" => {
52 converted.insert("async".to_string(), value.clone());
53 }
54 "continue_" => {
55 converted.insert("continue".to_string(), value.clone());
56 }
57 _ => {
58 converted.insert(key.clone(), value.clone());
59 }
60 }
61 }
62 Value::Object(converted)
63}
64
/// Extracts a human-readable message from a panic payload.
///
/// Panics raised with a formatted message carry a `String`, and panics with a
/// string literal carry a `&'static str`. Any other payload type yields the
/// fixed text `"unknown panic payload"`.
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    match payload.downcast::<String>() {
        Ok(owned) => *owned,
        Err(payload) => match payload.downcast_ref::<&str>() {
            Some(text) => (*text).to_owned(),
            None => String::from("unknown panic payload"),
        },
    }
}
74
75fn callback_panic_error(callback_type: &str, payload: Box<dyn std::any::Any + Send>) -> Error {
76 let panic_message = panic_payload_to_string(payload);
77 warn!(
78 callback_type,
79 panic_message, "Caught panic in callback invocation"
80 );
81 Error::Other(format!(
82 "{callback_type} callback panicked: {panic_message}"
83 ))
84}
85
86async fn await_callback_with_panic_isolation<T, F>(
87 callback_type: &str,
88 callback_future: F,
89) -> Result<T>
90where
91 F: Future<Output = Result<T>>,
92{
93 match AssertUnwindSafe(callback_future).catch_unwind().await {
94 Ok(result) => result,
95 Err(payload) => Err(callback_panic_error(callback_type, payload)),
96 }
97}
98
/// Bookkeeping that pairs outgoing control requests with their responses.
struct PendingControlsState {
    /// Waiters registered by `Query::send_control_request`, keyed by request ID.
    senders: HashMap<String, oneshot::Sender<Result<Value>>>,
    /// Responses that arrived while no waiter was registered for their request ID.
    /// `send_control_request` checks here first because it registers its waiter
    /// only after writing the request. Entries nobody claims (for example a
    /// response that arrives after its request timed out) remain here.
    buffered: HashMap<String, Result<Value>>,
}
108
/// State shared between a `Query` handle and its background reader task.
struct QuerySharedState {
    /// Optional tool-permission callback, invoked for `can_use_tool` control requests.
    can_use_tool: Option<CanUseToolCallback>,
    /// Hook callbacks keyed by callback ID, invoked for `hook_callback` control requests.
    hook_callbacks: Mutex<HashMap<String, HookCallback>>,
    /// In-process SDK MCP servers keyed by server name, used for `mcp_message` requests.
    sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
    /// Outstanding control requests and responses not yet claimed by a requester.
    pending_controls: Mutex<PendingControlsState>,
    /// Transport writer; every outgoing line is written while holding this lock.
    writer: Arc<Mutex<Box<dyn TransportWriter>>>,
    /// Set when closing stdin has been deferred until the first result message.
    pending_stdin_close: AtomicBool,
    /// Per-read timeout the reader applies while a deferred stdin close is pending.
    stream_close_timeout: Duration,
    /// Set by `mark_reader_terminated` once the reader task stops.
    reader_terminated: AtomicBool,
    /// First termination reason recorded by `mark_reader_terminated`.
    reader_termination_reason: Mutex<Option<String>>,
}
127
/// A control-protocol session with the Claude Code CLI over a transport.
///
/// Owns the background reader task, the channel of parsed messages it produces,
/// and the transport close handle. State shared with the reader task lives in
/// `QuerySharedState`.
pub struct Query {
    /// Shared with the reader task; `None` once the query has been shut down.
    state: Option<Arc<QuerySharedState>>,

    /// Parsed messages (or errors) forwarded by the reader task; `None` after
    /// shutdown or after `take_message_receiver`.
    message_rx: Option<mpsc::Receiver<Result<Message>>>,

    /// Handle of the spawned background reader task.
    reader_task: Option<JoinHandle<()>>,

    /// Closes the underlying transport on shutdown or drop.
    close_handle: Option<Box<dyn TransportCloseHandle>>,

    /// Source of control request IDs (`req_1`, `req_2`, ...).
    request_counter: Arc<AtomicUsize>,

    /// Control requests (including `initialize`) are only issued in streaming mode.
    is_streaming_mode: bool,

    /// Agent definitions sent with the `initialize` request, if any.
    agents: Option<HashMap<String, AgentDefinition>>,

    /// Set to `true` once `initialize` completes successfully.
    initialized: bool,

    /// Response payload of the successful `initialize` request.
    initialization_result: Option<Value>,

    /// Timeout applied to the `initialize` control request.
    initialize_timeout: Duration,

    /// Whether hook callbacks or SDK MCP servers are registered; when true,
    /// closing stdin is deferred until the first result (or a timeout).
    has_hooks_or_mcp: bool,
}
174
175impl Query {
176 #[allow(clippy::too_many_arguments)]
193 pub(crate) fn start(
194 reader: Box<dyn TransportReader>,
195 writer: Box<dyn TransportWriter>,
196 close_handle: Box<dyn TransportCloseHandle>,
197 is_streaming_mode: bool,
198 can_use_tool: Option<CanUseToolCallback>,
199 hook_callbacks: HashMap<String, HookCallback>,
200 sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
201 agents: Option<HashMap<String, AgentDefinition>>,
202 initialize_timeout: Duration,
203 ) -> Self {
204 let stream_close_timeout_ms: u64 = std::env::var("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT")
205 .ok()
206 .and_then(|v| v.parse().ok())
207 .unwrap_or(60_000);
208 let stream_close_timeout =
209 Duration::from_millis(stream_close_timeout_ms).max(Duration::from_secs(60));
210
211 let has_hooks_or_mcp = !hook_callbacks.is_empty() || !sdk_mcp_servers.is_empty();
212 let writer = Arc::new(Mutex::new(writer));
213
214 let state = Arc::new(QuerySharedState {
215 can_use_tool,
216 hook_callbacks: Mutex::new(hook_callbacks),
217 sdk_mcp_servers,
218 pending_controls: Mutex::new(PendingControlsState {
219 senders: HashMap::new(),
220 buffered: HashMap::new(),
221 }),
222 writer: writer.clone(),
223 pending_stdin_close: AtomicBool::new(false),
224 stream_close_timeout,
225 reader_terminated: AtomicBool::new(false),
226 reader_termination_reason: Mutex::new(None),
227 });
228
229 let (message_tx, message_rx) = mpsc::channel(MESSAGE_CHANNEL_BUFFER);
230
231 let reader_state = state.clone();
232 let reader_task = tokio::spawn(async move {
233 background_reader_task(reader, reader_state, message_tx).await;
234 });
235
236 Self {
237 state: Some(state),
238 message_rx: Some(message_rx),
239 reader_task: Some(reader_task),
240 close_handle: Some(close_handle),
241 request_counter: Arc::new(AtomicUsize::new(0)),
242 is_streaming_mode,
243 agents,
244 initialized: false,
245 initialization_result: None,
246 initialize_timeout,
247 has_hooks_or_mcp,
248 }
249 }
250
251 pub async fn initialize(&mut self, hooks_config: Map<String, Value>) -> Result<Option<Value>> {
271 if !self.is_streaming_mode {
272 return Ok(None);
273 }
274
275 let mut request = Map::new();
276 request.insert(
277 "subtype".to_string(),
278 Value::String("initialize".to_string()),
279 );
280 request.insert(
281 "hooks".to_string(),
282 if hooks_config.is_empty() {
283 Value::Null
284 } else {
285 Value::Object(hooks_config)
286 },
287 );
288
289 if let Some(agents) = &self.agents {
290 request.insert(
291 "agents".to_string(),
292 serde_json::to_value(agents).unwrap_or(Value::Null),
293 );
294 }
295
296 let response = self
297 .send_control_request(Value::Object(request), self.initialize_timeout)
298 .await?;
299 self.initialized = true;
300 self.initialization_result = Some(response.clone());
301 Ok(Some(response))
302 }
303
304 pub fn initialization_result(&self) -> Option<Value> {
318 self.initialization_result.clone()
319 }
320
321 async fn send_control_request(&self, request: Value, timeout: Duration) -> Result<Value> {
326 if !self.is_streaming_mode {
327 return Err(Error::Other(
328 "Control requests require streaming mode".to_string(),
329 ));
330 }
331
332 let state = self
333 .state
334 .as_ref()
335 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
336
337 let request_id = format!(
338 "req_{}",
339 self.request_counter.fetch_add(1, Ordering::SeqCst) + 1
340 );
341
342 let control_request = json!({
344 "type": "control_request",
345 "request_id": request_id,
346 "request": request,
347 });
348 state
349 .writer
350 .lock()
351 .await
352 .write(&(control_request.to_string() + "\n"))
353 .await?;
354
355 let (tx, rx) = oneshot::channel();
359 {
360 let mut controls = state.pending_controls.lock().await;
361 if let Some(result) = controls.buffered.remove(&request_id) {
362 return result;
363 }
364 controls.senders.insert(request_id.clone(), tx);
365 }
366 if state.reader_terminated.load(Ordering::SeqCst) {
367 state
368 .pending_controls
369 .lock()
370 .await
371 .senders
372 .remove(&request_id);
373 let reason = reader_termination_reason(state).await;
374 return Err(Error::Other(format!(
375 "Background reader task terminated: {reason}"
376 )));
377 }
378
379 let result = tokio::time::timeout(timeout, rx).await;
381 match result {
382 Ok(Ok(value)) => value,
383 Ok(Err(_)) => {
384 Err(Error::Other(
386 "Background reader task terminated while waiting for control response"
387 .to_string(),
388 ))
389 }
390 Err(_) => {
391 let subtype = request
393 .get("subtype")
394 .and_then(Value::as_str)
395 .unwrap_or("unknown");
396 state
397 .pending_controls
398 .lock()
399 .await
400 .senders
401 .remove(&request_id);
402 Err(Error::Other(format!("Control request timeout: {subtype}")))
403 }
404 }
405 }
406
407 pub async fn send_user_message(&self, prompt: &str, session_id: &str) -> Result<()> {
420 let message = json!({
421 "type": "user",
422 "message": {"role": "user", "content": prompt},
423 "parent_tool_use_id": Value::Null,
424 "session_id": session_id
425 });
426 self.write_message(&message).await
427 }
428
429 pub async fn send_raw_message(&self, message: Value) -> Result<()> {
443 self.write_message(&message).await
444 }
445
446 async fn write_message(&self, message: &Value) -> Result<()> {
448 let state = self
449 .state
450 .as_ref()
451 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
452 state
453 .writer
454 .lock()
455 .await
456 .write(&(message.to_string() + "\n"))
457 .await
458 }
459
460 pub async fn send_input_messages(&self, messages: Vec<Value>) -> Result<()> {
476 for message in messages {
477 self.send_raw_message(message).await?;
478 }
479 Ok(())
480 }
481
482 pub async fn send_input_from_stream<S>(&self, mut messages: S) -> Result<()>
499 where
500 S: Stream<Item = Value> + Unpin,
501 {
502 while let Some(message) = messages.next().await {
503 self.send_raw_message(message).await?;
504 }
505 Ok(())
506 }
507
508 pub fn spawn_input_from_stream<S>(&self, mut messages: S) -> Result<JoinHandle<Result<()>>>
532 where
533 S: Stream<Item = Value> + Send + Unpin + 'static,
534 {
535 let state = self
536 .state
537 .as_ref()
538 .cloned()
539 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
540
541 Ok(tokio::spawn(async move {
542 while let Some(message) = messages.next().await {
543 state
544 .writer
545 .lock()
546 .await
547 .write(&(message.to_string() + "\n"))
548 .await?;
549 }
550 Ok(())
551 }))
552 }
553
554 pub async fn stream_input(&self, messages: Vec<Value>) -> Result<()> {
573 self.send_input_messages(messages).await?;
574 self.finalize_stream_input().await
575 }
576
577 pub async fn stream_input_from_stream<S>(&self, mut messages: S) -> Result<()>
594 where
595 S: Stream<Item = Value> + Unpin,
596 {
597 self.send_input_from_stream(&mut messages).await?;
598 self.finalize_stream_input().await
599 }
600
601 async fn finalize_stream_input(&self) -> Result<()> {
602 let state = self
603 .state
604 .as_ref()
605 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
606
607 if self.has_hooks_or_mcp {
608 debug!(
609 has_hooks_or_mcp = self.has_hooks_or_mcp,
610 "Deferring stdin close until first result"
611 );
612 state.pending_stdin_close.store(true, Ordering::SeqCst);
613 } else {
614 state.writer.lock().await.end_input().await?;
615 }
616 Ok(())
617 }
618
619 pub async fn end_input(&self) -> Result<()> {
632 let state = self
633 .state
634 .as_ref()
635 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
636 state.writer.lock().await.end_input().await
637 }
638
639 pub async fn receive_next_message(&mut self) -> Result<Option<Message>> {
659 let rx = self
660 .message_rx
661 .as_mut()
662 .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
663
664 match rx.recv().await {
665 Some(Ok(message)) => Ok(Some(message)),
666 Some(Err(err)) => Err(err),
667 None => Ok(None),
668 }
669 }
670
671 pub async fn get_mcp_status(&self) -> Result<Value> {
684 self.send_control_request(json!({ "subtype": "mcp_status" }), Duration::from_secs(60))
685 .await
686 }
687
688 pub async fn interrupt(&self) -> Result<()> {
701 self.send_control_request(json!({ "subtype": "interrupt" }), Duration::from_secs(60))
702 .await?;
703 Ok(())
704 }
705
706 pub async fn set_permission_mode(&self, mode: &str) -> Result<()> {
719 self.send_control_request(
720 json!({ "subtype": "set_permission_mode", "mode": mode }),
721 Duration::from_secs(60),
722 )
723 .await?;
724 Ok(())
725 }
726
727 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
740 self.send_control_request(
741 json!({ "subtype": "set_model", "model": model }),
742 Duration::from_secs(60),
743 )
744 .await?;
745 Ok(())
746 }
747
748 pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
761 self.send_control_request(
762 json!({ "subtype": "rewind_files", "user_message_id": user_message_id }),
763 Duration::from_secs(60),
764 )
765 .await?;
766 Ok(())
767 }
768
769 pub async fn close(mut self) -> Result<()> {
782 self.shutdown().await
783 }
784
785 async fn shutdown(&mut self) -> Result<()> {
787 self.message_rx.take();
788 self.state.take();
789
790 if let Some(task) = self.reader_task.take() {
791 task.abort();
792 let _ = task.await;
793 }
794
795 if let Some(close_handle) = self.close_handle.take() {
796 close_handle.close().await?;
797 }
798
799 Ok(())
800 }
801
802 pub(crate) fn take_message_receiver(&mut self) -> Option<mpsc::Receiver<Result<Message>>> {
804 self.message_rx.take()
805 }
806}
807
808impl Drop for Query {
809 fn drop(&mut self) {
810 if let Some(task) = self.reader_task.take() {
811 task.abort();
812 }
813
814 if let Some(close_handle) = self.close_handle.take() {
815 if let Ok(handle) = tokio::runtime::Handle::try_current() {
819 handle.spawn(async move {
820 let _ = close_handle.close().await;
821 });
822 } else if let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
823 .enable_all()
824 .build()
825 {
826 let _ = runtime.block_on(async move { close_handle.close().await });
827 }
828 }
829 }
830}
831
832async fn background_reader_task(
839 mut reader: Box<dyn TransportReader>,
840 state: Arc<QuerySharedState>,
841 message_tx: mpsc::Sender<Result<Message>>,
842) {
843 loop {
844 let read_result = if state.pending_stdin_close.load(Ordering::SeqCst) {
846 let timeout_dur = state.stream_close_timeout;
847 match tokio::time::timeout(timeout_dur, reader.read_next_message()).await {
848 Ok(result) => result,
849 Err(_) => {
850 debug!("Timed out waiting for first result, closing input stream");
851 try_close_deferred_stdin(&state).await;
852 continue;
853 }
854 }
855 } else {
856 reader.read_next_message().await
857 };
858
859 let raw = match read_result {
860 Ok(Some(raw)) => raw,
861 Ok(None) => {
862 try_close_deferred_stdin(&state).await;
863 break;
864 }
865 Err(err) => {
866 mark_reader_terminated(&state, err.to_string()).await;
867 let _ = message_tx.send(Err(err)).await;
868 break;
869 }
870 };
871
872 let msg_type = raw.get("type").and_then(Value::as_str).unwrap_or_default();
873
874 if msg_type == "control_response" {
875 handle_control_response(&state, &raw).await;
876 continue;
877 }
878
879 if msg_type == "control_request" {
880 if let Err(err) = handle_control_request(&state, raw).await {
881 debug!("Error handling control request: {err}");
882 }
883 continue;
884 }
885
886 if msg_type == "control_cancel_request" {
887 continue;
888 }
889
890 match parse_message(&raw) {
892 Ok(Some(msg)) => {
893 if matches!(msg, Message::Result(_))
894 && state.pending_stdin_close.load(Ordering::SeqCst)
895 {
896 debug!("Received first result, closing input stream");
897 try_close_deferred_stdin(&state).await;
898 }
899
900 if message_tx.send(Ok(msg)).await.is_err() {
901 break;
902 }
903 }
904 Ok(None) => {}
905 Err(err) => {
906 if message_tx
907 .send(Err(Error::MessageParse(err)))
908 .await
909 .is_err()
910 {
911 break;
912 }
913 }
914 }
915 }
916}
917
918async fn mark_reader_terminated(state: &QuerySharedState, reason: String) {
920 state.reader_terminated.store(true, Ordering::SeqCst);
921 let stored_reason = {
922 let mut termination_reason = state.reader_termination_reason.lock().await;
923 if termination_reason.is_none() {
924 *termination_reason = Some(reason);
925 }
926 termination_reason
927 .clone()
928 .unwrap_or_else(|| "Unknown reason".to_string())
929 };
930
931 let mut controls = state.pending_controls.lock().await;
932 for (_, sender) in controls.senders.drain() {
933 let _ = sender.send(Err(Error::Other(format!(
934 "Background reader task terminated: {stored_reason}"
935 ))));
936 }
937}
938
939async fn reader_termination_reason(state: &QuerySharedState) -> String {
941 state
942 .reader_termination_reason
943 .lock()
944 .await
945 .clone()
946 .unwrap_or_else(|| "Unknown reason".to_string())
947}
948
949async fn try_close_deferred_stdin(state: &QuerySharedState) {
951 if state
952 .pending_stdin_close
953 .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
954 .is_ok()
955 {
956 if let Err(e) = state.writer.lock().await.end_input().await {
957 debug!("Error closing deferred stdin: {e}");
958 }
959 }
960}
961
962async fn handle_control_response(state: &QuerySharedState, raw: &Value) {
967 let Some(response) = raw.get("response").and_then(Value::as_object) else {
968 return;
969 };
970 let response_request_id = response
971 .get("request_id")
972 .and_then(Value::as_str)
973 .unwrap_or_default();
974
975 let subtype = response
976 .get("subtype")
977 .and_then(Value::as_str)
978 .unwrap_or_default();
979
980 let result: Result<Value> = if subtype == "error" {
981 let error = response
982 .get("error")
983 .and_then(Value::as_str)
984 .unwrap_or("Unknown error");
985 Err(Error::Other(error.to_string()))
986 } else {
987 Ok(response
988 .get("response")
989 .cloned()
990 .unwrap_or_else(|| json!({})))
991 };
992
993 let mut controls = state.pending_controls.lock().await;
994 if let Some(sender) = controls.senders.remove(response_request_id) {
995 let _ = sender.send(result);
996 } else {
997 controls
999 .buffered
1000 .insert(response_request_id.to_string(), result);
1001 }
1002}
1003
1004async fn handle_can_use_tool_request(
1005 state: &QuerySharedState,
1006 request_data: &Map<String, Value>,
1007) -> Result<Value> {
1008 let callback = state
1009 .can_use_tool
1010 .clone()
1011 .ok_or_else(|| Error::Other("canUseTool callback is not provided".to_string()))?;
1012 let tool_name = request_data
1013 .get("tool_name")
1014 .and_then(Value::as_str)
1015 .unwrap_or_default()
1016 .to_string();
1017 let input = request_data
1018 .get("input")
1019 .cloned()
1020 .unwrap_or_else(|| json!({}));
1021 let suggestions = request_data
1022 .get("permission_suggestions")
1023 .and_then(Value::as_array)
1024 .cloned()
1025 .unwrap_or_default()
1026 .into_iter()
1027 .filter_map(|value| serde_json::from_value(value).ok())
1028 .collect();
1029 let blocked_path = request_data
1030 .get("blocked_path")
1031 .and_then(Value::as_str)
1032 .map(ToString::to_string);
1033 let context = ToolPermissionContext {
1034 suggestions,
1035 blocked_path,
1036 signal: None,
1037 };
1038
1039 let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1040 callback(tool_name, input.clone(), context)
1041 }))
1042 .map_err(|payload| callback_panic_error("can_use_tool", payload))?;
1043 let callback_result =
1044 await_callback_with_panic_isolation("can_use_tool", callback_future).await?;
1045 let output = match callback_result {
1046 PermissionResult::Allow(allow) => {
1047 let mut obj = Map::new();
1048 obj.insert("behavior".to_string(), Value::String("allow".to_string()));
1049 obj.insert(
1050 "updatedInput".to_string(),
1051 allow.updated_input.unwrap_or(input),
1052 );
1053 if let Some(updated_permissions) = allow.updated_permissions {
1054 let permissions_json: Vec<Value> = updated_permissions
1055 .into_iter()
1056 .map(|permission| permission.to_cli_dict())
1057 .collect();
1058 obj.insert(
1059 "updatedPermissions".to_string(),
1060 Value::Array(permissions_json),
1061 );
1062 }
1063 Value::Object(obj)
1064 }
1065 PermissionResult::Deny(deny) => {
1066 let mut obj = Map::new();
1067 obj.insert("behavior".to_string(), Value::String("deny".to_string()));
1068 obj.insert("message".to_string(), Value::String(deny.message));
1069 if deny.interrupt {
1070 obj.insert("interrupt".to_string(), Value::Bool(true));
1071 }
1072 Value::Object(obj)
1073 }
1074 };
1075 Ok(output)
1076}
1077
1078async fn handle_hook_callback_request(
1079 state: &QuerySharedState,
1080 request_data: &Map<String, Value>,
1081) -> Result<Value> {
1082 let callback_id = request_data
1083 .get("callback_id")
1084 .and_then(Value::as_str)
1085 .ok_or_else(|| Error::Other("Missing callback_id in hook_callback".to_string()))?;
1086 let callback = state
1087 .hook_callbacks
1088 .lock()
1089 .await
1090 .get(callback_id)
1091 .cloned()
1092 .ok_or_else(|| Error::Other(format!("No hook callback found for ID: {callback_id}")))?;
1093 let input = request_data.get("input").cloned().unwrap_or(Value::Null);
1094 let tool_use_id = request_data
1095 .get("tool_use_id")
1096 .and_then(Value::as_str)
1097 .map(ToString::to_string);
1098 let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1099 callback(input, tool_use_id, Default::default())
1100 }))
1101 .map_err(|payload| callback_panic_error("hook", payload))?;
1102 let output = await_callback_with_panic_isolation("hook", callback_future).await?;
1103 Ok(convert_hook_output_for_cli(output))
1104}
1105
1106async fn handle_mcp_message_request(
1107 state: &QuerySharedState,
1108 request_data: &Map<String, Value>,
1109) -> Result<Value> {
1110 let server_name = request_data
1111 .get("server_name")
1112 .and_then(Value::as_str)
1113 .ok_or_else(|| Error::Other("Missing server_name in mcp_message".to_string()))?;
1114 let message = request_data
1115 .get("message")
1116 .cloned()
1117 .ok_or_else(|| Error::Other("Missing message in mcp_message".to_string()))?;
1118 let response = handle_sdk_mcp_request(&state.sdk_mcp_servers, server_name, &message).await;
1119 Ok(json!({ "mcp_response": response }))
1120}
1121
1122async fn handle_control_request(state: &QuerySharedState, request: Value) -> Result<()> {
1124 let Some(request_obj) = request.as_object() else {
1125 return Err(Error::Other("Invalid control request format".to_string()));
1126 };
1127 let request_id = request_obj
1128 .get("request_id")
1129 .and_then(Value::as_str)
1130 .ok_or_else(|| Error::Other("Missing request_id in control request".to_string()))?
1131 .to_string();
1132 let request_data = request_obj
1133 .get("request")
1134 .and_then(Value::as_object)
1135 .ok_or_else(|| Error::Other("Missing request payload".to_string()))?;
1136 let subtype = request_data
1137 .get("subtype")
1138 .and_then(Value::as_str)
1139 .ok_or_else(|| Error::Other("Missing request subtype".to_string()))?;
1140
1141 let result: Result<Value> = match subtype {
1142 "can_use_tool" => handle_can_use_tool_request(state, request_data).await,
1143 "hook_callback" => handle_hook_callback_request(state, request_data).await,
1144 "mcp_message" => handle_mcp_message_request(state, request_data).await,
1145 _ => Err(Error::Other(format!(
1146 "Unsupported control request subtype: {subtype}"
1147 ))),
1148 };
1149
1150 let response_json = match result {
1151 Ok(payload) => json!({
1152 "type": "control_response",
1153 "response": {
1154 "subtype": "success",
1155 "request_id": request_id,
1156 "response": payload
1157 }
1158 }),
1159 Err(err) => json!({
1160 "type": "control_response",
1161 "response": {
1162 "subtype": "error",
1163 "request_id": request_id,
1164 "error": err.to_string()
1165 }
1166 }),
1167 };
1168
1169 state
1170 .writer
1171 .lock()
1172 .await
1173 .write(&(response_json.to_string() + "\n"))
1174 .await
1175}
1176
1177pub async fn handle_sdk_mcp_request(
1214 sdk_mcp_servers: &HashMap<String, Arc<McpSdkServer>>,
1215 server_name: &str,
1216 message: &Value,
1217) -> Value {
1218 let Some(server) = sdk_mcp_servers.get(server_name) else {
1219 return json!({
1220 "jsonrpc": "2.0",
1221 "id": message.get("id").cloned().unwrap_or(Value::Null),
1222 "error": {
1223 "code": -32601,
1224 "message": format!("Server '{server_name}' not found")
1225 }
1226 });
1227 };
1228
1229 let method = message
1230 .get("method")
1231 .and_then(Value::as_str)
1232 .unwrap_or_default();
1233 let id = message.get("id").cloned().unwrap_or(Value::Null);
1234 let params = message.get("params").cloned().unwrap_or_else(|| json!({}));
1235
1236 match method {
1237 "initialize" => json!({
1238 "jsonrpc": "2.0",
1239 "id": id,
1240 "result": {
1241 "protocolVersion": "2024-11-05",
1242 "capabilities": {"tools": {}},
1243 "serverInfo": {
1244 "name": server.name,
1245 "version": server.version
1246 }
1247 }
1248 }),
1249 "tools/list" => json!({
1250 "jsonrpc": "2.0",
1251 "id": id,
1252 "result": {
1253 "tools": server.list_tools_json()
1254 }
1255 }),
1256 "tools/call" => {
1257 let tool_name = params
1258 .get("name")
1259 .and_then(Value::as_str)
1260 .unwrap_or_default();
1261 let arguments = params
1262 .get("arguments")
1263 .cloned()
1264 .unwrap_or_else(|| json!({}));
1265 let result = server.call_tool_json(tool_name, arguments).await;
1266 json!({
1267 "jsonrpc": "2.0",
1268 "id": id,
1269 "result": result
1270 })
1271 }
1272 "notifications/initialized" => json!({
1273 "jsonrpc": "2.0",
1274 "result": {}
1275 }),
1276 _ => json!({
1277 "jsonrpc": "2.0",
1278 "id": id,
1279 "error": {
1280 "code": -32601,
1281 "message": format!("Method '{method}' not found")
1282 }
1283 }),
1284 }
1285}
1286
1287pub(crate) fn build_hooks_config(
1294 hooks: &HashMap<String, Vec<HookMatcher>>,
1295) -> (Map<String, Value>, HashMap<String, HookCallback>) {
1296 let mut hooks_config = Map::new();
1297 let mut hook_callbacks = HashMap::new();
1298 let mut next_callback_id: usize = 0;
1299
1300 for (event, matchers) in hooks {
1301 if matchers.is_empty() {
1302 continue;
1303 }
1304 let mut event_matchers = Vec::new();
1305 for matcher in matchers {
1306 let mut callback_ids = Vec::new();
1307 for callback in &matcher.hooks {
1308 let callback_id = format!("hook_{}", next_callback_id);
1309 next_callback_id += 1;
1310 hook_callbacks.insert(callback_id.clone(), callback.clone());
1311 callback_ids.push(callback_id);
1312 }
1313
1314 let mut matcher_obj = Map::new();
1315 matcher_obj.insert(
1316 "matcher".to_string(),
1317 matcher
1318 .matcher
1319 .as_ref()
1320 .map(|m| Value::String(m.clone()))
1321 .unwrap_or(Value::Null),
1322 );
1323 matcher_obj.insert("hookCallbackIds".to_string(), json!(callback_ids));
1324 if let Some(timeout) = matcher.timeout {
1325 matcher_obj.insert("timeout".to_string(), json!(timeout));
1326 }
1327 event_matchers.push(Value::Object(matcher_obj));
1328 }
1329 hooks_config.insert(event.clone(), Value::Array(event_matchers));
1330 }
1331
1332 (hooks_config, hook_callbacks)
1333}