Skip to main content

claude_code_sdk_rust/internal/
transport.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::process::Stdio;
4use std::sync::Arc;
5use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
6use tokio::process::{Child, ChildStdin, ChildStdout, Command};
7use tokio::sync::Mutex;
8
9use crate::error::{CLIConnectionError, CLINotFoundError, ClaudeSDKError, ProcessError, Result};
10use crate::internal::stdout_decoder::StdoutDecoder;
11use crate::types::ClaudeAgentOptions;
12
/// Value exported as `CLAUDE_CODE_ENTRYPOINT` unless the caller overrides it
/// through `TransportOptions::env`.
const DEFAULT_ENTRY_POINT: &str = "sdk-rust";
/// Stdout decoder limit used when `TransportOptions::max_buffer_size` is
/// `None` (1 MiB).
const DEFAULT_MAX_BUFFER_SIZE: usize = 1 << 20;
15
16#[async_trait]
17pub trait Transport: Send + Sync {
18    async fn connect(&mut self) -> Result<()>;
19    async fn write(&mut self, data: &[u8]) -> Result<()>;
20    async fn close_input(&mut self) -> Result<()>;
21    async fn read(&mut self) -> Result<Option<Vec<u8>>>;
22    async fn close(&mut self) -> Result<()>;
23}
24
25#[derive(Debug)]
26pub struct SubprocessCLITransport {
27    options: TransportOptions,
28    child: Option<Child>,
29    stdin: Option<ChildStdin>,
30    stdout_reader: Option<BufReader<ChildStdout>>,
31    stdout_decoder: StdoutDecoder,
32    stderr: Arc<Mutex<String>>,
33}
34
35#[derive(Debug, Clone)]
36pub struct TransportOptions {
37    pub tools: Vec<String>,
38    pub tools_set: bool,
39    pub tools_preset: Option<crate::types::ToolsPreset>,
40    pub allowed_tools: Vec<String>,
41    pub system_prompt: Option<String>,
42    pub system_prompt_preset: Option<crate::types::SystemPromptPreset>,
43    pub system_prompt_file: Option<crate::types::SystemPromptFile>,
44    pub mcp_servers: std::collections::HashMap<String, crate::types::MCPServerConfig>,
45    pub mcp_servers_config: Option<String>,
46    pub permission_mode: Option<crate::types::PermissionMode>,
47    pub continue_conversation: bool,
48    pub resume: Option<String>,
49    pub session_id: Option<String>,
50    pub fork_session: bool,
51    pub max_turns: Option<i32>,
52    pub max_budget_usd: Option<f64>,
53    pub task_budget: Option<crate::types::TaskBudget>,
54    pub disallowed_tools: Vec<String>,
55    pub model: Option<String>,
56    pub fallback_model: Option<String>,
57    pub betas: Vec<crate::types::SdkBeta>,
58    pub permission_prompt_tool_name: Option<String>,
59    pub cwd: Option<String>,
60    pub cli_path: Option<String>,
61    pub settings: Option<String>,
62    pub add_dirs: Vec<String>,
63    pub env: std::collections::HashMap<String, String>,
64    pub extra_args: std::collections::HashMap<String, Option<String>>,
65    pub max_buffer_size: Option<usize>,
66    pub user: Option<String>,
67    pub include_partial_messages: bool,
68    pub include_hook_events: bool,
69    pub strict_mcp_config: bool,
70    pub setting_sources: Option<Vec<crate::types::SettingSource>>,
71    pub skills: Option<crate::types::SkillsConfig>,
72    pub sandbox: Option<crate::types::SandboxSettings>,
73    pub plugins: Vec<crate::types::SDKPluginConfig>,
74    pub max_thinking_tokens: Option<i32>,
75    pub thinking: Option<crate::types::ThinkingConfig>,
76    pub effort: Option<String>,
77    pub output_format: Option<serde_json::Map<String, serde_json::Value>>,
78    pub enable_file_checkpointing: bool,
79    pub stderr: Option<crate::types::StderrCallback>,
80    pub can_use_tool: Option<crate::types::CanUseToolCallback>,
81    pub sdk_mcp_servers: std::collections::HashMap<String, crate::mcp::SimpleMCPServer>,
82    pub session_store_enabled: bool,
83}
84
85impl From<&ClaudeAgentOptions> for TransportOptions {
86    fn from(opts: &ClaudeAgentOptions) -> Self {
87        Self {
88            tools: opts.tools.clone(),
89            tools_set: opts.tools_set,
90            tools_preset: opts.tools_preset.clone(),
91            allowed_tools: opts.allowed_tools.clone(),
92            system_prompt: opts.system_prompt.clone(),
93            system_prompt_preset: opts.system_prompt_preset.clone(),
94            system_prompt_file: opts.system_prompt_file.clone(),
95            mcp_servers: opts.mcp_servers.clone(),
96            mcp_servers_config: opts.mcp_servers_config.clone(),
97            permission_mode: opts.permission_mode,
98            continue_conversation: opts.continue_conversation,
99            resume: opts.resume.clone(),
100            session_id: opts.session_id.clone(),
101            fork_session: opts.fork_session,
102            max_turns: opts.max_turns,
103            max_budget_usd: opts.max_budget_usd,
104            task_budget: opts.task_budget.clone(),
105            disallowed_tools: opts.disallowed_tools.clone(),
106            model: opts.model.clone(),
107            fallback_model: opts.fallback_model.clone(),
108            betas: opts.betas.clone(),
109            permission_prompt_tool_name: opts
110                .permission_prompt_tool_name
111                .clone()
112                .or_else(|| opts.can_use_tool.as_ref().map(|_| "stdio".to_string())),
113            cwd: opts.cwd.clone(),
114            cli_path: opts.cli_path.clone(),
115            settings: opts.settings.clone(),
116            add_dirs: opts.add_dirs.clone(),
117            env: opts.env.clone(),
118            extra_args: opts.extra_args.clone(),
119            max_buffer_size: opts.max_buffer_size,
120            user: opts.user.clone(),
121            include_partial_messages: opts.include_partial_messages,
122            include_hook_events: opts.include_hook_events,
123            strict_mcp_config: opts.strict_mcp_config,
124            setting_sources: opts.setting_sources.clone(),
125            skills: opts.skills.clone(),
126            sandbox: opts.sandbox.clone(),
127            plugins: opts.plugins.clone(),
128            max_thinking_tokens: opts.max_thinking_tokens,
129            thinking: opts.thinking.clone(),
130            effort: opts.effort.clone(),
131            output_format: opts.output_format.clone(),
132            enable_file_checkpointing: opts.enable_file_checkpointing,
133            stderr: opts.stderr.clone(),
134            can_use_tool: opts.can_use_tool.clone(),
135            sdk_mcp_servers: opts.sdk_mcp_servers.clone(),
136            session_store_enabled: opts.session_store.is_some(),
137        }
138    }
139}
140
141impl SubprocessCLITransport {
142    pub fn new(options: TransportOptions) -> Self {
143        let max_buffer_size = options.max_buffer_size.unwrap_or(DEFAULT_MAX_BUFFER_SIZE);
144        Self {
145            options,
146            child: None,
147            stdin: None,
148            stdout_reader: None,
149            stdout_decoder: StdoutDecoder::new(max_buffer_size),
150            stderr: Arc::new(Mutex::new(String::new())),
151        }
152    }
153
154    fn resolve_cli_path(&self) -> Result<String> {
155        crate::internal::cli_discovery::find_cli_path(self.options.cli_path.as_deref())
156    }
157
158    fn build_args(&self) -> Result<Vec<String>> {
159        crate::internal::cli_args::build_cli_args(&self.options)
160    }
161
162    fn build_env(&self) -> std::collections::HashMap<String, String> {
163        build_process_env(std::env::vars(), &self.options)
164    }
165    async fn finish_read(&mut self) -> Result<Option<Vec<u8>>> {
166        if let Some(ref mut child) = self.child {
167            match child.wait().await {
168                Ok(status) => {
169                    if !status.success() {
170                        let stderr = self.stderr.lock().await.clone();
171                        return Err(ProcessError::new(
172                            "Claude Code process exited with error",
173                            status.code(),
174                            stderr,
175                        )
176                        .into());
177                    }
178                }
179                Err(e) => {
180                    return Err(CLIConnectionError::new(format!(
181                        "failed to wait for process: {}",
182                        e
183                    ))
184                    .into());
185                }
186            }
187        }
188        Ok(None)
189    }
190}
191
192fn build_process_env<I>(inherited: I, options: &TransportOptions) -> HashMap<String, String>
193where
194    I: IntoIterator<Item = (String, String)>,
195{
196    let mut env = inherited
197        .into_iter()
198        .filter(|(key, _)| key != "CLAUDECODE")
199        .collect::<HashMap<_, _>>();
200
201    env.insert(
202        "CLAUDE_CODE_ENTRYPOINT".to_string(),
203        DEFAULT_ENTRY_POINT.to_string(),
204    );
205
206    for (key, value) in &options.env {
207        env.insert(key.clone(), value.clone());
208    }
209
210    env.insert(
211        "CLAUDE_AGENT_SDK_VERSION".to_string(),
212        env!("CARGO_PKG_VERSION").to_string(),
213    );
214
215    apply_otel_trace_context(&mut env, &options.env, active_otel_trace_context());
216
217    if options.enable_file_checkpointing {
218        env.insert(
219            "CLAUDE_CODE_ENABLE_SDK_FILE_CHECKPOINTING".to_string(),
220            "true".to_string(),
221        );
222    }
223
224    if let Some(ref cwd) = options.cwd {
225        env.insert("PWD".to_string(), cwd.clone());
226    }
227
228    env
229}
230
/// Exports the caller's OpenTelemetry propagation carrier into `env`.
///
/// Each carrier key is upper-cased to form the variable name
/// (`traceparent` becomes `TRACEPARENT`). Nothing happens unless `carrier`
/// contains `traceparent`, so inherited values survive when there is no
/// active span.
///
/// Otherwise, inherited `TRACEPARENT` and `TRACESTATE` are discarded first,
/// so a fresh `traceparent` is never combined with an inherited `tracestate`.
/// Keys present in `explicit_env` are never removed or overwritten.
fn apply_otel_trace_context(
    env: &mut HashMap<String, String>,
    explicit_env: &HashMap<String, String>,
    carrier: HashMap<String, String>,
) {
    if carrier.get("traceparent").is_none() {
        return;
    }

    let is_explicit = |key: &str| explicit_env.contains_key(key);

    for stale in ["TRACEPARENT", "TRACESTATE"] {
        if !is_explicit(stale) {
            env.remove(stale);
        }
    }

    env.extend(
        carrier
            .into_iter()
            .map(|(key, value)| (key.to_ascii_uppercase(), value))
            .filter(|(key, _)| !is_explicit(key.as_str())),
    );
}
253
/// Serialises the current OpenTelemetry context into a string carrier using
/// the globally installed text-map propagator.
#[cfg(feature = "otel")]
fn active_otel_trace_context() -> HashMap<String, String> {
    let mut injected: HashMap<String, String> = HashMap::new();
    opentelemetry::global::get_text_map_propagator(|propagator| propagator.inject(&mut injected));
    injected
}

