cc_sdk/
internal_query.rs

1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7    errors::{Result, SdkError},
8    transport::{InputMessage, Transport},
9    types::{
10        CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11        SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12        SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13        ToolPermissionContext,
14    },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27    /// Transport layer (shared with client)
28    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29    /// Whether in streaming mode
30    #[allow(dead_code)]
31    is_streaming_mode: bool,
32    /// Tool permission callback
33    can_use_tool: Option<Arc<dyn CanUseTool>>,
34    /// Hook configurations
35    hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36    /// SDK MCP servers
37    sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38    /// Message channel sender (reserved for future streaming receive support)
39    #[allow(dead_code)]
40    message_tx: mpsc::Sender<Result<Message>>,
41    /// Message channel receiver (reserved for future streaming receive support)
42    #[allow(dead_code)]
43    message_rx: Option<mpsc::Receiver<Result<Message>>>,
44    /// Initialization result
45    initialization_result: Option<JsonValue>,
46    /// Active hook callbacks
47    hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48    /// Hook callback counter
49    callback_counter: Arc<Mutex<u64>>,
50    /// Request counter for generating unique IDs
51    request_counter: Arc<Mutex<u64>>,
52    /// Pending control request responses
53    pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57    /// Create a new Query handler
58    pub fn new(
59        transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60        is_streaming_mode: bool,
61        can_use_tool: Option<Arc<dyn CanUseTool>>,
62        hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63        sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64    ) -> Self {
65        let (tx, rx) = mpsc::channel(100);
66
67        Self {
68            transport,
69            is_streaming_mode,
70            can_use_tool,
71            hooks,
72            sdk_mcp_servers,
73            message_tx: tx,
74            message_rx: Some(rx),
75            initialization_result: None,
76            hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77            callback_counter: Arc::new(Mutex::new(0)),
78            request_counter: Arc::new(Mutex::new(0)),
79            pending_responses: Arc::new(RwLock::new(HashMap::new())),
80        }
81    }
82
83    /// Test helper to register a hook callback with a known ID
84    ///
85    /// This is intended for E2E tests to inject a callback ID that can be
86    /// referenced by inbound `hook_callback` control messages.
87    pub async fn register_hook_callback_for_test(
88        &self,
89        callback_id: String,
90        callback: Arc<dyn HookCallback>,
91    ) {
92        let mut map = self.hook_callbacks.write().await;
93        map.insert(callback_id, callback);
94    }
95
96    /// Start the query handler
97    pub async fn start(&mut self) -> Result<()> {
98        // Start control request handler task
99        self.start_control_handler().await;
100
101        // Start SDK message forwarder task (route non-control messages to message_tx)
102        let transport = self.transport.clone();
103        let tx = self.message_tx.clone();
104        tokio::spawn(async move {
105            // Get message stream once and consume it continuously
106            let mut stream = {
107                let mut guard = transport.lock().await;
108                guard.receive_messages()
109            }; // Lock released immediately after getting stream
110
111            // Continuously consume from the same stream
112            while let Some(result) = stream.next().await {
113                match result {
114                    Ok(msg) => {
115                        if tx.send(Ok(msg)).await.is_err() {
116                            break;
117                        }
118                    }
119                    Err(e) => {
120                        let _ = tx.send(Err(e)).await;
121                        break;
122                    }
123                }
124            }
125        });
126        Ok(())
127    }
128
129    /// Initialize the control protocol
130    pub async fn initialize(&mut self) -> Result<()> {
131        // Build hooks with callback IDs (Python SDK style)
132        let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133            let mut counter = self.callback_counter.lock().await;
134            let mut callbacks_map = self.hook_callbacks.write().await;
135
136            let hooks_json: HashMap<String, JsonValue> = hooks
137                .iter()
138                .map(|(event_name, matchers)| {
139                    let matchers_with_ids: Vec<JsonValue> = matchers
140                        .iter()
141                        .map(|matcher| {
142                            // Generate callback IDs for each hook in this matcher
143                            let callback_ids: Vec<String> = matcher
144                                .hooks
145                                .iter()
146                                .map(|hook_callback| {
147                                    *counter += 1;
148                                    let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150                                    // Store the callback for later use
151                                    callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153                                    callback_id
154                                })
155                                .collect();
156
157                            serde_json::json!({
158                                "matcher": matcher.matcher.clone(),
159                                "hookCallbackIds": callback_ids
160                            })
161                        })
162                        .collect();
163
164                    (event_name.clone(), serde_json::json!(matchers_with_ids))
165                })
166                .collect();
167
168            Some(hooks_json)
169        } else {
170            None
171        };
172
173        // Send initialize request
174        let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175            subtype: "initialize".to_string(),
176            hooks: hooks_with_ids,
177        });
178
179        // Send control request and save result
180        let result = self.send_control_request(init_request).await?;
181        self.initialization_result = Some(result);
182
183        debug!("Initialization request sent with hook callback IDs");
184        Ok(())
185    }
186
187    /// Send a control request and wait for response
188    async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189        // Generate unique request ID
190        let request_id = {
191            let mut counter = self.request_counter.lock().await;
192            *counter += 1;
193            format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194        };
195
196        // Create oneshot channel for response
197        let (tx, rx) = tokio::sync::oneshot::channel();
198
199        // Register pending response
200        {
201            let mut pending = self.pending_responses.write().await;
202            pending.insert(request_id.clone(), tx);
203        }
204
205        // Build control request with request_id (snake_case for CLI compatibility)
206        let control_request = serde_json::json!({
207            "type": "control_request",
208            "request_id": request_id,
209            "request": request
210        });
211
212        debug!("Sending control request: {:?}", control_request);
213
214        // Send via transport
215        {
216            let mut transport = self.transport.lock().await;
217            transport.send_sdk_control_request(control_request).await?;
218        }
219
220        // Wait for response with timeout
221        match timeout(Duration::from_secs(60), rx).await {
222            Ok(Ok(response)) => {
223                debug!("Received control response for {}", request_id);
224                Ok(response)
225            }
226            Ok(Err(_)) => Err(SdkError::ControlRequestError(
227                "Response channel closed".to_string(),
228            )),
229            Err(_) => {
230                // Clean up pending response
231                let mut pending = self.pending_responses.write().await;
232                pending.remove(&request_id);
233                Err(SdkError::Timeout { seconds: 60 })
234            }
235        }
236    }
237
238    /// Handle permission request
239    #[allow(dead_code)]
240    async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
241        if let Some(ref can_use_tool) = self.can_use_tool {
242            let context = ToolPermissionContext {
243                signal: None,
244                suggestions: request.permission_suggestions.unwrap_or_default(),
245            };
246
247            let result = can_use_tool
248                .can_use_tool(&request.tool_name, &request.input, &context)
249                .await;
250
251            // Send response back (CLI expects: { allow: bool, input?, reason? })
252            let response = match result {
253                PermissionResult::Allow(allow) => {
254                    let mut obj = serde_json::json!({ "allow": true });
255                    if let Some(updated) = allow.updated_input {
256                        obj["input"] = updated;
257                    }
258                    obj
259                }
260                PermissionResult::Deny(deny) => {
261                    let mut obj = serde_json::json!({ "allow": false });
262                    if !deny.message.is_empty() {
263                        obj["reason"] = serde_json::json!(deny.message);
264                    }
265                    obj
266                }
267            };
268
269            // Send response back through transport
270            let mut transport = self.transport.lock().await;
271            transport.send_sdk_control_response(response).await?;
272            debug!("Permission response sent");
273        }
274        
275        Ok(())
276    }
277
278    /// Extract requestId from CLI message (supports both camelCase and snake_case)
279    fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
280        msg.get("requestId")
281            .or_else(|| msg.get("request_id"))
282            .cloned()
283    }
284
285    /// Start control request handler task
286    async fn start_control_handler(&mut self) {
287        let transport = self.transport.clone();
288        let can_use_tool = self.can_use_tool.clone();
289        let hook_callbacks = self.hook_callbacks.clone();
290        let sdk_mcp_servers = self.sdk_mcp_servers.clone();
291        let pending_responses = self.pending_responses.clone();
292        
293        // Take ownership of the SDK control receiver to avoid holding locks
294        let sdk_control_rx = {
295            let mut transport_lock = transport.lock().await;
296            transport_lock.take_sdk_control_receiver()
297        }; // Lock released here
298        
299        if let Some(mut control_rx) = sdk_control_rx {
300            tokio::spawn(async move {
301                // Now we can receive control requests without holding any locks
302                let transport_for_control = transport;
303                let can_use_tool_clone = can_use_tool;
304                let hook_callbacks_clone = hook_callbacks;
305                let sdk_mcp_servers_clone = sdk_mcp_servers;
306                let pending_responses_clone = pending_responses;
307
308                loop {
309                    // Receive control request without holding lock
310                    let control_message = control_rx.recv().await;
311
312                    if let Some(control_message) = control_message {
313                        debug!("Received control message: {:?}", control_message);
314
315                        // Check if this is a control response (from CLI to SDK)
316                        if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
317                            // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
318                            if let Some(resp_obj) = control_message.get("response") {
319                                let request_id = resp_obj
320                                    .get("request_id")
321                                    .or_else(|| resp_obj.get("requestId"))
322                                    .and_then(|v| v.as_str());
323
324                                if let Some(request_id) = request_id {
325                                    let mut pending = pending_responses_clone.write().await;
326                                    if let Some(tx) = pending.remove(request_id) {
327                                        // Deliver only the nested "response" object (matches Python SDK semantics)
328                                        let _ = tx.send(resp_obj.clone());
329                                        debug!("Control response delivered for request_id: {}", request_id);
330                                    } else {
331                                        warn!("No pending request found for request_id: {}", request_id);
332                                    }
333                                } else {
334                                    warn!("Control response missing request_id: {:?}", control_message);
335                                }
336                            } else {
337                                warn!("Control response missing 'response' payload: {:?}", control_message);
338                            }
339                            continue;
340                        }
341
342                        // Parse and handle control requests (from CLI to SDK)
343                        // Check if this is a control_request with a nested request field
344                        let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
345                            control_message.get("request").cloned().unwrap_or(control_message.clone())
346                        } else {
347                            control_message.clone()
348                        };
349
350                        if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
351                            match subtype {
352                                "can_use_tool" => {
353                                    // Handle permission request
354                                    if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
355                                        // Handle with can_use_tool callback
356                                        if let Some(ref can_use_tool) = can_use_tool_clone {
357                                            let context = ToolPermissionContext {
358                                                signal: None,
359                                                suggestions: request.permission_suggestions.unwrap_or_default(),
360                                            };
361                                                
362                                            let result = can_use_tool
363                                                .can_use_tool(&request.tool_name, &request.input, &context)
364                                                .await;
365                                                
366                                            // CLI expects: {"allow": true, "input": ...} or {"allow": false, "reason": ...}
367                                            let permission_response = match result {
368                                                PermissionResult::Allow(allow) => {
369                                                    let mut resp = serde_json::json!({
370                                                        "allow": true,
371                                                    });
372                                                    if let Some(input) = allow.updated_input {
373                                                        resp["input"] = input;
374                                                    }
375                                                    if let Some(perms) = allow.updated_permissions {
376                                                        resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
377                                                    }
378                                                    resp
379                                                }
380                                                PermissionResult::Deny(deny) => {
381                                                    let mut resp = serde_json::json!({
382                                                        "allow": false,
383                                                    });
384                                                    if !deny.message.is_empty() {
385                                                        resp["reason"] = serde_json::json!(deny.message);
386                                                    }
387                                                    if deny.interrupt {
388                                                        resp["interrupt"] = serde_json::json!(true);
389                                                    }
390                                                    resp
391                                                }
392                                            };
393                                                
394                                            // Wrap response with proper structure
395                                            // CLI expects "subtype": "success" for all successful responses
396                                            let response = serde_json::json!({
397                                                "subtype": "success",
398                                                "request_id": Self::extract_request_id(&control_message),
399                                                "response": permission_response
400                                            });
401                                                
402                                            // Send response
403                                            let mut transport = transport_for_control.lock().await;
404                                            if let Err(e) = transport.send_sdk_control_response(response).await {
405                                                error!("Failed to send permission response: {}", e);
406                                            }
407                                        }
408                                    } else {
409                                        // Fallback for snake_case fields (tool_name, permission_suggestions)
410                                        if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
411                                            && let Some(input_val) = request_data.get("input").cloned()
412                                                && let Some(ref can_use_tool) = can_use_tool_clone {
413                                                    // Try to parse permission suggestions (snake_case)
414                                                    let suggestions: Vec<PermissionUpdate> = request_data
415                                                        .get("permission_suggestions")
416                                                        .cloned()
417                                                        .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
418                                                        .unwrap_or_default();
419
420                                                    let context = ToolPermissionContext { signal: None, suggestions };
421                                                    let result = can_use_tool
422                                                        .can_use_tool(tool_name, &input_val, &context)
423                                                        .await;
424
425                                                    let permission_response = match result {
426                                                        PermissionResult::Allow(allow) => {
427                                                            let mut resp = serde_json::json!({ "allow": true });
428                                                            if let Some(input) = allow.updated_input { resp["input"] = input; }
429                                                            if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
430                                                            resp
431                                                        }
432                                                        PermissionResult::Deny(deny) => {
433                                                            let mut resp = serde_json::json!({ "allow": false });
434                                                            if !deny.message.is_empty() { resp["reason"] = serde_json::json!(deny.message); }
435                                                            if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
436                                                            resp
437                                                        }
438                                                    };
439
440                                                    let response = serde_json::json!({
441                                                        "subtype": "success",
442                                                        "request_id": Self::extract_request_id(&control_message),
443                                                        "response": permission_response
444                                                    });
445                                                    let mut transport = transport_for_control.lock().await;
446                                                    if let Err(e) = transport.send_sdk_control_response(response).await {
447                                                        error!("Failed to send permission response (fallback): {}", e);
448                                                    }
449                                                }
450                                    }
451                                }
452                                "hook_callback" => {
453                                    // Handle hook callback with strongly-typed inputs/outputs
454                                    if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
455                                        let callbacks = hook_callbacks_clone.read().await;
456
457                                        if let Some(callback) = callbacks.get(&request.callback_id) {
458                                            let context = HookContext { signal: None };
459
460                                            // Try to deserialize input as HookInput
461                                            let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
462                                                Ok(hook_input) => {
463                                                    // Call the hook with strongly-typed input
464                                                    callback
465                                                        .execute(&hook_input, request.tool_use_id.as_deref(), &context)
466                                                        .await
467                                                }
468                                                Err(parse_err) => {
469                                                    error!("Failed to parse hook input: {}", parse_err);
470                                                    // Return error using MessageParseError
471                                                    Err(crate::errors::SdkError::MessageParseError {
472                                                        error: format!("Invalid hook input: {parse_err}"),
473                                                        raw: request.input.to_string(),
474                                                    })
475                                                }
476                                            };
477
478                                            // Handle hook result
479                                            let response_json = match hook_result {
480                                                Ok(hook_output) => {
481                                                    // Serialize HookJSONOutput to JSON
482                                                    let output_value = serde_json::to_value(&hook_output)
483                                                        .unwrap_or_else(|e| {
484                                                            error!("Failed to serialize hook output: {}", e);
485                                                            serde_json::json!({})
486                                                        });
487
488                                                    serde_json::json!({
489                                                        "subtype": "success",
490                                                        "request_id": Self::extract_request_id(&control_message),
491                                                        "response": output_value
492                                                    })
493                                                }
494                                                Err(e) => {
495                                                    error!("Hook callback failed: {}", e);
496                                                    serde_json::json!({
497                                                        "subtype": "error",
498                                                        "request_id": Self::extract_request_id(&control_message),
499                                                        "error": e.to_string()
500                                                    })
501                                                }
502                                            };
503
504                                            let mut transport = transport_for_control.lock().await;
505                                            if let Err(e) = transport.send_sdk_control_response(response_json).await {
506                                                error!("Failed to send hook callback response: {}", e);
507                                            }
508                                        } else {
509                                            warn!("No hook callback found for ID: {}", request.callback_id);
510                                            // Send error response
511                                            let error_response = serde_json::json!({
512                                                "subtype": "error",
513                                                "request_id": Self::extract_request_id(&control_message),
514                                                "error": format!("No hook callback found for ID: {}", request.callback_id)
515                                            });
516                                            let mut transport = transport_for_control.lock().await;
517                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
518                                                error!("Failed to send error response: {}", e);
519                                            }
520                                        }
521                                    } else {
522                                        // Fallback for snake_case fields (callback_id, tool_use_id)
523                                        let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
524                                        let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
525                                        let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
526
527                                        if let Some(callback_id) = callback_id {
528                                            let callbacks = hook_callbacks_clone.read().await;
529                                            if let Some(callback) = callbacks.get(callback_id) {
530                                                let context = HookContext { signal: None };
531
532                                                // Try to parse as HookInput
533                                                let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
534                                                    Ok(hook_input) => {
535                                                        callback
536                                                            .execute(&hook_input, tool_use_id.as_deref(), &context)
537                                                            .await
538                                                    }
539                                                    Err(parse_err) => {
540                                                        error!("Failed to parse hook input (fallback): {}", parse_err);
541                                                        Err(crate::errors::SdkError::MessageParseError {
542                                                            error: format!("Invalid hook input: {parse_err}"),
543                                                            raw: input.to_string(),
544                                                        })
545                                                    }
546                                                };
547
548                                                let response_json = match hook_result {
549                                                    Ok(hook_output) => {
550                                                        let output_value = serde_json::to_value(&hook_output)
551                                                            .unwrap_or_else(|e| {
552                                                                error!("Failed to serialize hook output (fallback): {}", e);
553                                                                serde_json::json!({})
554                                                            });
555
556                                                        serde_json::json!({
557                                                            "subtype": "success",
558                                                            "request_id": Self::extract_request_id(&control_message),
559                                                            "response": output_value
560                                                        })
561                                                    }
562                                                    Err(e) => {
563                                                        error!("Hook callback failed (fallback): {}", e);
564                                                        serde_json::json!({
565                                                            "subtype": "error",
566                                                            "request_id": Self::extract_request_id(&control_message),
567                                                            "error": e.to_string()
568                                                        })
569                                                    }
570                                                };
571
572                                                let mut transport = transport_for_control.lock().await;
573                                                if let Err(e) = transport.send_sdk_control_response(response_json).await {
574                                                    error!("Failed to send hook callback response (fallback): {}", e);
575                                                }
576                                            } else {
577                                                warn!("No hook callback found for ID: {}", callback_id);
578                                            }
579                                        } else {
580                                            warn!("Invalid hook_callback control message: missing callback_id");
581                                        }
582                                    }
583                                }
584                                "mcp_message" => {
585                                    // Handle MCP message
586                                    if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
587                                        && let Some(message) = request_data.get("message") {
588                                            debug!("Processing MCP message for SDK server: {}", server_name);
589
590                                            // Check if we have an SDK server with this name
591                                            if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
592                                                // Try to downcast to SdkMcpServer
593                                                if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
594                                                    // Call the SDK MCP server
595                                                    match sdk_server.handle_message(message.clone()).await {
596                                                        Ok(mcp_result) => {
597                                                            // Wrap response with proper structure
598                                                            let response = serde_json::json!({
599                                                                "subtype": "success",
600                                                                "request_id": Self::extract_request_id(&control_message),
601                                                                "response": {
602                                                                    "mcp_response": mcp_result
603                                                                }
604                                                            });
605
606                                                            let mut transport = transport_for_control.lock().await;
607                                                            if let Err(e) = transport.send_sdk_control_response(response).await {
608                                                                error!("Failed to send MCP response: {}", e);
609                                                            }
610                                                        }
611                                                        Err(e) => {
612                                                            error!("SDK MCP server error: {}", e);
613                                                            let error_response = serde_json::json!({
614                                                                "subtype": "error",
615                                                                "request_id": Self::extract_request_id(&control_message),
616                                                                "error": format!("MCP server error: {}", e)
617                                                            });
618
619                                                            let mut transport = transport_for_control.lock().await;
620                                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
621                                                                error!("Failed to send MCP error response: {}", e);
622                                                            }
623                                                        }
624                                                    }
625                                                } else {
626                                                    warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
627                                                }
628                                            } else {
629                                                warn!("No SDK MCP server found with name: {}", server_name);
630                                                let error_response = serde_json::json!({
631                                                    "subtype": "error",
632                                                    "request_id": Self::extract_request_id(&control_message),
633                                                    "error": format!("Server '{}' not found", server_name)
634                                                });
635
636                                                let mut transport = transport_for_control.lock().await;
637                                                if let Err(e) = transport.send_sdk_control_response(error_response).await {
638                                                    error!("Failed to send MCP error response: {}", e);
639                                                }
640                                            }
641                                        }
642                                }
643                                _ => {
644                                    debug!("Unknown SDK control subtype: {}", subtype);
645                                }
646                            }
647                        }
648                    }
649                }
650            });
651        }
652    }
653
654    /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
655    #[allow(dead_code)]
656    pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
657    where
658        S: Stream<Item = JsonValue> + Send + 'static,
659    {
660        let transport = self.transport.clone();
661
662        tokio::spawn(async move {
663            use futures::StreamExt;
664            let mut stream = Box::pin(input_stream);
665
666            while let Some(value) = stream.next().await {
667                // Best-effort conversion from arbitrary JSON to InputMessage
668                let input_msg_opt = Self::json_to_input_message(value);
669                if let Some(input_msg) = input_msg_opt {
670                    let mut guard = transport.lock().await;
671                    if let Err(e) = guard.send_message(input_msg).await {
672                        warn!("Failed to send streaming input message: {}", e);
673                    }
674                } else {
675                    warn!("Invalid streaming input JSON; expected user message shape");
676                }
677            }
678
679            // After streaming all inputs, signal end of input
680            let mut guard = transport.lock().await;
681            if let Err(e) = guard.end_input().await {
682                warn!("Failed to signal end_input: {}", e);
683            }
684        });
685        Ok(())
686    }
687
688    /// Receive messages
689    #[allow(dead_code)]
690    pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
691        self.message_rx.take().expect("Receiver already taken")
692    }
693
694    /// Send interrupt request
695    pub async fn interrupt(&mut self) -> Result<()> {
696        let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
697            subtype: "interrupt".to_string(),
698        });
699
700        self.send_control_request(interrupt_request).await?;
701        Ok(())
702    }
703
704    /// Set permission mode via control protocol
705    #[allow(dead_code)]
706    pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
707        let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
708            subtype: "set_permission_mode".to_string(),
709            mode: mode.to_string(),
710        });
711        // Ignore response payload; errors propagate
712        let _ = self.send_control_request(req).await?;
713        Ok(())
714    }
715
716    /// Set the active model via control protocol
717    #[allow(dead_code)]
718    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
719        let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
720            subtype: "set_model".to_string(),
721            model,
722        });
723        let _ = self.send_control_request(req).await?;
724        Ok(())
725    }
726
727    /// Handle MCP message for SDK servers
728    #[allow(dead_code)]
729    async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
730        // Check if we have an SDK server with this name
731        if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
732            // TODO: Implement actual MCP server invocation
733            // For now, return a placeholder response
734            debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
735            Ok(serde_json::json!({
736                "jsonrpc": "2.0",
737                "id": message.get("id"),
738                "result": {
739                    "content": "MCP server response placeholder"
740                }
741            }))
742        } else {
743            Err(SdkError::InvalidState {
744                message: format!("No SDK MCP server found with name: {server_name}"),
745            })
746        }
747    }
748
749    /// Close the query handler
750    #[allow(dead_code)]
751    pub async fn close(&mut self) -> Result<()> {
752        // Clean up resources
753        let mut transport = self.transport.lock().await;
754        transport.disconnect().await?;
755        Ok(())
756    }
757
758    /// Get initialization result
759    pub fn get_initialization_result(&self) -> Option<&JsonValue> {
760        self.initialization_result.as_ref()
761    }
762
763    /// Convert arbitrary JSON value to InputMessage understood by CLI
764    #[allow(dead_code)]
765    fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
766        // 1) Already in SDK message shape
767        if let Some(obj) = v.as_object() {
768            if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
769                && t.as_str() == Some("user") {
770                    let parent = obj
771                        .get("parent_tool_use_id")
772                        .and_then(|p| p.as_str().map(|s| s.to_string()));
773                    let session_id = obj
774                        .get("session_id")
775                        .and_then(|s| s.as_str())
776                        .unwrap_or("default")
777                        .to_string();
778
779                    let im = InputMessage {
780                        r#type: "user".to_string(),
781                        message: message.clone(),
782                        parent_tool_use_id: parent,
783                        session_id,
784                    };
785                    return Some(im);
786                }
787
788            // 2) Simple wrapper: {"content":"...", "session_id":"..."}
789            if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
790                let session_id = obj
791                    .get("session_id")
792                    .and_then(|s| s.as_str())
793                    .unwrap_or("default")
794                    .to_string();
795                return Some(InputMessage::user(content.to_string(), session_id));
796            }
797        }
798
799        // 3) Bare string
800        if let Some(s) = v.as_str() {
801            return Some(InputMessage::user(s.to_string(), "default".to_string()));
802        }
803
804        None
805    }
806}
807
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers on `Query`: request-id extraction and
    //! JSON-to-`InputMessage` conversion.

    use super::*;
    use serde_json::json;

    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let snake = json!({"request_id": "req_1"});
        let camel = json!({"requestId": "req_2"});
        assert_eq!(Query::extract_request_id(&snake), Some(json!("req_1")));
        assert_eq!(Query::extract_request_id(&camel), Some(json!("req_2")));
    }

    #[test]
    fn test_json_to_input_message_from_string() {
        let v = json!("Hello");
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_json_to_input_message_from_object_content() {
        let v = json!({"content":"Ping","session_id":"s1"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    #[test]
    fn test_json_to_input_message_object_content_defaults_session() {
        // A wrapper object without "session_id" lands in the "default" session.
        let v = json!({"content":"Ping"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let v = json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str().unwrap(), "user");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
    }

    #[test]
    fn test_json_to_input_message_full_user_shape_null_parent_is_none() {
        let v = json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert!(im.parent_tool_use_id.is_none());
        // No "session_id" was given, so the fallback applies here as well.
        assert_eq!(im.session_id, "default");
    }

    #[test]
    fn test_json_to_input_message_full_user_shape_keeps_string_parent() {
        let v = json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": "tool_1"
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.parent_tool_use_id.as_deref(), Some("tool_1"));
        // The "message" payload is forwarded unchanged.
        assert_eq!(im.message, json!({"role":"user","content":"Hi"}));
    }

    #[test]
    fn test_json_to_input_message_rejects_unsupported_shapes() {
        let unsupported = [
            json!(42),
            json!(null),
            json!(true),
            json!(["Hello"]),
            json!({}),
            // "content" must be a string.
            json!({"content": 5}),
            // A user message needs a "message" key (or string "content").
            json!({"type": "user"}),
            // Only type "user" is accepted in the full message shape.
            json!({"type": "assistant", "message": {"role": "assistant", "content": "Hi"}}),
        ];
        for v in unsupported {
            assert!(Query::json_to_input_message(v).is_none());
        }
    }
}