Skip to main content

cc_sdk/
internal_query.rs

1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7    errors::{Result, SdkError},
8    transport::{InputMessage, Transport},
9    types::{
10        CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11        SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12        SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13        ToolPermissionContext,
14    },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27    /// Transport layer (shared with client)
28    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29    /// Whether in streaming mode
30    #[allow(dead_code)]
31    is_streaming_mode: bool,
32    /// Tool permission callback
33    can_use_tool: Option<Arc<dyn CanUseTool>>,
34    /// Hook configurations
35    hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36    /// SDK MCP servers
37    sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38    /// Message channel sender (reserved for future streaming receive support)
39    #[allow(dead_code)]
40    message_tx: mpsc::Sender<Result<Message>>,
41    /// Message channel receiver (reserved for future streaming receive support)
42    #[allow(dead_code)]
43    message_rx: Option<mpsc::Receiver<Result<Message>>>,
44    /// Initialization result
45    initialization_result: Option<JsonValue>,
46    /// Active hook callbacks
47    hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48    /// Hook callback counter
49    callback_counter: Arc<Mutex<u64>>,
50    /// Request counter for generating unique IDs
51    request_counter: Arc<Mutex<u64>>,
52    /// Pending control request responses
53    pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57    /// Create a new Query handler
58    pub fn new(
59        transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60        is_streaming_mode: bool,
61        can_use_tool: Option<Arc<dyn CanUseTool>>,
62        hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63        sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64    ) -> Self {
65        let (tx, rx) = mpsc::channel(100);
66
67        Self {
68            transport,
69            is_streaming_mode,
70            can_use_tool,
71            hooks,
72            sdk_mcp_servers,
73            message_tx: tx,
74            message_rx: Some(rx),
75            initialization_result: None,
76            hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77            callback_counter: Arc::new(Mutex::new(0)),
78            request_counter: Arc::new(Mutex::new(0)),
79            pending_responses: Arc::new(RwLock::new(HashMap::new())),
80        }
81    }
82
83    /// Test helper to register a hook callback with a known ID
84    ///
85    /// This is intended for E2E tests to inject a callback ID that can be
86    /// referenced by inbound `hook_callback` control messages.
87    pub async fn register_hook_callback_for_test(
88        &self,
89        callback_id: String,
90        callback: Arc<dyn HookCallback>,
91    ) {
92        let mut map = self.hook_callbacks.write().await;
93        map.insert(callback_id, callback);
94    }
95
96    /// Start the query handler
97    pub async fn start(&mut self) -> Result<()> {
98        // Start control request handler task
99        self.start_control_handler().await;
100
101        // Start SDK message forwarder task (route non-control messages to message_tx)
102        let transport = self.transport.clone();
103        let tx = self.message_tx.clone();
104        tokio::spawn(async move {
105            // Get message stream once and consume it continuously
106            let mut stream = {
107                let mut guard = transport.lock().await;
108                guard.receive_messages()
109            }; // Lock released immediately after getting stream
110
111            // Continuously consume from the same stream
112            while let Some(result) = stream.next().await {
113                match result {
114                    Ok(msg) => {
115                        if tx.send(Ok(msg)).await.is_err() {
116                            break;
117                        }
118                    }
119                    Err(e) => {
120                        let _ = tx.send(Err(e)).await;
121                        break;
122                    }
123                }
124            }
125        });
126        Ok(())
127    }
128
129    /// Initialize the control protocol
130    pub async fn initialize(&mut self) -> Result<()> {
131        // Build hooks with callback IDs (Python SDK style)
132        let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133            let mut counter = self.callback_counter.lock().await;
134            let mut callbacks_map = self.hook_callbacks.write().await;
135
136            let hooks_json: HashMap<String, JsonValue> = hooks
137                .iter()
138                .map(|(event_name, matchers)| {
139                    let matchers_with_ids: Vec<JsonValue> = matchers
140                        .iter()
141                        .map(|matcher| {
142                            // Generate callback IDs for each hook in this matcher
143                            let callback_ids: Vec<String> = matcher
144                                .hooks
145                                .iter()
146                                .map(|hook_callback| {
147                                    *counter += 1;
148                                    let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150                                    // Store the callback for later use
151                                    callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153                                    callback_id
154                                })
155                                .collect();
156
157                            serde_json::json!({
158                                "matcher": matcher.matcher.clone(),
159                                "hookCallbackIds": callback_ids
160                            })
161                        })
162                        .collect();
163
164                    (event_name.clone(), serde_json::json!(matchers_with_ids))
165                })
166                .collect();
167
168            Some(hooks_json)
169        } else {
170            None
171        };
172
173        // Send initialize request
174        let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175            subtype: "initialize".to_string(),
176            hooks: hooks_with_ids,
177        });
178
179        // Send control request and save result
180        let result = self.send_control_request(init_request).await?;
181        self.initialization_result = Some(result);
182
183        debug!("Initialization request sent with hook callback IDs");
184        Ok(())
185    }
186
187    /// Send a control request and wait for response
188    async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189        // Generate unique request ID
190        let request_id = {
191            let mut counter = self.request_counter.lock().await;
192            *counter += 1;
193            format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194        };
195
196        // Create oneshot channel for response
197        let (tx, rx) = tokio::sync::oneshot::channel();
198
199        // Register pending response
200        {
201            let mut pending = self.pending_responses.write().await;
202            pending.insert(request_id.clone(), tx);
203        }
204
205        // Build control request with request_id (snake_case for CLI compatibility)
206        let control_request = serde_json::json!({
207            "type": "control_request",
208            "request_id": request_id,
209            "request": request
210        });
211
212        debug!("Sending control request: {:?}", control_request);
213
214        // Send via transport
215        {
216            let mut transport = self.transport.lock().await;
217            transport.send_sdk_control_request(control_request).await?;
218        }
219
220        // Wait for response with timeout
221        match timeout(Duration::from_secs(60), rx).await {
222            Ok(Ok(response)) => {
223                debug!("Received control response for {}", request_id);
224
225                // Python parity: treat subtype=error as an error, and return only
226                // the payload from `response` (or legacy `data`) on success.
227                if response.get("subtype").and_then(|v| v.as_str()) == Some("error") {
228                    let msg = response
229                        .get("error")
230                        .and_then(|v| v.as_str())
231                        .unwrap_or("Unknown control request error");
232                    return Err(SdkError::ControlRequestError(msg.to_string()));
233                }
234
235                Ok(response
236                    .get("response")
237                    .or_else(|| response.get("data"))
238                    .cloned()
239                    .unwrap_or_else(|| serde_json::json!({})))
240            }
241            Ok(Err(_)) => Err(SdkError::ControlRequestError(
242                "Response channel closed".to_string(),
243            )),
244            Err(_) => {
245                // Clean up pending response
246                let mut pending = self.pending_responses.write().await;
247                pending.remove(&request_id);
248                Err(SdkError::Timeout { seconds: 60 })
249            }
250        }
251    }
252
253    /// Handle permission request
254    #[allow(dead_code)]
255    async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
256        if let Some(ref can_use_tool) = self.can_use_tool {
257            let context = ToolPermissionContext {
258                signal: None,
259                suggestions: request.permission_suggestions.unwrap_or_default(),
260            };
261
262            let result = can_use_tool
263                .can_use_tool(&request.tool_name, &request.input, &context)
264                .await;
265
266            // Send response back (CLI expects: { allow: bool, input?, reason? })
267            let response = match result {
268                PermissionResult::Allow(allow) => {
269                    let mut obj = serde_json::json!({ "allow": true });
270                    if let Some(updated) = allow.updated_input {
271                        obj["input"] = updated;
272                    }
273                    obj
274                }
275                PermissionResult::Deny(deny) => {
276                    let mut obj = serde_json::json!({ "allow": false });
277                    if !deny.message.is_empty() {
278                        obj["reason"] = serde_json::json!(deny.message);
279                    }
280                    obj
281                }
282            };
283
284            // Send response back through transport
285            let mut transport = self.transport.lock().await;
286            transport.send_sdk_control_response(response).await?;
287            debug!("Permission response sent");
288        }
289        
290        Ok(())
291    }
292
293    /// Extract requestId from CLI message (supports both camelCase and snake_case)
294    fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
295        msg.get("requestId")
296            .or_else(|| msg.get("request_id"))
297            .cloned()
298    }
299
300    /// Start control request handler task
301    async fn start_control_handler(&mut self) {
302        let transport = self.transport.clone();
303        let can_use_tool = self.can_use_tool.clone();
304        let hook_callbacks = self.hook_callbacks.clone();
305        let sdk_mcp_servers = self.sdk_mcp_servers.clone();
306        let pending_responses = self.pending_responses.clone();
307        
308        // Take ownership of the SDK control receiver to avoid holding locks
309        let sdk_control_rx = {
310            let mut transport_lock = transport.lock().await;
311            transport_lock.take_sdk_control_receiver()
312        }; // Lock released here
313        
314        if let Some(mut control_rx) = sdk_control_rx {
315            tokio::spawn(async move {
316                // Now we can receive control requests without holding any locks
317                let transport_for_control = transport;
318                let can_use_tool_clone = can_use_tool;
319                let hook_callbacks_clone = hook_callbacks;
320                let sdk_mcp_servers_clone = sdk_mcp_servers;
321                let pending_responses_clone = pending_responses;
322
323                loop {
324                    // Receive control request without holding lock
325                    let control_message = control_rx.recv().await;
326
327                    // If channel closed (sender dropped), exit the loop
328                    let Some(control_message) = control_message else {
329                        debug!("Control channel closed, exiting control handler");
330                        break;
331                    };
332
333                    debug!("Received control message: {:?}", control_message);
334
335                        // Check if this is a control response (from CLI to SDK)
336                        if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
337                            // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
338                            if let Some(resp_obj) = control_message.get("response") {
339                                let request_id = resp_obj
340                                    .get("request_id")
341                                    .or_else(|| resp_obj.get("requestId"))
342                                    .and_then(|v| v.as_str());
343
344                                if let Some(request_id) = request_id {
345                                    let mut pending = pending_responses_clone.write().await;
346                                    if let Some(tx) = pending.remove(request_id) {
347                                        // Deliver the nested control response object; send_control_request will
348                                        // extract the `response` (or legacy `data`) payload for callers.
349                                        let _ = tx.send(resp_obj.clone());
350                                        debug!("Control response delivered for request_id: {}", request_id);
351                                    } else {
352                                        warn!("No pending request found for request_id: {}", request_id);
353                                    }
354                                } else {
355                                    warn!("Control response missing request_id: {:?}", control_message);
356                                }
357                            } else {
358                                warn!("Control response missing 'response' payload: {:?}", control_message);
359                            }
360                            continue;
361                        }
362
363                        // Parse and handle control requests (from CLI to SDK)
364                        // Check if this is a control_request with a nested request field
365                        let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
366                            control_message.get("request").cloned().unwrap_or(control_message.clone())
367                        } else {
368                            control_message.clone()
369                        };
370
371                        if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
372                            match subtype {
373                                "can_use_tool" => {
374                                    // Handle permission request
375                                    if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
376                                        // Handle with can_use_tool callback
377                                        if let Some(ref can_use_tool) = can_use_tool_clone {
378                                            let context = ToolPermissionContext {
379                                                signal: None,
380                                                suggestions: request.permission_suggestions.unwrap_or_default(),
381                                            };
382                                                
383                                            // Save original input for fallback (Python SDK always sends updatedInput)
384                                            let original_input = request.input.clone();
385
386                                            let result = can_use_tool
387                                                .can_use_tool(&request.tool_name, &request.input, &context)
388                                                .await;
389
390                                            // Match Python SDK response format:
391                                            // Allow: {"behavior": "allow", "updatedInput": ..., "updatedPermissions": ...}
392                                            // Deny: {"behavior": "deny", "message": "...", "interrupt": false}
393                                            // NOTE: updatedInput is ALWAYS required for allow (CLI Zod schema expects it)
394                                            let permission_response = match result {
395                                                PermissionResult::Allow(allow) => {
396                                                    let mut resp = serde_json::json!({
397                                                        "behavior": "allow",
398                                                        "updatedInput": allow.updated_input.unwrap_or(original_input),
399                                                    });
400                                                    if let Some(perms) = allow.updated_permissions {
401                                                        resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
402                                                    }
403                                                    resp
404                                                }
405                                                PermissionResult::Deny(deny) => {
406                                                    let mut resp = serde_json::json!({
407                                                        "behavior": "deny",
408                                                    });
409                                                    if !deny.message.is_empty() {
410                                                        resp["message"] = serde_json::json!(deny.message);
411                                                    }
412                                                    if deny.interrupt {
413                                                        resp["interrupt"] = serde_json::json!(true);
414                                                    }
415                                                    resp
416                                                }
417                                            };
418                                                
419                                            // Wrap response with proper structure
420                                            // CLI expects "subtype": "success" for all successful responses
421                                            let response = serde_json::json!({
422                                                "subtype": "success",
423                                                "request_id": Self::extract_request_id(&control_message),
424                                                "response": permission_response
425                                            });
426                                                
427                                            // Send response
428                                            let mut transport = transport_for_control.lock().await;
429                                            if let Err(e) = transport.send_sdk_control_response(response).await {
430                                                error!("Failed to send permission response: {}", e);
431                                            }
432                                        }
433                                    } else {
434                                        // Fallback for snake_case fields (tool_name, permission_suggestions)
435                                        if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
436                                            && let Some(input_val) = request_data.get("input").cloned()
437                                                && let Some(ref can_use_tool) = can_use_tool_clone {
438                                                    // Try to parse permission suggestions (snake_case)
439                                                    let suggestions: Vec<PermissionUpdate> = request_data
440                                                        .get("permission_suggestions")
441                                                        .cloned()
442                                                        .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
443                                                        .unwrap_or_default();
444
445                                                    let context = ToolPermissionContext { signal: None, suggestions };
446                                                    let original_input = input_val.clone();
447                                                    let result = can_use_tool
448                                                        .can_use_tool(tool_name, &input_val, &context)
449                                                        .await;
450
451                                                    let permission_response = match result {
452                                                        PermissionResult::Allow(allow) => {
453                                                            let mut resp = serde_json::json!({
454                                                                "behavior": "allow",
455                                                                "updatedInput": allow.updated_input.unwrap_or(original_input),
456                                                            });
457                                                            if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
458                                                            resp
459                                                        }
460                                                        PermissionResult::Deny(deny) => {
461                                                            let mut resp = serde_json::json!({ "behavior": "deny" });
462                                                            if !deny.message.is_empty() { resp["message"] = serde_json::json!(deny.message); }
463                                                            if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
464                                                            resp
465                                                        }
466                                                    };
467
468                                                    let response = serde_json::json!({
469                                                        "subtype": "success",
470                                                        "request_id": Self::extract_request_id(&control_message),
471                                                        "response": permission_response
472                                                    });
473                                                    let mut transport = transport_for_control.lock().await;
474                                                    if let Err(e) = transport.send_sdk_control_response(response).await {
475                                                        error!("Failed to send permission response (fallback): {}", e);
476                                                    }
477                                                }
478                                    }
479                                }
480                                "hook_callback" => {
481                                    // Handle hook callback with strongly-typed inputs/outputs
482                                    if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
483                                        let callbacks = hook_callbacks_clone.read().await;
484
485                                        if let Some(callback) = callbacks.get(&request.callback_id) {
486                                            let context = HookContext { signal: None };
487
488                                            // Try to deserialize input as HookInput
489                                            let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
490                                                Ok(hook_input) => {
491                                                    // Call the hook with strongly-typed input
492                                                    callback
493                                                        .execute(&hook_input, request.tool_use_id.as_deref(), &context)
494                                                        .await
495                                                }
496                                                Err(parse_err) => {
497                                                    error!("Failed to parse hook input: {}", parse_err);
498                                                    // Return error using MessageParseError
499                                                    Err(crate::errors::SdkError::MessageParseError {
500                                                        error: format!("Invalid hook input: {parse_err}"),
501                                                        raw: request.input.to_string(),
502                                                    })
503                                                }
504                                            };
505
506                                            // Handle hook result
507                                            let response_json = match hook_result {
508                                                Ok(hook_output) => {
509                                                    // Serialize HookJSONOutput to JSON
510                                                    let output_value = serde_json::to_value(&hook_output)
511                                                        .unwrap_or_else(|e| {
512                                                            error!("Failed to serialize hook output: {}", e);
513                                                            serde_json::json!({})
514                                                        });
515
516                                                    serde_json::json!({
517                                                        "subtype": "success",
518                                                        "request_id": Self::extract_request_id(&control_message),
519                                                        "response": output_value
520                                                    })
521                                                }
522                                                Err(e) => {
523                                                    error!("Hook callback failed: {}", e);
524                                                    serde_json::json!({
525                                                        "subtype": "error",
526                                                        "request_id": Self::extract_request_id(&control_message),
527                                                        "error": e.to_string()
528                                                    })
529                                                }
530                                            };
531
532                                            let mut transport = transport_for_control.lock().await;
533                                            if let Err(e) = transport.send_sdk_control_response(response_json).await {
534                                                error!("Failed to send hook callback response: {}", e);
535                                            }
536                                        } else {
537                                            warn!("No hook callback found for ID: {}", request.callback_id);
538                                            // Send error response
539                                            let error_response = serde_json::json!({
540                                                "subtype": "error",
541                                                "request_id": Self::extract_request_id(&control_message),
542                                                "error": format!("No hook callback found for ID: {}", request.callback_id)
543                                            });
544                                            let mut transport = transport_for_control.lock().await;
545                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
546                                                error!("Failed to send error response: {}", e);
547                                            }
548                                        }
549                                    } else {
550                                        // Fallback for snake_case fields (callback_id, tool_use_id)
551                                        let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
552                                        let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
553                                        let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
554
555                                        if let Some(callback_id) = callback_id {
556                                            let callbacks = hook_callbacks_clone.read().await;
557                                            if let Some(callback) = callbacks.get(callback_id) {
558                                                let context = HookContext { signal: None };
559
560                                                // Try to parse as HookInput
561                                                let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
562                                                    Ok(hook_input) => {
563                                                        callback
564                                                            .execute(&hook_input, tool_use_id.as_deref(), &context)
565                                                            .await
566                                                    }
567                                                    Err(parse_err) => {
568                                                        error!("Failed to parse hook input (fallback): {}", parse_err);
569                                                        Err(crate::errors::SdkError::MessageParseError {
570                                                            error: format!("Invalid hook input: {parse_err}"),
571                                                            raw: input.to_string(),
572                                                        })
573                                                    }
574                                                };
575
576                                                let response_json = match hook_result {
577                                                    Ok(hook_output) => {
578                                                        let output_value = serde_json::to_value(&hook_output)
579                                                            .unwrap_or_else(|e| {
580                                                                error!("Failed to serialize hook output (fallback): {}", e);
581                                                                serde_json::json!({})
582                                                            });
583
584                                                        serde_json::json!({
585                                                            "subtype": "success",
586                                                            "request_id": Self::extract_request_id(&control_message),
587                                                            "response": output_value
588                                                        })
589                                                    }
590                                                    Err(e) => {
591                                                        error!("Hook callback failed (fallback): {}", e);
592                                                        serde_json::json!({
593                                                            "subtype": "error",
594                                                            "request_id": Self::extract_request_id(&control_message),
595                                                            "error": e.to_string()
596                                                        })
597                                                    }
598                                                };
599
600                                                let mut transport = transport_for_control.lock().await;
601                                                if let Err(e) = transport.send_sdk_control_response(response_json).await {
602                                                    error!("Failed to send hook callback response (fallback): {}", e);
603                                                }
604                                            } else {
605                                                warn!("No hook callback found for ID: {}", callback_id);
606                                            }
607                                        } else {
608                                            warn!("Invalid hook_callback control message: missing callback_id");
609                                        }
610                                    }
611                                }
612                                "mcp_message" => {
613                                    // Handle MCP message
614                                    if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
615                                        && let Some(message) = request_data.get("message") {
616                                            debug!("Processing MCP message for SDK server: {}", server_name);
617
618                                            // Check if we have an SDK server with this name
619                                            if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
620                                                // Try to downcast to SdkMcpServer
621                                                if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
622                                                    // Call the SDK MCP server
623                                                    match sdk_server.handle_message(message.clone()).await {
624                                                        Ok(mcp_result) => {
625                                                            // Wrap response with proper structure
626                                                            let response = serde_json::json!({
627                                                                "subtype": "success",
628                                                                "request_id": Self::extract_request_id(&control_message),
629                                                                "response": {
630                                                                    "mcp_response": mcp_result
631                                                                }
632                                                            });
633
634                                                            let mut transport = transport_for_control.lock().await;
635                                                            if let Err(e) = transport.send_sdk_control_response(response).await {
636                                                                error!("Failed to send MCP response: {}", e);
637                                                            }
638                                                        }
639                                                        Err(e) => {
640                                                            error!("SDK MCP server error: {}", e);
641                                                            let error_response = serde_json::json!({
642                                                                "subtype": "error",
643                                                                "request_id": Self::extract_request_id(&control_message),
644                                                                "error": format!("MCP server error: {}", e)
645                                                            });
646
647                                                            let mut transport = transport_for_control.lock().await;
648                                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
649                                                                error!("Failed to send MCP error response: {}", e);
650                                                            }
651                                                        }
652                                                    }
653                                                } else {
654                                                    warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
655                                                }
656                                            } else {
657                                                warn!("No SDK MCP server found with name: {}", server_name);
658                                                let error_response = serde_json::json!({
659                                                    "subtype": "error",
660                                                    "request_id": Self::extract_request_id(&control_message),
661                                                    "error": format!("Server '{}' not found", server_name)
662                                                });
663
664                                                let mut transport = transport_for_control.lock().await;
665                                                if let Err(e) = transport.send_sdk_control_response(error_response).await {
666                                                    error!("Failed to send MCP error response: {}", e);
667                                                }
668                                            }
669                                        }
670                                }
671                                _ => {
672                                    debug!("Unknown SDK control subtype: {}", subtype);
673                                }
674                            }
675                        }
676                }
677            });
678        }
679    }
680
681    /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
682    #[allow(dead_code)]
683    pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
684    where
685        S: Stream<Item = JsonValue> + Send + 'static,
686    {
687        let transport = self.transport.clone();
688
689        tokio::spawn(async move {
690            use futures::StreamExt;
691            let mut stream = Box::pin(input_stream);
692
693            while let Some(value) = stream.next().await {
694                // Best-effort conversion from arbitrary JSON to InputMessage
695                let input_msg_opt = Self::json_to_input_message(value);
696                if let Some(input_msg) = input_msg_opt {
697                    let mut guard = transport.lock().await;
698                    if let Err(e) = guard.send_message(input_msg).await {
699                        warn!("Failed to send streaming input message: {}", e);
700                    }
701                } else {
702                    warn!("Invalid streaming input JSON; expected user message shape");
703                }
704            }
705
706            // After streaming all inputs, signal end of input
707            let mut guard = transport.lock().await;
708            if let Err(e) = guard.end_input().await {
709                warn!("Failed to signal end_input: {}", e);
710            }
711        });
712        Ok(())
713    }
714
715    /// Receive messages
716    #[allow(dead_code)]
717    pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
718        self.message_rx.take().expect("Receiver already taken")
719    }
720
721    /// Send interrupt request
722    pub async fn interrupt(&mut self) -> Result<()> {
723        let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
724            subtype: "interrupt".to_string(),
725        });
726
727        self.send_control_request(interrupt_request).await?;
728        Ok(())
729    }
730
731    /// Set permission mode via control protocol
732    #[allow(dead_code)]
733    pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
734        let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
735            subtype: "set_permission_mode".to_string(),
736            mode: mode.to_string(),
737        });
738        // Ignore response payload; errors propagate
739        let _ = self.send_control_request(req).await?;
740        Ok(())
741    }
742
743    /// Set the active model via control protocol
744    #[allow(dead_code)]
745    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
746        let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
747            subtype: "set_model".to_string(),
748            model,
749        });
750        let _ = self.send_control_request(req).await?;
751        Ok(())
752    }
753
754    /// Rewind tracked files to their state at a specific user message
755    ///
756    /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
757    ///
758    /// # Arguments
759    ///
760    /// * `user_message_id` - UUID of the user message to rewind to
761    ///
762    /// # Example
763    ///
764    /// ```rust,no_run
765    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
766    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
767    /// let options = ClaudeCodeOptions::builder()
768    ///     .enable_file_checkpointing(true)
769    ///     .build();
770    /// let mut client = ClaudeSDKClient::new(options);
771    /// client.connect(None).await?;
772    ///
773    /// // Later, rewind to a checkpoint
774    /// // client.rewind_files("user-message-uuid-here").await?;
775    /// # Ok(())
776    /// # }
777    /// ```
778    pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
779        let req = SDKControlRequest::RewindFiles(crate::types::SDKControlRewindFilesRequest::new(user_message_id));
780        let _ = self.send_control_request(req).await?;
781        Ok(())
782    }
783
784    /// Get context usage information
785    pub async fn get_context_usage(&mut self) -> Result<serde_json::Value> {
786        let req = SDKControlRequest::GetContextUsage(crate::types::SDKControlGetContextUsageRequest::new());
787        self.send_control_request(req).await
788    }
789
790    /// Stop a background task
791    pub async fn stop_task(&mut self, task_id: &str) -> Result<()> {
792        let req = SDKControlRequest::StopTask(crate::types::SDKControlStopTaskRequest::new(task_id));
793        let _ = self.send_control_request(req).await?;
794        Ok(())
795    }
796
797    /// Get MCP server status
798    pub async fn get_mcp_status(&mut self) -> Result<serde_json::Value> {
799        let req = SDKControlRequest::McpStatus(crate::types::SDKControlMcpStatusRequest::new());
800        self.send_control_request(req).await
801    }
802
803    /// Reconnect an MCP server
804    pub async fn reconnect_mcp_server(&mut self, server_name: &str) -> Result<()> {
805        let req = SDKControlRequest::McpReconnect(crate::types::SDKControlMcpReconnectRequest::new(server_name));
806        let _ = self.send_control_request(req).await?;
807        Ok(())
808    }
809
810    /// Toggle an MCP server on/off
811    pub async fn toggle_mcp_server(&mut self, server_name: &str, enabled: bool) -> Result<()> {
812        let req = SDKControlRequest::McpToggle(crate::types::SDKControlMcpToggleRequest::new(server_name, enabled));
813        let _ = self.send_control_request(req).await?;
814        Ok(())
815    }
816
817    /// Handle MCP message for SDK servers
818    #[allow(dead_code)]
819    async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
820        // Check if we have an SDK server with this name
821        if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
822            // TODO: Implement actual MCP server invocation
823            // For now, return a placeholder response
824            debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
825            Ok(serde_json::json!({
826                "jsonrpc": "2.0",
827                "id": message.get("id"),
828                "result": {
829                    "content": "MCP server response placeholder"
830                }
831            }))
832        } else {
833            Err(SdkError::InvalidState {
834                message: format!("No SDK MCP server found with name: {server_name}"),
835            })
836        }
837    }
838
839    /// Close the query handler
840    #[allow(dead_code)]
841    pub async fn close(&mut self) -> Result<()> {
842        // Clean up resources
843        let mut transport = self.transport.lock().await;
844        transport.disconnect().await?;
845        Ok(())
846    }
847
848    /// Get initialization result
849    pub fn get_initialization_result(&self) -> Option<&JsonValue> {
850        self.initialization_result.as_ref()
851    }
852
853    /// Convert arbitrary JSON value to InputMessage understood by CLI
854    #[allow(dead_code)]
855    fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
856        // 1) Already in SDK message shape
857        if let Some(obj) = v.as_object() {
858            if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
859                && t.as_str() == Some("user") {
860                    let parent = obj
861                        .get("parent_tool_use_id")
862                        .and_then(|p| p.as_str().map(|s| s.to_string()));
863                    let session_id = obj
864                        .get("session_id")
865                        .and_then(|s| s.as_str())
866                        .unwrap_or("default")
867                        .to_string();
868
869                    let im = InputMessage {
870                        r#type: "user".to_string(),
871                        message: message.clone(),
872                        parent_tool_use_id: parent,
873                        session_id,
874                    };
875                    return Some(im);
876                }
877
878            // 2) Simple wrapper: {"content":"...", "session_id":"..."}
879            if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
880                let session_id = obj
881                    .get("session_id")
882                    .and_then(|s| s.as_str())
883                    .unwrap_or("default")
884                    .to_string();
885                return Some(InputMessage::user(content.to_string(), session_id));
886            }
887        }
888
889        // 3) Bare string
890        if let Some(s) = v.as_str() {
891            return Some(InputMessage::user(s.to_string(), "default".to_string()));
892        }
893
894        None
895    }
896}
897
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let snake = serde_json::json!({"request_id": "req_1"});
        let camel = serde_json::json!({"requestId": "req_2"});
        assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
        assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
    }

    #[test]
    fn test_json_to_input_message_from_string() {
        let v = serde_json::json!("Hello");
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_json_to_input_message_from_object_content() {
        let v = serde_json::json!({"content":"Ping","session_id":"s1"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    /// A wrapper without a usable `session_id` (missing or non-string) falls back to "default".
    #[test]
    fn test_json_to_input_message_object_content_defaults_session() {
        for v in [
            serde_json::json!({"content": "Ping"}),
            serde_json::json!({"content": "Ping", "session_id": 7}),
        ] {
            let im = Query::json_to_input_message(v).expect("should convert");
            assert_eq!(im.r#type, "user");
            assert_eq!(im.session_id, "default");
            assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
        }
    }

    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let v = serde_json::json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str().unwrap(), "user");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
        // A JSON null parent is not a string, so no parent is recorded.
        assert!(im.parent_tool_use_id.is_none());
    }

    #[test]
    fn test_json_to_input_message_full_user_shape_with_parent_and_default_session() {
        let v = serde_json::json!({
            "type": "user",
            "message": {"role": "user", "content": "Hi"},
            "parent_tool_use_id": "tool_1"
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.parent_tool_use_id.as_deref(), Some("tool_1"));
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
    }

    /// Shapes that match none of the accepted forms must be rejected, not coerced.
    #[test]
    fn test_json_to_input_message_rejects_unsupported_shapes() {
        for v in [
            serde_json::json!(42),
            serde_json::json!(null),
            serde_json::json!(["Hello"]),
            serde_json::json!({"content": 5}),
            serde_json::json!({"type": "assistant", "message": {"role": "assistant", "content": "Hi"}}),
            serde_json::json!({"type": "user"}),
        ] {
            assert!(
                Query::json_to_input_message(v.clone()).is_none(),
                "expected None for {v}"
            );
        }
    }
}