/// Without the `otel` feature there is never an active context to propagate.
#[cfg(not(feature = "otel"))]
fn active_otel_trace_context() -> HashMap<String, String> {
    HashMap::default()
}
267
268#[async_trait]
269impl Transport for SubprocessCLITransport {
270    async fn connect(&mut self) -> Result<()> {
271        if self.child.is_some() {
272            return Ok(());
273        }
274
275        let cli_path = self.resolve_cli_path()?;
276        if std::env::var_os("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK").is_none() {
277            let _ = crate::internal::cli_discovery::check_cli_version(&cli_path).await;
278        }
279
280        if let Some(ref cwd) = self.options.cwd {
281            if !tokio::fs::metadata(cwd)
282                .await
283                .map(|m| m.is_dir())
284                .unwrap_or(false)
285            {
286                return Err(CLIConnectionError::new(format!(
287                    "working directory does not exist: {}",
288                    cwd
289                ))
290                .into());
291            }
292        }
293
294        let args = self.build_args()?;
295        let env = self.build_env();
296
297        let mut cmd = Command::new(&cli_path);
298        cmd.args(&args)
299            .stdin(Stdio::piped())
300            .stdout(Stdio::piped())
301            .stderr(Stdio::piped());
302
303        if let Some(ref cwd) = self.options.cwd {
304            cmd.current_dir(cwd);
305        }
306
307        for (key, value) in &env {
308            cmd.env(key, value);
309        }
310
311        let mut child = cmd.spawn().map_err(|e| {
312            if e.kind() == std::io::ErrorKind::NotFound {
313                ClaudeSDKError::CLINotFound(CLINotFoundError::new(
314                    "Claude Code not found",
315                    cli_path,
316                ))
317            } else {
318                CLIConnectionError::new(format!("failed to start Claude Code: {}", e)).into()
319            }
320        })?;
321
322        let stdin = child
323            .stdin
324            .take()
325            .ok_or_else(|| CLIConnectionError::new("failed to open CLI stdin"))?;
326        let stdout = child
327            .stdout
328            .take()
329            .ok_or_else(|| CLIConnectionError::new("failed to open CLI stdout"))?;
330        let stderr = child
331            .stderr
332            .take()
333            .ok_or_else(|| CLIConnectionError::new("failed to open CLI stderr"))?;
334
335        let stderr_arc = self.stderr.clone();
336        let stderr_callback = self.options.stderr.clone();
337        tokio::spawn(async move {
338            let mut reader = BufReader::new(stderr);
339            let mut line = String::new();
340            while let Ok(n) = reader.read_line(&mut line).await {
341                if n == 0 {
342                    break;
343                }
344                let mut stderr_guard = stderr_arc.lock().await;
345                stderr_guard.push_str(&line);
346                if let Some(callback) = &stderr_callback {
347                    callback.call(line.clone());
348                }
349                line.clear();
350            }
351        });
352
353        self.child = Some(child);
354        self.stdin = Some(stdin);
355        self.stdout_reader = Some(BufReader::new(stdout));
356
357        Ok(())
358    }
359
360    async fn write(&mut self, data: &[u8]) -> Result<()> {
361        let stdin = self
362            .stdin
363            .as_mut()
364            .ok_or_else(|| CLIConnectionError::new("transport is not connected"))?;
365
366        stdin
367            .write_all(data)
368            .await
369            .map_err(|e| CLIConnectionError::new(format!("failed to write to stdin: {}", e)))?;
370        stdin
371            .flush()
372            .await
373            .map_err(|e| CLIConnectionError::new(format!("failed to flush stdin: {}", e)))?;
374
375        Ok(())
376    }
377
378    async fn close_input(&mut self) -> Result<()> {
379        if let Some(mut stdin) = self.stdin.take() {
380            stdin
381                .shutdown()
382                .await
383                .map_err(|e| CLIConnectionError::new(format!("failed to close stdin: {}", e)))?;
384        }
385        Ok(())
386    }
387
388    async fn read(&mut self) -> Result<Option<Vec<u8>>> {
389        loop {
390            if let Some(data) = self.stdout_decoder.next() {
391                return Ok(Some(data));
392            }
393
394            let mut chunk = [0u8; 8192];
395            let read_result = {
396                let reader = self
397                    .stdout_reader
398                    .as_mut()
399                    .ok_or_else(|| CLIConnectionError::new("transport is not connected"))?;
400                reader.read(&mut chunk).await
401            };
402
403            match read_result {
404                Ok(0) => {
405                    self.stdout_decoder.finish()?;
406                    if let Some(data) = self.stdout_decoder.next() {
407                        return Ok(Some(data));
408                    }
409                    return self.finish_read().await;
410                }
411                Ok(n) => self
412                    .stdout_decoder
413                    .push(std::str::from_utf8(&chunk[..n]).map_err(|e| {
414                        CLIConnectionError::new(format!("stdout was not valid UTF-8: {}", e))
415                    })?)?,
416                Err(e) => {
417                    return Err(
418                        CLIConnectionError::new(format!("failed reading stdout: {}", e)).into(),
419                    )
420                }
421            }
422        }
423    }
424
425    async fn close(&mut self) -> Result<()> {
426        let _ = self.close_input().await;
427
428        if let Some(mut child) = self.child.take() {
429            let _ = child.kill().await;
430            let _ = child.wait().await;
431        }
432
433        Ok(())
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn process_env_matches_python_sdk_subprocess_defaults() {
        let options = crate::types::ClaudeAgentOptions::builder()
            .env_var("CLAUDE_CODE_ENTRYPOINT", "custom-entrypoint")
            .env_var("TRACEPARENT", "explicit-trace")
            .cwd("/tmp/project")
            .enable_file_checkpointing(true)
            .build();
        let transport_options = TransportOptions::from(&options);

        let env = build_process_env(
            [
                ("CLAUDECODE".to_string(), "1".to_string()),
                ("PATH".to_string(), "/bin".to_string()),
                ("TRACEPARENT".to_string(), "ambient-trace".to_string()),
            ],
            &transport_options,
        );

        assert_eq!(env.get("CLAUDECODE"), None);
        assert_eq!(
            env.get("CLAUDE_CODE_ENTRYPOINT").map(String::as_str),
            Some("custom-entrypoint")
        );
        assert_eq!(
            env.get("CLAUDE_AGENT_SDK_VERSION").map(String::as_str),
            Some(env!("CARGO_PKG_VERSION"))
        );
        assert_eq!(
            env.get("CLAUDE_CODE_ENABLE_SDK_FILE_CHECKPOINTING")
                .map(String::as_str),
            Some("true")
        );
        assert_eq!(env.get("PWD").map(String::as_str), Some("/tmp/project"));
        assert_eq!(
            env.get("TRACEPARENT").map(String::as_str),
            Some("explicit-trace")
        );
    }

    #[test]
    fn process_env_injects_active_otel_context_like_python_sdk() {
        let options = crate::types::ClaudeAgentOptions::builder().build();
        let transport_options = TransportOptions::from(&options);
        let mut env = build_process_env(
            [
                (
                    "TRACEPARENT".to_string(),
                    "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01".to_string(),
                ),
                ("TRACESTATE".to_string(), "vendor=stale".to_string()),
            ],
            &transport_options,
        );

        apply_otel_trace_context(
            &mut env,
            &transport_options.env,
            HashMap::from([
                (
                    "traceparent".to_string(),
                    "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01".to_string(),
                ),
                ("tracestate".to_string(), "vendor=value".to_string()),
            ]),
        );

        assert_eq!(
            env.get("TRACEPARENT").map(String::as_str),
            Some("00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")
        );
        assert_eq!(
            env.get("TRACESTATE").map(String::as_str),
            Some("vendor=value")
        );
    }

    #[test]
    fn process_env_preserves_explicit_traceparent_over_otel_context() {
        let options = crate::types::ClaudeAgentOptions::builder()
            .env_var("TRACEPARENT", "custom")
            .build();
        let transport_options = TransportOptions::from(&options);
        let mut env = build_process_env(
            [("TRACEPARENT".to_string(), "ambient".to_string())],
            &transport_options,
        );

        apply_otel_trace_context(
            &mut env,
            &transport_options.env,
            HashMap::from([("traceparent".to_string(), "active".to_string())]),
        );

        assert_eq!(env.get("TRACEPARENT").map(String::as_str), Some("custom"));
    }

    #[test]
    fn process_env_preserves_inherited_w3c_env_without_active_otel_span() {
        let options = crate::types::ClaudeAgentOptions::builder().build();
        let transport_options = TransportOptions::from(&options);
        let mut env = build_process_env(
            [
                ("TRACEPARENT".to_string(), "ambient".to_string()),
                ("TRACESTATE".to_string(), "vendor=abc".to_string()),
            ],
            &transport_options,
        );

        apply_otel_trace_context(
            &mut env,
            &transport_options.env,
            HashMap::from([("baggage".to_string(), "user.id=123".to_string())]),
        );

        assert_eq!(env.get("TRACEPARENT").map(String::as_str), Some("ambient"));
        assert_eq!(
            env.get("TRACESTATE").map(String::as_str),
            Some("vendor=abc")
        );
    }

    /// Spawns a fake `claude` shell script that writes one stderr line and one
    /// result message, then checks that the stderr callback sees the line.
    ///
    /// Unix-only: the fake CLI is a `/bin/sh` script made executable through
    /// Unix permission bits.
    #[cfg(unix)]
    #[tokio::test]
    async fn subprocess_stderr_callback_receives_lines() {
        use std::io::Write;
        use std::os::unix::fs::PermissionsExt;
        use std::sync::{Arc, Mutex};

        let dir =
            std::env::temp_dir().join(format!("claude-rust-stderr-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).unwrap();
        let script = dir.join("claude");
        let mut file = std::fs::File::create(&script).unwrap();
        writeln!(
            file,
            r#"#!/bin/sh
if [ "$1" = "-v" ]; then
  printf '2.0.0 (Claude Code)\n'
  exit 0
fi
printf 'diagnostic line\n' >&2
printf '{{"type":"result","subtype":"success","duration_ms":1,"duration_api_ms":1,"is_error":false,"num_turns":1,"session_id":"s"}}\n'
"#
        )
        .unwrap();
        // Close the write handle before the script is executed: Linux can
        // refuse to exec a file that is still open for writing (ETXTBSY).
        drop(file);
        let mut permissions = std::fs::metadata(&script).unwrap().permissions();
        permissions.set_mode(0o755);
        std::fs::set_permissions(&script, permissions).unwrap();

        let lines = Arc::new(Mutex::new(Vec::<String>::new()));
        let captured = lines.clone();
        let options = crate::types::ClaudeAgentOptions::builder()
            .cli_path(script.to_string_lossy().to_string())
            .stderr(move |line| captured.lock().unwrap().push(line))
            .build();
        let mut transport = SubprocessCLITransport::new(TransportOptions::from(&options));

        transport.connect().await.unwrap();
        let message = transport.read().await.unwrap().expect("result");
        let value: serde_json::Value = serde_json::from_slice(&message).unwrap();
        assert_eq!(value["type"], "result");

        // The stderr task runs concurrently, so poll briefly for the line.
        let mut received = false;
        for _ in 0..20 {
            if lines
                .lock()
                .unwrap()
                .iter()
                .any(|line| line == "diagnostic line\n")
            {
                received = true;
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        // Clean up before asserting so a failure does not leak the child or temp dir.
        let _ = transport.close().await;
        let _ = std::fs::remove_dir_all(&dir);
        assert!(received, "stderr callback did not receive diagnostic line");
    }
}