cc_sdk/internal_query.rs
1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7 errors::{Result, SdkError},
8 transport::{InputMessage, Transport},
9 types::{
10 CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11 SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12 SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13 ToolPermissionContext,
14 },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27 /// Transport layer (shared with client)
28 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29 /// Whether in streaming mode
30 #[allow(dead_code)]
31 is_streaming_mode: bool,
32 /// Tool permission callback
33 can_use_tool: Option<Arc<dyn CanUseTool>>,
34 /// Hook configurations
35 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36 /// SDK MCP servers
37 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38 /// Message channel sender (reserved for future streaming receive support)
39 #[allow(dead_code)]
40 message_tx: mpsc::Sender<Result<Message>>,
41 /// Message channel receiver (reserved for future streaming receive support)
42 #[allow(dead_code)]
43 message_rx: Option<mpsc::Receiver<Result<Message>>>,
44 /// Initialization result
45 initialization_result: Option<JsonValue>,
46 /// Active hook callbacks
47 hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48 /// Hook callback counter
49 callback_counter: Arc<Mutex<u64>>,
50 /// Request counter for generating unique IDs
51 request_counter: Arc<Mutex<u64>>,
52 /// Pending control request responses
53 pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57 /// Create a new Query handler
58 pub fn new(
59 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60 is_streaming_mode: bool,
61 can_use_tool: Option<Arc<dyn CanUseTool>>,
62 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64 ) -> Self {
65 let (tx, rx) = mpsc::channel(100);
66
67 Self {
68 transport,
69 is_streaming_mode,
70 can_use_tool,
71 hooks,
72 sdk_mcp_servers,
73 message_tx: tx,
74 message_rx: Some(rx),
75 initialization_result: None,
76 hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77 callback_counter: Arc::new(Mutex::new(0)),
78 request_counter: Arc::new(Mutex::new(0)),
79 pending_responses: Arc::new(RwLock::new(HashMap::new())),
80 }
81 }
82
83 /// Test helper to register a hook callback with a known ID
84 ///
85 /// This is intended for E2E tests to inject a callback ID that can be
86 /// referenced by inbound `hook_callback` control messages.
87 pub async fn register_hook_callback_for_test(
88 &self,
89 callback_id: String,
90 callback: Arc<dyn HookCallback>,
91 ) {
92 let mut map = self.hook_callbacks.write().await;
93 map.insert(callback_id, callback);
94 }
95
96 /// Start the query handler
97 pub async fn start(&mut self) -> Result<()> {
98 // Start control request handler task
99 self.start_control_handler().await;
100
101 // Start SDK message forwarder task (route non-control messages to message_tx)
102 let transport = self.transport.clone();
103 let tx = self.message_tx.clone();
104 tokio::spawn(async move {
105 // Get message stream once and consume it continuously
106 let mut stream = {
107 let mut guard = transport.lock().await;
108 guard.receive_messages()
109 }; // Lock released immediately after getting stream
110
111 // Continuously consume from the same stream
112 while let Some(result) = stream.next().await {
113 match result {
114 Ok(msg) => {
115 if tx.send(Ok(msg)).await.is_err() {
116 break;
117 }
118 }
119 Err(e) => {
120 let _ = tx.send(Err(e)).await;
121 break;
122 }
123 }
124 }
125 });
126 Ok(())
127 }
128
129 /// Initialize the control protocol
130 pub async fn initialize(&mut self) -> Result<()> {
131 // Build hooks with callback IDs (Python SDK style)
132 let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133 let mut counter = self.callback_counter.lock().await;
134 let mut callbacks_map = self.hook_callbacks.write().await;
135
136 let hooks_json: HashMap<String, JsonValue> = hooks
137 .iter()
138 .map(|(event_name, matchers)| {
139 let matchers_with_ids: Vec<JsonValue> = matchers
140 .iter()
141 .map(|matcher| {
142 // Generate callback IDs for each hook in this matcher
143 let callback_ids: Vec<String> = matcher
144 .hooks
145 .iter()
146 .map(|hook_callback| {
147 *counter += 1;
148 let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150 // Store the callback for later use
151 callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153 callback_id
154 })
155 .collect();
156
157 serde_json::json!({
158 "matcher": matcher.matcher.clone(),
159 "hookCallbackIds": callback_ids
160 })
161 })
162 .collect();
163
164 (event_name.clone(), serde_json::json!(matchers_with_ids))
165 })
166 .collect();
167
168 Some(hooks_json)
169 } else {
170 None
171 };
172
173 // Send initialize request
174 let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175 subtype: "initialize".to_string(),
176 hooks: hooks_with_ids,
177 });
178
179 // Send control request and save result
180 let result = self.send_control_request(init_request).await?;
181 self.initialization_result = Some(result);
182
183 debug!("Initialization request sent with hook callback IDs");
184 Ok(())
185 }
186
187 /// Send a control request and wait for response
188 async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189 // Generate unique request ID
190 let request_id = {
191 let mut counter = self.request_counter.lock().await;
192 *counter += 1;
193 format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194 };
195
196 // Create oneshot channel for response
197 let (tx, rx) = tokio::sync::oneshot::channel();
198
199 // Register pending response
200 {
201 let mut pending = self.pending_responses.write().await;
202 pending.insert(request_id.clone(), tx);
203 }
204
205 // Build control request with request_id (snake_case for CLI compatibility)
206 let control_request = serde_json::json!({
207 "type": "control_request",
208 "request_id": request_id,
209 "request": request
210 });
211
212 debug!("Sending control request: {:?}", control_request);
213
214 // Send via transport
215 {
216 let mut transport = self.transport.lock().await;
217 transport.send_sdk_control_request(control_request).await?;
218 }
219
220 // Wait for response with timeout
221 match timeout(Duration::from_secs(60), rx).await {
222 Ok(Ok(response)) => {
223 debug!("Received control response for {}", request_id);
224
225 // Python parity: treat subtype=error as an error, and return only
226 // the payload from `response` (or legacy `data`) on success.
227 if response.get("subtype").and_then(|v| v.as_str()) == Some("error") {
228 let msg = response
229 .get("error")
230 .and_then(|v| v.as_str())
231 .unwrap_or("Unknown control request error");
232 return Err(SdkError::ControlRequestError(msg.to_string()));
233 }
234
235 Ok(response
236 .get("response")
237 .or_else(|| response.get("data"))
238 .cloned()
239 .unwrap_or_else(|| serde_json::json!({})))
240 }
241 Ok(Err(_)) => Err(SdkError::ControlRequestError(
242 "Response channel closed".to_string(),
243 )),
244 Err(_) => {
245 // Clean up pending response
246 let mut pending = self.pending_responses.write().await;
247 pending.remove(&request_id);
248 Err(SdkError::Timeout { seconds: 60 })
249 }
250 }
251 }
252
253 /// Handle permission request
254 #[allow(dead_code)]
255 async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
256 if let Some(ref can_use_tool) = self.can_use_tool {
257 let context = ToolPermissionContext {
258 signal: None,
259 suggestions: request.permission_suggestions.unwrap_or_default(),
260 };
261
262 let result = can_use_tool
263 .can_use_tool(&request.tool_name, &request.input, &context)
264 .await;
265
266 // Send response back (CLI expects: { allow: bool, input?, reason? })
267 let response = match result {
268 PermissionResult::Allow(allow) => {
269 let mut obj = serde_json::json!({ "allow": true });
270 if let Some(updated) = allow.updated_input {
271 obj["input"] = updated;
272 }
273 obj
274 }
275 PermissionResult::Deny(deny) => {
276 let mut obj = serde_json::json!({ "allow": false });
277 if !deny.message.is_empty() {
278 obj["reason"] = serde_json::json!(deny.message);
279 }
280 obj
281 }
282 };
283
284 // Send response back through transport
285 let mut transport = self.transport.lock().await;
286 transport.send_sdk_control_response(response).await?;
287 debug!("Permission response sent");
288 }
289
290 Ok(())
291 }
292
293 /// Extract requestId from CLI message (supports both camelCase and snake_case)
294 fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
295 msg.get("requestId")
296 .or_else(|| msg.get("request_id"))
297 .cloned()
298 }
299
300 /// Start control request handler task
301 async fn start_control_handler(&mut self) {
302 let transport = self.transport.clone();
303 let can_use_tool = self.can_use_tool.clone();
304 let hook_callbacks = self.hook_callbacks.clone();
305 let sdk_mcp_servers = self.sdk_mcp_servers.clone();
306 let pending_responses = self.pending_responses.clone();
307
308 // Take ownership of the SDK control receiver to avoid holding locks
309 let sdk_control_rx = {
310 let mut transport_lock = transport.lock().await;
311 transport_lock.take_sdk_control_receiver()
312 }; // Lock released here
313
314 if let Some(mut control_rx) = sdk_control_rx {
315 tokio::spawn(async move {
316 // Now we can receive control requests without holding any locks
317 let transport_for_control = transport;
318 let can_use_tool_clone = can_use_tool;
319 let hook_callbacks_clone = hook_callbacks;
320 let sdk_mcp_servers_clone = sdk_mcp_servers;
321 let pending_responses_clone = pending_responses;
322
323 loop {
324 // Receive control request without holding lock
325 let control_message = control_rx.recv().await;
326
327 if let Some(control_message) = control_message {
328 debug!("Received control message: {:?}", control_message);
329
330 // Check if this is a control response (from CLI to SDK)
331 if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
332 // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
333 if let Some(resp_obj) = control_message.get("response") {
334 let request_id = resp_obj
335 .get("request_id")
336 .or_else(|| resp_obj.get("requestId"))
337 .and_then(|v| v.as_str());
338
339 if let Some(request_id) = request_id {
340 let mut pending = pending_responses_clone.write().await;
341 if let Some(tx) = pending.remove(request_id) {
342 // Deliver the nested control response object; send_control_request will
343 // extract the `response` (or legacy `data`) payload for callers.
344 let _ = tx.send(resp_obj.clone());
345 debug!("Control response delivered for request_id: {}", request_id);
346 } else {
347 warn!("No pending request found for request_id: {}", request_id);
348 }
349 } else {
350 warn!("Control response missing request_id: {:?}", control_message);
351 }
352 } else {
353 warn!("Control response missing 'response' payload: {:?}", control_message);
354 }
355 continue;
356 }
357
358 // Parse and handle control requests (from CLI to SDK)
359 // Check if this is a control_request with a nested request field
360 let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
361 control_message.get("request").cloned().unwrap_or(control_message.clone())
362 } else {
363 control_message.clone()
364 };
365
366 if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
367 match subtype {
368 "can_use_tool" => {
369 // Handle permission request
370 if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
371 // Handle with can_use_tool callback
372 if let Some(ref can_use_tool) = can_use_tool_clone {
373 let context = ToolPermissionContext {
374 signal: None,
375 suggestions: request.permission_suggestions.unwrap_or_default(),
376 };
377
378 // Save original input for fallback (Python SDK always sends updatedInput)
379 let original_input = request.input.clone();
380
381 let result = can_use_tool
382 .can_use_tool(&request.tool_name, &request.input, &context)
383 .await;
384
385 // Match Python SDK response format:
386 // Allow: {"behavior": "allow", "updatedInput": ..., "updatedPermissions": ...}
387 // Deny: {"behavior": "deny", "message": "...", "interrupt": false}
388 // NOTE: updatedInput is ALWAYS required for allow (CLI Zod schema expects it)
389 let permission_response = match result {
390 PermissionResult::Allow(allow) => {
391 let mut resp = serde_json::json!({
392 "behavior": "allow",
393 "updatedInput": allow.updated_input.unwrap_or(original_input),
394 });
395 if let Some(perms) = allow.updated_permissions {
396 resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
397 }
398 resp
399 }
400 PermissionResult::Deny(deny) => {
401 let mut resp = serde_json::json!({
402 "behavior": "deny",
403 });
404 if !deny.message.is_empty() {
405 resp["message"] = serde_json::json!(deny.message);
406 }
407 if deny.interrupt {
408 resp["interrupt"] = serde_json::json!(true);
409 }
410 resp
411 }
412 };
413
414 // Wrap response with proper structure
415 // CLI expects "subtype": "success" for all successful responses
416 let response = serde_json::json!({
417 "subtype": "success",
418 "request_id": Self::extract_request_id(&control_message),
419 "response": permission_response
420 });
421
422 // Send response
423 let mut transport = transport_for_control.lock().await;
424 if let Err(e) = transport.send_sdk_control_response(response).await {
425 error!("Failed to send permission response: {}", e);
426 }
427 }
428 } else {
429 // Fallback for snake_case fields (tool_name, permission_suggestions)
430 if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
431 && let Some(input_val) = request_data.get("input").cloned()
432 && let Some(ref can_use_tool) = can_use_tool_clone {
433 // Try to parse permission suggestions (snake_case)
434 let suggestions: Vec<PermissionUpdate> = request_data
435 .get("permission_suggestions")
436 .cloned()
437 .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
438 .unwrap_or_default();
439
440 let context = ToolPermissionContext { signal: None, suggestions };
441 let original_input = input_val.clone();
442 let result = can_use_tool
443 .can_use_tool(tool_name, &input_val, &context)
444 .await;
445
446 let permission_response = match result {
447 PermissionResult::Allow(allow) => {
448 let mut resp = serde_json::json!({
449 "behavior": "allow",
450 "updatedInput": allow.updated_input.unwrap_or(original_input),
451 });
452 if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
453 resp
454 }
455 PermissionResult::Deny(deny) => {
456 let mut resp = serde_json::json!({ "behavior": "deny" });
457 if !deny.message.is_empty() { resp["message"] = serde_json::json!(deny.message); }
458 if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
459 resp
460 }
461 };
462
463 let response = serde_json::json!({
464 "subtype": "success",
465 "request_id": Self::extract_request_id(&control_message),
466 "response": permission_response
467 });
468 let mut transport = transport_for_control.lock().await;
469 if let Err(e) = transport.send_sdk_control_response(response).await {
470 error!("Failed to send permission response (fallback): {}", e);
471 }
472 }
473 }
474 }
475 "hook_callback" => {
476 // Handle hook callback with strongly-typed inputs/outputs
477 if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
478 let callbacks = hook_callbacks_clone.read().await;
479
480 if let Some(callback) = callbacks.get(&request.callback_id) {
481 let context = HookContext { signal: None };
482
483 // Try to deserialize input as HookInput
484 let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
485 Ok(hook_input) => {
486 // Call the hook with strongly-typed input
487 callback
488 .execute(&hook_input, request.tool_use_id.as_deref(), &context)
489 .await
490 }
491 Err(parse_err) => {
492 error!("Failed to parse hook input: {}", parse_err);
493 // Return error using MessageParseError
494 Err(crate::errors::SdkError::MessageParseError {
495 error: format!("Invalid hook input: {parse_err}"),
496 raw: request.input.to_string(),
497 })
498 }
499 };
500
501 // Handle hook result
502 let response_json = match hook_result {
503 Ok(hook_output) => {
504 // Serialize HookJSONOutput to JSON
505 let output_value = serde_json::to_value(&hook_output)
506 .unwrap_or_else(|e| {
507 error!("Failed to serialize hook output: {}", e);
508 serde_json::json!({})
509 });
510
511 serde_json::json!({
512 "subtype": "success",
513 "request_id": Self::extract_request_id(&control_message),
514 "response": output_value
515 })
516 }
517 Err(e) => {
518 error!("Hook callback failed: {}", e);
519 serde_json::json!({
520 "subtype": "error",
521 "request_id": Self::extract_request_id(&control_message),
522 "error": e.to_string()
523 })
524 }
525 };
526
527 let mut transport = transport_for_control.lock().await;
528 if let Err(e) = transport.send_sdk_control_response(response_json).await {
529 error!("Failed to send hook callback response: {}", e);
530 }
531 } else {
532 warn!("No hook callback found for ID: {}", request.callback_id);
533 // Send error response
534 let error_response = serde_json::json!({
535 "subtype": "error",
536 "request_id": Self::extract_request_id(&control_message),
537 "error": format!("No hook callback found for ID: {}", request.callback_id)
538 });
539 let mut transport = transport_for_control.lock().await;
540 if let Err(e) = transport.send_sdk_control_response(error_response).await {
541 error!("Failed to send error response: {}", e);
542 }
543 }
544 } else {
545 // Fallback for snake_case fields (callback_id, tool_use_id)
546 let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
547 let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
548 let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
549
550 if let Some(callback_id) = callback_id {
551 let callbacks = hook_callbacks_clone.read().await;
552 if let Some(callback) = callbacks.get(callback_id) {
553 let context = HookContext { signal: None };
554
555 // Try to parse as HookInput
556 let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
557 Ok(hook_input) => {
558 callback
559 .execute(&hook_input, tool_use_id.as_deref(), &context)
560 .await
561 }
562 Err(parse_err) => {
563 error!("Failed to parse hook input (fallback): {}", parse_err);
564 Err(crate::errors::SdkError::MessageParseError {
565 error: format!("Invalid hook input: {parse_err}"),
566 raw: input.to_string(),
567 })
568 }
569 };
570
571 let response_json = match hook_result {
572 Ok(hook_output) => {
573 let output_value = serde_json::to_value(&hook_output)
574 .unwrap_or_else(|e| {
575 error!("Failed to serialize hook output (fallback): {}", e);
576 serde_json::json!({})
577 });
578
579 serde_json::json!({
580 "subtype": "success",
581 "request_id": Self::extract_request_id(&control_message),
582 "response": output_value
583 })
584 }
585 Err(e) => {
586 error!("Hook callback failed (fallback): {}", e);
587 serde_json::json!({
588 "subtype": "error",
589 "request_id": Self::extract_request_id(&control_message),
590 "error": e.to_string()
591 })
592 }
593 };
594
595 let mut transport = transport_for_control.lock().await;
596 if let Err(e) = transport.send_sdk_control_response(response_json).await {
597 error!("Failed to send hook callback response (fallback): {}", e);
598 }
599 } else {
600 warn!("No hook callback found for ID: {}", callback_id);
601 }
602 } else {
603 warn!("Invalid hook_callback control message: missing callback_id");
604 }
605 }
606 }
607 "mcp_message" => {
608 // Handle MCP message
609 if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
610 && let Some(message) = request_data.get("message") {
611 debug!("Processing MCP message for SDK server: {}", server_name);
612
613 // Check if we have an SDK server with this name
614 if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
615 // Try to downcast to SdkMcpServer
616 if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
617 // Call the SDK MCP server
618 match sdk_server.handle_message(message.clone()).await {
619 Ok(mcp_result) => {
620 // Wrap response with proper structure
621 let response = serde_json::json!({
622 "subtype": "success",
623 "request_id": Self::extract_request_id(&control_message),
624 "response": {
625 "mcp_response": mcp_result
626 }
627 });
628
629 let mut transport = transport_for_control.lock().await;
630 if let Err(e) = transport.send_sdk_control_response(response).await {
631 error!("Failed to send MCP response: {}", e);
632 }
633 }
634 Err(e) => {
635 error!("SDK MCP server error: {}", e);
636 let error_response = serde_json::json!({
637 "subtype": "error",
638 "request_id": Self::extract_request_id(&control_message),
639 "error": format!("MCP server error: {}", e)
640 });
641
642 let mut transport = transport_for_control.lock().await;
643 if let Err(e) = transport.send_sdk_control_response(error_response).await {
644 error!("Failed to send MCP error response: {}", e);
645 }
646 }
647 }
648 } else {
649 warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
650 }
651 } else {
652 warn!("No SDK MCP server found with name: {}", server_name);
653 let error_response = serde_json::json!({
654 "subtype": "error",
655 "request_id": Self::extract_request_id(&control_message),
656 "error": format!("Server '{}' not found", server_name)
657 });
658
659 let mut transport = transport_for_control.lock().await;
660 if let Err(e) = transport.send_sdk_control_response(error_response).await {
661 error!("Failed to send MCP error response: {}", e);
662 }
663 }
664 }
665 }
666 _ => {
667 debug!("Unknown SDK control subtype: {}", subtype);
668 }
669 }
670 }
671 }
672 }
673 });
674 }
675 }
676
677 /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
678 #[allow(dead_code)]
679 pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
680 where
681 S: Stream<Item = JsonValue> + Send + 'static,
682 {
683 let transport = self.transport.clone();
684
685 tokio::spawn(async move {
686 use futures::StreamExt;
687 let mut stream = Box::pin(input_stream);
688
689 while let Some(value) = stream.next().await {
690 // Best-effort conversion from arbitrary JSON to InputMessage
691 let input_msg_opt = Self::json_to_input_message(value);
692 if let Some(input_msg) = input_msg_opt {
693 let mut guard = transport.lock().await;
694 if let Err(e) = guard.send_message(input_msg).await {
695 warn!("Failed to send streaming input message: {}", e);
696 }
697 } else {
698 warn!("Invalid streaming input JSON; expected user message shape");
699 }
700 }
701
702 // After streaming all inputs, signal end of input
703 let mut guard = transport.lock().await;
704 if let Err(e) = guard.end_input().await {
705 warn!("Failed to signal end_input: {}", e);
706 }
707 });
708 Ok(())
709 }
710
711 /// Receive messages
712 #[allow(dead_code)]
713 pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
714 self.message_rx.take().expect("Receiver already taken")
715 }
716
717 /// Send interrupt request
718 pub async fn interrupt(&mut self) -> Result<()> {
719 let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
720 subtype: "interrupt".to_string(),
721 });
722
723 self.send_control_request(interrupt_request).await?;
724 Ok(())
725 }
726
727 /// Set permission mode via control protocol
728 #[allow(dead_code)]
729 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
730 let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
731 subtype: "set_permission_mode".to_string(),
732 mode: mode.to_string(),
733 });
734 // Ignore response payload; errors propagate
735 let _ = self.send_control_request(req).await?;
736 Ok(())
737 }
738
739 /// Set the active model via control protocol
740 #[allow(dead_code)]
741 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
742 let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
743 subtype: "set_model".to_string(),
744 model,
745 });
746 let _ = self.send_control_request(req).await?;
747 Ok(())
748 }
749
750 /// Rewind tracked files to their state at a specific user message
751 ///
752 /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
753 ///
754 /// # Arguments
755 ///
756 /// * `user_message_id` - UUID of the user message to rewind to
757 ///
758 /// # Example
759 ///
760 /// ```rust,no_run
761 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
762 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
763 /// let options = ClaudeCodeOptions::builder()
764 /// .enable_file_checkpointing(true)
765 /// .build();
766 /// let mut client = ClaudeSDKClient::new(options);
767 /// client.connect(None).await?;
768 ///
769 /// // Later, rewind to a checkpoint
770 /// // client.rewind_files("user-message-uuid-here").await?;
771 /// # Ok(())
772 /// # }
773 /// ```
774 pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
775 let req = SDKControlRequest::RewindFiles(crate::types::SDKControlRewindFilesRequest::new(user_message_id));
776 let _ = self.send_control_request(req).await?;
777 Ok(())
778 }
779
780 /// Get context usage information
781 pub async fn get_context_usage(&mut self) -> Result<serde_json::Value> {
782 let req = SDKControlRequest::GetContextUsage(crate::types::SDKControlGetContextUsageRequest::new());
783 self.send_control_request(req).await
784 }
785
786 /// Stop a background task
787 pub async fn stop_task(&mut self, task_id: &str) -> Result<()> {
788 let req = SDKControlRequest::StopTask(crate::types::SDKControlStopTaskRequest::new(task_id));
789 let _ = self.send_control_request(req).await?;
790 Ok(())
791 }
792
793 /// Get MCP server status
794 pub async fn get_mcp_status(&mut self) -> Result<serde_json::Value> {
795 let req = SDKControlRequest::McpStatus(crate::types::SDKControlMcpStatusRequest::new());
796 self.send_control_request(req).await
797 }
798
799 /// Reconnect an MCP server
800 pub async fn reconnect_mcp_server(&mut self, server_name: &str) -> Result<()> {
801 let req = SDKControlRequest::McpReconnect(crate::types::SDKControlMcpReconnectRequest::new(server_name));
802 let _ = self.send_control_request(req).await?;
803 Ok(())
804 }
805
806 /// Toggle an MCP server on/off
807 pub async fn toggle_mcp_server(&mut self, server_name: &str, enabled: bool) -> Result<()> {
808 let req = SDKControlRequest::McpToggle(crate::types::SDKControlMcpToggleRequest::new(server_name, enabled));
809 let _ = self.send_control_request(req).await?;
810 Ok(())
811 }
812
813 /// Handle MCP message for SDK servers
814 #[allow(dead_code)]
815 async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
816 // Check if we have an SDK server with this name
817 if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
818 // TODO: Implement actual MCP server invocation
819 // For now, return a placeholder response
820 debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
821 Ok(serde_json::json!({
822 "jsonrpc": "2.0",
823 "id": message.get("id"),
824 "result": {
825 "content": "MCP server response placeholder"
826 }
827 }))
828 } else {
829 Err(SdkError::InvalidState {
830 message: format!("No SDK MCP server found with name: {server_name}"),
831 })
832 }
833 }
834
835 /// Close the query handler
836 #[allow(dead_code)]
837 pub async fn close(&mut self) -> Result<()> {
838 // Clean up resources
839 let mut transport = self.transport.lock().await;
840 transport.disconnect().await?;
841 Ok(())
842 }
843
844 /// Get initialization result
845 pub fn get_initialization_result(&self) -> Option<&JsonValue> {
846 self.initialization_result.as_ref()
847 }
848
849 /// Convert arbitrary JSON value to InputMessage understood by CLI
850 #[allow(dead_code)]
851 fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
852 // 1) Already in SDK message shape
853 if let Some(obj) = v.as_object() {
854 if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
855 && t.as_str() == Some("user") {
856 let parent = obj
857 .get("parent_tool_use_id")
858 .and_then(|p| p.as_str().map(|s| s.to_string()));
859 let session_id = obj
860 .get("session_id")
861 .and_then(|s| s.as_str())
862 .unwrap_or("default")
863 .to_string();
864
865 let im = InputMessage {
866 r#type: "user".to_string(),
867 message: message.clone(),
868 parent_tool_use_id: parent,
869 session_id,
870 };
871 return Some(im);
872 }
873
874 // 2) Simple wrapper: {"content":"...", "session_id":"..."}
875 if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
876 let session_id = obj
877 .get("session_id")
878 .and_then(|s| s.as_str())
879 .unwrap_or("default")
880 .to_string();
881 return Some(InputMessage::user(content.to_string(), session_id));
882 }
883 }
884
885 // 3) Bare string
886 if let Some(s) = v.as_str() {
887 return Some(InputMessage::user(s.to_string(), "default".to_string()));
888 }
889
890 None
891 }
892}
893
894#[cfg(test)]
895mod tests {
896 use super::*;
897
898 #[test]
899 fn test_extract_request_id_supports_both_cases() {
900 let snake = serde_json::json!({"request_id": "req_1"});
901 let camel = serde_json::json!({"requestId": "req_2"});
902 assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
903 assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
904 }
905
906 #[test]
907 fn test_json_to_input_message_from_string() {
908 let v = serde_json::json!("Hello");
909 let im = Query::json_to_input_message(v).expect("should convert");
910 assert_eq!(im.r#type, "user");
911 assert_eq!(im.session_id, "default");
912 assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
913 }
914
915 #[test]
916 fn test_json_to_input_message_from_object_content() {
917 let v = serde_json::json!({"content":"Ping","session_id":"s1"});
918 let im = Query::json_to_input_message(v).expect("should convert");
919 assert_eq!(im.session_id, "s1");
920 assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
921 }
922
923 #[test]
924 fn test_json_to_input_message_full_user_shape() {
925 let v = serde_json::json!({
926 "type":"user",
927 "message": {"role":"user","content":"Hi"},
928 "session_id": "abc",
929 "parent_tool_use_id": null
930 });
931 let im = Query::json_to_input_message(v).expect("should convert");
932 assert_eq!(im.session_id, "abc");
933 assert_eq!(im.message["role"].as_str().unwrap(), "user");
934 assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
935 }
936}