cc_sdk/
internal_query.rs

1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7    errors::{Result, SdkError},
8    transport::{InputMessage, Transport},
9    types::{
10        CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11        SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12        SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13        ToolPermissionContext,
14    },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27    /// Transport layer (shared with client)
28    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29    /// Whether in streaming mode
30    #[allow(dead_code)]
31    is_streaming_mode: bool,
32    /// Tool permission callback
33    can_use_tool: Option<Arc<dyn CanUseTool>>,
34    /// Hook configurations
35    hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36    /// SDK MCP servers
37    sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38    /// Message channel sender (reserved for future streaming receive support)
39    #[allow(dead_code)]
40    message_tx: mpsc::Sender<Result<Message>>,
41    /// Message channel receiver (reserved for future streaming receive support)
42    #[allow(dead_code)]
43    message_rx: Option<mpsc::Receiver<Result<Message>>>,
44    /// Initialization result
45    initialization_result: Option<JsonValue>,
46    /// Active hook callbacks
47    hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48    /// Hook callback counter
49    callback_counter: Arc<Mutex<u64>>,
50    /// Request counter for generating unique IDs
51    request_counter: Arc<Mutex<u64>>,
52    /// Pending control request responses
53    pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57    /// Create a new Query handler
58    pub fn new(
59        transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60        is_streaming_mode: bool,
61        can_use_tool: Option<Arc<dyn CanUseTool>>,
62        hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63        sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64    ) -> Self {
65        let (tx, rx) = mpsc::channel(100);
66
67        Self {
68            transport,
69            is_streaming_mode,
70            can_use_tool,
71            hooks,
72            sdk_mcp_servers,
73            message_tx: tx,
74            message_rx: Some(rx),
75            initialization_result: None,
76            hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77            callback_counter: Arc::new(Mutex::new(0)),
78            request_counter: Arc::new(Mutex::new(0)),
79            pending_responses: Arc::new(RwLock::new(HashMap::new())),
80        }
81    }
82
83    /// Test helper to register a hook callback with a known ID
84    ///
85    /// This is intended for E2E tests to inject a callback ID that can be
86    /// referenced by inbound `hook_callback` control messages.
87    pub async fn register_hook_callback_for_test(
88        &self,
89        callback_id: String,
90        callback: Arc<dyn HookCallback>,
91    ) {
92        let mut map = self.hook_callbacks.write().await;
93        map.insert(callback_id, callback);
94    }
95
96    /// Start the query handler
97    pub async fn start(&mut self) -> Result<()> {
98        // Start control request handler task
99        self.start_control_handler().await;
100
101        // Start SDK message forwarder task (route non-control messages to message_tx)
102        let transport = self.transport.clone();
103        let tx = self.message_tx.clone();
104        tokio::spawn(async move {
105            loop {
106                let next = {
107                    let mut guard = transport.lock().await;
108                    let mut stream = guard.receive_messages();
109                    stream.next().await
110                };
111
112                match next {
113                    Some(Ok(msg)) => {
114                        if tx.send(Ok(msg)).await.is_err() { break; }
115                    }
116                    Some(Err(e)) => {
117                        let _ = tx.send(Err(e)).await;
118                        break;
119                    }
120                    None => {
121                        // No message available; yield
122                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
123                    }
124                }
125            }
126        });
127        Ok(())
128    }
129
130    /// Initialize the control protocol
131    pub async fn initialize(&mut self) -> Result<()> {
132        // Build hooks with callback IDs (Python SDK style)
133        let hooks_with_ids = if let Some(ref hooks) = self.hooks {
134            let mut counter = self.callback_counter.lock().await;
135            let mut callbacks_map = self.hook_callbacks.write().await;
136
137            let hooks_json: HashMap<String, JsonValue> = hooks
138                .iter()
139                .map(|(event_name, matchers)| {
140                    let matchers_with_ids: Vec<JsonValue> = matchers
141                        .iter()
142                        .map(|matcher| {
143                            // Generate callback IDs for each hook in this matcher
144                            let callback_ids: Vec<String> = matcher
145                                .hooks
146                                .iter()
147                                .map(|hook_callback| {
148                                    *counter += 1;
149                                    let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
150
151                                    // Store the callback for later use
152                                    callbacks_map.insert(callback_id.clone(), hook_callback.clone());
153
154                                    callback_id
155                                })
156                                .collect();
157
158                            serde_json::json!({
159                                "matcher": matcher.matcher.clone(),
160                                "hookCallbackIds": callback_ids
161                            })
162                        })
163                        .collect();
164
165                    (event_name.clone(), serde_json::json!(matchers_with_ids))
166                })
167                .collect();
168
169            Some(hooks_json)
170        } else {
171            None
172        };
173
174        // Send initialize request
175        let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
176            subtype: "initialize".to_string(),
177            hooks: hooks_with_ids,
178        });
179
180        // Send control request and save result
181        let result = self.send_control_request(init_request).await?;
182        self.initialization_result = Some(result);
183
184        debug!("Initialization request sent with hook callback IDs");
185        Ok(())
186    }
187
188    /// Send a control request and wait for response
189    async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
190        // Generate unique request ID
191        let request_id = {
192            let mut counter = self.request_counter.lock().await;
193            *counter += 1;
194            format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
195        };
196
197        // Create oneshot channel for response
198        let (tx, rx) = tokio::sync::oneshot::channel();
199
200        // Register pending response
201        {
202            let mut pending = self.pending_responses.write().await;
203            pending.insert(request_id.clone(), tx);
204        }
205
206        // Build control request with request_id (snake_case for CLI compatibility)
207        let control_request = serde_json::json!({
208            "type": "control_request",
209            "request_id": request_id,
210            "request": request
211        });
212
213        debug!("Sending control request: {:?}", control_request);
214
215        // Send via transport
216        {
217            let mut transport = self.transport.lock().await;
218            transport.send_sdk_control_request(control_request).await?;
219        }
220
221        // Wait for response with timeout
222        match timeout(Duration::from_secs(60), rx).await {
223            Ok(Ok(response)) => {
224                debug!("Received control response for {}", request_id);
225                Ok(response)
226            }
227            Ok(Err(_)) => Err(SdkError::ControlRequestError(
228                "Response channel closed".to_string(),
229            )),
230            Err(_) => {
231                // Clean up pending response
232                let mut pending = self.pending_responses.write().await;
233                pending.remove(&request_id);
234                Err(SdkError::Timeout { seconds: 60 })
235            }
236        }
237    }
238
239    /// Handle permission request
240    #[allow(dead_code)]
241    async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
242        if let Some(ref can_use_tool) = self.can_use_tool {
243            let context = ToolPermissionContext {
244                signal: None,
245                suggestions: request.permission_suggestions.unwrap_or_default(),
246            };
247
248            let result = can_use_tool
249                .can_use_tool(&request.tool_name, &request.input, &context)
250                .await;
251
252            // Send response back (CLI expects: { allow: bool, input?, reason? })
253            let response = match result {
254                PermissionResult::Allow(allow) => {
255                    let mut obj = serde_json::json!({ "allow": true });
256                    if let Some(updated) = allow.updated_input {
257                        obj["input"] = updated;
258                    }
259                    obj
260                }
261                PermissionResult::Deny(deny) => {
262                    let mut obj = serde_json::json!({ "allow": false });
263                    if !deny.message.is_empty() {
264                        obj["reason"] = serde_json::json!(deny.message);
265                    }
266                    obj
267                }
268            };
269
270            // Send response back through transport
271            let mut transport = self.transport.lock().await;
272            transport.send_sdk_control_response(response).await?;
273            debug!("Permission response sent");
274        }
275        
276        Ok(())
277    }
278
279    /// Extract requestId from CLI message (supports both camelCase and snake_case)
280    fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
281        msg.get("requestId")
282            .or_else(|| msg.get("request_id"))
283            .cloned()
284    }
285
286    /// Start control request handler task
287    async fn start_control_handler(&mut self) {
288        let transport = self.transport.clone();
289        let can_use_tool = self.can_use_tool.clone();
290        let hook_callbacks = self.hook_callbacks.clone();
291        let sdk_mcp_servers = self.sdk_mcp_servers.clone();
292        let pending_responses = self.pending_responses.clone();
293        
294        // Take ownership of the SDK control receiver to avoid holding locks
295        let sdk_control_rx = {
296            let mut transport_lock = transport.lock().await;
297            transport_lock.take_sdk_control_receiver()
298        }; // Lock released here
299        
300        if let Some(mut control_rx) = sdk_control_rx {
301            tokio::spawn(async move {
302                // Now we can receive control requests without holding any locks
303                let transport_for_control = transport;
304                let can_use_tool_clone = can_use_tool;
305                let hook_callbacks_clone = hook_callbacks;
306                let sdk_mcp_servers_clone = sdk_mcp_servers;
307                let pending_responses_clone = pending_responses;
308
309                loop {
310                    // Receive control request without holding lock
311                    let control_message = control_rx.recv().await;
312
313                    if let Some(control_message) = control_message {
314                        debug!("Received control message: {:?}", control_message);
315
316                        // Check if this is a control response (from CLI to SDK)
317                        if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
318                            // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
319                            if let Some(resp_obj) = control_message.get("response") {
320                                let request_id = resp_obj
321                                    .get("request_id")
322                                    .or_else(|| resp_obj.get("requestId"))
323                                    .and_then(|v| v.as_str());
324
325                                if let Some(request_id) = request_id {
326                                    let mut pending = pending_responses_clone.write().await;
327                                    if let Some(tx) = pending.remove(request_id) {
328                                        // Deliver only the nested "response" object (matches Python SDK semantics)
329                                        let _ = tx.send(resp_obj.clone());
330                                        debug!("Control response delivered for request_id: {}", request_id);
331                                    } else {
332                                        warn!("No pending request found for request_id: {}", request_id);
333                                    }
334                                } else {
335                                    warn!("Control response missing request_id: {:?}", control_message);
336                                }
337                            } else {
338                                warn!("Control response missing 'response' payload: {:?}", control_message);
339                            }
340                            continue;
341                        }
342
343                        // Parse and handle control requests (from CLI to SDK)
344                        // Check if this is a control_request with a nested request field
345                        let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
346                            control_message.get("request").cloned().unwrap_or(control_message.clone())
347                        } else {
348                            control_message.clone()
349                        };
350
351                        if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
352                            match subtype {
353                                "can_use_tool" => {
354                                    // Handle permission request
355                                    if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
356                                        // Handle with can_use_tool callback
357                                        if let Some(ref can_use_tool) = can_use_tool_clone {
358                                            let context = ToolPermissionContext {
359                                                signal: None,
360                                                suggestions: request.permission_suggestions.unwrap_or_default(),
361                                            };
362                                                
363                                            let result = can_use_tool
364                                                .can_use_tool(&request.tool_name, &request.input, &context)
365                                                .await;
366                                                
367                                            // CLI expects: {"allow": true, "input": ...} or {"allow": false, "reason": ...}
368                                            let permission_response = match result {
369                                                PermissionResult::Allow(allow) => {
370                                                    let mut resp = serde_json::json!({
371                                                        "allow": true,
372                                                    });
373                                                    if let Some(input) = allow.updated_input {
374                                                        resp["input"] = input;
375                                                    }
376                                                    if let Some(perms) = allow.updated_permissions {
377                                                        resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
378                                                    }
379                                                    resp
380                                                }
381                                                PermissionResult::Deny(deny) => {
382                                                    let mut resp = serde_json::json!({
383                                                        "allow": false,
384                                                    });
385                                                    if !deny.message.is_empty() {
386                                                        resp["reason"] = serde_json::json!(deny.message);
387                                                    }
388                                                    if deny.interrupt {
389                                                        resp["interrupt"] = serde_json::json!(true);
390                                                    }
391                                                    resp
392                                                }
393                                            };
394                                                
395                                            // Wrap response with proper structure
396                                            // CLI expects "subtype": "success" for all successful responses
397                                            let response = serde_json::json!({
398                                                "subtype": "success",
399                                                "request_id": Self::extract_request_id(&control_message),
400                                                "response": permission_response
401                                            });
402                                                
403                                            // Send response
404                                            let mut transport = transport_for_control.lock().await;
405                                            if let Err(e) = transport.send_sdk_control_response(response).await {
406                                                error!("Failed to send permission response: {}", e);
407                                            }
408                                        }
409                                    } else {
410                                        // Fallback for snake_case fields (tool_name, permission_suggestions)
411                                        if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str()) {
412                                            if let Some(input_val) = request_data.get("input").cloned() {
413                                                if let Some(ref can_use_tool) = can_use_tool_clone {
414                                                    // Try to parse permission suggestions (snake_case)
415                                                    let suggestions: Vec<PermissionUpdate> = request_data
416                                                        .get("permission_suggestions")
417                                                        .cloned()
418                                                        .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
419                                                        .unwrap_or_default();
420
421                                                    let context = ToolPermissionContext { signal: None, suggestions };
422                                                    let result = can_use_tool
423                                                        .can_use_tool(tool_name, &input_val, &context)
424                                                        .await;
425
426                                                    let permission_response = match result {
427                                                        PermissionResult::Allow(allow) => {
428                                                            let mut resp = serde_json::json!({ "allow": true });
429                                                            if let Some(input) = allow.updated_input { resp["input"] = input; }
430                                                            if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
431                                                            resp
432                                                        }
433                                                        PermissionResult::Deny(deny) => {
434                                                            let mut resp = serde_json::json!({ "allow": false });
435                                                            if !deny.message.is_empty() { resp["reason"] = serde_json::json!(deny.message); }
436                                                            if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
437                                                            resp
438                                                        }
439                                                    };
440
441                                                    let response = serde_json::json!({
442                                                        "subtype": "success",
443                                                        "request_id": Self::extract_request_id(&control_message),
444                                                        "response": permission_response
445                                                    });
446                                                    let mut transport = transport_for_control.lock().await;
447                                                    if let Err(e) = transport.send_sdk_control_response(response).await {
448                                                        error!("Failed to send permission response (fallback): {}", e);
449                                                    }
450                                                }
451                                            }
452                                        }
453                                    }
454                                }
455                                "hook_callback" => {
456                                    // Handle hook callback
457                                    if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
458                                        let callbacks = hook_callbacks_clone.read().await;
459                                        
460                                        if let Some(callback) = callbacks.get(&request.callback_id) {
461                                            let context = HookContext { signal: None };
462                                            
463                                            let response = callback
464                                                .execute(&request.input, request.tool_use_id.as_deref(), &context)
465                                                .await;
466                                            
467                                            // Send response back through transport
468                                            let response_json = serde_json::json!({
469                                                "subtype": "success",
470                                                "request_id": Self::extract_request_id(&control_message),
471                                                "response": response
472                                            });
473                                            
474                                            let mut transport = transport_for_control.lock().await;
475                                            if let Err(e) = transport.send_sdk_control_response(response_json).await {
476                                                error!("Failed to send hook callback response: {}", e);
477                                            }
478                                        } else {
479                                            warn!("No hook callback found for ID: {}", request.callback_id);
480                                        }
481                                    } else {
482                                        // Fallback for snake_case fields (callback_id, tool_use_id)
483                                        let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
484                                        let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
485                                        let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
486
487                                        if let Some(callback_id) = callback_id {
488                                            let callbacks = hook_callbacks_clone.read().await;
489                                            if let Some(callback) = callbacks.get(callback_id) {
490                                                let context = HookContext { signal: None };
491                                                let response = callback
492                                                    .execute(&input, tool_use_id.as_deref(), &context)
493                                                    .await;
494
495                                                let response_json = serde_json::json!({
496                                                    "subtype": "success",
497                                                    "request_id": Self::extract_request_id(&control_message),
498                                                    "response": response
499                                                });
500                                                let mut transport = transport_for_control.lock().await;
501                                                if let Err(e) = transport.send_sdk_control_response(response_json).await {
502                                                    error!("Failed to send hook callback response (fallback): {}", e);
503                                                }
504                                            } else {
505                                                warn!("No hook callback found for ID: {}", callback_id);
506                                            }
507                                        } else {
508                                            warn!("Invalid hook_callback control message: missing callback_id");
509                                        }
510                                    }
511                                }
512                                "mcp_message" => {
513                                    // Handle MCP message
514                                    if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str()) {
515                                        if let Some(message) = request_data.get("message") {
516                                            debug!("Processing MCP message for SDK server: {}", server_name);
517
518                                            // Check if we have an SDK server with this name
519                                            if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
520                                                // Try to downcast to SdkMcpServer
521                                                if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
522                                                    // Call the SDK MCP server
523                                                    match sdk_server.handle_message(message.clone()).await {
524                                                        Ok(mcp_result) => {
525                                                            // Wrap response with proper structure
526                                                            let response = serde_json::json!({
527                                                                "subtype": "success",
528                                                                "request_id": Self::extract_request_id(&control_message),
529                                                                "response": {
530                                                                    "mcp_response": mcp_result
531                                                                }
532                                                            });
533
534                                                            let mut transport = transport_for_control.lock().await;
535                                                            if let Err(e) = transport.send_sdk_control_response(response).await {
536                                                                error!("Failed to send MCP response: {}", e);
537                                                            }
538                                                        }
539                                                        Err(e) => {
540                                                            error!("SDK MCP server error: {}", e);
541                                                            let error_response = serde_json::json!({
542                                                                "subtype": "error",
543                                                                "request_id": Self::extract_request_id(&control_message),
544                                                                "error": format!("MCP server error: {}", e)
545                                                            });
546
547                                                            let mut transport = transport_for_control.lock().await;
548                                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
549                                                                error!("Failed to send MCP error response: {}", e);
550                                                            }
551                                                        }
552                                                    }
553                                                } else {
554                                                    warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
555                                                }
556                                            } else {
557                                                warn!("No SDK MCP server found with name: {}", server_name);
558                                                let error_response = serde_json::json!({
559                                                    "subtype": "error",
560                                                    "request_id": Self::extract_request_id(&control_message),
561                                                    "error": format!("Server '{}' not found", server_name)
562                                                });
563
564                                                let mut transport = transport_for_control.lock().await;
565                                                if let Err(e) = transport.send_sdk_control_response(error_response).await {
566                                                    error!("Failed to send MCP error response: {}", e);
567                                                }
568                                            }
569                                        }
570                                    }
571                                }
572                                _ => {
573                                    debug!("Unknown SDK control subtype: {}", subtype);
574                                }
575                            }
576                        }
577                    }
578                }
579            });
580        }
581    }
582
583    /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
584    #[allow(dead_code)]
585    pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
586    where
587        S: Stream<Item = JsonValue> + Send + 'static,
588    {
589        let transport = self.transport.clone();
590
591        tokio::spawn(async move {
592            use futures::StreamExt;
593            let mut stream = Box::pin(input_stream);
594
595            while let Some(value) = stream.next().await {
596                // Best-effort conversion from arbitrary JSON to InputMessage
597                let input_msg_opt = Self::json_to_input_message(value);
598                if let Some(input_msg) = input_msg_opt {
599                    let mut guard = transport.lock().await;
600                    if let Err(e) = guard.send_message(input_msg).await {
601                        warn!("Failed to send streaming input message: {}", e);
602                    }
603                } else {
604                    warn!("Invalid streaming input JSON; expected user message shape");
605                }
606            }
607
608            // After streaming all inputs, signal end of input
609            let mut guard = transport.lock().await;
610            if let Err(e) = guard.end_input().await {
611                warn!("Failed to signal end_input: {}", e);
612            }
613        });
614        Ok(())
615    }
616
617    /// Receive messages
618    #[allow(dead_code)]
619    pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
620        self.message_rx.take().expect("Receiver already taken")
621    }
622
623    /// Send interrupt request
624    pub async fn interrupt(&mut self) -> Result<()> {
625        let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
626            subtype: "interrupt".to_string(),
627        });
628
629        self.send_control_request(interrupt_request).await?;
630        Ok(())
631    }
632
633    /// Set permission mode via control protocol
634    #[allow(dead_code)]
635    pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
636        let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
637            subtype: "set_permission_mode".to_string(),
638            mode: mode.to_string(),
639        });
640        // Ignore response payload; errors propagate
641        let _ = self.send_control_request(req).await?;
642        Ok(())
643    }
644
645    /// Set the active model via control protocol
646    #[allow(dead_code)]
647    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
648        let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
649            subtype: "set_model".to_string(),
650            model,
651        });
652        let _ = self.send_control_request(req).await?;
653        Ok(())
654    }
655
656    /// Handle MCP message for SDK servers
657    #[allow(dead_code)]
658    async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
659        // Check if we have an SDK server with this name
660        if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
661            // TODO: Implement actual MCP server invocation
662            // For now, return a placeholder response
663            debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
664            Ok(serde_json::json!({
665                "jsonrpc": "2.0",
666                "id": message.get("id"),
667                "result": {
668                    "content": "MCP server response placeholder"
669                }
670            }))
671        } else {
672            Err(SdkError::InvalidState {
673                message: format!("No SDK MCP server found with name: {}", server_name),
674            })
675        }
676    }
677
678    /// Close the query handler
679    #[allow(dead_code)]
680    pub async fn close(&mut self) -> Result<()> {
681        // Clean up resources
682        let mut transport = self.transport.lock().await;
683        transport.disconnect().await?;
684        Ok(())
685    }
686
687    /// Get initialization result
688    pub fn get_initialization_result(&self) -> Option<&JsonValue> {
689        self.initialization_result.as_ref()
690    }
691
692    /// Convert arbitrary JSON value to InputMessage understood by CLI
693    #[allow(dead_code)]
694    fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
695        // 1) Already in SDK message shape
696        if let Some(obj) = v.as_object() {
697            if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message")) {
698                if t.as_str() == Some("user") {
699                    let parent = obj
700                        .get("parent_tool_use_id")
701                        .and_then(|p| p.as_str().map(|s| s.to_string()));
702                    let session_id = obj
703                        .get("session_id")
704                        .and_then(|s| s.as_str())
705                        .unwrap_or("default")
706                        .to_string();
707
708                    let im = InputMessage {
709                        r#type: "user".to_string(),
710                        message: message.clone(),
711                        parent_tool_use_id: parent,
712                        session_id,
713                    };
714                    return Some(im);
715                }
716            }
717
718            // 2) Simple wrapper: {"content":"...", "session_id":"..."}
719            if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
720                let session_id = obj
721                    .get("session_id")
722                    .and_then(|s| s.as_str())
723                    .unwrap_or("default")
724                    .to_string();
725                return Some(InputMessage::user(content.to_string(), session_id));
726            }
727        }
728
729        // 3) Bare string
730        if let Some(s) = v.as_str() {
731            return Some(InputMessage::user(s.to_string(), "default".to_string()));
732        }
733
734        None
735    }
736}
737
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers on `Query`: request-id extraction and
    //! JSON-to-`InputMessage` conversion.

    use super::*;

    /// Both the snake_case and camelCase request id keys are recognised.
    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let snake = serde_json::json!({"request_id": "req_1"});
        let camel = serde_json::json!({"requestId": "req_2"});
        assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
        assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
    }

    /// A bare JSON string becomes a user message in the default session.
    #[test]
    fn test_json_to_input_message_from_string() {
        let v = serde_json::json!("Hello");
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
    }

    /// The `{"content", "session_id"}` wrapper keeps the supplied session id.
    #[test]
    fn test_json_to_input_message_from_object_content() {
        let v = serde_json::json!({"content":"Ping","session_id":"s1"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    /// The wrapper shape falls back to the default session when none is given.
    #[test]
    fn test_json_to_input_message_object_content_defaults_session() {
        let v = serde_json::json!({"content":"Ping"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    /// A full SDK user message is passed through; a null parent id maps to `None`.
    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let v = serde_json::json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str().unwrap(), "user");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
        assert!(im.parent_tool_use_id.is_none());
    }

    /// A string parent id is preserved, and a missing session id defaults.
    #[test]
    fn test_json_to_input_message_full_user_shape_keeps_parent() {
        let v = serde_json::json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "parent_tool_use_id": "tool_1"
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.parent_tool_use_id.as_deref(), Some("tool_1"));
    }

    /// Values matching none of the accepted shapes are rejected.
    #[test]
    fn test_json_to_input_message_rejects_unsupported_shapes() {
        let unsupported = vec![
            serde_json::json!(42),
            serde_json::json!(null),
            serde_json::json!(["Hello"]),
            // Object without a usable "message" or string "content".
            serde_json::json!({"type":"user"}),
            serde_json::json!({"type":"assistant","message":{"role":"assistant"}}),
            serde_json::json!({"content": 1}),
        ];
        for v in unsupported {
            assert!(Query::json_to_input_message(v).is_none());
        }
    }
}