Skip to main content

claude_code_sdk_rust/internal/control.rs

1use crate::error::{CLIConnectionError, ClaudeSDKError, Result};
2use crate::internal::sdk_mcp::answer_mcp_message;
3use crate::internal::transport::Transport;
4use crate::types::{
5    CanUseToolCallback, ClaudeAgentOptions, HookCallback, HookContext, PermissionResult,
6    PermissionUpdate, SkillsConfig, ToolPermissionContext,
7};
8use std::collections::HashMap;
9use std::time::Duration;
10
11#[derive(Debug, Clone, Default)]
12pub struct ControlCallbacks {
13    pub can_use_tool: Option<CanUseToolCallback>,
14    pub sdk_mcp_servers: HashMap<String, crate::mcp::SimpleMCPServer>,
15    pub hook_callbacks: HashMap<String, HookCallback>,
16    pub hooks_config: Option<serde_json::Value>,
17    pub agents: Option<serde_json::Value>,
18    pub exclude_dynamic_sections: Option<bool>,
19    pub skills: Option<Vec<String>>,
20}
21
22impl ControlCallbacks {
23    pub fn from_options(options: &ClaudeAgentOptions) -> Self {
24        let (hooks_config, hook_callbacks) = build_hooks_config(options);
25        Self {
26            can_use_tool: options.can_use_tool.clone(),
27            sdk_mcp_servers: options.sdk_mcp_servers.clone(),
28            hook_callbacks,
29            hooks_config,
30            agents: agents_config(options),
31            exclude_dynamic_sections: options
32                .system_prompt_preset
33                .as_ref()
34                .and_then(|preset| preset.exclude_dynamic_sections),
35            skills: match &options.skills {
36                Some(SkillsConfig::Names(skills)) => Some(skills.clone()),
37                Some(SkillsConfig::All) | None => None,
38            },
39        }
40    }
41}
42
43pub fn initialize_request(callbacks: &ControlCallbacks) -> serde_json::Value {
44    let mut request = serde_json::Map::new();
45    request.insert(
46        "subtype".to_string(),
47        serde_json::Value::String("initialize".to_string()),
48    );
49    request.insert(
50        "hooks".to_string(),
51        callbacks
52            .hooks_config
53            .clone()
54            .unwrap_or(serde_json::Value::Null),
55    );
56    if let Some(agents) = &callbacks.agents {
57        request.insert("agents".to_string(), agents.clone());
58    }
59    if let Some(exclude_dynamic_sections) = callbacks.exclude_dynamic_sections {
60        request.insert(
61            "excludeDynamicSections".to_string(),
62            serde_json::Value::Bool(exclude_dynamic_sections),
63        );
64    }
65    if let Some(skills) = &callbacks.skills {
66        request.insert("skills".to_string(), serde_json::json!(skills));
67    }
68
69    serde_json::Value::Object(request)
70}
71
72fn agents_config(options: &ClaudeAgentOptions) -> Option<serde_json::Value> {
73    if options.agents.is_empty() {
74        return None;
75    }
76
77    let mut agents = serde_json::Map::new();
78    let mut names: Vec<_> = options.agents.keys().cloned().collect();
79    names.sort();
80    for name in names {
81        let Some(agent) = options.agents.get(&name) else {
82            continue;
83        };
84        agents.insert(name, serde_json::to_value(agent).ok()?);
85    }
86    Some(serde_json::Value::Object(agents))
87}
88
89pub fn control_request_payload(request_id: &str, request: serde_json::Value) -> serde_json::Value {
90    serde_json::json!({
91        "type": "control_request",
92        "request_id": request_id,
93        "request": request,
94    })
95}
96
97pub fn control_error_response_payload(request_id: &str, error: &str) -> serde_json::Value {
98    serde_json::json!({
99        "type": "control_response",
100        "response": {
101            "subtype": "error",
102            "request_id": request_id,
103            "error": error,
104        },
105    })
106}
107
108pub async fn send_control_request(
109    transport: &mut dyn Transport,
110    request: serde_json::Value,
111) -> Result<serde_json::Map<String, serde_json::Value>> {
112    send_control_request_with_callbacks(transport, request, &ControlCallbacks::default()).await
113}
114
115pub async fn send_control_request_with_callbacks(
116    transport: &mut dyn Transport,
117    request: serde_json::Value,
118    callbacks: &ControlCallbacks,
119) -> Result<serde_json::Map<String, serde_json::Value>> {
120    send_control_request_with_callbacks_and_timeout(
121        transport,
122        request,
123        callbacks,
124        Duration::from_secs(60),
125    )
126    .await
127}
128
129pub(crate) async fn send_control_request_with_callbacks_and_timeout(
130    transport: &mut dyn Transport,
131    request: serde_json::Value,
132    callbacks: &ControlCallbacks,
133    timeout_duration: Duration,
134) -> Result<serde_json::Map<String, serde_json::Value>> {
135    let request_id = format!("req_{}", uuid::Uuid::new_v4().simple());
136    let subtype = request
137        .get("subtype")
138        .and_then(|v| v.as_str())
139        .unwrap_or("unknown")
140        .to_string();
141    match tokio::time::timeout(
142        timeout_duration,
143        send_control_request_with_id(transport, &request_id, request, callbacks),
144    )
145    .await
146    {
147        Ok(result) => result,
148        Err(_) => Err(ClaudeSDKError::Other(format!(
149            "Control request timeout: {subtype}"
150        ))),
151    }
152}
153
154pub(crate) fn initialize_timeout_duration() -> Duration {
155    initialize_timeout_from_millis_env_value(
156        std::env::var("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT")
157            .ok()
158            .as_deref(),
159    )
160}
161
/// Converts a raw millisecond timeout value into the initialize timeout.
///
/// Surrounding whitespace is ignored. A missing, non-numeric, or negative value falls
/// back to 60 seconds, and any value below 60 seconds is raised to that floor.
fn initialize_timeout_from_millis_env_value(value: Option<&str>) -> Duration {
    // Both the default and the lower bound, in milliseconds.
    const MIN_MILLIS: u64 = 60_000;

    let millis = value
        // Trim so values such as "120000\n" from env files are honored rather than
        // silently replaced by the default.
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .map_or(MIN_MILLIS, |parsed| parsed.max(MIN_MILLIS));
    Duration::from_millis(millis)
}
169
170async fn send_control_request_with_id(
171    transport: &mut dyn Transport,
172    request_id: &str,
173    request: serde_json::Value,
174    callbacks: &ControlCallbacks,
175) -> Result<serde_json::Map<String, serde_json::Value>> {
176    let subtype = request
177        .get("subtype")
178        .and_then(|v| v.as_str())
179        .unwrap_or("unknown")
180        .to_string();
181    let payload = control_request_payload(request_id, request);
182    let mut encoded = serde_json::to_vec(&payload)?;
183    encoded.push(b'\n');
184    transport.write(&encoded).await?;
185
186    while let Some(data) = transport.read().await? {
187        let value: serde_json::Value = serde_json::from_slice(&data)?;
188        match value.get("type").and_then(|v| v.as_str()) {
189            Some("control_response") => {
190                if let Some(response) = matching_control_response(&value, request_id) {
191                    return parse_control_response(response, &subtype);
192                }
193            }
194            Some("control_request") => {
195                respond_to_control_request(transport, &value, callbacks).await?;
196            }
197            _ => {}
198        }
199    }
200
201    Err(CLIConnectionError::new(format!("control request ended before response: {subtype}")).into())
202}
203
204fn matching_control_response<'a>(
205    value: &'a serde_json::Value,
206    request_id: &str,
207) -> Option<&'a serde_json::Map<String, serde_json::Value>> {
208    let response = value.get("response")?.as_object()?;
209    let response_id = response.get("request_id")?.as_str()?;
210    (response_id == request_id).then_some(response)
211}
212
213fn parse_control_response(
214    response: &serde_json::Map<String, serde_json::Value>,
215    subtype: &str,
216) -> Result<serde_json::Map<String, serde_json::Value>> {
217    match response.get("subtype").and_then(|v| v.as_str()) {
218        Some("success") => Ok(response
219            .get("response")
220            .and_then(|v| v.as_object())
221            .cloned()
222            .unwrap_or_default()),
223        Some("error") => Err(ClaudeSDKError::ControlRequest {
224            subtype: subtype.to_string(),
225            message: response
226                .get("error")
227                .and_then(|v| v.as_str())
228                .unwrap_or("unknown control request error")
229                .to_string(),
230        }),
231        _ => Err(ClaudeSDKError::ControlRequest {
232            subtype: subtype.to_string(),
233            message: "malformed control response".to_string(),
234        }),
235    }
236}
237
238pub(crate) async fn respond_to_control_request(
239    transport: &mut dyn Transport,
240    value: &serde_json::Value,
241    callbacks: &ControlCallbacks,
242) -> Result<()> {
243    let Some(request_id) = value.get("request_id").and_then(|v| v.as_str()) else {
244        return Ok(());
245    };
246    let request = value
247        .get("request")
248        .and_then(|request| request.as_object())
249        .cloned()
250        .unwrap_or_default();
251    let subtype = value
252        .get("request")
253        .and_then(|request| request.get("subtype"))
254        .and_then(|v| v.as_str())
255        .unwrap_or("unknown");
256
257    if subtype == "can_use_tool" {
258        let response = match answer_can_use_tool(&request, callbacks).await {
259            Ok(response) => control_success_response_payload(request_id, response),
260            Err(error) => control_error_response_payload(request_id, &error.to_string()),
261        };
262        let mut encoded = serde_json::to_vec(&response)?;
263        encoded.push(b'\n');
264        return transport.write(&encoded).await;
265    }
266
267    if subtype == "mcp_message" {
268        let response = answer_mcp_control_request(&request, callbacks);
269        let mut encoded =
270            serde_json::to_vec(&control_success_response_payload(request_id, response))?;
271        encoded.push(b'\n');
272        return transport.write(&encoded).await;
273    }
274
275    if subtype == "hook_callback" {
276        let response = match answer_hook_callback(&request, callbacks).await {
277            Ok(response) => control_success_response_payload(request_id, response),
278            Err(error) => control_error_response_payload(request_id, &error.to_string()),
279        };
280        let mut encoded = serde_json::to_vec(&response)?;
281        encoded.push(b'\n');
282        return transport.write(&encoded).await;
283    }
284
285    let response = control_error_response_payload(
286        request_id,
287        &format!("Unsupported control request subtype: {subtype}"),
288    );
289    let mut encoded = serde_json::to_vec(&response)?;
290    encoded.push(b'\n');
291    transport.write(&encoded).await
292}
293
294fn build_hooks_config(
295    options: &ClaudeAgentOptions,
296) -> (Option<serde_json::Value>, HashMap<String, HookCallback>) {
297    if options.hooks.is_empty() {
298        return (None, HashMap::new());
299    }
300
301    let mut callback_index = 0usize;
302    let mut hook_callbacks = HashMap::new();
303    let mut config = serde_json::Map::new();
304    let mut events: Vec<_> = options.hooks.keys().cloned().collect();
305    events.sort();
306
307    for event in events {
308        let Some(matchers) = options.hooks.get(&event) else {
309            continue;
310        };
311        let mut matcher_values = Vec::new();
312        for matcher in matchers {
313            let mut callback_ids = Vec::new();
314            for callback in &matcher.hooks {
315                let callback_id = format!("hook_{callback_index}");
316                callback_index += 1;
317                hook_callbacks.insert(callback_id.clone(), callback.clone());
318                callback_ids.push(serde_json::Value::String(callback_id));
319            }
320            let mut matcher_value = serde_json::Map::new();
321            matcher_value.insert(
322                "matcher".to_string(),
323                matcher
324                    .matcher
325                    .clone()
326                    .map(serde_json::Value::String)
327                    .unwrap_or(serde_json::Value::Null),
328            );
329            matcher_value.insert(
330                "hookCallbackIds".to_string(),
331                serde_json::Value::Array(callback_ids),
332            );
333            if let Some(timeout) = matcher.timeout {
334                matcher_value.insert("timeout".to_string(), serde_json::json!(timeout));
335            }
336            matcher_values.push(serde_json::Value::Object(matcher_value));
337        }
338        config.insert(event, serde_json::Value::Array(matcher_values));
339    }
340
341    (Some(serde_json::Value::Object(config)), hook_callbacks)
342}
343
344async fn answer_hook_callback(
345    request: &serde_json::Map<String, serde_json::Value>,
346    callbacks: &ControlCallbacks,
347) -> Result<serde_json::Value> {
348    let callback_id =
349        string_field(request, "callback_id").ok_or_else(|| ClaudeSDKError::ControlRequest {
350            subtype: "hook_callback".to_string(),
351            message: "missing callback_id".to_string(),
352        })?;
353    let callback = callbacks.hook_callbacks.get(&callback_id).ok_or_else(|| {
354        ClaudeSDKError::ControlRequest {
355            subtype: "hook_callback".to_string(),
356            message: format!("No hook callback found for ID: {callback_id}"),
357        }
358    })?;
359    let input = request
360        .get("input")
361        .cloned()
362        .unwrap_or(serde_json::Value::Null);
363    let tool_use_id = string_field(request, "tool_use_id");
364    let output = callback
365        .call(input, tool_use_id, HookContext::default())
366        .await?;
367    Ok(convert_hook_output_for_cli(output))
368}
369
370fn convert_hook_output_for_cli(output: serde_json::Value) -> serde_json::Value {
371    let serde_json::Value::Object(map) = output else {
372        return output;
373    };
374    let mut converted = serde_json::Map::new();
375    for (key, value) in map {
376        let key = match key.as_str() {
377            "async_" => "async".to_string(),
378            "continue_" => "continue".to_string(),
379            _ => key,
380        };
381        converted.insert(key, value);
382    }
383    serde_json::Value::Object(converted)
384}
385
386fn answer_mcp_control_request(
387    request: &serde_json::Map<String, serde_json::Value>,
388    callbacks: &ControlCallbacks,
389) -> serde_json::Value {
390    let server_name = request
391        .get("server_name")
392        .and_then(|v| v.as_str())
393        .unwrap_or("");
394    let message = request.get("message").unwrap_or(&serde_json::Value::Null);
395    serde_json::json!({
396        "mcp_response": answer_mcp_message(&callbacks.sdk_mcp_servers, server_name, message)
397    })
398}
399
400pub fn control_success_response_payload(
401    request_id: &str,
402    response: serde_json::Value,
403) -> serde_json::Value {
404    serde_json::json!({
405        "type": "control_response",
406        "response": {
407            "subtype": "success",
408            "request_id": request_id,
409            "response": response,
410        },
411    })
412}
413
414async fn answer_can_use_tool(
415    request: &serde_json::Map<String, serde_json::Value>,
416    callbacks: &ControlCallbacks,
417) -> Result<serde_json::Value> {
418    let callback =
419        callbacks
420            .can_use_tool
421            .as_ref()
422            .ok_or_else(|| ClaudeSDKError::ControlRequest {
423                subtype: "can_use_tool".to_string(),
424                message: "can_use_tool callback is not provided".to_string(),
425            })?;
426    let tool_name = request
427        .get("tool_name")
428        .and_then(|v| v.as_str())
429        .ok_or_else(|| ClaudeSDKError::ControlRequest {
430            subtype: "can_use_tool".to_string(),
431            message: "missing tool_name".to_string(),
432        })?
433        .to_string();
434    let input = request
435        .get("input")
436        .and_then(|v| v.as_object())
437        .cloned()
438        .unwrap_or_default();
439    let context = ToolPermissionContext {
440        suggestions: permission_suggestions(request),
441        tool_use_id: string_field(request, "tool_use_id"),
442        agent_id: string_field(request, "agent_id"),
443        blocked_path: string_field(request, "blocked_path"),
444        decision_reason: string_field(request, "decision_reason"),
445        title: string_field(request, "title"),
446        display_name: string_field(request, "display_name"),
447        description: string_field(request, "description"),
448    };
449    let result = callback.call(tool_name, input.clone(), context).await?;
450    Ok(permission_result_response(result, input))
451}
452
453fn string_field(request: &serde_json::Map<String, serde_json::Value>, key: &str) -> Option<String> {
454    request.get(key).and_then(|v| v.as_str()).map(String::from)
455}
456
457fn permission_suggestions(
458    request: &serde_json::Map<String, serde_json::Value>,
459) -> Vec<PermissionUpdate> {
460    request
461        .get("permission_suggestions")
462        .and_then(|v| v.as_array())
463        .into_iter()
464        .flatten()
465        .filter_map(|value| serde_json::from_value(value.clone()).ok())
466        .collect()
467}
468
469fn permission_result_response(
470    result: PermissionResult,
471    original_input: serde_json::Map<String, serde_json::Value>,
472) -> serde_json::Value {
473    match result {
474        PermissionResult::Allow {
475            updated_input,
476            updated_permissions,
477        } => {
478            let mut response = serde_json::Map::new();
479            response.insert(
480                "behavior".to_string(),
481                serde_json::Value::String("allow".to_string()),
482            );
483            response.insert(
484                "updatedInput".to_string(),
485                serde_json::Value::Object(updated_input.unwrap_or(original_input)),
486            );
487            if let Some(updated_permissions) = updated_permissions {
488                response.insert(
489                    "updatedPermissions".to_string(),
490                    serde_json::to_value(updated_permissions).unwrap_or(serde_json::Value::Null),
491                );
492            }
493            serde_json::Value::Object(response)
494        }
495        PermissionResult::Deny { message, interrupt } => {
496            let mut response = serde_json::Map::new();
497            response.insert(
498                "behavior".to_string(),
499                serde_json::Value::String("deny".to_string()),
500            );
501            response.insert("message".to_string(), serde_json::Value::String(message));
502            if interrupt {
503                response.insert("interrupt".to_string(), serde_json::Value::Bool(true));
504            }
505            serde_json::Value::Object(response)
506        }
507    }
508}
509
#[cfg(test)]
mod tests {
    //! Unit tests for parsing the initialize timeout from its raw millisecond value.

