cc_sdk/
internal_query.rs

1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7    errors::{Result, SdkError},
8    transport::{InputMessage, Transport},
9    types::{
10        CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11        SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12        SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13        ToolPermissionContext,
14    },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27    /// Transport layer (shared with client)
28    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29    /// Whether in streaming mode
30    #[allow(dead_code)]
31    is_streaming_mode: bool,
32    /// Tool permission callback
33    can_use_tool: Option<Arc<dyn CanUseTool>>,
34    /// Hook configurations
35    hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36    /// SDK MCP servers
37    sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38    /// Message channel sender (reserved for future streaming receive support)
39    #[allow(dead_code)]
40    message_tx: mpsc::Sender<Result<Message>>,
41    /// Message channel receiver (reserved for future streaming receive support)
42    #[allow(dead_code)]
43    message_rx: Option<mpsc::Receiver<Result<Message>>>,
44    /// Initialization result
45    initialization_result: Option<JsonValue>,
46    /// Active hook callbacks
47    hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48    /// Hook callback counter
49    callback_counter: Arc<Mutex<u64>>,
50    /// Request counter for generating unique IDs
51    request_counter: Arc<Mutex<u64>>,
52    /// Pending control request responses
53    pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57    /// Create a new Query handler
58    pub fn new(
59        transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60        is_streaming_mode: bool,
61        can_use_tool: Option<Arc<dyn CanUseTool>>,
62        hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63        sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64    ) -> Self {
65        let (tx, rx) = mpsc::channel(100);
66
67        Self {
68            transport,
69            is_streaming_mode,
70            can_use_tool,
71            hooks,
72            sdk_mcp_servers,
73            message_tx: tx,
74            message_rx: Some(rx),
75            initialization_result: None,
76            hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77            callback_counter: Arc::new(Mutex::new(0)),
78            request_counter: Arc::new(Mutex::new(0)),
79            pending_responses: Arc::new(RwLock::new(HashMap::new())),
80        }
81    }
82
83    /// Test helper to register a hook callback with a known ID
84    ///
85    /// This is intended for E2E tests to inject a callback ID that can be
86    /// referenced by inbound `hook_callback` control messages.
87    pub async fn register_hook_callback_for_test(
88        &self,
89        callback_id: String,
90        callback: Arc<dyn HookCallback>,
91    ) {
92        let mut map = self.hook_callbacks.write().await;
93        map.insert(callback_id, callback);
94    }
95
96    /// Start the query handler
97    pub async fn start(&mut self) -> Result<()> {
98        // Start control request handler task
99        self.start_control_handler().await;
100
101        // Start SDK message forwarder task (route non-control messages to message_tx)
102        let transport = self.transport.clone();
103        let tx = self.message_tx.clone();
104        tokio::spawn(async move {
105            // Get message stream once and consume it continuously
106            let mut stream = {
107                let mut guard = transport.lock().await;
108                guard.receive_messages()
109            }; // Lock released immediately after getting stream
110
111            // Continuously consume from the same stream
112            while let Some(result) = stream.next().await {
113                match result {
114                    Ok(msg) => {
115                        if tx.send(Ok(msg)).await.is_err() {
116                            break;
117                        }
118                    }
119                    Err(e) => {
120                        let _ = tx.send(Err(e)).await;
121                        break;
122                    }
123                }
124            }
125        });
126        Ok(())
127    }
128
129    /// Initialize the control protocol
130    pub async fn initialize(&mut self) -> Result<()> {
131        // Build hooks with callback IDs (Python SDK style)
132        let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133            let mut counter = self.callback_counter.lock().await;
134            let mut callbacks_map = self.hook_callbacks.write().await;
135
136            let hooks_json: HashMap<String, JsonValue> = hooks
137                .iter()
138                .map(|(event_name, matchers)| {
139                    let matchers_with_ids: Vec<JsonValue> = matchers
140                        .iter()
141                        .map(|matcher| {
142                            // Generate callback IDs for each hook in this matcher
143                            let callback_ids: Vec<String> = matcher
144                                .hooks
145                                .iter()
146                                .map(|hook_callback| {
147                                    *counter += 1;
148                                    let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150                                    // Store the callback for later use
151                                    callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153                                    callback_id
154                                })
155                                .collect();
156
157                            serde_json::json!({
158                                "matcher": matcher.matcher.clone(),
159                                "hookCallbackIds": callback_ids
160                            })
161                        })
162                        .collect();
163
164                    (event_name.clone(), serde_json::json!(matchers_with_ids))
165                })
166                .collect();
167
168            Some(hooks_json)
169        } else {
170            None
171        };
172
173        // Send initialize request
174        let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175            subtype: "initialize".to_string(),
176            hooks: hooks_with_ids,
177        });
178
179        // Send control request and save result
180        let result = self.send_control_request(init_request).await?;
181        self.initialization_result = Some(result);
182
183        debug!("Initialization request sent with hook callback IDs");
184        Ok(())
185    }
186
187    /// Send a control request and wait for response
188    async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189        // Generate unique request ID
190        let request_id = {
191            let mut counter = self.request_counter.lock().await;
192            *counter += 1;
193            format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194        };
195
196        // Create oneshot channel for response
197        let (tx, rx) = tokio::sync::oneshot::channel();
198
199        // Register pending response
200        {
201            let mut pending = self.pending_responses.write().await;
202            pending.insert(request_id.clone(), tx);
203        }
204
205        // Build control request with request_id (snake_case for CLI compatibility)
206        let control_request = serde_json::json!({
207            "type": "control_request",
208            "request_id": request_id,
209            "request": request
210        });
211
212        debug!("Sending control request: {:?}", control_request);
213
214        // Send via transport
215        {
216            let mut transport = self.transport.lock().await;
217            transport.send_sdk_control_request(control_request).await?;
218        }
219
220        // Wait for response with timeout
221        match timeout(Duration::from_secs(60), rx).await {
222            Ok(Ok(response)) => {
223                debug!("Received control response for {}", request_id);
224
225                // Python parity: treat subtype=error as an error, and return only
226                // the payload from `response` (or legacy `data`) on success.
227                if response.get("subtype").and_then(|v| v.as_str()) == Some("error") {
228                    let msg = response
229                        .get("error")
230                        .and_then(|v| v.as_str())
231                        .unwrap_or("Unknown control request error");
232                    return Err(SdkError::ControlRequestError(msg.to_string()));
233                }
234
235                Ok(response
236                    .get("response")
237                    .or_else(|| response.get("data"))
238                    .cloned()
239                    .unwrap_or_else(|| serde_json::json!({})))
240            }
241            Ok(Err(_)) => Err(SdkError::ControlRequestError(
242                "Response channel closed".to_string(),
243            )),
244            Err(_) => {
245                // Clean up pending response
246                let mut pending = self.pending_responses.write().await;
247                pending.remove(&request_id);
248                Err(SdkError::Timeout { seconds: 60 })
249            }
250        }
251    }
252
253    /// Handle permission request
254    #[allow(dead_code)]
255    async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
256        if let Some(ref can_use_tool) = self.can_use_tool {
257            let context = ToolPermissionContext {
258                signal: None,
259                suggestions: request.permission_suggestions.unwrap_or_default(),
260            };
261
262            let result = can_use_tool
263                .can_use_tool(&request.tool_name, &request.input, &context)
264                .await;
265
266            // Send response back (CLI expects: { allow: bool, input?, reason? })
267            let response = match result {
268                PermissionResult::Allow(allow) => {
269                    let mut obj = serde_json::json!({ "allow": true });
270                    if let Some(updated) = allow.updated_input {
271                        obj["input"] = updated;
272                    }
273                    obj
274                }
275                PermissionResult::Deny(deny) => {
276                    let mut obj = serde_json::json!({ "allow": false });
277                    if !deny.message.is_empty() {
278                        obj["reason"] = serde_json::json!(deny.message);
279                    }
280                    obj
281                }
282            };
283
284            // Send response back through transport
285            let mut transport = self.transport.lock().await;
286            transport.send_sdk_control_response(response).await?;
287            debug!("Permission response sent");
288        }
289        
290        Ok(())
291    }
292
293    /// Extract requestId from CLI message (supports both camelCase and snake_case)
294    fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
295        msg.get("requestId")
296            .or_else(|| msg.get("request_id"))
297            .cloned()
298    }
299
300    /// Start control request handler task
301    async fn start_control_handler(&mut self) {
302        let transport = self.transport.clone();
303        let can_use_tool = self.can_use_tool.clone();
304        let hook_callbacks = self.hook_callbacks.clone();
305        let sdk_mcp_servers = self.sdk_mcp_servers.clone();
306        let pending_responses = self.pending_responses.clone();
307        
308        // Take ownership of the SDK control receiver to avoid holding locks
309        let sdk_control_rx = {
310            let mut transport_lock = transport.lock().await;
311            transport_lock.take_sdk_control_receiver()
312        }; // Lock released here
313        
314        if let Some(mut control_rx) = sdk_control_rx {
315            tokio::spawn(async move {
316                // Now we can receive control requests without holding any locks
317                let transport_for_control = transport;
318                let can_use_tool_clone = can_use_tool;
319                let hook_callbacks_clone = hook_callbacks;
320                let sdk_mcp_servers_clone = sdk_mcp_servers;
321                let pending_responses_clone = pending_responses;
322
323                loop {
324                    // Receive control request without holding lock
325                    let control_message = control_rx.recv().await;
326
327                    if let Some(control_message) = control_message {
328                        debug!("Received control message: {:?}", control_message);
329
330                        // Check if this is a control response (from CLI to SDK)
331                        if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
332                            // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
333                            if let Some(resp_obj) = control_message.get("response") {
334                                let request_id = resp_obj
335                                    .get("request_id")
336                                    .or_else(|| resp_obj.get("requestId"))
337                                    .and_then(|v| v.as_str());
338
339                                if let Some(request_id) = request_id {
340                                    let mut pending = pending_responses_clone.write().await;
341                                    if let Some(tx) = pending.remove(request_id) {
342                                        // Deliver the nested control response object; send_control_request will
343                                        // extract the `response` (or legacy `data`) payload for callers.
344                                        let _ = tx.send(resp_obj.clone());
345                                        debug!("Control response delivered for request_id: {}", request_id);
346                                    } else {
347                                        warn!("No pending request found for request_id: {}", request_id);
348                                    }
349                                } else {
350                                    warn!("Control response missing request_id: {:?}", control_message);
351                                }
352                            } else {
353                                warn!("Control response missing 'response' payload: {:?}", control_message);
354                            }
355                            continue;
356                        }
357
358                        // Parse and handle control requests (from CLI to SDK)
359                        // Check if this is a control_request with a nested request field
360                        let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
361                            control_message.get("request").cloned().unwrap_or(control_message.clone())
362                        } else {
363                            control_message.clone()
364                        };
365
366                        if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
367                            match subtype {
368                                "can_use_tool" => {
369                                    // Handle permission request
370                                    if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
371                                        // Handle with can_use_tool callback
372                                        if let Some(ref can_use_tool) = can_use_tool_clone {
373                                            let context = ToolPermissionContext {
374                                                signal: None,
375                                                suggestions: request.permission_suggestions.unwrap_or_default(),
376                                            };
377                                                
378                                            let result = can_use_tool
379                                                .can_use_tool(&request.tool_name, &request.input, &context)
380                                                .await;
381                                                
382                                            // CLI expects: {"allow": true, "input": ...} or {"allow": false, "reason": ...}
383                                            let permission_response = match result {
384                                                PermissionResult::Allow(allow) => {
385                                                    let mut resp = serde_json::json!({
386                                                        "allow": true,
387                                                    });
388                                                    if let Some(input) = allow.updated_input {
389                                                        resp["input"] = input;
390                                                    }
391                                                    if let Some(perms) = allow.updated_permissions {
392                                                        resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
393                                                    }
394                                                    resp
395                                                }
396                                                PermissionResult::Deny(deny) => {
397                                                    let mut resp = serde_json::json!({
398                                                        "allow": false,
399                                                    });
400                                                    if !deny.message.is_empty() {
401                                                        resp["reason"] = serde_json::json!(deny.message);
402                                                    }
403                                                    if deny.interrupt {
404                                                        resp["interrupt"] = serde_json::json!(true);
405                                                    }
406                                                    resp
407                                                }
408                                            };
409                                                
410                                            // Wrap response with proper structure
411                                            // CLI expects "subtype": "success" for all successful responses
412                                            let response = serde_json::json!({
413                                                "subtype": "success",
414                                                "request_id": Self::extract_request_id(&control_message),
415                                                "response": permission_response
416                                            });
417                                                
418                                            // Send response
419                                            let mut transport = transport_for_control.lock().await;
420                                            if let Err(e) = transport.send_sdk_control_response(response).await {
421                                                error!("Failed to send permission response: {}", e);
422                                            }
423                                        }
424                                    } else {
425                                        // Fallback for snake_case fields (tool_name, permission_suggestions)
426                                        if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
427                                            && let Some(input_val) = request_data.get("input").cloned()
428                                                && let Some(ref can_use_tool) = can_use_tool_clone {
429                                                    // Try to parse permission suggestions (snake_case)
430                                                    let suggestions: Vec<PermissionUpdate> = request_data
431                                                        .get("permission_suggestions")
432                                                        .cloned()
433                                                        .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
434                                                        .unwrap_or_default();
435
436                                                    let context = ToolPermissionContext { signal: None, suggestions };
437                                                    let result = can_use_tool
438                                                        .can_use_tool(tool_name, &input_val, &context)
439                                                        .await;
440
441                                                    let permission_response = match result {
442                                                        PermissionResult::Allow(allow) => {
443                                                            let mut resp = serde_json::json!({ "allow": true });
444                                                            if let Some(input) = allow.updated_input { resp["input"] = input; }
445                                                            if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
446                                                            resp
447                                                        }
448                                                        PermissionResult::Deny(deny) => {
449                                                            let mut resp = serde_json::json!({ "allow": false });
450                                                            if !deny.message.is_empty() { resp["reason"] = serde_json::json!(deny.message); }
451                                                            if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
452                                                            resp
453                                                        }
454                                                    };
455
456                                                    let response = serde_json::json!({
457                                                        "subtype": "success",
458                                                        "request_id": Self::extract_request_id(&control_message),
459                                                        "response": permission_response
460                                                    });
461                                                    let mut transport = transport_for_control.lock().await;
462                                                    if let Err(e) = transport.send_sdk_control_response(response).await {
463                                                        error!("Failed to send permission response (fallback): {}", e);
464                                                    }
465                                                }
466                                    }
467                                }
468                                "hook_callback" => {
469                                    // Handle hook callback with strongly-typed inputs/outputs
470                                    if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
471                                        let callbacks = hook_callbacks_clone.read().await;
472
473                                        if let Some(callback) = callbacks.get(&request.callback_id) {
474                                            let context = HookContext { signal: None };
475
476                                            // Try to deserialize input as HookInput
477                                            let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
478                                                Ok(hook_input) => {
479                                                    // Call the hook with strongly-typed input
480                                                    callback
481                                                        .execute(&hook_input, request.tool_use_id.as_deref(), &context)
482                                                        .await
483                                                }
484                                                Err(parse_err) => {
485                                                    error!("Failed to parse hook input: {}", parse_err);
486                                                    // Return error using MessageParseError
487                                                    Err(crate::errors::SdkError::MessageParseError {
488                                                        error: format!("Invalid hook input: {parse_err}"),
489                                                        raw: request.input.to_string(),
490                                                    })
491                                                }
492                                            };
493
494                                            // Handle hook result
495                                            let response_json = match hook_result {
496                                                Ok(hook_output) => {
497                                                    // Serialize HookJSONOutput to JSON
498                                                    let output_value = serde_json::to_value(&hook_output)
499                                                        .unwrap_or_else(|e| {
500                                                            error!("Failed to serialize hook output: {}", e);
501                                                            serde_json::json!({})
502                                                        });
503
504                                                    serde_json::json!({
505                                                        "subtype": "success",
506                                                        "request_id": Self::extract_request_id(&control_message),
507                                                        "response": output_value
508                                                    })
509                                                }
510                                                Err(e) => {
511                                                    error!("Hook callback failed: {}", e);
512                                                    serde_json::json!({
513                                                        "subtype": "error",
514                                                        "request_id": Self::extract_request_id(&control_message),
515                                                        "error": e.to_string()
516                                                    })
517                                                }
518                                            };
519
520                                            let mut transport = transport_for_control.lock().await;
521                                            if let Err(e) = transport.send_sdk_control_response(response_json).await {
522                                                error!("Failed to send hook callback response: {}", e);
523                                            }
524                                        } else {
525                                            warn!("No hook callback found for ID: {}", request.callback_id);
526                                            // Send error response
527                                            let error_response = serde_json::json!({
528                                                "subtype": "error",
529                                                "request_id": Self::extract_request_id(&control_message),
530                                                "error": format!("No hook callback found for ID: {}", request.callback_id)
531                                            });
532                                            let mut transport = transport_for_control.lock().await;
533                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
534                                                error!("Failed to send error response: {}", e);
535                                            }
536                                        }
537                                    } else {
538                                        // Fallback for snake_case fields (callback_id, tool_use_id)
539                                        let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
540                                        let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
541                                        let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
542
543                                        if let Some(callback_id) = callback_id {
544                                            let callbacks = hook_callbacks_clone.read().await;
545                                            if let Some(callback) = callbacks.get(callback_id) {
546                                                let context = HookContext { signal: None };
547
548                                                // Try to parse as HookInput
549                                                let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
550                                                    Ok(hook_input) => {
551                                                        callback
552                                                            .execute(&hook_input, tool_use_id.as_deref(), &context)
553                                                            .await
554                                                    }
555                                                    Err(parse_err) => {
556                                                        error!("Failed to parse hook input (fallback): {}", parse_err);
557                                                        Err(crate::errors::SdkError::MessageParseError {
558                                                            error: format!("Invalid hook input: {parse_err}"),
559                                                            raw: input.to_string(),
560                                                        })
561                                                    }
562                                                };
563
564                                                let response_json = match hook_result {
565                                                    Ok(hook_output) => {
566                                                        let output_value = serde_json::to_value(&hook_output)
567                                                            .unwrap_or_else(|e| {
568                                                                error!("Failed to serialize hook output (fallback): {}", e);
569                                                                serde_json::json!({})
570                                                            });
571
572                                                        serde_json::json!({
573                                                            "subtype": "success",
574                                                            "request_id": Self::extract_request_id(&control_message),
575                                                            "response": output_value
576                                                        })
577                                                    }
578                                                    Err(e) => {
579                                                        error!("Hook callback failed (fallback): {}", e);
580                                                        serde_json::json!({
581                                                            "subtype": "error",
582                                                            "request_id": Self::extract_request_id(&control_message),
583                                                            "error": e.to_string()
584                                                        })
585                                                    }
586                                                };
587
588                                                let mut transport = transport_for_control.lock().await;
589                                                if let Err(e) = transport.send_sdk_control_response(response_json).await {
590                                                    error!("Failed to send hook callback response (fallback): {}", e);
591                                                }
592                                            } else {
593                                                warn!("No hook callback found for ID: {}", callback_id);
594                                            }
595                                        } else {
596                                            warn!("Invalid hook_callback control message: missing callback_id");
597                                        }
598                                    }
599                                }
600                                "mcp_message" => {
601                                    // Handle MCP message
602                                    if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
603                                        && let Some(message) = request_data.get("message") {
604                                            debug!("Processing MCP message for SDK server: {}", server_name);
605
606                                            // Check if we have an SDK server with this name
607                                            if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
608                                                // Try to downcast to SdkMcpServer
609                                                if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
610                                                    // Call the SDK MCP server
611                                                    match sdk_server.handle_message(message.clone()).await {
612                                                        Ok(mcp_result) => {
613                                                            // Wrap response with proper structure
614                                                            let response = serde_json::json!({
615                                                                "subtype": "success",
616                                                                "request_id": Self::extract_request_id(&control_message),
617                                                                "response": {
618                                                                    "mcp_response": mcp_result
619                                                                }
620                                                            });
621
622                                                            let mut transport = transport_for_control.lock().await;
623                                                            if let Err(e) = transport.send_sdk_control_response(response).await {
624                                                                error!("Failed to send MCP response: {}", e);
625                                                            }
626                                                        }
627                                                        Err(e) => {
628                                                            error!("SDK MCP server error: {}", e);
629                                                            let error_response = serde_json::json!({
630                                                                "subtype": "error",
631                                                                "request_id": Self::extract_request_id(&control_message),
632                                                                "error": format!("MCP server error: {}", e)
633                                                            });
634
635                                                            let mut transport = transport_for_control.lock().await;
636                                                            if let Err(e) = transport.send_sdk_control_response(error_response).await {
637                                                                error!("Failed to send MCP error response: {}", e);
638                                                            }
639                                                        }
640                                                    }
641                                                } else {
642                                                    warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
643                                                }
644                                            } else {
645                                                warn!("No SDK MCP server found with name: {}", server_name);
646                                                let error_response = serde_json::json!({
647                                                    "subtype": "error",
648                                                    "request_id": Self::extract_request_id(&control_message),
649                                                    "error": format!("Server '{}' not found", server_name)
650                                                });
651
652                                                let mut transport = transport_for_control.lock().await;
653                                                if let Err(e) = transport.send_sdk_control_response(error_response).await {
654                                                    error!("Failed to send MCP error response: {}", e);
655                                                }
656                                            }
657                                        }
658                                }
659                                _ => {
660                                    debug!("Unknown SDK control subtype: {}", subtype);
661                                }
662                            }
663                        }
664                    }
665                }
666            });
667        }
668    }
669
670    /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
671    #[allow(dead_code)]
672    pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
673    where
674        S: Stream<Item = JsonValue> + Send + 'static,
675    {
676        let transport = self.transport.clone();
677
678        tokio::spawn(async move {
679            use futures::StreamExt;
680            let mut stream = Box::pin(input_stream);
681
682            while let Some(value) = stream.next().await {
683                // Best-effort conversion from arbitrary JSON to InputMessage
684                let input_msg_opt = Self::json_to_input_message(value);
685                if let Some(input_msg) = input_msg_opt {
686                    let mut guard = transport.lock().await;
687                    if let Err(e) = guard.send_message(input_msg).await {
688                        warn!("Failed to send streaming input message: {}", e);
689                    }
690                } else {
691                    warn!("Invalid streaming input JSON; expected user message shape");
692                }
693            }
694
695            // After streaming all inputs, signal end of input
696            let mut guard = transport.lock().await;
697            if let Err(e) = guard.end_input().await {
698                warn!("Failed to signal end_input: {}", e);
699            }
700        });
701        Ok(())
702    }
703
704    /// Receive messages
705    #[allow(dead_code)]
706    pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
707        self.message_rx.take().expect("Receiver already taken")
708    }
709
710    /// Send interrupt request
711    pub async fn interrupt(&mut self) -> Result<()> {
712        let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
713            subtype: "interrupt".to_string(),
714        });
715
716        self.send_control_request(interrupt_request).await?;
717        Ok(())
718    }
719
720    /// Set permission mode via control protocol
721    #[allow(dead_code)]
722    pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
723        let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
724            subtype: "set_permission_mode".to_string(),
725            mode: mode.to_string(),
726        });
727        // Ignore response payload; errors propagate
728        let _ = self.send_control_request(req).await?;
729        Ok(())
730    }
731
732    /// Set the active model via control protocol
733    #[allow(dead_code)]
734    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
735        let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
736            subtype: "set_model".to_string(),
737            model,
738        });
739        let _ = self.send_control_request(req).await?;
740        Ok(())
741    }
742
743    /// Rewind tracked files to their state at a specific user message
744    ///
745    /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
746    ///
747    /// # Arguments
748    ///
749    /// * `user_message_id` - UUID of the user message to rewind to
750    ///
751    /// # Example
752    ///
753    /// ```rust,no_run
754    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
755    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
756    /// let options = ClaudeCodeOptions::builder()
757    ///     .enable_file_checkpointing(true)
758    ///     .build();
759    /// let mut client = ClaudeSDKClient::new(options);
760    /// client.connect(None).await?;
761    ///
762    /// // Later, rewind to a checkpoint
763    /// // client.rewind_files("user-message-uuid-here").await?;
764    /// # Ok(())
765    /// # }
766    /// ```
767    pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
768        let req = SDKControlRequest::RewindFiles(crate::types::SDKControlRewindFilesRequest::new(user_message_id));
769        let _ = self.send_control_request(req).await?;
770        Ok(())
771    }
772
773    /// Handle MCP message for SDK servers
774    #[allow(dead_code)]
775    async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
776        // Check if we have an SDK server with this name
777        if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
778            // TODO: Implement actual MCP server invocation
779            // For now, return a placeholder response
780            debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
781            Ok(serde_json::json!({
782                "jsonrpc": "2.0",
783                "id": message.get("id"),
784                "result": {
785                    "content": "MCP server response placeholder"
786                }
787            }))
788        } else {
789            Err(SdkError::InvalidState {
790                message: format!("No SDK MCP server found with name: {server_name}"),
791            })
792        }
793    }
794
795    /// Close the query handler
796    #[allow(dead_code)]
797    pub async fn close(&mut self) -> Result<()> {
798        // Clean up resources
799        let mut transport = self.transport.lock().await;
800        transport.disconnect().await?;
801        Ok(())
802    }
803
804    /// Get initialization result
805    pub fn get_initialization_result(&self) -> Option<&JsonValue> {
806        self.initialization_result.as_ref()
807    }
808
809    /// Convert arbitrary JSON value to InputMessage understood by CLI
810    #[allow(dead_code)]
811    fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
812        // 1) Already in SDK message shape
813        if let Some(obj) = v.as_object() {
814            if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
815                && t.as_str() == Some("user") {
816                    let parent = obj
817                        .get("parent_tool_use_id")
818                        .and_then(|p| p.as_str().map(|s| s.to_string()));
819                    let session_id = obj
820                        .get("session_id")
821                        .and_then(|s| s.as_str())
822                        .unwrap_or("default")
823                        .to_string();
824
825                    let im = InputMessage {
826                        r#type: "user".to_string(),
827                        message: message.clone(),
828                        parent_tool_use_id: parent,
829                        session_id,
830                    };
831                    return Some(im);
832                }
833
834            // 2) Simple wrapper: {"content":"...", "session_id":"..."}
835            if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
836                let session_id = obj
837                    .get("session_id")
838                    .and_then(|s| s.as_str())
839                    .unwrap_or("default")
840                    .to_string();
841                return Some(InputMessage::user(content.to_string(), session_id));
842            }
843        }
844
845        // 3) Bare string
846        if let Some(s) = v.as_str() {
847            return Some(InputMessage::user(s.to_string(), "default".to_string()));
848        }
849
850        None
851    }
852}
853
#[cfg(test)]
mod tests {
    use super::*;

