cc_sdk/internal_query.rs
1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7 errors::{Result, SdkError},
8 transport::{InputMessage, Transport},
9 types::{
10 CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11 SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12 SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13 ToolPermissionContext,
14 },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27 /// Transport layer (shared with client)
28 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29 /// Whether in streaming mode
30 #[allow(dead_code)]
31 is_streaming_mode: bool,
32 /// Tool permission callback
33 can_use_tool: Option<Arc<dyn CanUseTool>>,
34 /// Hook configurations
35 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36 /// SDK MCP servers
37 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38 /// Message channel sender (reserved for future streaming receive support)
39 #[allow(dead_code)]
40 message_tx: mpsc::Sender<Result<Message>>,
41 /// Message channel receiver (reserved for future streaming receive support)
42 #[allow(dead_code)]
43 message_rx: Option<mpsc::Receiver<Result<Message>>>,
44 /// Initialization result
45 initialization_result: Option<JsonValue>,
46 /// Active hook callbacks
47 hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48 /// Hook callback counter
49 callback_counter: Arc<Mutex<u64>>,
50 /// Request counter for generating unique IDs
51 request_counter: Arc<Mutex<u64>>,
52 /// Pending control request responses
53 pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57 /// Create a new Query handler
58 pub fn new(
59 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60 is_streaming_mode: bool,
61 can_use_tool: Option<Arc<dyn CanUseTool>>,
62 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64 ) -> Self {
65 let (tx, rx) = mpsc::channel(100);
66
67 Self {
68 transport,
69 is_streaming_mode,
70 can_use_tool,
71 hooks,
72 sdk_mcp_servers,
73 message_tx: tx,
74 message_rx: Some(rx),
75 initialization_result: None,
76 hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77 callback_counter: Arc::new(Mutex::new(0)),
78 request_counter: Arc::new(Mutex::new(0)),
79 pending_responses: Arc::new(RwLock::new(HashMap::new())),
80 }
81 }
82
83 /// Test helper to register a hook callback with a known ID
84 ///
85 /// This is intended for E2E tests to inject a callback ID that can be
86 /// referenced by inbound `hook_callback` control messages.
87 pub async fn register_hook_callback_for_test(
88 &self,
89 callback_id: String,
90 callback: Arc<dyn HookCallback>,
91 ) {
92 let mut map = self.hook_callbacks.write().await;
93 map.insert(callback_id, callback);
94 }
95
96 /// Start the query handler
97 pub async fn start(&mut self) -> Result<()> {
98 // Start control request handler task
99 self.start_control_handler().await;
100
101 // Start SDK message forwarder task (route non-control messages to message_tx)
102 let transport = self.transport.clone();
103 let tx = self.message_tx.clone();
104 tokio::spawn(async move {
105 // Get message stream once and consume it continuously
106 let mut stream = {
107 let mut guard = transport.lock().await;
108 guard.receive_messages()
109 }; // Lock released immediately after getting stream
110
111 // Continuously consume from the same stream
112 while let Some(result) = stream.next().await {
113 match result {
114 Ok(msg) => {
115 if tx.send(Ok(msg)).await.is_err() {
116 break;
117 }
118 }
119 Err(e) => {
120 let _ = tx.send(Err(e)).await;
121 break;
122 }
123 }
124 }
125 });
126 Ok(())
127 }
128
129 /// Initialize the control protocol
130 pub async fn initialize(&mut self) -> Result<()> {
131 // Build hooks with callback IDs (Python SDK style)
132 let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133 let mut counter = self.callback_counter.lock().await;
134 let mut callbacks_map = self.hook_callbacks.write().await;
135
136 let hooks_json: HashMap<String, JsonValue> = hooks
137 .iter()
138 .map(|(event_name, matchers)| {
139 let matchers_with_ids: Vec<JsonValue> = matchers
140 .iter()
141 .map(|matcher| {
142 // Generate callback IDs for each hook in this matcher
143 let callback_ids: Vec<String> = matcher
144 .hooks
145 .iter()
146 .map(|hook_callback| {
147 *counter += 1;
148 let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150 // Store the callback for later use
151 callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153 callback_id
154 })
155 .collect();
156
157 serde_json::json!({
158 "matcher": matcher.matcher.clone(),
159 "hookCallbackIds": callback_ids
160 })
161 })
162 .collect();
163
164 (event_name.clone(), serde_json::json!(matchers_with_ids))
165 })
166 .collect();
167
168 Some(hooks_json)
169 } else {
170 None
171 };
172
173 // Send initialize request
174 let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175 subtype: "initialize".to_string(),
176 hooks: hooks_with_ids,
177 });
178
179 // Send control request and save result
180 let result = self.send_control_request(init_request).await?;
181 self.initialization_result = Some(result);
182
183 debug!("Initialization request sent with hook callback IDs");
184 Ok(())
185 }
186
187 /// Send a control request and wait for response
188 async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189 // Generate unique request ID
190 let request_id = {
191 let mut counter = self.request_counter.lock().await;
192 *counter += 1;
193 format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194 };
195
196 // Create oneshot channel for response
197 let (tx, rx) = tokio::sync::oneshot::channel();
198
199 // Register pending response
200 {
201 let mut pending = self.pending_responses.write().await;
202 pending.insert(request_id.clone(), tx);
203 }
204
205 // Build control request with request_id (snake_case for CLI compatibility)
206 let control_request = serde_json::json!({
207 "type": "control_request",
208 "request_id": request_id,
209 "request": request
210 });
211
212 debug!("Sending control request: {:?}", control_request);
213
214 // Send via transport
215 {
216 let mut transport = self.transport.lock().await;
217 transport.send_sdk_control_request(control_request).await?;
218 }
219
220 // Wait for response with timeout
221 match timeout(Duration::from_secs(60), rx).await {
222 Ok(Ok(response)) => {
223 debug!("Received control response for {}", request_id);
224
225 // Python parity: treat subtype=error as an error, and return only
226 // the payload from `response` (or legacy `data`) on success.
227 if response.get("subtype").and_then(|v| v.as_str()) == Some("error") {
228 let msg = response
229 .get("error")
230 .and_then(|v| v.as_str())
231 .unwrap_or("Unknown control request error");
232 return Err(SdkError::ControlRequestError(msg.to_string()));
233 }
234
235 Ok(response
236 .get("response")
237 .or_else(|| response.get("data"))
238 .cloned()
239 .unwrap_or_else(|| serde_json::json!({})))
240 }
241 Ok(Err(_)) => Err(SdkError::ControlRequestError(
242 "Response channel closed".to_string(),
243 )),
244 Err(_) => {
245 // Clean up pending response
246 let mut pending = self.pending_responses.write().await;
247 pending.remove(&request_id);
248 Err(SdkError::Timeout { seconds: 60 })
249 }
250 }
251 }
252
253 /// Handle permission request
254 #[allow(dead_code)]
255 async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
256 if let Some(ref can_use_tool) = self.can_use_tool {
257 let context = ToolPermissionContext {
258 signal: None,
259 suggestions: request.permission_suggestions.unwrap_or_default(),
260 };
261
262 let result = can_use_tool
263 .can_use_tool(&request.tool_name, &request.input, &context)
264 .await;
265
266 // Send response back (CLI expects: { allow: bool, input?, reason? })
267 let response = match result {
268 PermissionResult::Allow(allow) => {
269 let mut obj = serde_json::json!({ "allow": true });
270 if let Some(updated) = allow.updated_input {
271 obj["input"] = updated;
272 }
273 obj
274 }
275 PermissionResult::Deny(deny) => {
276 let mut obj = serde_json::json!({ "allow": false });
277 if !deny.message.is_empty() {
278 obj["reason"] = serde_json::json!(deny.message);
279 }
280 obj
281 }
282 };
283
284 // Send response back through transport
285 let mut transport = self.transport.lock().await;
286 transport.send_sdk_control_response(response).await?;
287 debug!("Permission response sent");
288 }
289
290 Ok(())
291 }
292
293 /// Extract requestId from CLI message (supports both camelCase and snake_case)
294 fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
295 msg.get("requestId")
296 .or_else(|| msg.get("request_id"))
297 .cloned()
298 }
299
300 /// Start control request handler task
301 async fn start_control_handler(&mut self) {
302 let transport = self.transport.clone();
303 let can_use_tool = self.can_use_tool.clone();
304 let hook_callbacks = self.hook_callbacks.clone();
305 let sdk_mcp_servers = self.sdk_mcp_servers.clone();
306 let pending_responses = self.pending_responses.clone();
307
308 // Take ownership of the SDK control receiver to avoid holding locks
309 let sdk_control_rx = {
310 let mut transport_lock = transport.lock().await;
311 transport_lock.take_sdk_control_receiver()
312 }; // Lock released here
313
314 if let Some(mut control_rx) = sdk_control_rx {
315 tokio::spawn(async move {
316 // Now we can receive control requests without holding any locks
317 let transport_for_control = transport;
318 let can_use_tool_clone = can_use_tool;
319 let hook_callbacks_clone = hook_callbacks;
320 let sdk_mcp_servers_clone = sdk_mcp_servers;
321 let pending_responses_clone = pending_responses;
322
323 loop {
324 // Receive control request without holding lock
325 let control_message = control_rx.recv().await;
326
327 // If channel closed (sender dropped), exit the loop
328 let Some(control_message) = control_message else {
329 debug!("Control channel closed, exiting control handler");
330 break;
331 };
332
333 debug!("Received control message: {:?}", control_message);
334
335 // Check if this is a control response (from CLI to SDK)
336 if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
337 // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
338 if let Some(resp_obj) = control_message.get("response") {
339 let request_id = resp_obj
340 .get("request_id")
341 .or_else(|| resp_obj.get("requestId"))
342 .and_then(|v| v.as_str());
343
344 if let Some(request_id) = request_id {
345 let mut pending = pending_responses_clone.write().await;
346 if let Some(tx) = pending.remove(request_id) {
347 // Deliver the nested control response object; send_control_request will
348 // extract the `response` (or legacy `data`) payload for callers.
349 let _ = tx.send(resp_obj.clone());
350 debug!("Control response delivered for request_id: {}", request_id);
351 } else {
352 warn!("No pending request found for request_id: {}", request_id);
353 }
354 } else {
355 warn!("Control response missing request_id: {:?}", control_message);
356 }
357 } else {
358 warn!("Control response missing 'response' payload: {:?}", control_message);
359 }
360 continue;
361 }
362
363 // Parse and handle control requests (from CLI to SDK)
364 // Check if this is a control_request with a nested request field
365 let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
366 control_message.get("request").cloned().unwrap_or(control_message.clone())
367 } else {
368 control_message.clone()
369 };
370
371 if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
372 match subtype {
373 "can_use_tool" => {
374 // Handle permission request
375 if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
376 // Handle with can_use_tool callback
377 if let Some(ref can_use_tool) = can_use_tool_clone {
378 let context = ToolPermissionContext {
379 signal: None,
380 suggestions: request.permission_suggestions.unwrap_or_default(),
381 };
382
383 // Save original input for fallback (Python SDK always sends updatedInput)
384 let original_input = request.input.clone();
385
386 let result = can_use_tool
387 .can_use_tool(&request.tool_name, &request.input, &context)
388 .await;
389
390 // Match Python SDK response format:
391 // Allow: {"behavior": "allow", "updatedInput": ..., "updatedPermissions": ...}
392 // Deny: {"behavior": "deny", "message": "...", "interrupt": false}
393 // NOTE: updatedInput is ALWAYS required for allow (CLI Zod schema expects it)
394 let permission_response = match result {
395 PermissionResult::Allow(allow) => {
396 let mut resp = serde_json::json!({
397 "behavior": "allow",
398 "updatedInput": allow.updated_input.unwrap_or(original_input),
399 });
400 if let Some(perms) = allow.updated_permissions {
401 resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
402 }
403 resp
404 }
405 PermissionResult::Deny(deny) => {
406 let mut resp = serde_json::json!({
407 "behavior": "deny",
408 });
409 if !deny.message.is_empty() {
410 resp["message"] = serde_json::json!(deny.message);
411 }
412 if deny.interrupt {
413 resp["interrupt"] = serde_json::json!(true);
414 }
415 resp
416 }
417 };
418
419 // Wrap response with proper structure
420 // CLI expects "subtype": "success" for all successful responses
421 let response = serde_json::json!({
422 "subtype": "success",
423 "request_id": Self::extract_request_id(&control_message),
424 "response": permission_response
425 });
426
427 // Send response
428 let mut transport = transport_for_control.lock().await;
429 if let Err(e) = transport.send_sdk_control_response(response).await {
430 error!("Failed to send permission response: {}", e);
431 }
432 }
433 } else {
434 // Fallback for snake_case fields (tool_name, permission_suggestions)
435 if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
436 && let Some(input_val) = request_data.get("input").cloned()
437 && let Some(ref can_use_tool) = can_use_tool_clone {
438 // Try to parse permission suggestions (snake_case)
439 let suggestions: Vec<PermissionUpdate> = request_data
440 .get("permission_suggestions")
441 .cloned()
442 .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
443 .unwrap_or_default();
444
445 let context = ToolPermissionContext { signal: None, suggestions };
446 let original_input = input_val.clone();
447 let result = can_use_tool
448 .can_use_tool(tool_name, &input_val, &context)
449 .await;
450
451 let permission_response = match result {
452 PermissionResult::Allow(allow) => {
453 let mut resp = serde_json::json!({
454 "behavior": "allow",
455 "updatedInput": allow.updated_input.unwrap_or(original_input),
456 });
457 if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
458 resp
459 }
460 PermissionResult::Deny(deny) => {
461 let mut resp = serde_json::json!({ "behavior": "deny" });
462 if !deny.message.is_empty() { resp["message"] = serde_json::json!(deny.message); }
463 if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
464 resp
465 }
466 };
467
468 let response = serde_json::json!({
469 "subtype": "success",
470 "request_id": Self::extract_request_id(&control_message),
471 "response": permission_response
472 });
473 let mut transport = transport_for_control.lock().await;
474 if let Err(e) = transport.send_sdk_control_response(response).await {
475 error!("Failed to send permission response (fallback): {}", e);
476 }
477 }
478 }
479 }
480 "hook_callback" => {
481 // Handle hook callback with strongly-typed inputs/outputs
482 if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
483 let callbacks = hook_callbacks_clone.read().await;
484
485 if let Some(callback) = callbacks.get(&request.callback_id) {
486 let context = HookContext { signal: None };
487
488 // Try to deserialize input as HookInput
489 let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
490 Ok(hook_input) => {
491 // Call the hook with strongly-typed input
492 callback
493 .execute(&hook_input, request.tool_use_id.as_deref(), &context)
494 .await
495 }
496 Err(parse_err) => {
497 error!("Failed to parse hook input: {}", parse_err);
498 // Return error using MessageParseError
499 Err(crate::errors::SdkError::MessageParseError {
500 error: format!("Invalid hook input: {parse_err}"),
501 raw: request.input.to_string(),
502 })
503 }
504 };
505
506 // Handle hook result
507 let response_json = match hook_result {
508 Ok(hook_output) => {
509 // Serialize HookJSONOutput to JSON
510 let output_value = serde_json::to_value(&hook_output)
511 .unwrap_or_else(|e| {
512 error!("Failed to serialize hook output: {}", e);
513 serde_json::json!({})
514 });
515
516 serde_json::json!({
517 "subtype": "success",
518 "request_id": Self::extract_request_id(&control_message),
519 "response": output_value
520 })
521 }
522 Err(e) => {
523 error!("Hook callback failed: {}", e);
524 serde_json::json!({
525 "subtype": "error",
526 "request_id": Self::extract_request_id(&control_message),
527 "error": e.to_string()
528 })
529 }
530 };
531
532 let mut transport = transport_for_control.lock().await;
533 if let Err(e) = transport.send_sdk_control_response(response_json).await {
534 error!("Failed to send hook callback response: {}", e);
535 }
536 } else {
537 warn!("No hook callback found for ID: {}", request.callback_id);
538 // Send error response
539 let error_response = serde_json::json!({
540 "subtype": "error",
541 "request_id": Self::extract_request_id(&control_message),
542 "error": format!("No hook callback found for ID: {}", request.callback_id)
543 });
544 let mut transport = transport_for_control.lock().await;
545 if let Err(e) = transport.send_sdk_control_response(error_response).await {
546 error!("Failed to send error response: {}", e);
547 }
548 }
549 } else {
550 // Fallback for snake_case fields (callback_id, tool_use_id)
551 let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
552 let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
553 let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
554
555 if let Some(callback_id) = callback_id {
556 let callbacks = hook_callbacks_clone.read().await;
557 if let Some(callback) = callbacks.get(callback_id) {
558 let context = HookContext { signal: None };
559
560 // Try to parse as HookInput
561 let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
562 Ok(hook_input) => {
563 callback
564 .execute(&hook_input, tool_use_id.as_deref(), &context)
565 .await
566 }
567 Err(parse_err) => {
568 error!("Failed to parse hook input (fallback): {}", parse_err);
569 Err(crate::errors::SdkError::MessageParseError {
570 error: format!("Invalid hook input: {parse_err}"),
571 raw: input.to_string(),
572 })
573 }
574 };
575
576 let response_json = match hook_result {
577 Ok(hook_output) => {
578 let output_value = serde_json::to_value(&hook_output)
579 .unwrap_or_else(|e| {
580 error!("Failed to serialize hook output (fallback): {}", e);
581 serde_json::json!({})
582 });
583
584 serde_json::json!({
585 "subtype": "success",
586 "request_id": Self::extract_request_id(&control_message),
587 "response": output_value
588 })
589 }
590 Err(e) => {
591 error!("Hook callback failed (fallback): {}", e);
592 serde_json::json!({
593 "subtype": "error",
594 "request_id": Self::extract_request_id(&control_message),
595 "error": e.to_string()
596 })
597 }
598 };
599
600 let mut transport = transport_for_control.lock().await;
601 if let Err(e) = transport.send_sdk_control_response(response_json).await {
602 error!("Failed to send hook callback response (fallback): {}", e);
603 }
604 } else {
605 warn!("No hook callback found for ID: {}", callback_id);
606 }
607 } else {
608 warn!("Invalid hook_callback control message: missing callback_id");
609 }
610 }
611 }
612 "mcp_message" => {
613 // Handle MCP message
614 if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
615 && let Some(message) = request_data.get("message") {
616 debug!("Processing MCP message for SDK server: {}", server_name);
617
618 // Check if we have an SDK server with this name
619 if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
620 // Try to downcast to SdkMcpServer
621 if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
622 // Call the SDK MCP server
623 match sdk_server.handle_message(message.clone()).await {
624 Ok(mcp_result) => {
625 // Wrap response with proper structure
626 let response = serde_json::json!({
627 "subtype": "success",
628 "request_id": Self::extract_request_id(&control_message),
629 "response": {
630 "mcp_response": mcp_result
631 }
632 });
633
634 let mut transport = transport_for_control.lock().await;
635 if let Err(e) = transport.send_sdk_control_response(response).await {
636 error!("Failed to send MCP response: {}", e);
637 }
638 }
639 Err(e) => {
640 error!("SDK MCP server error: {}", e);
641 let error_response = serde_json::json!({
642 "subtype": "error",
643 "request_id": Self::extract_request_id(&control_message),
644 "error": format!("MCP server error: {}", e)
645 });
646
647 let mut transport = transport_for_control.lock().await;
648 if let Err(e) = transport.send_sdk_control_response(error_response).await {
649 error!("Failed to send MCP error response: {}", e);
650 }
651 }
652 }
653 } else {
654 warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
655 }
656 } else {
657 warn!("No SDK MCP server found with name: {}", server_name);
658 let error_response = serde_json::json!({
659 "subtype": "error",
660 "request_id": Self::extract_request_id(&control_message),
661 "error": format!("Server '{}' not found", server_name)
662 });
663
664 let mut transport = transport_for_control.lock().await;
665 if let Err(e) = transport.send_sdk_control_response(error_response).await {
666 error!("Failed to send MCP error response: {}", e);
667 }
668 }
669 }
670 }
671 _ => {
672 debug!("Unknown SDK control subtype: {}", subtype);
673 }
674 }
675 }
676 }
677 });
678 }
679 }
680
681 /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
682 #[allow(dead_code)]
683 pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
684 where
685 S: Stream<Item = JsonValue> + Send + 'static,
686 {
687 let transport = self.transport.clone();
688
689 tokio::spawn(async move {
690 use futures::StreamExt;
691 let mut stream = Box::pin(input_stream);
692
693 while let Some(value) = stream.next().await {
694 // Best-effort conversion from arbitrary JSON to InputMessage
695 let input_msg_opt = Self::json_to_input_message(value);
696 if let Some(input_msg) = input_msg_opt {
697 let mut guard = transport.lock().await;
698 if let Err(e) = guard.send_message(input_msg).await {
699 warn!("Failed to send streaming input message: {}", e);
700 }
701 } else {
702 warn!("Invalid streaming input JSON; expected user message shape");
703 }
704 }
705
706 // After streaming all inputs, signal end of input
707 let mut guard = transport.lock().await;
708 if let Err(e) = guard.end_input().await {
709 warn!("Failed to signal end_input: {}", e);
710 }
711 });
712 Ok(())
713 }
714
715 /// Receive messages
716 #[allow(dead_code)]
717 pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
718 self.message_rx.take().expect("Receiver already taken")
719 }
720
721 /// Send interrupt request
722 pub async fn interrupt(&mut self) -> Result<()> {
723 let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
724 subtype: "interrupt".to_string(),
725 });
726
727 self.send_control_request(interrupt_request).await?;
728 Ok(())
729 }
730
731 /// Set permission mode via control protocol
732 #[allow(dead_code)]
733 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
734 let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
735 subtype: "set_permission_mode".to_string(),
736 mode: mode.to_string(),
737 });
738 // Ignore response payload; errors propagate
739 let _ = self.send_control_request(req).await?;
740 Ok(())
741 }
742
743 /// Set the active model via control protocol
744 #[allow(dead_code)]
745 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
746 let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
747 subtype: "set_model".to_string(),
748 model,
749 });
750 let _ = self.send_control_request(req).await?;
751 Ok(())
752 }
753
754 /// Rewind tracked files to their state at a specific user message
755 ///
756 /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
757 ///
758 /// # Arguments
759 ///
760 /// * `user_message_id` - UUID of the user message to rewind to
761 ///
762 /// # Example
763 ///
764 /// ```rust,no_run
765 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
766 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
767 /// let options = ClaudeCodeOptions::builder()
768 /// .enable_file_checkpointing(true)
769 /// .build();
770 /// let mut client = ClaudeSDKClient::new(options);
771 /// client.connect(None).await?;
772 ///
773 /// // Later, rewind to a checkpoint
774 /// // client.rewind_files("user-message-uuid-here").await?;
775 /// # Ok(())
776 /// # }
777 /// ```
778 pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
779 let req = SDKControlRequest::RewindFiles(crate::types::SDKControlRewindFilesRequest::new(user_message_id));
780 let _ = self.send_control_request(req).await?;
781 Ok(())
782 }
783
784 /// Get context usage information
785 pub async fn get_context_usage(&mut self) -> Result<serde_json::Value> {
786 let req = SDKControlRequest::GetContextUsage(crate::types::SDKControlGetContextUsageRequest::new());
787 self.send_control_request(req).await
788 }
789
790 /// Stop a background task
791 pub async fn stop_task(&mut self, task_id: &str) -> Result<()> {
792 let req = SDKControlRequest::StopTask(crate::types::SDKControlStopTaskRequest::new(task_id));
793 let _ = self.send_control_request(req).await?;
794 Ok(())
795 }
796
797 /// Get MCP server status
798 pub async fn get_mcp_status(&mut self) -> Result<serde_json::Value> {
799 let req = SDKControlRequest::McpStatus(crate::types::SDKControlMcpStatusRequest::new());
800 self.send_control_request(req).await
801 }
802
803 /// Reconnect an MCP server
804 pub async fn reconnect_mcp_server(&mut self, server_name: &str) -> Result<()> {
805 let req = SDKControlRequest::McpReconnect(crate::types::SDKControlMcpReconnectRequest::new(server_name));
806 let _ = self.send_control_request(req).await?;
807 Ok(())
808 }
809
810 /// Toggle an MCP server on/off
811 pub async fn toggle_mcp_server(&mut self, server_name: &str, enabled: bool) -> Result<()> {
812 let req = SDKControlRequest::McpToggle(crate::types::SDKControlMcpToggleRequest::new(server_name, enabled));
813 let _ = self.send_control_request(req).await?;
814 Ok(())
815 }
816
817 /// Handle MCP message for SDK servers
818 #[allow(dead_code)]
819 async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
820 // Check if we have an SDK server with this name
821 if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
822 // TODO: Implement actual MCP server invocation
823 // For now, return a placeholder response
824 debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
825 Ok(serde_json::json!({
826 "jsonrpc": "2.0",
827 "id": message.get("id"),
828 "result": {
829 "content": "MCP server response placeholder"
830 }
831 }))
832 } else {
833 Err(SdkError::InvalidState {
834 message: format!("No SDK MCP server found with name: {server_name}"),
835 })
836 }
837 }
838
839 /// Close the query handler
840 #[allow(dead_code)]
841 pub async fn close(&mut self) -> Result<()> {
842 // Clean up resources
843 let mut transport = self.transport.lock().await;
844 transport.disconnect().await?;
845 Ok(())
846 }
847
848 /// Get initialization result
849 pub fn get_initialization_result(&self) -> Option<&JsonValue> {
850 self.initialization_result.as_ref()
851 }
852
853 /// Convert arbitrary JSON value to InputMessage understood by CLI
854 #[allow(dead_code)]
855 fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
856 // 1) Already in SDK message shape
857 if let Some(obj) = v.as_object() {
858 if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
859 && t.as_str() == Some("user") {
860 let parent = obj
861 .get("parent_tool_use_id")
862 .and_then(|p| p.as_str().map(|s| s.to_string()));
863 let session_id = obj
864 .get("session_id")
865 .and_then(|s| s.as_str())
866 .unwrap_or("default")
867 .to_string();
868
869 let im = InputMessage {
870 r#type: "user".to_string(),
871 message: message.clone(),
872 parent_tool_use_id: parent,
873 session_id,
874 };
875 return Some(im);
876 }
877
878 // 2) Simple wrapper: {"content":"...", "session_id":"..."}
879 if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
880 let session_id = obj
881 .get("session_id")
882 .and_then(|s| s.as_str())
883 .unwrap_or("default")
884 .to_string();
885 return Some(InputMessage::user(content.to_string(), session_id));
886 }
887 }
888
889 // 3) Bare string
890 if let Some(s) = v.as_str() {
891 return Some(InputMessage::user(s.to_string(), "default".to_string()));
892 }
893
894 None
895 }
896}
897
898#[cfg(test)]
899mod tests {
900 use super::*;
901
902 #[test]
903 fn test_extract_request_id_supports_both_cases() {
904 let snake = serde_json::json!({"request_id": "req_1"});
905 let camel = serde_json::json!({"requestId": "req_2"});
906 assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
907 assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
908 }
909
910 #[test]
911 fn test_json_to_input_message_from_string() {
912 let v = serde_json::json!("Hello");
913 let im = Query::json_to_input_message(v).expect("should convert");
914 assert_eq!(im.r#type, "user");
915 assert_eq!(im.session_id, "default");
916 assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
917 }
918
919 #[test]
920 fn test_json_to_input_message_from_object_content() {
921 let v = serde_json::json!({"content":"Ping","session_id":"s1"});
922 let im = Query::json_to_input_message(v).expect("should convert");
923 assert_eq!(im.session_id, "s1");
924 assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
925 }
926
927 #[test]
928 fn test_json_to_input_message_full_user_shape() {
929 let v = serde_json::json!({
930 "type":"user",
931 "message": {"role":"user","content":"Hi"},
932 "session_id": "abc",
933 "parent_tool_use_id": null
934 });
935 let im = Query::json_to_input_message(v).expect("should convert");
936 assert_eq!(im.session_id, "abc");
937 assert_eq!(im.message["role"].as_str().unwrap(), "user");
938 assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
939 }
940}