1use crate::error::{CLIConnectionError, ClaudeSDKError, Result};
2use crate::internal::sdk_mcp::answer_mcp_message;
3use crate::internal::transport::Transport;
4use crate::types::{
5 CanUseToolCallback, ClaudeAgentOptions, HookCallback, HookContext, PermissionResult,
6 PermissionUpdate, SkillsConfig, ToolPermissionContext,
7};
8use std::collections::HashMap;
9use std::time::Duration;
10
/// Everything the control-protocol loop needs to answer requests coming *from* the CLI
/// and to build the `initialize` request sent *to* it.
#[derive(Debug, Clone, Default)]
pub struct ControlCallbacks {
    /// Permission callback invoked for inbound `can_use_tool` control requests; when
    /// absent those requests are answered with an error response.
    pub can_use_tool: Option<CanUseToolCallback>,
    /// In-process MCP servers, looked up by `server_name` for `mcp_message` requests.
    pub sdk_mcp_servers: HashMap<String, crate::mcp::SimpleMCPServer>,
    /// Hook callbacks keyed by the `hook_N` identifiers assigned in `build_hooks_config`.
    pub hook_callbacks: HashMap<String, HookCallback>,
    /// The `hooks` section of the initialize request (sent as JSON `null` when `None`).
    pub hooks_config: Option<serde_json::Value>,
    /// The `agents` section of the initialize request; omitted when `None`.
    pub agents: Option<serde_json::Value>,
    /// Sent as `excludeDynamicSections` in the initialize request when set.
    pub exclude_dynamic_sections: Option<bool>,
    /// Explicit skill names sent as `skills`; `None` omits the key entirely.
    pub skills: Option<Vec<String>>,
}
21
22impl ControlCallbacks {
23 pub fn from_options(options: &ClaudeAgentOptions) -> Self {
24 let (hooks_config, hook_callbacks) = build_hooks_config(options);
25 Self {
26 can_use_tool: options.can_use_tool.clone(),
27 sdk_mcp_servers: options.sdk_mcp_servers.clone(),
28 hook_callbacks,
29 hooks_config,
30 agents: agents_config(options),
31 exclude_dynamic_sections: options
32 .system_prompt_preset
33 .as_ref()
34 .and_then(|preset| preset.exclude_dynamic_sections),
35 skills: match &options.skills {
36 Some(SkillsConfig::Names(skills)) => Some(skills.clone()),
37 Some(SkillsConfig::All) | None => None,
38 },
39 }
40 }
41}
42
43pub fn initialize_request(callbacks: &ControlCallbacks) -> serde_json::Value {
44 let mut request = serde_json::Map::new();
45 request.insert(
46 "subtype".to_string(),
47 serde_json::Value::String("initialize".to_string()),
48 );
49 request.insert(
50 "hooks".to_string(),
51 callbacks
52 .hooks_config
53 .clone()
54 .unwrap_or(serde_json::Value::Null),
55 );
56 if let Some(agents) = &callbacks.agents {
57 request.insert("agents".to_string(), agents.clone());
58 }
59 if let Some(exclude_dynamic_sections) = callbacks.exclude_dynamic_sections {
60 request.insert(
61 "excludeDynamicSections".to_string(),
62 serde_json::Value::Bool(exclude_dynamic_sections),
63 );
64 }
65 if let Some(skills) = &callbacks.skills {
66 request.insert("skills".to_string(), serde_json::json!(skills));
67 }
68
69 serde_json::Value::Object(request)
70}
71
72fn agents_config(options: &ClaudeAgentOptions) -> Option<serde_json::Value> {
73 if options.agents.is_empty() {
74 return None;
75 }
76
77 let mut agents = serde_json::Map::new();
78 let mut names: Vec<_> = options.agents.keys().cloned().collect();
79 names.sort();
80 for name in names {
81 let Some(agent) = options.agents.get(&name) else {
82 continue;
83 };
84 agents.insert(name, serde_json::to_value(agent).ok()?);
85 }
86 Some(serde_json::Value::Object(agents))
87}
88
89pub fn control_request_payload(request_id: &str, request: serde_json::Value) -> serde_json::Value {
90 serde_json::json!({
91 "type": "control_request",
92 "request_id": request_id,
93 "request": request,
94 })
95}
96
97pub fn control_error_response_payload(request_id: &str, error: &str) -> serde_json::Value {
98 serde_json::json!({
99 "type": "control_response",
100 "response": {
101 "subtype": "error",
102 "request_id": request_id,
103 "error": error,
104 },
105 })
106}
107
108pub async fn send_control_request(
109 transport: &mut dyn Transport,
110 request: serde_json::Value,
111) -> Result<serde_json::Map<String, serde_json::Value>> {
112 send_control_request_with_callbacks(transport, request, &ControlCallbacks::default()).await
113}
114
115pub async fn send_control_request_with_callbacks(
116 transport: &mut dyn Transport,
117 request: serde_json::Value,
118 callbacks: &ControlCallbacks,
119) -> Result<serde_json::Map<String, serde_json::Value>> {
120 send_control_request_with_callbacks_and_timeout(
121 transport,
122 request,
123 callbacks,
124 Duration::from_secs(60),
125 )
126 .await
127}
128
129pub(crate) async fn send_control_request_with_callbacks_and_timeout(
130 transport: &mut dyn Transport,
131 request: serde_json::Value,
132 callbacks: &ControlCallbacks,
133 timeout_duration: Duration,
134) -> Result<serde_json::Map<String, serde_json::Value>> {
135 let request_id = format!("req_{}", uuid::Uuid::new_v4().simple());
136 let subtype = request
137 .get("subtype")
138 .and_then(|v| v.as_str())
139 .unwrap_or("unknown")
140 .to_string();
141 match tokio::time::timeout(
142 timeout_duration,
143 send_control_request_with_id(transport, &request_id, request, callbacks),
144 )
145 .await
146 {
147 Ok(result) => result,
148 Err(_) => Err(ClaudeSDKError::Other(format!(
149 "Control request timeout: {subtype}"
150 ))),
151 }
152}
153
154pub(crate) fn initialize_timeout_duration() -> Duration {
155 initialize_timeout_from_millis_env_value(
156 std::env::var("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT")
157 .ok()
158 .as_deref(),
159 )
160}
161
/// Convert a raw millisecond setting (the value of `CLAUDE_CODE_STREAM_CLOSE_TIMEOUT`)
/// into the initialize timeout.
///
/// Surrounding whitespace is ignored, so values read from files or shell exports such as
/// `"120000\n"` are honoured. Missing, empty or unparsable values fall back to the
/// default. The result is never shorter than 60 seconds.
fn initialize_timeout_from_millis_env_value(value: Option<&str>) -> Duration {
    // Both the default and the floor.
    const MIN_TIMEOUT_MS: u64 = 60_000;

    let millis = value
        .map(str::trim)
        .and_then(|raw| raw.parse::<u64>().ok())
        .map_or(MIN_TIMEOUT_MS, |configured| configured.max(MIN_TIMEOUT_MS));
    Duration::from_millis(millis)
}
169
170async fn send_control_request_with_id(
171 transport: &mut dyn Transport,
172 request_id: &str,
173 request: serde_json::Value,
174 callbacks: &ControlCallbacks,
175) -> Result<serde_json::Map<String, serde_json::Value>> {
176 let subtype = request
177 .get("subtype")
178 .and_then(|v| v.as_str())
179 .unwrap_or("unknown")
180 .to_string();
181 let payload = control_request_payload(request_id, request);
182 let mut encoded = serde_json::to_vec(&payload)?;
183 encoded.push(b'\n');
184 transport.write(&encoded).await?;
185
186 while let Some(data) = transport.read().await? {
187 let value: serde_json::Value = serde_json::from_slice(&data)?;
188 match value.get("type").and_then(|v| v.as_str()) {
189 Some("control_response") => {
190 if let Some(response) = matching_control_response(&value, request_id) {
191 return parse_control_response(response, &subtype);
192 }
193 }
194 Some("control_request") => {
195 respond_to_control_request(transport, &value, callbacks).await?;
196 }
197 _ => {}
198 }
199 }
200
201 Err(CLIConnectionError::new(format!("control request ended before response: {subtype}")).into())
202}
203
204fn matching_control_response<'a>(
205 value: &'a serde_json::Value,
206 request_id: &str,
207) -> Option<&'a serde_json::Map<String, serde_json::Value>> {
208 let response = value.get("response")?.as_object()?;
209 let response_id = response.get("request_id")?.as_str()?;
210 (response_id == request_id).then_some(response)
211}
212
213fn parse_control_response(
214 response: &serde_json::Map<String, serde_json::Value>,
215 subtype: &str,
216) -> Result<serde_json::Map<String, serde_json::Value>> {
217 match response.get("subtype").and_then(|v| v.as_str()) {
218 Some("success") => Ok(response
219 .get("response")
220 .and_then(|v| v.as_object())
221 .cloned()
222 .unwrap_or_default()),
223 Some("error") => Err(ClaudeSDKError::ControlRequest {
224 subtype: subtype.to_string(),
225 message: response
226 .get("error")
227 .and_then(|v| v.as_str())
228 .unwrap_or("unknown control request error")
229 .to_string(),
230 }),
231 _ => Err(ClaudeSDKError::ControlRequest {
232 subtype: subtype.to_string(),
233 message: "malformed control response".to_string(),
234 }),
235 }
236}
237
238pub(crate) async fn respond_to_control_request(
239 transport: &mut dyn Transport,
240 value: &serde_json::Value,
241 callbacks: &ControlCallbacks,
242) -> Result<()> {
243 let Some(request_id) = value.get("request_id").and_then(|v| v.as_str()) else {
244 return Ok(());
245 };
246 let request = value
247 .get("request")
248 .and_then(|request| request.as_object())
249 .cloned()
250 .unwrap_or_default();
251 let subtype = value
252 .get("request")
253 .and_then(|request| request.get("subtype"))
254 .and_then(|v| v.as_str())
255 .unwrap_or("unknown");
256
257 if subtype == "can_use_tool" {
258 let response = match answer_can_use_tool(&request, callbacks).await {
259 Ok(response) => control_success_response_payload(request_id, response),
260 Err(error) => control_error_response_payload(request_id, &error.to_string()),
261 };
262 let mut encoded = serde_json::to_vec(&response)?;
263 encoded.push(b'\n');
264 return transport.write(&encoded).await;
265 }
266
267 if subtype == "mcp_message" {
268 let response = answer_mcp_control_request(&request, callbacks);
269 let mut encoded =
270 serde_json::to_vec(&control_success_response_payload(request_id, response))?;
271 encoded.push(b'\n');
272 return transport.write(&encoded).await;
273 }
274
275 if subtype == "hook_callback" {
276 let response = match answer_hook_callback(&request, callbacks).await {
277 Ok(response) => control_success_response_payload(request_id, response),
278 Err(error) => control_error_response_payload(request_id, &error.to_string()),
279 };
280 let mut encoded = serde_json::to_vec(&response)?;
281 encoded.push(b'\n');
282 return transport.write(&encoded).await;
283 }
284
285 let response = control_error_response_payload(
286 request_id,
287 &format!("Unsupported control request subtype: {subtype}"),
288 );
289 let mut encoded = serde_json::to_vec(&response)?;
290 encoded.push(b'\n');
291 transport.write(&encoded).await
292}
293
294fn build_hooks_config(
295 options: &ClaudeAgentOptions,
296) -> (Option<serde_json::Value>, HashMap<String, HookCallback>) {
297 if options.hooks.is_empty() {
298 return (None, HashMap::new());
299 }
300
301 let mut callback_index = 0usize;
302 let mut hook_callbacks = HashMap::new();
303 let mut config = serde_json::Map::new();
304 let mut events: Vec<_> = options.hooks.keys().cloned().collect();
305 events.sort();
306
307 for event in events {
308 let Some(matchers) = options.hooks.get(&event) else {
309 continue;
310 };
311 let mut matcher_values = Vec::new();
312 for matcher in matchers {
313 let mut callback_ids = Vec::new();
314 for callback in &matcher.hooks {
315 let callback_id = format!("hook_{callback_index}");
316 callback_index += 1;
317 hook_callbacks.insert(callback_id.clone(), callback.clone());
318 callback_ids.push(serde_json::Value::String(callback_id));
319 }
320 let mut matcher_value = serde_json::Map::new();
321 matcher_value.insert(
322 "matcher".to_string(),
323 matcher
324 .matcher
325 .clone()
326 .map(serde_json::Value::String)
327 .unwrap_or(serde_json::Value::Null),
328 );
329 matcher_value.insert(
330 "hookCallbackIds".to_string(),
331 serde_json::Value::Array(callback_ids),
332 );
333 if let Some(timeout) = matcher.timeout {
334 matcher_value.insert("timeout".to_string(), serde_json::json!(timeout));
335 }
336 matcher_values.push(serde_json::Value::Object(matcher_value));
337 }
338 config.insert(event, serde_json::Value::Array(matcher_values));
339 }
340
341 (Some(serde_json::Value::Object(config)), hook_callbacks)
342}
343
344async fn answer_hook_callback(
345 request: &serde_json::Map<String, serde_json::Value>,
346 callbacks: &ControlCallbacks,
347) -> Result<serde_json::Value> {
348 let callback_id =
349 string_field(request, "callback_id").ok_or_else(|| ClaudeSDKError::ControlRequest {
350 subtype: "hook_callback".to_string(),
351 message: "missing callback_id".to_string(),
352 })?;
353 let callback = callbacks.hook_callbacks.get(&callback_id).ok_or_else(|| {
354 ClaudeSDKError::ControlRequest {
355 subtype: "hook_callback".to_string(),
356 message: format!("No hook callback found for ID: {callback_id}"),
357 }
358 })?;
359 let input = request
360 .get("input")
361 .cloned()
362 .unwrap_or(serde_json::Value::Null);
363 let tool_use_id = string_field(request, "tool_use_id");
364 let output = callback
365 .call(input, tool_use_id, HookContext::default())
366 .await?;
367 Ok(convert_hook_output_for_cli(output))
368}
369
370fn convert_hook_output_for_cli(output: serde_json::Value) -> serde_json::Value {
371 let serde_json::Value::Object(map) = output else {
372 return output;
373 };
374 let mut converted = serde_json::Map::new();
375 for (key, value) in map {
376 let key = match key.as_str() {
377 "async_" => "async".to_string(),
378 "continue_" => "continue".to_string(),
379 _ => key,
380 };
381 converted.insert(key, value);
382 }
383 serde_json::Value::Object(converted)
384}
385
386fn answer_mcp_control_request(
387 request: &serde_json::Map<String, serde_json::Value>,
388 callbacks: &ControlCallbacks,
389) -> serde_json::Value {
390 let server_name = request
391 .get("server_name")
392 .and_then(|v| v.as_str())
393 .unwrap_or("");
394 let message = request.get("message").unwrap_or(&serde_json::Value::Null);
395 serde_json::json!({
396 "mcp_response": answer_mcp_message(&callbacks.sdk_mcp_servers, server_name, message)
397 })
398}
399
400pub fn control_success_response_payload(
401 request_id: &str,
402 response: serde_json::Value,
403) -> serde_json::Value {
404 serde_json::json!({
405 "type": "control_response",
406 "response": {
407 "subtype": "success",
408 "request_id": request_id,
409 "response": response,
410 },
411 })
412}
413
414async fn answer_can_use_tool(
415 request: &serde_json::Map<String, serde_json::Value>,
416 callbacks: &ControlCallbacks,
417) -> Result<serde_json::Value> {
418 let callback =
419 callbacks
420 .can_use_tool
421 .as_ref()
422 .ok_or_else(|| ClaudeSDKError::ControlRequest {
423 subtype: "can_use_tool".to_string(),
424 message: "can_use_tool callback is not provided".to_string(),
425 })?;
426 let tool_name = request
427 .get("tool_name")
428 .and_then(|v| v.as_str())
429 .ok_or_else(|| ClaudeSDKError::ControlRequest {
430 subtype: "can_use_tool".to_string(),
431 message: "missing tool_name".to_string(),
432 })?
433 .to_string();
434 let input = request
435 .get("input")
436 .and_then(|v| v.as_object())
437 .cloned()
438 .unwrap_or_default();
439 let context = ToolPermissionContext {
440 suggestions: permission_suggestions(request),
441 tool_use_id: string_field(request, "tool_use_id"),
442 agent_id: string_field(request, "agent_id"),
443 blocked_path: string_field(request, "blocked_path"),
444 decision_reason: string_field(request, "decision_reason"),
445 title: string_field(request, "title"),
446 display_name: string_field(request, "display_name"),
447 description: string_field(request, "description"),
448 };
449 let result = callback.call(tool_name, input.clone(), context).await?;
450 Ok(permission_result_response(result, input))
451}
452
453fn string_field(request: &serde_json::Map<String, serde_json::Value>, key: &str) -> Option<String> {
454 request.get(key).and_then(|v| v.as_str()).map(String::from)
455}
456
457fn permission_suggestions(
458 request: &serde_json::Map<String, serde_json::Value>,
459) -> Vec<PermissionUpdate> {
460 request
461 .get("permission_suggestions")
462 .and_then(|v| v.as_array())
463 .into_iter()
464 .flatten()
465 .filter_map(|value| serde_json::from_value(value.clone()).ok())
466 .collect()
467}
468
469fn permission_result_response(
470 result: PermissionResult,
471 original_input: serde_json::Map<String, serde_json::Value>,
472) -> serde_json::Value {
473 match result {
474 PermissionResult::Allow {
475 updated_input,
476 updated_permissions,
477 } => {
478 let mut response = serde_json::Map::new();
479 response.insert(
480 "behavior".to_string(),
481 serde_json::Value::String("allow".to_string()),
482 );
483 response.insert(
484 "updatedInput".to_string(),
485 serde_json::Value::Object(updated_input.unwrap_or(original_input)),
486 );
487 if let Some(updated_permissions) = updated_permissions {
488 response.insert(
489 "updatedPermissions".to_string(),
490 serde_json::to_value(updated_permissions).unwrap_or(serde_json::Value::Null),
491 );
492 }
493 serde_json::Value::Object(response)
494 }
495 PermissionResult::Deny { message, interrupt } => {
496 let mut response = serde_json::Map::new();
497 response.insert(
498 "behavior".to_string(),
499 serde_json::Value::String("deny".to_string()),
500 );
501 response.insert("message".to_string(), serde_json::Value::String(message));
502 if interrupt {
503 response.insert("interrupt".to_string(), serde_json::Value::Bool(true));
504 }
505 serde_json::Value::Object(response)
506 }
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::initialize_timeout_from_millis_env_value as timeout_from_env;
    use std::time::Duration;

    /// Floor and default applied by the initialize-timeout parser.
    const MINIMUM: Duration = Duration::from_secs(60);

    #[test]
    fn initialize_timeout_defaults_to_sixty_seconds() {
        // No environment override at all.
        let timeout = timeout_from_env(None);
        assert_eq!(timeout, MINIMUM);
    }

    #[test]
    fn initialize_timeout_uses_env_millis_when_above_minimum() {
        // 120 000 ms is above the floor, so it is honoured as-is.
        let timeout = timeout_from_env(Some("120000"));
        assert_eq!(timeout, Duration::from_millis(120_000));
    }

    #[test]
    fn initialize_timeout_keeps_sixty_second_minimum() {
        // 1 000 ms is below the floor, so the floor wins.
        let timeout = timeout_from_env(Some("1000"));
        assert_eq!(timeout, MINIMUM);
    }
}