Skip to main content

cc_sdk/
internal_query.rs

1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7    errors::{Result, SdkError},
8    transport::{InputMessage, Transport},
9    types::{
10        CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11        SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12        SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13        ToolPermissionContext,
14    },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27    /// Transport layer (shared with client)
28    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29    /// Whether in streaming mode
30    #[allow(dead_code)]
31    is_streaming_mode: bool,
32    /// Tool permission callback
33    can_use_tool: Option<Arc<dyn CanUseTool>>,
34    /// Hook configurations
35    hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36    /// SDK MCP servers
37    sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38    /// Message channel sender (reserved for future streaming receive support)
39    #[allow(dead_code)]
40    message_tx: mpsc::Sender<Result<Message>>,
41    /// Message channel receiver (reserved for future streaming receive support)
42    #[allow(dead_code)]
43    message_rx: Option<mpsc::Receiver<Result<Message>>>,
44    /// Initialization result
45    initialization_result: Option<JsonValue>,
46    /// Active hook callbacks
47    hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48    /// Hook callback counter
49    callback_counter: Arc<Mutex<u64>>,
50    /// Request counter for generating unique IDs
51    request_counter: Arc<Mutex<u64>>,
52    /// Pending control request responses
53    pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57    /// Create a new Query handler
58    pub fn new(
59        transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60        is_streaming_mode: bool,
61        can_use_tool: Option<Arc<dyn CanUseTool>>,
62        hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63        sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64    ) -> Self {
65        let (tx, rx) = mpsc::channel(100);
66
67        Self {
68            transport,
69            is_streaming_mode,
70            can_use_tool,
71            hooks,
72            sdk_mcp_servers,
73            message_tx: tx,
74            message_rx: Some(rx),
75            initialization_result: None,
76            hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77            callback_counter: Arc::new(Mutex::new(0)),
78            request_counter: Arc::new(Mutex::new(0)),
79            pending_responses: Arc::new(RwLock::new(HashMap::new())),
80        }
81    }
82
83    /// Test helper to register a hook callback with a known ID
84    ///
85    /// This is intended for E2E tests to inject a callback ID that can be
86    /// referenced by inbound `hook_callback` control messages.
87    pub async fn register_hook_callback_for_test(
88        &self,
89        callback_id: String,
90        callback: Arc<dyn HookCallback>,
91    ) {
92        let mut map = self.hook_callbacks.write().await;
93        map.insert(callback_id, callback);
94    }
95
96    /// Start the query handler
97    pub async fn start(&mut self) -> Result<()> {
98        // Start control request handler task
99        self.start_control_handler().await;
100
101        // Start SDK message forwarder task (route non-control messages to message_tx)
102        let transport = self.transport.clone();
103        let tx = self.message_tx.clone();
104        tokio::spawn(async move {
105            // Get message stream once and consume it continuously
106            let mut stream = {
107                let mut guard = transport.lock().await;
108                guard.receive_messages()
109            }; // Lock released immediately after getting stream
110
111            // Continuously consume from the same stream
112            while let Some(result) = stream.next().await {
113                match result {
114                    Ok(msg) => {
115                        if tx.send(Ok(msg)).await.is_err() {
116                            break;
117                        }
118                    }
119                    Err(e) => {
120                        let _ = tx.send(Err(e)).await;
121                        break;
122                    }
123                }
124            }
125        });
126        Ok(())
127    }
128
129    /// Initialize the control protocol
130    pub async fn initialize(&mut self) -> Result<()> {
131        // Build hooks with callback IDs (Python SDK style)
132        let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133            let mut counter = self.callback_counter.lock().await;
134            let mut callbacks_map = self.hook_callbacks.write().await;
135
136            let hooks_json: HashMap<String, JsonValue> = hooks
137                .iter()
138                .map(|(event_name, matchers)| {
139                    let matchers_with_ids: Vec<JsonValue> = matchers
140                        .iter()
141                        .map(|matcher| {
142                            // Generate callback IDs for each hook in this matcher
143                            let callback_ids: Vec<String> = matcher
144                                .hooks
145                                .iter()
146                                .map(|hook_callback| {
147                                    *counter += 1;
148                                    let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150                                    // Store the callback for later use
151                                    callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153                                    callback_id
154                                })
155                                .collect();
156
157                            serde_json::json!({
158                                "matcher": matcher.matcher.clone(),
159                                "hookCallbackIds": callback_ids
160                            })
161                        })
162                        .collect();
163
164                    (event_name.clone(), serde_json::json!(matchers_with_ids))
165                })
166                .collect();
167
168            Some(hooks_json)
169        } else {
170            None
171        };
172
173        // Send initialize request
174        let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175            subtype: "initialize".to_string(),
176            hooks: hooks_with_ids,
177        });
178
179        // Send control request and save result
180        let result = self.send_control_request(init_request).await?;
181        self.initialization_result = Some(result);
182
183        debug!("Initialization request sent with hook callback IDs");
184        Ok(())
185    }
186
187    /// Send a control request and wait for response
188    async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189        // Generate unique request ID
190        let request_id = {
191            let mut counter = self.request_counter.lock().await;
192            *counter += 1;
193            format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194        };
195
196        // Create oneshot channel for response
197        let (tx, rx) = tokio::sync::oneshot::channel();
198
199        // Register pending response
200        {
201            let mut pending = self.pending_responses.write().await;
202            pending.insert(request_id.clone(), tx);
203        }
204
205        // Build control request with request_id (snake_case for CLI compatibility)
206        let control_request = serde_json::json!({
207            "type": "control_request",
208            "request_id": request_id,
209            "request": request
210        });
211
212        debug!("Sending control request: {:?}", control_request);
213
214        // Send via transport
215        {
216            let mut transport = self.transport.lock().await;
217            transport.send_sdk_control_request(control_request).await?;
218        }
219
220        // Wait for response with timeout
221        match timeout(Duration::from_secs(60), rx).await {
222            Ok(Ok(response)) => {
223                debug!("Received control response for {}", request_id);
224
225                // Python parity: treat subtype=error as an error, and return only
226                // the payload from `response` (or legacy `data`) on success.
227                if response.get("subtype").and_then(|v| v.as_str()) == Some("error") {
228                    let msg = response
229                        .get("error")
230                        .and_then(|v| v.as_str())
231                        .unwrap_or("Unknown control request error");
232                    return Err(SdkError::ControlRequestError(msg.to_string()));
233                }
234
235                Ok(response
236                    .get("response")
237                    .or_else(|| response.get("data"))
238                    .cloned()
239                    .unwrap_or_else(|| serde_json::json!({})))
240            }
241            Ok(Err(_)) => Err(SdkError::ControlRequestError(
242                "Response channel closed".to_string(),
243            )),
244            Err(_) => {
245                // Clean up pending response
246                let mut pending = self.pending_responses.write().await;
247                pending.remove(&request_id);
248                Err(SdkError::Timeout { seconds: 60 })
249            }
250        }
251    }
252
253    /// Handle permission request
254    #[allow(dead_code)]
255    async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
256        if let Some(ref can_use_tool) = self.can_use_tool {
257            let context = ToolPermissionContext {
258                signal: None,
259                suggestions: request.permission_suggestions.unwrap_or_default(),
260            };
261
262            let result = can_use_tool
263                .can_use_tool(&request.tool_name, &request.input, &context)
264                .await;
265
266            // Send response back (CLI expects: { allow: bool, input?, reason? })
267            let response = match result {
268                PermissionResult::Allow(allow) => {
269                    let mut obj = serde_json::json!({ "allow": true });
270                    if let Some(updated) = allow.updated_input {
271                        obj["input"] = updated;
272                    }
273                    obj
274                }
275                PermissionResult::Deny(deny) => {
276                    let mut obj = serde_json::json!({ "allow": false });
277                    if !deny.message.is_empty() {
278                        obj["reason"] = serde_json::json!(deny.message);
279                    }
280                    obj
281                }
282            };
283
284            // Send response back through transport
285            let mut transport = self.transport.lock().await;
286            transport.send_sdk_control_response(response).await?;
287            debug!("Permission response sent");
288        }
289        
290        Ok(())
291    }
292
293    /// Extract requestId from CLI message (supports both camelCase and snake_case)
294    fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
295        msg.get("requestId")
296            .or_else(|| msg.get("request_id"))
297            .cloned()
298    }
299
300    /// Start control request handler task
301    async fn start_control_handler(&mut self) {
302        let transport = self.transport.clone();
303        let can_use_tool = self.can_use_tool.clone();
304        let hook_callbacks = self.hook_callbacks.clone();
305        let sdk_mcp_servers = self.sdk_mcp_servers.clone();
306        let pending_responses = self.pending_responses.clone();
307        
308        // Take ownership of the SDK control receiver to avoid holding locks
309        let sdk_control_rx = {
310            let mut transport_lock = transport.lock().await;
311            transport_lock.take_sdk_control_receiver()
312        }; // Lock released here
313        
314        if let Some(mut control_rx) = sdk_control_rx {
315            tokio::spawn(async move {
316                // Now we can receive control requests without holding any locks
317                let transport_for_control = transport;
318                let can_use_tool_clone = can_use_tool;
319                let hook_callbacks_clone = hook_callbacks;
320                let sdk_mcp_servers_clone = sdk_mcp_servers;
321                let pending_responses_clone = pending_responses;
322
323                loop {
324                    // Receive control request without holding lock
325                    let control_message = control_rx.recv().await;
326
327                    if let Some(control_message) = control_message {
328                        debug!("Received control message: {:?}", control_message);
329
330                        // Check if this is a control response (from CLI to SDK)
331                        if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
332                            // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
333                            if let Some(resp_obj) = control_message.get("response") {
334                                let request_id = resp_obj
335                                    .get("request_id")
336                                    .or_else(|| resp_obj.get("requestId"))
337                                    .and_then(|v| v.as_str());
338
339                                if let Some(request_id) = request_id {
340                                    let mut pending = pending_responses_clone.write().await;
341                                    if let Some(tx) = pending.remove(request_id) {
342                                        // Deliver the nested control response object; send_control_request will
343                                        // extract the `response` (or legacy `data`) payload for callers.
344                                        let _ = tx.send(resp_obj.clone());
345                                        debug!("Control response delivered for request_id: {}", request_id);
346                                    } else {
347                                        warn!("No pending request found for request_id: {}", request_id);
348                                    }
349                                } else {
350                                    warn!("Control response missing request_id: {:?}", control_message);
351                                }
352                            } else {
353                                warn!("Control response missing 'response' payload: {:?}", control_message);
354                            }
355                            continue;
356                        }
357
358                        // Parse and handle control requests (from CLI to SDK)
359                        // Check if this is a control_request with a nested request field
360                        let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
361                            control_message.get("request").cloned().unwrap_or(control_message.clone())
362                        } else {
363                            control_message.clone()
364                        };
365
366                        if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
367                            match subtype {
368                                "can_use_tool" => {
369                                    // Handle permission request
370                                    if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
371                                        // Handle with can_use_tool callback
372                                        if let Some(ref can_use_tool) = can_use_tool_clone {
373                                            let context = ToolPermissionContext {
374                                                signal: None,
375                                                suggestions: request.permission_suggestions.unwrap_or_default(),
376                                            };
377                                                
378                                            // Save original input for fallback (Python SDK always sends updatedInput)
379                                            let original_input = request.input.clone();
380
381                                            let result = can_use_tool
382                                                .can_use_tool(&request.tool_name, &request.input, &context)
383                                                .await;
384
385                                            // Match Python SDK response format:
386                                            // Allow: {"behavior": "allow", "updatedInput": ..., "updatedPermissions": ...}
387                                            // Deny: {"behavior": "deny", "message": "...", "interrupt": false}
388                                            // NOTE: updatedInput is ALWAYS required for allow (CLI Zod schema expects it)
389                                            let permission_response = match result {
390                                                PermissionResult::Allow(allow) => {
391                                                    let mut resp = serde_json::json!({
392                                                        "behavior": "allow",
393                                                        "updatedInput": allow.updated_input.unwrap_or(original_input),
394                                                    });
395                                                    if let Some(perms) = allow.updated_permissions {
396                                                        resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
397                                                    }
398                                                    resp
399                                                }
400                                                PermissionResult::Deny(deny) => {
401                                                    let mut resp = serde_json::json!({
402                                                        "behavior": "deny",
403                                                    });
404                                                    if !deny.message.is_empty() {
405                                                        resp["message"] = serde_json::json!(deny.message);
406                                                    }
407                                                    if deny.interrupt {
408                                                        resp["interrupt"] = serde_json::json!(true);
409                                                    }
410                                                    resp
411                                                }
412                                            };
413                                                
414                                            // Wrap response with proper structure
415                                            // CLI expects "subtype": "success" for all successful responses
416                                            let response = serde_json::json!({
417                                                "subtype": "success",
418                                                "request_id": Self::extract_request_id(&control_message),
419                                                "response": permission_response
420                                            });
421                                                
422                                            // Send response
423                                            let mut transport = transport_for_control.lock().await;
424                                            if let Err(e) = transport.send_sdk_control_response(response).await {
425                                                error!("Failed to send permission response: {}", e);
426                                            }
427                                        }
428                                    } else {
429                                        // Fallback for snake_case fields (tool_name, permission_suggestions)
430                                        if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
431                                            && let Some(input_val) = request_data.get("input").cloned()
432                                                && let Some(ref can_use_tool) = can_use_tool_clone {
433                                                    // Try to parse permission suggestions (snake_case)
434                                                    let suggestions: Vec<PermissionUpdate> = request_data
435                                                        .get("permission_suggestions")
436                                                        .cloned()
437                                                        .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
438                                                        .unwrap_or_default();
439
440                                                    let context = ToolPermissionContext { signal: None, suggestions };
441                                                    let original_input = input_val.clone();
442                                                    let result = can_use_tool
443                                                        .can_use_tool(tool_name, &input_val, &context)
444                                                        .await;
445
446                                                    let permission_response = match result {
447                                                        PermissionResult::Allow(allow) => {
448                                                            let mut resp = serde_json::json!({
449                                                                "behavior": "allow",
450                                                                "updatedInput": allow.updated_input.unwrap_or(original_input),
451                                                            });
452                                                            if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
453                                                            resp
454                                                        }
455                                                        PermissionResult::Deny(deny) => {
456                                                            let mut resp = serde_json::json!({ "behavior": "deny" });
457                                                            if !deny.message.is_empty() { resp["message"] = serde_json::json!(deny.message); }
458                                                            if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
459                                                            resp
460                                                        }
461                                                    };
462
463                                                    let response = serde_json::json!({
464                                                        "subtype": "success",
465                                                        "request_id": Self::extract_request_id(&control_message),
466                                                        "response": permission_response
467                                                    });
468                                                    let mut transport = transport_for_control.lock().await;
469                                                    if let Err(e) = transport.send_sdk_control_response(response).await {
470                                                        error!("Failed to send permission response (fallback): {}", e);
471                                                    }
472                                                }
473                                    }
474                                }
475                                "hook_callback" => {
476                                    // Handle hook callback with strongly-typed inputs/outputs
477                                    if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
478                                        let callbacks = hook_callbacks_clone.read().await;
479
480                                        if let Some(callback) = callbacks.get(&request.callback_id) {
481                                            let context = HookContext { signal: None };
482
483                                            // Try to deserialize input as HookInput
484                                            let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
485                                                Ok(hook_input) => {
486                                                    // Call the hook with strongly-typed input
487                                                    callback
488                                                        .execute(&hook_input, request.tool_use_id.as_deref(), &context)
489                                                        .await
490                                                }
491                                                Err(parse_err) => {
492                                                    error!("Failed to parse hook input: {}", parse_err);
493                                                    // Return error using MessageParseError
494                                                    Err(crate::errors::SdkError::MessageParseError {
495                                                        error: format!("Invalid hook input: {parse_err}"),
496                                                        raw: request.input.to_string(),
497                                                    })
498                                                }
499                                            };
500
501                                            // Handle hook result
502                                            let response_json = match hook_result {
503                                                Ok(hook_output) => {
504                                                    // Serialize HookJSONOutput to JSON
505                                                    let output_value = serde_json::to_value(&hook_output)
506                                                        .unwrap_or_else(|e| {
507                                                            error!("Failed to serialize hook output: {}", e);
508                                                            serde_json::json!({})
509                                                        });
510
511                                                    serde_json::json!({
512                                                        "subtype": "success",
513                                                        "request_id": Self::extract_request_id(&control_message),
514                                                        "response": output_value
515                                                    })
516                                                }
517                                                Err(e) => {
518                                                    error!("Hook callback failed: {}", e);
519                                                    serde_json::json!({
520                                                        "subtype": "error",
521                                                        "request_id": Self::extract_request_id(&control_message),
522                                                        "error": e.to_string()
523                                                    })
524                                                }
525                                            };
526
527                                            let mut transport = transport_for_control.lock().await;
528                                            if let Err(e) = transport.send_sdk_control_response(response_json).await {
529                                                error!("Failed to send hook callback response: {}", e);
530                                            }
531                                        } else {
532                                            warn!("No hook callback found for ID: {}", request.callback_id);
533                                            // Send error response
534                                            let error_response = serde_json::json!({
535                                                "subtype": "error",
536                                                "request_id": Self::extract_request_id(&control_message),
537                                                "error": format!("No hook callback found for ID: {}", request.callback_id)
538                                            });
539                                            let mut transport = transport_for_control.lock().await;
540                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
541                                                error!("Failed to send error response: {}", e);
542                                            }
543                                        }
544                                    } else {
545                                        // Fallback for snake_case fields (callback_id, tool_use_id)
546                                        let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
547                                        let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
548                                        let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
549
550                                        if let Some(callback_id) = callback_id {
551                                            let callbacks = hook_callbacks_clone.read().await;
552                                            if let Some(callback) = callbacks.get(callback_id) {
553                                                let context = HookContext { signal: None };
554
555                                                // Try to parse as HookInput
556                                                let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
557                                                    Ok(hook_input) => {
558                                                        callback
559                                                            .execute(&hook_input, tool_use_id.as_deref(), &context)
560                                                            .await
561                                                    }
562                                                    Err(parse_err) => {
563                                                        error!("Failed to parse hook input (fallback): {}", parse_err);
564                                                        Err(crate::errors::SdkError::MessageParseError {
565                                                            error: format!("Invalid hook input: {parse_err}"),
566                                                            raw: input.to_string(),
567                                                        })
568                                                    }
569                                                };
570
571                                                let response_json = match hook_result {
572                                                    Ok(hook_output) => {
573                                                        let output_value = serde_json::to_value(&hook_output)
574                                                            .unwrap_or_else(|e| {
575                                                                error!("Failed to serialize hook output (fallback): {}", e);
576                                                                serde_json::json!({})
577                                                            });
578
579                                                        serde_json::json!({
580                                                            "subtype": "success",
581                                                            "request_id": Self::extract_request_id(&control_message),
582                                                            "response": output_value
583                                                        })
584                                                    }
585                                                    Err(e) => {
586                                                        error!("Hook callback failed (fallback): {}", e);
587                                                        serde_json::json!({
588                                                            "subtype": "error",
589                                                            "request_id": Self::extract_request_id(&control_message),
590                                                            "error": e.to_string()
591                                                        })
592                                                    }
593                                                };
594
595                                                let mut transport = transport_for_control.lock().await;
596                                                if let Err(e) = transport.send_sdk_control_response(response_json).await {
597                                                    error!("Failed to send hook callback response (fallback): {}", e);
598                                                }
599                                            } else {
600                                                warn!("No hook callback found for ID: {}", callback_id);
601                                            }
602                                        } else {
603                                            warn!("Invalid hook_callback control message: missing callback_id");
604                                        }
605                                    }
606                                }
607                                "mcp_message" => {
608                                    // Handle MCP message
609                                    if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
610                                        && let Some(message) = request_data.get("message") {
611                                            debug!("Processing MCP message for SDK server: {}", server_name);
612
613                                            // Check if we have an SDK server with this name
614                                            if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
615                                                // Try to downcast to SdkMcpServer
616                                                if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
617                                                    // Call the SDK MCP server
618                                                    match sdk_server.handle_message(message.clone()).await {
619                                                        Ok(mcp_result) => {
620                                                            // Wrap response with proper structure
621                                                            let response = serde_json::json!({
622                                                                "subtype": "success",
623                                                                "request_id": Self::extract_request_id(&control_message),
624                                                                "response": {
625                                                                    "mcp_response": mcp_result
626                                                                }
627                                                            });
628
629                                                            let mut transport = transport_for_control.lock().await;
630                                                            if let Err(e) = transport.send_sdk_control_response(response).await {
631                                                                error!("Failed to send MCP response: {}", e);
632                                                            }
633                                                        }
634                                                        Err(e) => {
635                                                            error!("SDK MCP server error: {}", e);
636                                                            let error_response = serde_json::json!({
637                                                                "subtype": "error",
638                                                                "request_id": Self::extract_request_id(&control_message),
639                                                                "error": format!("MCP server error: {}", e)
640                                                            });
641
642                                                            let mut transport = transport_for_control.lock().await;
643                                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
644                                                                error!("Failed to send MCP error response: {}", e);
645                                                            }
646                                                        }
647                                                    }
648                                                } else {
649                                                    warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
650                                                }
651                                            } else {
652                                                warn!("No SDK MCP server found with name: {}", server_name);
653                                                let error_response = serde_json::json!({
654                                                    "subtype": "error",
655                                                    "request_id": Self::extract_request_id(&control_message),
656                                                    "error": format!("Server '{}' not found", server_name)
657                                                });
658
659                                                let mut transport = transport_for_control.lock().await;
660                                                if let Err(e) = transport.send_sdk_control_response(error_response).await {
661                                                    error!("Failed to send MCP error response: {}", e);
662                                                }
663                                            }
664                                        }
665                                }
666                                _ => {
667                                    debug!("Unknown SDK control subtype: {}", subtype);
668                                }
669                            }
670                        }
671                    }
672                }
673            });
674        }
675    }
676
677    /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
678    #[allow(dead_code)]
679    pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
680    where
681        S: Stream<Item = JsonValue> + Send + 'static,
682    {
683        let transport = self.transport.clone();
684
685        tokio::spawn(async move {
686            use futures::StreamExt;
687            let mut stream = Box::pin(input_stream);
688
689            while let Some(value) = stream.next().await {
690                // Best-effort conversion from arbitrary JSON to InputMessage
691                let input_msg_opt = Self::json_to_input_message(value);
692                if let Some(input_msg) = input_msg_opt {
693                    let mut guard = transport.lock().await;
694                    if let Err(e) = guard.send_message(input_msg).await {
695                        warn!("Failed to send streaming input message: {}", e);
696                    }
697                } else {
698                    warn!("Invalid streaming input JSON; expected user message shape");
699                }
700            }
701
702            // After streaming all inputs, signal end of input
703            let mut guard = transport.lock().await;
704            if let Err(e) = guard.end_input().await {
705                warn!("Failed to signal end_input: {}", e);
706            }
707        });
708        Ok(())
709    }
710
711    /// Receive messages
712    #[allow(dead_code)]
713    pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
714        self.message_rx.take().expect("Receiver already taken")
715    }
716
717    /// Send interrupt request
718    pub async fn interrupt(&mut self) -> Result<()> {
719        let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
720            subtype: "interrupt".to_string(),
721        });
722
723        self.send_control_request(interrupt_request).await?;
724        Ok(())
725    }
726
727    /// Set permission mode via control protocol
728    #[allow(dead_code)]
729    pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
730        let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
731            subtype: "set_permission_mode".to_string(),
732            mode: mode.to_string(),
733        });
734        // Ignore response payload; errors propagate
735        let _ = self.send_control_request(req).await?;
736        Ok(())
737    }
738
739    /// Set the active model via control protocol
740    #[allow(dead_code)]
741    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
742        let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
743            subtype: "set_model".to_string(),
744            model,
745        });
746        let _ = self.send_control_request(req).await?;
747        Ok(())
748    }
749
750    /// Rewind tracked files to their state at a specific user message
751    ///
752    /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
753    ///
754    /// # Arguments
755    ///
756    /// * `user_message_id` - UUID of the user message to rewind to
757    ///
758    /// # Example
759    ///
760    /// ```rust,no_run
761    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
762    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
763    /// let options = ClaudeCodeOptions::builder()
764    ///     .enable_file_checkpointing(true)
765    ///     .build();
766    /// let mut client = ClaudeSDKClient::new(options);
767    /// client.connect(None).await?;
768    ///
769    /// // Later, rewind to a checkpoint
770    /// // client.rewind_files("user-message-uuid-here").await?;
771    /// # Ok(())
772    /// # }
773    /// ```
774    pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
775        let req = SDKControlRequest::RewindFiles(crate::types::SDKControlRewindFilesRequest::new(user_message_id));
776        let _ = self.send_control_request(req).await?;
777        Ok(())
778    }
779
780    /// Get context usage information
781    pub async fn get_context_usage(&mut self) -> Result<serde_json::Value> {
782        let req = SDKControlRequest::GetContextUsage(crate::types::SDKControlGetContextUsageRequest::new());
783        self.send_control_request(req).await
784    }
785
786    /// Stop a background task
787    pub async fn stop_task(&mut self, task_id: &str) -> Result<()> {
788        let req = SDKControlRequest::StopTask(crate::types::SDKControlStopTaskRequest::new(task_id));
789        let _ = self.send_control_request(req).await?;
790        Ok(())
791    }
792
793    /// Get MCP server status
794    pub async fn get_mcp_status(&mut self) -> Result<serde_json::Value> {
795        let req = SDKControlRequest::McpStatus(crate::types::SDKControlMcpStatusRequest::new());
796        self.send_control_request(req).await
797    }
798
799    /// Reconnect an MCP server
800    pub async fn reconnect_mcp_server(&mut self, server_name: &str) -> Result<()> {
801        let req = SDKControlRequest::McpReconnect(crate::types::SDKControlMcpReconnectRequest::new(server_name));
802        let _ = self.send_control_request(req).await?;
803        Ok(())
804    }
805
806    /// Toggle an MCP server on/off
807    pub async fn toggle_mcp_server(&mut self, server_name: &str, enabled: bool) -> Result<()> {
808        let req = SDKControlRequest::McpToggle(crate::types::SDKControlMcpToggleRequest::new(server_name, enabled));
809        let _ = self.send_control_request(req).await?;
810        Ok(())
811    }
812
813    /// Handle MCP message for SDK servers
814    #[allow(dead_code)]
815    async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
816        // Check if we have an SDK server with this name
817        if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
818            // TODO: Implement actual MCP server invocation
819            // For now, return a placeholder response
820            debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
821            Ok(serde_json::json!({
822                "jsonrpc": "2.0",
823                "id": message.get("id"),
824                "result": {
825                    "content": "MCP server response placeholder"
826                }
827            }))
828        } else {
829            Err(SdkError::InvalidState {
830                message: format!("No SDK MCP server found with name: {server_name}"),
831            })
832        }
833    }
834
835    /// Close the query handler
836    #[allow(dead_code)]
837    pub async fn close(&mut self) -> Result<()> {
838        // Clean up resources
839        let mut transport = self.transport.lock().await;
840        transport.disconnect().await?;
841        Ok(())
842    }
843
844    /// Get initialization result
845    pub fn get_initialization_result(&self) -> Option<&JsonValue> {
846        self.initialization_result.as_ref()
847    }
848
849    /// Convert arbitrary JSON value to InputMessage understood by CLI
850    #[allow(dead_code)]
851    fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
852        // 1) Already in SDK message shape
853        if let Some(obj) = v.as_object() {
854            if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
855                && t.as_str() == Some("user") {
856                    let parent = obj
857                        .get("parent_tool_use_id")
858                        .and_then(|p| p.as_str().map(|s| s.to_string()));
859                    let session_id = obj
860                        .get("session_id")
861                        .and_then(|s| s.as_str())
862                        .unwrap_or("default")
863                        .to_string();
864
865                    let im = InputMessage {
866                        r#type: "user".to_string(),
867                        message: message.clone(),
868                        parent_tool_use_id: parent,
869                        session_id,
870                    };
871                    return Some(im);
872                }
873
874            // 2) Simple wrapper: {"content":"...", "session_id":"..."}
875            if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
876                let session_id = obj
877                    .get("session_id")
878                    .and_then(|s| s.as_str())
879                    .unwrap_or("default")
880                    .to_string();
881                return Some(InputMessage::user(content.to_string(), session_id));
882            }
883        }
884
885        // 3) Bare string
886        if let Some(s) = v.as_str() {
887            return Some(InputMessage::user(s.to_string(), "default".to_string()));
888        }
889
890        None
891    }
892}
893
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let snake = serde_json::json!({"request_id": "req_1"});
        let camel = serde_json::json!({"requestId": "req_2"});
        assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
        assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
    }

    #[test]
    fn test_json_to_input_message_from_string() {
        let v = serde_json::json!("Hello");
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_json_to_input_message_from_object_content() {
        let v = serde_json::json!({"content":"Ping","session_id":"s1"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let v = serde_json::json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str().unwrap(), "user");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
        // A JSON null parent id maps to `None`.
        assert!(im.parent_tool_use_id.is_none());
    }

    #[test]
    fn test_json_to_input_message_full_user_shape_parent_and_default_session() {
        let v = serde_json::json!({
            "type": "user",
            "message": {"role": "user", "content": "Hi"},
            "parent_tool_use_id": "tool_1"
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        // No session_id supplied: falls back to "default".
        assert_eq!(im.session_id, "default");
        assert_eq!(im.parent_tool_use_id.as_deref(), Some("tool_1"));
    }

    #[test]
    fn test_json_to_input_message_rejects_unsupported_shapes() {
        assert!(Query::json_to_input_message(serde_json::json!(42)).is_none());
        assert!(Query::json_to_input_message(serde_json::json!(null)).is_none());
        assert!(Query::json_to_input_message(serde_json::json!(["a", "b"])).is_none());
        // Non-user type without content matches no shape.
        assert!(Query::json_to_input_message(serde_json::json!({"type": "assistant"})).is_none());
        // Non-string content is not accepted by the simple-wrapper shape.
        assert!(Query::json_to_input_message(serde_json::json!({
            "content": [{"type": "text", "text": "x"}]
        }))
        .is_none());
    }
}