    use super::initialize_timeout_from_millis_env_value;
    use std::time::Duration;

    #[test]
    fn initialize_timeout_defaults_to_sixty_seconds() {
        assert_eq!(
            initialize_timeout_from_millis_env_value(None),
            Duration::from_secs(60)
        );
    }

    #[test]
    fn initialize_timeout_uses_env_millis_when_above_minimum() {
        assert_eq!(
            initialize_timeout_from_millis_env_value(Some("120000")),
            Duration::from_secs(120)
        );
    }

    #[test]
    fn initialize_timeout_keeps_sixty_second_minimum() {
        assert_eq!(
            initialize_timeout_from_millis_env_value(Some("1000")),
            Duration::from_secs(60)
        );
    }

    #[test]
    fn initialize_timeout_ignores_unparseable_values() {
        // Empty, non-numeric, negative, fractional, and unit-suffixed values all fall
        // back to the default instead of erroring.
        for raw in ["", "abc", "-5", "1.5", "60s"] {
            assert_eq!(
                initialize_timeout_from_millis_env_value(Some(raw)),
                Duration::from_secs(60),
                "value {raw:?} should fall back to the default"
            );
        }
    }

    #[test]
    fn initialize_timeout_respects_exact_minimum_boundary() {
        assert_eq!(
            initialize_timeout_from_millis_env_value(Some("60000")),
            Duration::from_secs(60)
        );
        assert_eq!(
            initialize_timeout_from_millis_env_value(Some("60001")),
            Duration::from_millis(60_001)
        );
    }
}