cc_sdk/internal_query.rs
1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7 errors::{Result, SdkError},
8 transport::{InputMessage, Transport},
9 types::{
10 CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11 SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12 SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13 ToolPermissionContext,
14 },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27 /// Transport layer (shared with client)
28 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29 /// Whether in streaming mode
30 #[allow(dead_code)]
31 is_streaming_mode: bool,
32 /// Tool permission callback
33 can_use_tool: Option<Arc<dyn CanUseTool>>,
34 /// Hook configurations
35 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36 /// SDK MCP servers
37 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38 /// Message channel sender (reserved for future streaming receive support)
39 #[allow(dead_code)]
40 message_tx: mpsc::Sender<Result<Message>>,
41 /// Message channel receiver (reserved for future streaming receive support)
42 #[allow(dead_code)]
43 message_rx: Option<mpsc::Receiver<Result<Message>>>,
44 /// Initialization result
45 initialization_result: Option<JsonValue>,
46 /// Active hook callbacks
47 hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48 /// Hook callback counter
49 callback_counter: Arc<Mutex<u64>>,
50 /// Request counter for generating unique IDs
51 request_counter: Arc<Mutex<u64>>,
52 /// Pending control request responses
53 pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57 /// Create a new Query handler
58 pub fn new(
59 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60 is_streaming_mode: bool,
61 can_use_tool: Option<Arc<dyn CanUseTool>>,
62 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64 ) -> Self {
65 let (tx, rx) = mpsc::channel(100);
66
67 Self {
68 transport,
69 is_streaming_mode,
70 can_use_tool,
71 hooks,
72 sdk_mcp_servers,
73 message_tx: tx,
74 message_rx: Some(rx),
75 initialization_result: None,
76 hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77 callback_counter: Arc::new(Mutex::new(0)),
78 request_counter: Arc::new(Mutex::new(0)),
79 pending_responses: Arc::new(RwLock::new(HashMap::new())),
80 }
81 }
82
83 /// Test helper to register a hook callback with a known ID
84 ///
85 /// This is intended for E2E tests to inject a callback ID that can be
86 /// referenced by inbound `hook_callback` control messages.
87 pub async fn register_hook_callback_for_test(
88 &self,
89 callback_id: String,
90 callback: Arc<dyn HookCallback>,
91 ) {
92 let mut map = self.hook_callbacks.write().await;
93 map.insert(callback_id, callback);
94 }
95
96 /// Start the query handler
97 pub async fn start(&mut self) -> Result<()> {
98 // Start control request handler task
99 self.start_control_handler().await;
100
101 // Start SDK message forwarder task (route non-control messages to message_tx)
102 let transport = self.transport.clone();
103 let tx = self.message_tx.clone();
104 tokio::spawn(async move {
105 // Get message stream once and consume it continuously
106 let mut stream = {
107 let mut guard = transport.lock().await;
108 guard.receive_messages()
109 }; // Lock released immediately after getting stream
110
111 // Continuously consume from the same stream
112 while let Some(result) = stream.next().await {
113 match result {
114 Ok(msg) => {
115 if tx.send(Ok(msg)).await.is_err() {
116 break;
117 }
118 }
119 Err(e) => {
120 let _ = tx.send(Err(e)).await;
121 break;
122 }
123 }
124 }
125 });
126 Ok(())
127 }
128
129 /// Initialize the control protocol
130 pub async fn initialize(&mut self) -> Result<()> {
131 // Build hooks with callback IDs (Python SDK style)
132 let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133 let mut counter = self.callback_counter.lock().await;
134 let mut callbacks_map = self.hook_callbacks.write().await;
135
136 let hooks_json: HashMap<String, JsonValue> = hooks
137 .iter()
138 .map(|(event_name, matchers)| {
139 let matchers_with_ids: Vec<JsonValue> = matchers
140 .iter()
141 .map(|matcher| {
142 // Generate callback IDs for each hook in this matcher
143 let callback_ids: Vec<String> = matcher
144 .hooks
145 .iter()
146 .map(|hook_callback| {
147 *counter += 1;
148 let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150 // Store the callback for later use
151 callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153 callback_id
154 })
155 .collect();
156
157 serde_json::json!({
158 "matcher": matcher.matcher.clone(),
159 "hookCallbackIds": callback_ids
160 })
161 })
162 .collect();
163
164 (event_name.clone(), serde_json::json!(matchers_with_ids))
165 })
166 .collect();
167
168 Some(hooks_json)
169 } else {
170 None
171 };
172
173 // Send initialize request
174 let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175 subtype: "initialize".to_string(),
176 hooks: hooks_with_ids,
177 });
178
179 // Send control request and save result
180 let result = self.send_control_request(init_request).await?;
181 self.initialization_result = Some(result);
182
183 debug!("Initialization request sent with hook callback IDs");
184 Ok(())
185 }
186
187 /// Send a control request and wait for response
188 async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189 // Generate unique request ID
190 let request_id = {
191 let mut counter = self.request_counter.lock().await;
192 *counter += 1;
193 format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194 };
195
196 // Create oneshot channel for response
197 let (tx, rx) = tokio::sync::oneshot::channel();
198
199 // Register pending response
200 {
201 let mut pending = self.pending_responses.write().await;
202 pending.insert(request_id.clone(), tx);
203 }
204
205 // Build control request with request_id (snake_case for CLI compatibility)
206 let control_request = serde_json::json!({
207 "type": "control_request",
208 "request_id": request_id,
209 "request": request
210 });
211
212 debug!("Sending control request: {:?}", control_request);
213
214 // Send via transport
215 {
216 let mut transport = self.transport.lock().await;
217 transport.send_sdk_control_request(control_request).await?;
218 }
219
220 // Wait for response with timeout
221 match timeout(Duration::from_secs(60), rx).await {
222 Ok(Ok(response)) => {
223 debug!("Received control response for {}", request_id);
224
225 // Python parity: treat subtype=error as an error, and return only
226 // the payload from `response` (or legacy `data`) on success.
227 if response.get("subtype").and_then(|v| v.as_str()) == Some("error") {
228 let msg = response
229 .get("error")
230 .and_then(|v| v.as_str())
231 .unwrap_or("Unknown control request error");
232 return Err(SdkError::ControlRequestError(msg.to_string()));
233 }
234
235 Ok(response
236 .get("response")
237 .or_else(|| response.get("data"))
238 .cloned()
239 .unwrap_or_else(|| serde_json::json!({})))
240 }
241 Ok(Err(_)) => Err(SdkError::ControlRequestError(
242 "Response channel closed".to_string(),
243 )),
244 Err(_) => {
245 // Clean up pending response
246 let mut pending = self.pending_responses.write().await;
247 pending.remove(&request_id);
248 Err(SdkError::Timeout { seconds: 60 })
249 }
250 }
251 }
252
253 /// Handle permission request
254 #[allow(dead_code)]
255 async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
256 if let Some(ref can_use_tool) = self.can_use_tool {
257 let context = ToolPermissionContext {
258 signal: None,
259 suggestions: request.permission_suggestions.unwrap_or_default(),
260 };
261
262 let result = can_use_tool
263 .can_use_tool(&request.tool_name, &request.input, &context)
264 .await;
265
266 // Send response back (CLI expects: { allow: bool, input?, reason? })
267 let response = match result {
268 PermissionResult::Allow(allow) => {
269 let mut obj = serde_json::json!({ "allow": true });
270 if let Some(updated) = allow.updated_input {
271 obj["input"] = updated;
272 }
273 obj
274 }
275 PermissionResult::Deny(deny) => {
276 let mut obj = serde_json::json!({ "allow": false });
277 if !deny.message.is_empty() {
278 obj["reason"] = serde_json::json!(deny.message);
279 }
280 obj
281 }
282 };
283
284 // Send response back through transport
285 let mut transport = self.transport.lock().await;
286 transport.send_sdk_control_response(response).await?;
287 debug!("Permission response sent");
288 }
289
290 Ok(())
291 }
292
293 /// Extract requestId from CLI message (supports both camelCase and snake_case)
294 fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
295 msg.get("requestId")
296 .or_else(|| msg.get("request_id"))
297 .cloned()
298 }
299
300 /// Start control request handler task
301 async fn start_control_handler(&mut self) {
302 let transport = self.transport.clone();
303 let can_use_tool = self.can_use_tool.clone();
304 let hook_callbacks = self.hook_callbacks.clone();
305 let sdk_mcp_servers = self.sdk_mcp_servers.clone();
306 let pending_responses = self.pending_responses.clone();
307
308 // Take ownership of the SDK control receiver to avoid holding locks
309 let sdk_control_rx = {
310 let mut transport_lock = transport.lock().await;
311 transport_lock.take_sdk_control_receiver()
312 }; // Lock released here
313
314 if let Some(mut control_rx) = sdk_control_rx {
315 tokio::spawn(async move {
316 // Now we can receive control requests without holding any locks
317 let transport_for_control = transport;
318 let can_use_tool_clone = can_use_tool;
319 let hook_callbacks_clone = hook_callbacks;
320 let sdk_mcp_servers_clone = sdk_mcp_servers;
321 let pending_responses_clone = pending_responses;
322
323 loop {
324 // Receive control request without holding lock
325 let control_message = control_rx.recv().await;
326
327 if let Some(control_message) = control_message {
328 debug!("Received control message: {:?}", control_message);
329
330 // Check if this is a control response (from CLI to SDK)
331 if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
332 // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
333 if let Some(resp_obj) = control_message.get("response") {
334 let request_id = resp_obj
335 .get("request_id")
336 .or_else(|| resp_obj.get("requestId"))
337 .and_then(|v| v.as_str());
338
339 if let Some(request_id) = request_id {
340 let mut pending = pending_responses_clone.write().await;
341 if let Some(tx) = pending.remove(request_id) {
342 // Deliver the nested control response object; send_control_request will
343 // extract the `response` (or legacy `data`) payload for callers.
344 let _ = tx.send(resp_obj.clone());
345 debug!("Control response delivered for request_id: {}", request_id);
346 } else {
347 warn!("No pending request found for request_id: {}", request_id);
348 }
349 } else {
350 warn!("Control response missing request_id: {:?}", control_message);
351 }
352 } else {
353 warn!("Control response missing 'response' payload: {:?}", control_message);
354 }
355 continue;
356 }
357
358 // Parse and handle control requests (from CLI to SDK)
359 // Check if this is a control_request with a nested request field
360 let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
361 control_message.get("request").cloned().unwrap_or(control_message.clone())
362 } else {
363 control_message.clone()
364 };
365
366 if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
367 match subtype {
368 "can_use_tool" => {
369 // Handle permission request
370 if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
371 // Handle with can_use_tool callback
372 if let Some(ref can_use_tool) = can_use_tool_clone {
373 let context = ToolPermissionContext {
374 signal: None,
375 suggestions: request.permission_suggestions.unwrap_or_default(),
376 };
377
378 let result = can_use_tool
379 .can_use_tool(&request.tool_name, &request.input, &context)
380 .await;
381
382 // CLI expects: {"allow": true, "input": ...} or {"allow": false, "reason": ...}
383 let permission_response = match result {
384 PermissionResult::Allow(allow) => {
385 let mut resp = serde_json::json!({
386 "allow": true,
387 });
388 if let Some(input) = allow.updated_input {
389 resp["input"] = input;
390 }
391 if let Some(perms) = allow.updated_permissions {
392 resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
393 }
394 resp
395 }
396 PermissionResult::Deny(deny) => {
397 let mut resp = serde_json::json!({
398 "allow": false,
399 });
400 if !deny.message.is_empty() {
401 resp["reason"] = serde_json::json!(deny.message);
402 }
403 if deny.interrupt {
404 resp["interrupt"] = serde_json::json!(true);
405 }
406 resp
407 }
408 };
409
410 // Wrap response with proper structure
411 // CLI expects "subtype": "success" for all successful responses
412 let response = serde_json::json!({
413 "subtype": "success",
414 "request_id": Self::extract_request_id(&control_message),
415 "response": permission_response
416 });
417
418 // Send response
419 let mut transport = transport_for_control.lock().await;
420 if let Err(e) = transport.send_sdk_control_response(response).await {
421 error!("Failed to send permission response: {}", e);
422 }
423 }
424 } else {
425 // Fallback for snake_case fields (tool_name, permission_suggestions)
426 if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
427 && let Some(input_val) = request_data.get("input").cloned()
428 && let Some(ref can_use_tool) = can_use_tool_clone {
429 // Try to parse permission suggestions (snake_case)
430 let suggestions: Vec<PermissionUpdate> = request_data
431 .get("permission_suggestions")
432 .cloned()
433 .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
434 .unwrap_or_default();
435
436 let context = ToolPermissionContext { signal: None, suggestions };
437 let result = can_use_tool
438 .can_use_tool(tool_name, &input_val, &context)
439 .await;
440
441 let permission_response = match result {
442 PermissionResult::Allow(allow) => {
443 let mut resp = serde_json::json!({ "allow": true });
444 if let Some(input) = allow.updated_input { resp["input"] = input; }
445 if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
446 resp
447 }
448 PermissionResult::Deny(deny) => {
449 let mut resp = serde_json::json!({ "allow": false });
450 if !deny.message.is_empty() { resp["reason"] = serde_json::json!(deny.message); }
451 if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
452 resp
453 }
454 };
455
456 let response = serde_json::json!({
457 "subtype": "success",
458 "request_id": Self::extract_request_id(&control_message),
459 "response": permission_response
460 });
461 let mut transport = transport_for_control.lock().await;
462 if let Err(e) = transport.send_sdk_control_response(response).await {
463 error!("Failed to send permission response (fallback): {}", e);
464 }
465 }
466 }
467 }
468 "hook_callback" => {
469 // Handle hook callback with strongly-typed inputs/outputs
470 if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
471 let callbacks = hook_callbacks_clone.read().await;
472
473 if let Some(callback) = callbacks.get(&request.callback_id) {
474 let context = HookContext { signal: None };
475
476 // Try to deserialize input as HookInput
477 let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
478 Ok(hook_input) => {
479 // Call the hook with strongly-typed input
480 callback
481 .execute(&hook_input, request.tool_use_id.as_deref(), &context)
482 .await
483 }
484 Err(parse_err) => {
485 error!("Failed to parse hook input: {}", parse_err);
486 // Return error using MessageParseError
487 Err(crate::errors::SdkError::MessageParseError {
488 error: format!("Invalid hook input: {parse_err}"),
489 raw: request.input.to_string(),
490 })
491 }
492 };
493
494 // Handle hook result
495 let response_json = match hook_result {
496 Ok(hook_output) => {
497 // Serialize HookJSONOutput to JSON
498 let output_value = serde_json::to_value(&hook_output)
499 .unwrap_or_else(|e| {
500 error!("Failed to serialize hook output: {}", e);
501 serde_json::json!({})
502 });
503
504 serde_json::json!({
505 "subtype": "success",
506 "request_id": Self::extract_request_id(&control_message),
507 "response": output_value
508 })
509 }
510 Err(e) => {
511 error!("Hook callback failed: {}", e);
512 serde_json::json!({
513 "subtype": "error",
514 "request_id": Self::extract_request_id(&control_message),
515 "error": e.to_string()
516 })
517 }
518 };
519
520 let mut transport = transport_for_control.lock().await;
521 if let Err(e) = transport.send_sdk_control_response(response_json).await {
522 error!("Failed to send hook callback response: {}", e);
523 }
524 } else {
525 warn!("No hook callback found for ID: {}", request.callback_id);
526 // Send error response
527 let error_response = serde_json::json!({
528 "subtype": "error",
529 "request_id": Self::extract_request_id(&control_message),
530 "error": format!("No hook callback found for ID: {}", request.callback_id)
531 });
532 let mut transport = transport_for_control.lock().await;
533 if let Err(e) = transport.send_sdk_control_response(error_response).await {
534 error!("Failed to send error response: {}", e);
535 }
536 }
537 } else {
538 // Fallback for snake_case fields (callback_id, tool_use_id)
539 let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
540 let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
541 let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
542
543 if let Some(callback_id) = callback_id {
544 let callbacks = hook_callbacks_clone.read().await;
545 if let Some(callback) = callbacks.get(callback_id) {
546 let context = HookContext { signal: None };
547
548 // Try to parse as HookInput
549 let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
550 Ok(hook_input) => {
551 callback
552 .execute(&hook_input, tool_use_id.as_deref(), &context)
553 .await
554 }
555 Err(parse_err) => {
556 error!("Failed to parse hook input (fallback): {}", parse_err);
557 Err(crate::errors::SdkError::MessageParseError {
558 error: format!("Invalid hook input: {parse_err}"),
559 raw: input.to_string(),
560 })
561 }
562 };
563
564 let response_json = match hook_result {
565 Ok(hook_output) => {
566 let output_value = serde_json::to_value(&hook_output)
567 .unwrap_or_else(|e| {
568 error!("Failed to serialize hook output (fallback): {}", e);
569 serde_json::json!({})
570 });
571
572 serde_json::json!({
573 "subtype": "success",
574 "request_id": Self::extract_request_id(&control_message),
575 "response": output_value
576 })
577 }
578 Err(e) => {
579 error!("Hook callback failed (fallback): {}", e);
580 serde_json::json!({
581 "subtype": "error",
582 "request_id": Self::extract_request_id(&control_message),
583 "error": e.to_string()
584 })
585 }
586 };
587
588 let mut transport = transport_for_control.lock().await;
589 if let Err(e) = transport.send_sdk_control_response(response_json).await {
590 error!("Failed to send hook callback response (fallback): {}", e);
591 }
592 } else {
593 warn!("No hook callback found for ID: {}", callback_id);
594 }
595 } else {
596 warn!("Invalid hook_callback control message: missing callback_id");
597 }
598 }
599 }
600 "mcp_message" => {
601 // Handle MCP message
602 if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
603 && let Some(message) = request_data.get("message") {
604 debug!("Processing MCP message for SDK server: {}", server_name);
605
606 // Check if we have an SDK server with this name
607 if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
608 // Try to downcast to SdkMcpServer
609 if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
610 // Call the SDK MCP server
611 match sdk_server.handle_message(message.clone()).await {
612 Ok(mcp_result) => {
613 // Wrap response with proper structure
614 let response = serde_json::json!({
615 "subtype": "success",
616 "request_id": Self::extract_request_id(&control_message),
617 "response": {
618 "mcp_response": mcp_result
619 }
620 });
621
622 let mut transport = transport_for_control.lock().await;
623 if let Err(e) = transport.send_sdk_control_response(response).await {
624 error!("Failed to send MCP response: {}", e);
625 }
626 }
627 Err(e) => {
628 error!("SDK MCP server error: {}", e);
629 let error_response = serde_json::json!({
630 "subtype": "error",
631 "request_id": Self::extract_request_id(&control_message),
632 "error": format!("MCP server error: {}", e)
633 });
634
635 let mut transport = transport_for_control.lock().await;
636 if let Err(e) = transport.send_sdk_control_response(error_response).await {
637 error!("Failed to send MCP error response: {}", e);
638 }
639 }
640 }
641 } else {
642 warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
643 }
644 } else {
645 warn!("No SDK MCP server found with name: {}", server_name);
646 let error_response = serde_json::json!({
647 "subtype": "error",
648 "request_id": Self::extract_request_id(&control_message),
649 "error": format!("Server '{}' not found", server_name)
650 });
651
652 let mut transport = transport_for_control.lock().await;
653 if let Err(e) = transport.send_sdk_control_response(error_response).await {
654 error!("Failed to send MCP error response: {}", e);
655 }
656 }
657 }
658 }
659 _ => {
660 debug!("Unknown SDK control subtype: {}", subtype);
661 }
662 }
663 }
664 }
665 }
666 });
667 }
668 }
669
670 /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
671 #[allow(dead_code)]
672 pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
673 where
674 S: Stream<Item = JsonValue> + Send + 'static,
675 {
676 let transport = self.transport.clone();
677
678 tokio::spawn(async move {
679 use futures::StreamExt;
680 let mut stream = Box::pin(input_stream);
681
682 while let Some(value) = stream.next().await {
683 // Best-effort conversion from arbitrary JSON to InputMessage
684 let input_msg_opt = Self::json_to_input_message(value);
685 if let Some(input_msg) = input_msg_opt {
686 let mut guard = transport.lock().await;
687 if let Err(e) = guard.send_message(input_msg).await {
688 warn!("Failed to send streaming input message: {}", e);
689 }
690 } else {
691 warn!("Invalid streaming input JSON; expected user message shape");
692 }
693 }
694
695 // After streaming all inputs, signal end of input
696 let mut guard = transport.lock().await;
697 if let Err(e) = guard.end_input().await {
698 warn!("Failed to signal end_input: {}", e);
699 }
700 });
701 Ok(())
702 }
703
704 /// Receive messages
705 #[allow(dead_code)]
706 pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
707 self.message_rx.take().expect("Receiver already taken")
708 }
709
710 /// Send interrupt request
711 pub async fn interrupt(&mut self) -> Result<()> {
712 let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
713 subtype: "interrupt".to_string(),
714 });
715
716 self.send_control_request(interrupt_request).await?;
717 Ok(())
718 }
719
720 /// Set permission mode via control protocol
721 #[allow(dead_code)]
722 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
723 let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
724 subtype: "set_permission_mode".to_string(),
725 mode: mode.to_string(),
726 });
727 // Ignore response payload; errors propagate
728 let _ = self.send_control_request(req).await?;
729 Ok(())
730 }
731
732 /// Set the active model via control protocol
733 #[allow(dead_code)]
734 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
735 let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
736 subtype: "set_model".to_string(),
737 model,
738 });
739 let _ = self.send_control_request(req).await?;
740 Ok(())
741 }
742
743 /// Rewind tracked files to their state at a specific user message
744 ///
745 /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
746 ///
747 /// # Arguments
748 ///
749 /// * `user_message_id` - UUID of the user message to rewind to
750 ///
751 /// # Example
752 ///
753 /// ```rust,no_run
754 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
755 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
756 /// let options = ClaudeCodeOptions::builder()
757 /// .enable_file_checkpointing(true)
758 /// .build();
759 /// let mut client = ClaudeSDKClient::new(options);
760 /// client.connect(None).await?;
761 ///
762 /// // Later, rewind to a checkpoint
763 /// // client.rewind_files("user-message-uuid-here").await?;
764 /// # Ok(())
765 /// # }
766 /// ```
767 pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
768 let req = SDKControlRequest::RewindFiles(crate::types::SDKControlRewindFilesRequest::new(user_message_id));
769 let _ = self.send_control_request(req).await?;
770 Ok(())
771 }
772
773 /// Handle MCP message for SDK servers
774 #[allow(dead_code)]
775 async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
776 // Check if we have an SDK server with this name
777 if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
778 // TODO: Implement actual MCP server invocation
779 // For now, return a placeholder response
780 debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
781 Ok(serde_json::json!({
782 "jsonrpc": "2.0",
783 "id": message.get("id"),
784 "result": {
785 "content": "MCP server response placeholder"
786 }
787 }))
788 } else {
789 Err(SdkError::InvalidState {
790 message: format!("No SDK MCP server found with name: {server_name}"),
791 })
792 }
793 }
794
795 /// Close the query handler
796 #[allow(dead_code)]
797 pub async fn close(&mut self) -> Result<()> {
798 // Clean up resources
799 let mut transport = self.transport.lock().await;
800 transport.disconnect().await?;
801 Ok(())
802 }
803
804 /// Get initialization result
805 pub fn get_initialization_result(&self) -> Option<&JsonValue> {
806 self.initialization_result.as_ref()
807 }
808
809 /// Convert arbitrary JSON value to InputMessage understood by CLI
810 #[allow(dead_code)]
811 fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
812 // 1) Already in SDK message shape
813 if let Some(obj) = v.as_object() {
814 if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
815 && t.as_str() == Some("user") {
816 let parent = obj
817 .get("parent_tool_use_id")
818 .and_then(|p| p.as_str().map(|s| s.to_string()));
819 let session_id = obj
820 .get("session_id")
821 .and_then(|s| s.as_str())
822 .unwrap_or("default")
823 .to_string();
824
825 let im = InputMessage {
826 r#type: "user".to_string(),
827 message: message.clone(),
828 parent_tool_use_id: parent,
829 session_id,
830 };
831 return Some(im);
832 }
833
834 // 2) Simple wrapper: {"content":"...", "session_id":"..."}
835 if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
836 let session_id = obj
837 .get("session_id")
838 .and_then(|s| s.as_str())
839 .unwrap_or("default")
840 .to_string();
841 return Some(InputMessage::user(content.to_string(), session_id));
842 }
843 }
844
845 // 3) Bare string
846 if let Some(s) = v.as_str() {
847 return Some(InputMessage::user(s.to_string(), "default".to_string()));
848 }
849
850 None
851 }
852}
853
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let cases = [
            (json!({"request_id": "req_1"}), "req_1"),
            (json!({"requestId": "req_2"}), "req_2"),
        ];
        for (msg, expected) in cases {
            assert_eq!(Query::extract_request_id(&msg), Some(json!(expected)));
        }
    }

    #[test]
    fn test_json_to_input_message_from_string() {
        let im = Query::json_to_input_message(json!("Hello")).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str(), Some("Hello"));
    }

    #[test]
    fn test_json_to_input_message_from_object_content() {
        let im = Query::json_to_input_message(json!({"content": "Ping", "session_id": "s1"}))
            .expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str(), Some("Ping"));
    }

    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let payload = json!({
            "type": "user",
            "message": {"role": "user", "content": "Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(payload).expect("should convert");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str(), Some("user"));
        assert_eq!(im.message["content"].as_str(), Some("Hi"));
    }
}