    /// `extract_request_id` must accept both snake_case and camelCase keys.
    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let snake = serde_json::json!({"request_id": "req_1"});
        let camel = serde_json::json!({"requestId": "req_2"});
        assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
        assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
    }

    #[test]
    fn test_json_to_input_message_from_string() {
        let v = serde_json::json!("Hello");
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_json_to_input_message_from_object_content() {
        let v = serde_json::json!({"content":"Ping","session_id":"s1"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    /// A wrapper object without `session_id` falls back to "default".
    #[test]
    fn test_json_to_input_message_object_content_defaults_session() {
        let v = serde_json::json!({"content": "Ping"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let v = serde_json::json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str().unwrap(), "user");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
        // A JSON null parent id must map to None, not to the string "null".
        assert!(im.parent_tool_use_id.is_none());
    }

    /// A string `parent_tool_use_id` is carried through unchanged.
    #[test]
    fn test_json_to_input_message_full_user_shape_keeps_parent_tool_use_id() {
        let v = serde_json::json!({
            "type": "user",
            "message": {"role": "user", "content": "Hi"},
            "parent_tool_use_id": "toolu_1"
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.parent_tool_use_id.as_deref(), Some("toolu_1"));
        assert_eq!(im.session_id, "default");
    }

    /// Values matching none of the accepted shapes are rejected with `None`.
    #[test]
    fn test_json_to_input_message_rejects_unsupported_shapes() {
        let rejected = [
            serde_json::json!(42),
            serde_json::json!(null),
            serde_json::json!(["Hello"]),
            serde_json::json!({"content": 7}),
            serde_json::json!({"type": "assistant", "message": {"role": "assistant"}}),
            serde_json::json!({"type": "user"}),
        ];
        for v in rejected {
            let shown = v.to_string();
            assert!(
                Query::json_to_input_message(v).is_none(),
                "expected None for {shown}"
            );
        }
    }
}