cc_sdk/internal_query.rs
1//! Internal query implementation with control protocol support
2//!
3//! This module provides the internal Query struct that handles control protocol,
4//! permissions, hooks, and MCP server integration.
5
6use crate::{
7 errors::{Result, SdkError},
8 transport::{InputMessage, Transport},
9 types::{
10 CanUseTool, HookCallback, HookContext, HookMatcher, Message, PermissionResult, PermissionUpdate,
11 SDKControlInitializeRequest, SDKControlInterruptRequest, SDKControlPermissionRequest,
12 SDKControlRequest, SDKHookCallbackRequest, SDKControlSetPermissionModeRequest,
13 ToolPermissionContext,
14 },
15};
16use futures::stream::Stream;
17use futures::StreamExt;
18use serde_json::Value as JsonValue;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex, RwLock};
22use tokio::time::{timeout, Duration};
23use tracing::{debug, error, warn};
24
25/// Internal query handler with control protocol support
26pub struct Query {
27 /// Transport layer (shared with client)
28 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
29 /// Whether in streaming mode
30 #[allow(dead_code)]
31 is_streaming_mode: bool,
32 /// Tool permission callback
33 can_use_tool: Option<Arc<dyn CanUseTool>>,
34 /// Hook configurations
35 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
36 /// SDK MCP servers
37 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
38 /// Message channel sender (reserved for future streaming receive support)
39 #[allow(dead_code)]
40 message_tx: mpsc::Sender<Result<Message>>,
41 /// Message channel receiver (reserved for future streaming receive support)
42 #[allow(dead_code)]
43 message_rx: Option<mpsc::Receiver<Result<Message>>>,
44 /// Initialization result
45 initialization_result: Option<JsonValue>,
46 /// Active hook callbacks
47 hook_callbacks: Arc<RwLock<HashMap<String, Arc<dyn HookCallback>>>>,
48 /// Hook callback counter
49 callback_counter: Arc<Mutex<u64>>,
50 /// Request counter for generating unique IDs
51 request_counter: Arc<Mutex<u64>>,
52 /// Pending control request responses
53 pending_responses: Arc<RwLock<HashMap<String, tokio::sync::oneshot::Sender<JsonValue>>>>,
54}
55
56impl Query {
57 /// Create a new Query handler
58 pub fn new(
59 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
60 is_streaming_mode: bool,
61 can_use_tool: Option<Arc<dyn CanUseTool>>,
62 hooks: Option<HashMap<String, Vec<HookMatcher>>>,
63 sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
64 ) -> Self {
65 let (tx, rx) = mpsc::channel(100);
66
67 Self {
68 transport,
69 is_streaming_mode,
70 can_use_tool,
71 hooks,
72 sdk_mcp_servers,
73 message_tx: tx,
74 message_rx: Some(rx),
75 initialization_result: None,
76 hook_callbacks: Arc::new(RwLock::new(HashMap::new())),
77 callback_counter: Arc::new(Mutex::new(0)),
78 request_counter: Arc::new(Mutex::new(0)),
79 pending_responses: Arc::new(RwLock::new(HashMap::new())),
80 }
81 }
82
83 /// Test helper to register a hook callback with a known ID
84 ///
85 /// This is intended for E2E tests to inject a callback ID that can be
86 /// referenced by inbound `hook_callback` control messages.
87 pub async fn register_hook_callback_for_test(
88 &self,
89 callback_id: String,
90 callback: Arc<dyn HookCallback>,
91 ) {
92 let mut map = self.hook_callbacks.write().await;
93 map.insert(callback_id, callback);
94 }
95
96 /// Start the query handler
97 pub async fn start(&mut self) -> Result<()> {
98 // Start control request handler task
99 self.start_control_handler().await;
100
101 // Start SDK message forwarder task (route non-control messages to message_tx)
102 let transport = self.transport.clone();
103 let tx = self.message_tx.clone();
104 tokio::spawn(async move {
105 // Get message stream once and consume it continuously
106 let mut stream = {
107 let mut guard = transport.lock().await;
108 guard.receive_messages()
109 }; // Lock released immediately after getting stream
110
111 // Continuously consume from the same stream
112 while let Some(result) = stream.next().await {
113 match result {
114 Ok(msg) => {
115 if tx.send(Ok(msg)).await.is_err() {
116 break;
117 }
118 }
119 Err(e) => {
120 let _ = tx.send(Err(e)).await;
121 break;
122 }
123 }
124 }
125 });
126 Ok(())
127 }
128
129 /// Initialize the control protocol
130 pub async fn initialize(&mut self) -> Result<()> {
131 // Build hooks with callback IDs (Python SDK style)
132 let hooks_with_ids = if let Some(ref hooks) = self.hooks {
133 let mut counter = self.callback_counter.lock().await;
134 let mut callbacks_map = self.hook_callbacks.write().await;
135
136 let hooks_json: HashMap<String, JsonValue> = hooks
137 .iter()
138 .map(|(event_name, matchers)| {
139 let matchers_with_ids: Vec<JsonValue> = matchers
140 .iter()
141 .map(|matcher| {
142 // Generate callback IDs for each hook in this matcher
143 let callback_ids: Vec<String> = matcher
144 .hooks
145 .iter()
146 .map(|hook_callback| {
147 *counter += 1;
148 let callback_id = format!("hook_{}_{}", *counter, uuid::Uuid::new_v4().simple());
149
150 // Store the callback for later use
151 callbacks_map.insert(callback_id.clone(), hook_callback.clone());
152
153 callback_id
154 })
155 .collect();
156
157 serde_json::json!({
158 "matcher": matcher.matcher.clone(),
159 "hookCallbackIds": callback_ids
160 })
161 })
162 .collect();
163
164 (event_name.clone(), serde_json::json!(matchers_with_ids))
165 })
166 .collect();
167
168 Some(hooks_json)
169 } else {
170 None
171 };
172
173 // Send initialize request
174 let init_request = SDKControlRequest::Initialize(SDKControlInitializeRequest {
175 subtype: "initialize".to_string(),
176 hooks: hooks_with_ids,
177 });
178
179 // Send control request and save result
180 let result = self.send_control_request(init_request).await?;
181 self.initialization_result = Some(result);
182
183 debug!("Initialization request sent with hook callback IDs");
184 Ok(())
185 }
186
187 /// Send a control request and wait for response
188 async fn send_control_request(&mut self, request: SDKControlRequest) -> Result<JsonValue> {
189 // Generate unique request ID
190 let request_id = {
191 let mut counter = self.request_counter.lock().await;
192 *counter += 1;
193 format!("req_{}_{}", *counter, uuid::Uuid::new_v4().simple())
194 };
195
196 // Create oneshot channel for response
197 let (tx, rx) = tokio::sync::oneshot::channel();
198
199 // Register pending response
200 {
201 let mut pending = self.pending_responses.write().await;
202 pending.insert(request_id.clone(), tx);
203 }
204
205 // Build control request with request_id (snake_case for CLI compatibility)
206 let control_request = serde_json::json!({
207 "type": "control_request",
208 "request_id": request_id,
209 "request": request
210 });
211
212 debug!("Sending control request: {:?}", control_request);
213
214 // Send via transport
215 {
216 let mut transport = self.transport.lock().await;
217 transport.send_sdk_control_request(control_request).await?;
218 }
219
220 // Wait for response with timeout
221 match timeout(Duration::from_secs(60), rx).await {
222 Ok(Ok(response)) => {
223 debug!("Received control response for {}", request_id);
224 Ok(response)
225 }
226 Ok(Err(_)) => Err(SdkError::ControlRequestError(
227 "Response channel closed".to_string(),
228 )),
229 Err(_) => {
230 // Clean up pending response
231 let mut pending = self.pending_responses.write().await;
232 pending.remove(&request_id);
233 Err(SdkError::Timeout { seconds: 60 })
234 }
235 }
236 }
237
238 /// Handle permission request
239 #[allow(dead_code)]
240 async fn handle_permission_request(&mut self, request: SDKControlPermissionRequest) -> Result<()> {
241 if let Some(ref can_use_tool) = self.can_use_tool {
242 let context = ToolPermissionContext {
243 signal: None,
244 suggestions: request.permission_suggestions.unwrap_or_default(),
245 };
246
247 let result = can_use_tool
248 .can_use_tool(&request.tool_name, &request.input, &context)
249 .await;
250
251 // Send response back (CLI expects: { allow: bool, input?, reason? })
252 let response = match result {
253 PermissionResult::Allow(allow) => {
254 let mut obj = serde_json::json!({ "allow": true });
255 if let Some(updated) = allow.updated_input {
256 obj["input"] = updated;
257 }
258 obj
259 }
260 PermissionResult::Deny(deny) => {
261 let mut obj = serde_json::json!({ "allow": false });
262 if !deny.message.is_empty() {
263 obj["reason"] = serde_json::json!(deny.message);
264 }
265 obj
266 }
267 };
268
269 // Send response back through transport
270 let mut transport = self.transport.lock().await;
271 transport.send_sdk_control_response(response).await?;
272 debug!("Permission response sent");
273 }
274
275 Ok(())
276 }
277
278 /// Extract requestId from CLI message (supports both camelCase and snake_case)
279 fn extract_request_id(msg: &JsonValue) -> Option<JsonValue> {
280 msg.get("requestId")
281 .or_else(|| msg.get("request_id"))
282 .cloned()
283 }
284
285 /// Start control request handler task
286 async fn start_control_handler(&mut self) {
287 let transport = self.transport.clone();
288 let can_use_tool = self.can_use_tool.clone();
289 let hook_callbacks = self.hook_callbacks.clone();
290 let sdk_mcp_servers = self.sdk_mcp_servers.clone();
291 let pending_responses = self.pending_responses.clone();
292
293 // Take ownership of the SDK control receiver to avoid holding locks
294 let sdk_control_rx = {
295 let mut transport_lock = transport.lock().await;
296 transport_lock.take_sdk_control_receiver()
297 }; // Lock released here
298
299 if let Some(mut control_rx) = sdk_control_rx {
300 tokio::spawn(async move {
301 // Now we can receive control requests without holding any locks
302 let transport_for_control = transport;
303 let can_use_tool_clone = can_use_tool;
304 let hook_callbacks_clone = hook_callbacks;
305 let sdk_mcp_servers_clone = sdk_mcp_servers;
306 let pending_responses_clone = pending_responses;
307
308 loop {
309 // Receive control request without holding lock
310 let control_message = control_rx.recv().await;
311
312 if let Some(control_message) = control_message {
313 debug!("Received control message: {:?}", control_message);
314
315 // Check if this is a control response (from CLI to SDK)
316 if control_message.get("type").and_then(|v| v.as_str()) == Some("control_response") {
317 // Expected shape: {"type":"control_response", "response": {"request_id": "...", ...}}
318 if let Some(resp_obj) = control_message.get("response") {
319 let request_id = resp_obj
320 .get("request_id")
321 .or_else(|| resp_obj.get("requestId"))
322 .and_then(|v| v.as_str());
323
324 if let Some(request_id) = request_id {
325 let mut pending = pending_responses_clone.write().await;
326 if let Some(tx) = pending.remove(request_id) {
327 // Deliver only the nested "response" object (matches Python SDK semantics)
328 let _ = tx.send(resp_obj.clone());
329 debug!("Control response delivered for request_id: {}", request_id);
330 } else {
331 warn!("No pending request found for request_id: {}", request_id);
332 }
333 } else {
334 warn!("Control response missing request_id: {:?}", control_message);
335 }
336 } else {
337 warn!("Control response missing 'response' payload: {:?}", control_message);
338 }
339 continue;
340 }
341
342 // Parse and handle control requests (from CLI to SDK)
343 // Check if this is a control_request with a nested request field
344 let request_data = if control_message.get("type").and_then(|v| v.as_str()) == Some("control_request") {
345 control_message.get("request").cloned().unwrap_or(control_message.clone())
346 } else {
347 control_message.clone()
348 };
349
350 if let Some(subtype) = request_data.get("subtype").and_then(|v| v.as_str()) {
351 match subtype {
352 "can_use_tool" => {
353 // Handle permission request
354 if let Ok(request) = serde_json::from_value::<SDKControlPermissionRequest>(request_data.clone()) {
355 // Handle with can_use_tool callback
356 if let Some(ref can_use_tool) = can_use_tool_clone {
357 let context = ToolPermissionContext {
358 signal: None,
359 suggestions: request.permission_suggestions.unwrap_or_default(),
360 };
361
362 let result = can_use_tool
363 .can_use_tool(&request.tool_name, &request.input, &context)
364 .await;
365
366 // CLI expects: {"allow": true, "input": ...} or {"allow": false, "reason": ...}
367 let permission_response = match result {
368 PermissionResult::Allow(allow) => {
369 let mut resp = serde_json::json!({
370 "allow": true,
371 });
372 if let Some(input) = allow.updated_input {
373 resp["input"] = input;
374 }
375 if let Some(perms) = allow.updated_permissions {
376 resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default();
377 }
378 resp
379 }
380 PermissionResult::Deny(deny) => {
381 let mut resp = serde_json::json!({
382 "allow": false,
383 });
384 if !deny.message.is_empty() {
385 resp["reason"] = serde_json::json!(deny.message);
386 }
387 if deny.interrupt {
388 resp["interrupt"] = serde_json::json!(true);
389 }
390 resp
391 }
392 };
393
394 // Wrap response with proper structure
395 // CLI expects "subtype": "success" for all successful responses
396 let response = serde_json::json!({
397 "subtype": "success",
398 "request_id": Self::extract_request_id(&control_message),
399 "response": permission_response
400 });
401
402 // Send response
403 let mut transport = transport_for_control.lock().await;
404 if let Err(e) = transport.send_sdk_control_response(response).await {
405 error!("Failed to send permission response: {}", e);
406 }
407 }
408 } else {
409 // Fallback for snake_case fields (tool_name, permission_suggestions)
410 if let Some(tool_name) = request_data.get("tool_name").and_then(|v| v.as_str())
411 && let Some(input_val) = request_data.get("input").cloned()
412 && let Some(ref can_use_tool) = can_use_tool_clone {
413 // Try to parse permission suggestions (snake_case)
414 let suggestions: Vec<PermissionUpdate> = request_data
415 .get("permission_suggestions")
416 .cloned()
417 .and_then(|v| serde_json::from_value::<Vec<PermissionUpdate>>(v).ok())
418 .unwrap_or_default();
419
420 let context = ToolPermissionContext { signal: None, suggestions };
421 let result = can_use_tool
422 .can_use_tool(tool_name, &input_val, &context)
423 .await;
424
425 let permission_response = match result {
426 PermissionResult::Allow(allow) => {
427 let mut resp = serde_json::json!({ "allow": true });
428 if let Some(input) = allow.updated_input { resp["input"] = input; }
429 if let Some(perms) = allow.updated_permissions { resp["updatedPermissions"] = serde_json::to_value(perms).unwrap_or_default(); }
430 resp
431 }
432 PermissionResult::Deny(deny) => {
433 let mut resp = serde_json::json!({ "allow": false });
434 if !deny.message.is_empty() { resp["reason"] = serde_json::json!(deny.message); }
435 if deny.interrupt { resp["interrupt"] = serde_json::json!(true); }
436 resp
437 }
438 };
439
440 let response = serde_json::json!({
441 "subtype": "success",
442 "request_id": Self::extract_request_id(&control_message),
443 "response": permission_response
444 });
445 let mut transport = transport_for_control.lock().await;
446 if let Err(e) = transport.send_sdk_control_response(response).await {
447 error!("Failed to send permission response (fallback): {}", e);
448 }
449 }
450 }
451 }
452 "hook_callback" => {
453 // Handle hook callback with strongly-typed inputs/outputs
454 if let Ok(request) = serde_json::from_value::<SDKHookCallbackRequest>(request_data.clone()) {
455 let callbacks = hook_callbacks_clone.read().await;
456
457 if let Some(callback) = callbacks.get(&request.callback_id) {
458 let context = HookContext { signal: None };
459
460 // Try to deserialize input as HookInput
461 let hook_result = match serde_json::from_value::<crate::types::HookInput>(request.input.clone()) {
462 Ok(hook_input) => {
463 // Call the hook with strongly-typed input
464 callback
465 .execute(&hook_input, request.tool_use_id.as_deref(), &context)
466 .await
467 }
468 Err(parse_err) => {
469 error!("Failed to parse hook input: {}", parse_err);
470 // Return error using MessageParseError
471 Err(crate::errors::SdkError::MessageParseError {
472 error: format!("Invalid hook input: {parse_err}"),
473 raw: request.input.to_string(),
474 })
475 }
476 };
477
478 // Handle hook result
479 let response_json = match hook_result {
480 Ok(hook_output) => {
481 // Serialize HookJSONOutput to JSON
482 let output_value = serde_json::to_value(&hook_output)
483 .unwrap_or_else(|e| {
484 error!("Failed to serialize hook output: {}", e);
485 serde_json::json!({})
486 });
487
488 serde_json::json!({
489 "subtype": "success",
490 "request_id": Self::extract_request_id(&control_message),
491 "response": output_value
492 })
493 }
494 Err(e) => {
495 error!("Hook callback failed: {}", e);
496 serde_json::json!({
497 "subtype": "error",
498 "request_id": Self::extract_request_id(&control_message),
499 "error": e.to_string()
500 })
501 }
502 };
503
504 let mut transport = transport_for_control.lock().await;
505 if let Err(e) = transport.send_sdk_control_response(response_json).await {
506 error!("Failed to send hook callback response: {}", e);
507 }
508 } else {
509 warn!("No hook callback found for ID: {}", request.callback_id);
510 // Send error response
511 let error_response = serde_json::json!({
512 "subtype": "error",
513 "request_id": Self::extract_request_id(&control_message),
514 "error": format!("No hook callback found for ID: {}", request.callback_id)
515 });
516 let mut transport = transport_for_control.lock().await;
517 if let Err(e) = transport.send_sdk_control_response(error_response).await {
518 error!("Failed to send error response: {}", e);
519 }
520 }
521 } else {
522 // Fallback for snake_case fields (callback_id, tool_use_id)
523 let callback_id = request_data.get("callback_id").and_then(|v| v.as_str());
524 let tool_use_id = request_data.get("tool_use_id").and_then(|v| v.as_str()).map(|s| s.to_string());
525 let input = request_data.get("input").cloned().unwrap_or(serde_json::json!({}));
526
527 if let Some(callback_id) = callback_id {
528 let callbacks = hook_callbacks_clone.read().await;
529 if let Some(callback) = callbacks.get(callback_id) {
530 let context = HookContext { signal: None };
531
532 // Try to parse as HookInput
533 let hook_result = match serde_json::from_value::<crate::types::HookInput>(input.clone()) {
534 Ok(hook_input) => {
535 callback
536 .execute(&hook_input, tool_use_id.as_deref(), &context)
537 .await
538 }
539 Err(parse_err) => {
540 error!("Failed to parse hook input (fallback): {}", parse_err);
541 Err(crate::errors::SdkError::MessageParseError {
542 error: format!("Invalid hook input: {parse_err}"),
543 raw: input.to_string(),
544 })
545 }
546 };
547
548 let response_json = match hook_result {
549 Ok(hook_output) => {
550 let output_value = serde_json::to_value(&hook_output)
551 .unwrap_or_else(|e| {
552 error!("Failed to serialize hook output (fallback): {}", e);
553 serde_json::json!({})
554 });
555
556 serde_json::json!({
557 "subtype": "success",
558 "request_id": Self::extract_request_id(&control_message),
559 "response": output_value
560 })
561 }
562 Err(e) => {
563 error!("Hook callback failed (fallback): {}", e);
564 serde_json::json!({
565 "subtype": "error",
566 "request_id": Self::extract_request_id(&control_message),
567 "error": e.to_string()
568 })
569 }
570 };
571
572 let mut transport = transport_for_control.lock().await;
573 if let Err(e) = transport.send_sdk_control_response(response_json).await {
574 error!("Failed to send hook callback response (fallback): {}", e);
575 }
576 } else {
577 warn!("No hook callback found for ID: {}", callback_id);
578 }
579 } else {
580 warn!("Invalid hook_callback control message: missing callback_id");
581 }
582 }
583 }
584 "mcp_message" => {
585 // Handle MCP message
586 if let Some(server_name) = request_data.get("server_name").and_then(|v| v.as_str())
587 && let Some(message) = request_data.get("message") {
588 debug!("Processing MCP message for SDK server: {}", server_name);
589
590 // Check if we have an SDK server with this name
591 if let Some(server_arc) = sdk_mcp_servers_clone.get(server_name) {
592 // Try to downcast to SdkMcpServer
593 if let Some(sdk_server) = server_arc.downcast_ref::<crate::sdk_mcp::SdkMcpServer>() {
594 // Call the SDK MCP server
595 match sdk_server.handle_message(message.clone()).await {
596 Ok(mcp_result) => {
597 // Wrap response with proper structure
598 let response = serde_json::json!({
599 "subtype": "success",
600 "request_id": Self::extract_request_id(&control_message),
601 "response": {
602 "mcp_response": mcp_result
603 }
604 });
605
606 let mut transport = transport_for_control.lock().await;
607 if let Err(e) = transport.send_sdk_control_response(response).await {
608 error!("Failed to send MCP response: {}", e);
609 }
610 }
611 Err(e) => {
612 error!("SDK MCP server error: {}", e);
613 let error_response = serde_json::json!({
614 "subtype": "error",
615 "request_id": Self::extract_request_id(&control_message),
616 "error": format!("MCP server error: {}", e)
617 });
618
619 let mut transport = transport_for_control.lock().await;
620 if let Err(e) = transport.send_sdk_control_response(error_response).await {
621 error!("Failed to send MCP error response: {}", e);
622 }
623 }
624 }
625 } else {
626 warn!("SDK server '{}' is not of type SdkMcpServer", server_name);
627 }
628 } else {
629 warn!("No SDK MCP server found with name: {}", server_name);
630 let error_response = serde_json::json!({
631 "subtype": "error",
632 "request_id": Self::extract_request_id(&control_message),
633 "error": format!("Server '{}' not found", server_name)
634 });
635
636 let mut transport = transport_for_control.lock().await;
637 if let Err(e) = transport.send_sdk_control_response(error_response).await {
638 error!("Failed to send MCP error response: {}", e);
639 }
640 }
641 }
642 }
643 _ => {
644 debug!("Unknown SDK control subtype: {}", subtype);
645 }
646 }
647 }
648 }
649 }
650 });
651 }
652 }
653
654 /// Stream input messages to the CLI stdin by converting JSON values to InputMessage
655 #[allow(dead_code)]
656 pub async fn stream_input<S>(&mut self, input_stream: S) -> Result<()>
657 where
658 S: Stream<Item = JsonValue> + Send + 'static,
659 {
660 let transport = self.transport.clone();
661
662 tokio::spawn(async move {
663 use futures::StreamExt;
664 let mut stream = Box::pin(input_stream);
665
666 while let Some(value) = stream.next().await {
667 // Best-effort conversion from arbitrary JSON to InputMessage
668 let input_msg_opt = Self::json_to_input_message(value);
669 if let Some(input_msg) = input_msg_opt {
670 let mut guard = transport.lock().await;
671 if let Err(e) = guard.send_message(input_msg).await {
672 warn!("Failed to send streaming input message: {}", e);
673 }
674 } else {
675 warn!("Invalid streaming input JSON; expected user message shape");
676 }
677 }
678
679 // After streaming all inputs, signal end of input
680 let mut guard = transport.lock().await;
681 if let Err(e) = guard.end_input().await {
682 warn!("Failed to signal end_input: {}", e);
683 }
684 });
685 Ok(())
686 }
687
688 /// Receive messages
689 #[allow(dead_code)]
690 pub async fn receive_messages(&mut self) -> mpsc::Receiver<Result<Message>> {
691 self.message_rx.take().expect("Receiver already taken")
692 }
693
694 /// Send interrupt request
695 pub async fn interrupt(&mut self) -> Result<()> {
696 let interrupt_request = SDKControlRequest::Interrupt(SDKControlInterruptRequest {
697 subtype: "interrupt".to_string(),
698 });
699
700 self.send_control_request(interrupt_request).await?;
701 Ok(())
702 }
703
704 /// Set permission mode via control protocol
705 #[allow(dead_code)]
706 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
707 let req = SDKControlRequest::SetPermissionMode(SDKControlSetPermissionModeRequest {
708 subtype: "set_permission_mode".to_string(),
709 mode: mode.to_string(),
710 });
711 // Ignore response payload; errors propagate
712 let _ = self.send_control_request(req).await?;
713 Ok(())
714 }
715
716 /// Set the active model via control protocol
717 #[allow(dead_code)]
718 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
719 let req = SDKControlRequest::SetModel(crate::types::SDKControlSetModelRequest {
720 subtype: "set_model".to_string(),
721 model,
722 });
723 let _ = self.send_control_request(req).await?;
724 Ok(())
725 }
726
727 /// Handle MCP message for SDK servers
728 #[allow(dead_code)]
729 async fn handle_mcp_message(&mut self, server_name: &str, message: &JsonValue) -> Result<JsonValue> {
730 // Check if we have an SDK server with this name
731 if let Some(_server) = self.sdk_mcp_servers.get(server_name) {
732 // TODO: Implement actual MCP server invocation
733 // For now, return a placeholder response
734 debug!("Handling MCP message for SDK server {}: {:?}", server_name, message);
735 Ok(serde_json::json!({
736 "jsonrpc": "2.0",
737 "id": message.get("id"),
738 "result": {
739 "content": "MCP server response placeholder"
740 }
741 }))
742 } else {
743 Err(SdkError::InvalidState {
744 message: format!("No SDK MCP server found with name: {server_name}"),
745 })
746 }
747 }
748
749 /// Close the query handler
750 #[allow(dead_code)]
751 pub async fn close(&mut self) -> Result<()> {
752 // Clean up resources
753 let mut transport = self.transport.lock().await;
754 transport.disconnect().await?;
755 Ok(())
756 }
757
758 /// Get initialization result
759 pub fn get_initialization_result(&self) -> Option<&JsonValue> {
760 self.initialization_result.as_ref()
761 }
762
763 /// Convert arbitrary JSON value to InputMessage understood by CLI
764 #[allow(dead_code)]
765 fn json_to_input_message(v: JsonValue) -> Option<InputMessage> {
766 // 1) Already in SDK message shape
767 if let Some(obj) = v.as_object() {
768 if let (Some(t), Some(message)) = (obj.get("type"), obj.get("message"))
769 && t.as_str() == Some("user") {
770 let parent = obj
771 .get("parent_tool_use_id")
772 .and_then(|p| p.as_str().map(|s| s.to_string()));
773 let session_id = obj
774 .get("session_id")
775 .and_then(|s| s.as_str())
776 .unwrap_or("default")
777 .to_string();
778
779 let im = InputMessage {
780 r#type: "user".to_string(),
781 message: message.clone(),
782 parent_tool_use_id: parent,
783 session_id,
784 };
785 return Some(im);
786 }
787
788 // 2) Simple wrapper: {"content":"...", "session_id":"..."}
789 if let Some(content) = obj.get("content").and_then(|c| c.as_str()) {
790 let session_id = obj
791 .get("session_id")
792 .and_then(|s| s.as_str())
793 .unwrap_or("default")
794 .to_string();
795 return Some(InputMessage::user(content.to_string(), session_id));
796 }
797 }
798
799 // 3) Bare string
800 if let Some(s) = v.as_str() {
801 return Some(InputMessage::user(s.to_string(), "default".to_string()));
802 }
803
804 None
805 }
806}
807
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_request_id_supports_both_cases() {
        let snake = serde_json::json!({"request_id": "req_1"});
        let camel = serde_json::json!({"requestId": "req_2"});
        assert_eq!(Query::extract_request_id(&snake), Some(serde_json::json!("req_1")));
        assert_eq!(Query::extract_request_id(&camel), Some(serde_json::json!("req_2")));
    }

    #[test]
    fn test_extract_request_id_prefers_camel_case_and_handles_missing() {
        let both = serde_json::json!({"requestId": "camel", "request_id": "snake"});
        assert_eq!(Query::extract_request_id(&both), Some(serde_json::json!("camel")));

        let missing = serde_json::json!({"type": "control_request"});
        assert_eq!(Query::extract_request_id(&missing), None);
    }

    #[test]
    fn test_json_to_input_message_from_string() {
        let v = serde_json::json!("Hello");
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.r#type, "user");
        assert_eq!(im.session_id, "default");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hello");
    }

    #[test]
    fn test_json_to_input_message_from_object_content() {
        let v = serde_json::json!({"content":"Ping","session_id":"s1"});
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "s1");
        assert_eq!(im.message["content"].as_str().unwrap(), "Ping");
    }

    #[test]
    fn test_json_to_input_message_full_user_shape() {
        let v = serde_json::json!({
            "type":"user",
            "message": {"role":"user","content":"Hi"},
            "session_id": "abc",
            "parent_tool_use_id": null
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.session_id, "abc");
        assert_eq!(im.message["role"].as_str().unwrap(), "user");
        assert_eq!(im.message["content"].as_str().unwrap(), "Hi");
    }

    #[test]
    fn test_json_to_input_message_keeps_parent_tool_use_id() {
        let v = serde_json::json!({
            "type": "user",
            "message": {"role": "user", "content": "Result"},
            "parent_tool_use_id": "tool_1"
        });
        let im = Query::json_to_input_message(v).expect("should convert");
        assert_eq!(im.parent_tool_use_id.as_deref(), Some("tool_1"));
        // No session_id in the envelope falls back to "default".
        assert_eq!(im.session_id, "default");
    }

    #[test]
    fn test_json_to_input_message_rejects_unsupported_shapes() {
        assert!(Query::json_to_input_message(serde_json::json!(42)).is_none());
        assert!(Query::json_to_input_message(serde_json::json!(null)).is_none());
        assert!(Query::json_to_input_message(serde_json::json!({"foo": "bar"})).is_none());
        // A non-user envelope without a top-level `content` string is not accepted.
        assert!(Query::json_to_input_message(serde_json::json!({
            "type": "assistant",
            "message": {"content": "x"}
        }))
        .is_none());
    }
}