Skip to main content

claude_code_rs/transport/
subprocess.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::process::{Child, Command};
8use tokio::sync::{mpsc, Mutex};
9use tokio_util::sync::CancellationToken;
10
11use crate::error::{Error, Result};
12use crate::types::options::{ClaudeAgentOptions, StderrCallback};
13use crate::types::permissions::PermissionMode;
14
15use super::{Transport, TransportWriter};
16
/// Transport implementation that communicates with the Claude CLI via subprocess.
///
/// The CLI is spawned lazily by `connect`, which pipes its stdin/stdout/stderr
/// and drives them from background tokio tasks. Those tasks stop when `cancel`
/// is cancelled (by `close` or on drop).
pub struct SubprocessTransport {
    // Path of the CLI executable handed to `Command::new`.
    cli_path: PathBuf,
    // Snapshot of the launch-relevant options, copied in `new`.
    options: BuildOptions,
    // The running CLI process: `Some` after a successful `connect`, cleared by `close`.
    child: Option<Child>,
    // Cloned into every background task spawned by `connect`; cancelling it stops them.
    cancel: CancellationToken,
    // Set to true by `connect` and back to false by `close`; it is not reset if the
    // CLI process exits on its own.
    ready: bool,
}
25
/// Subset of options needed for building the CLI command.
///
/// Copied out of `ClaudeAgentOptions` in `SubprocessTransport::new` so the
/// transport does not borrow the caller's options.
struct BuildOptions {
    // Emitted as `--model <model>`.
    model: Option<String>,
    // Emitted as `--system-prompt <prompt>`.
    system_prompt: Option<String>,
    // Emitted as `--append-system-prompt <prompt>`.
    append_system_prompt: Option<String>,
    // Emitted as `--max-turns <n>`.
    max_turns: Option<u32>,
    // Emitted as `--max-tokens <n>`.
    max_tokens: Option<u32>,
    // Emitted as `--session-id <id>`.
    session_id: Option<String>,
    // Adds `--continue` when true.
    continue_session: bool,
    // Working directory of the spawned CLI process, if set.
    cwd: Option<PathBuf>,
    // `AcceptAll` -> `--permission-mode bypassPermissions`, `DenyAll` ->
    // `--permission-mode plan`, `AllowedTools` -> one `--allowedTools` per entry in
    // `allowed_tools`, `Default` -> no flag.
    permission_mode: PermissionMode,
    // Only used when `permission_mode` is `AllowedTools`; ignored otherwise.
    allowed_tools: Vec<String>,
    // Adds `--no-cache` when true.
    no_cache: bool,
    // Emitted as `--temperature <value>`.
    temperature: Option<f64>,
    // Emitted as `--context-window <value>`.
    context_window: Option<f64>,
    // Appended verbatim after all built-in flags.
    extra_cli_args: Vec<String>,
    // Extra environment variables set on the CLI process.
    env: std::collections::HashMap<String, String>,
    // Called with each stderr line; when `None`, lines are logged at debug level
    // under the `claude_cli_stderr` tracing target.
    on_stderr: Option<StderrCallback>,
}
45
46impl SubprocessTransport {
47    pub fn new(cli_path: PathBuf, options: &ClaudeAgentOptions) -> Self {
48        Self {
49            cli_path,
50            options: BuildOptions {
51                model: options.model.clone(),
52                system_prompt: options.system_prompt.clone(),
53                append_system_prompt: options.append_system_prompt.clone(),
54                max_turns: options.max_turns,
55                max_tokens: options.max_tokens,
56                session_id: options.session_id.clone(),
57                continue_session: options.continue_session,
58                cwd: options.cwd.clone(),
59                permission_mode: options.permission_mode.clone(),
60                allowed_tools: options.allowed_tools.clone(),
61                no_cache: options.no_cache,
62                temperature: options.temperature,
63                context_window: options.context_window,
64                extra_cli_args: options.extra_cli_args.clone(),
65                env: options.env.clone(),
66                on_stderr: options.on_stderr.clone(),
67            },
68            child: None,
69            cancel: CancellationToken::new(),
70            ready: false,
71        }
72    }
73
74    /// Build the CLI command with all flags.
75    fn build_command(&self) -> Command {
76        let mut cmd = Command::new(&self.cli_path);
77
78        cmd.args(["--output-format", "stream-json"]);
79        cmd.args(["--input-format", "stream-json"]);
80        cmd.arg("--verbose");
81
82        if let Some(ref model) = self.options.model {
83            cmd.args(["--model", model]);
84        }
85
86        if let Some(ref sp) = self.options.system_prompt {
87            cmd.args(["--system-prompt", sp]);
88        }
89
90        if let Some(ref asp) = self.options.append_system_prompt {
91            cmd.args(["--append-system-prompt", asp]);
92        }
93
94        if let Some(turns) = self.options.max_turns {
95            cmd.args(["--max-turns", &turns.to_string()]);
96        }
97
98        if let Some(tokens) = self.options.max_tokens {
99            cmd.args(["--max-tokens", &tokens.to_string()]);
100        }
101
102        if let Some(ref sid) = self.options.session_id {
103            cmd.args(["--session-id", sid]);
104        }
105
106        if self.options.continue_session {
107            cmd.arg("--continue");
108        }
109
110        match &self.options.permission_mode {
111            PermissionMode::Default => {}
112            PermissionMode::AcceptAll => {
113                cmd.args(["--permission-mode", "bypassPermissions"]);
114            }
115            PermissionMode::DenyAll => {
116                cmd.args(["--permission-mode", "plan"]);
117            }
118            PermissionMode::AllowedTools => {
119                for tool in &self.options.allowed_tools {
120                    cmd.args(["--allowedTools", tool]);
121                }
122            }
123        }
124
125        if self.options.no_cache {
126            cmd.arg("--no-cache");
127        }
128
129        if let Some(temp) = self.options.temperature {
130            cmd.args(["--temperature", &temp.to_string()]);
131        }
132
133        if let Some(cw) = self.options.context_window {
134            cmd.args(["--context-window", &cw.to_string()]);
135        }
136
137        for arg in &self.options.extra_cli_args {
138            cmd.arg(arg);
139        }
140
141        if let Some(ref cwd) = self.options.cwd {
142            cmd.current_dir(cwd);
143        }
144
145        for (key, val) in &self.options.env {
146            cmd.env(key, val);
147        }
148
149        cmd.stdin(std::process::Stdio::piped());
150        cmd.stdout(std::process::Stdio::piped());
151        cmd.stderr(std::process::Stdio::piped());
152
153        cmd
154    }
155}
156
157#[async_trait]
158impl Transport for SubprocessTransport {
159    async fn connect(&mut self) -> Result<(mpsc::Receiver<Result<Value>>, TransportWriter)> {
160        if self.ready {
161            return Err(Error::AlreadyConnected);
162        }
163
164        let mut cmd = self.build_command();
165        let mut child = cmd
166            .spawn()
167            .map_err(|e| Error::CliConnection(format!("failed to spawn CLI: {e}")))?;
168
169        let stdout = child
170            .stdout
171            .take()
172            .ok_or_else(|| Error::CliConnection("no stdout".into()))?;
173        let stderr = child
174            .stderr
175            .take()
176            .ok_or_else(|| Error::CliConnection("no stderr".into()))?;
177        let stdin = child
178            .stdin
179            .take()
180            .ok_or_else(|| Error::CliConnection("no stdin".into()))?;
181
182        let stdin = Arc::new(Mutex::new(stdin));
183        self.child = Some(child);
184        self.ready = true;
185
186        // Incoming message channel (stdout -> reader).
187        let (read_tx, read_rx) = mpsc::channel::<Result<Value>>(256);
188
189        // Outgoing message channel (writer -> stdin).
190        let (write_tx, mut write_rx) = mpsc::channel::<Value>(256);
191
192        let cancel = self.cancel.clone();
193
194        // Stdout reader task.
195        let stdout_tx = read_tx.clone();
196        let stdout_cancel = cancel.clone();
197        tokio::spawn(async move {
198            let reader = BufReader::new(stdout);
199            let mut lines = reader.lines();
200
201            loop {
202                tokio::select! {
203                    _ = stdout_cancel.cancelled() => break,
204                    line = lines.next_line() => {
205                        match line {
206                            Ok(Some(line)) => {
207                                let line = line.trim().to_string();
208                                if line.is_empty() {
209                                    continue;
210                                }
211                                match serde_json::from_str::<Value>(&line) {
212                                    Ok(value) => {
213                                        if stdout_tx.send(Ok(value)).await.is_err() {
214                                            break;
215                                        }
216                                    }
217                                    Err(e) => {
218                                        tracing::warn!(line = %line, "failed to parse JSON from CLI: {e}");
219                                    }
220                                }
221                            }
222                            Ok(None) => break,
223                            Err(e) => {
224                                let _ = stdout_tx.send(Err(Error::Io(e))).await;
225                                break;
226                            }
227                        }
228                    }
229                }
230            }
231        });
232
233        // Stdin writer task: reads from write channel, serializes to stdin.
234        let write_cancel = cancel.clone();
235        let write_stdin = stdin.clone();
236        tokio::spawn(async move {
237            loop {
238                tokio::select! {
239                    _ = write_cancel.cancelled() => break,
240                    msg = write_rx.recv() => {
241                        match msg {
242                            Some(value) => {
243                                let mut data = match serde_json::to_string(&value) {
244                                    Ok(s) => s,
245                                    Err(e) => {
246                                        tracing::error!("failed to serialize outgoing message: {e}");
247                                        continue;
248                                    }
249                                };
250                                data.push('\n');
251
252                                let mut guard = write_stdin.lock().await;
253                                if let Err(e) = guard.write_all(data.as_bytes()).await {
254                                    tracing::error!("failed to write to stdin: {e}");
255                                    break;
256                                }
257                                if let Err(e) = guard.flush().await {
258                                    tracing::error!("failed to flush stdin: {e}");
259                                    break;
260                                }
261                            }
262                            None => break,
263                        }
264                    }
265                }
266            }
267        });
268
269        // Stderr reader task.
270        let on_stderr = self.options.on_stderr.clone();
271        let stderr_cancel = cancel;
272        tokio::spawn(async move {
273            let reader = BufReader::new(stderr);
274            let mut lines = reader.lines();
275
276            loop {
277                tokio::select! {
278                    _ = stderr_cancel.cancelled() => break,
279                    line = lines.next_line() => {
280                        match line {
281                            Ok(Some(line)) => {
282                                if let Some(ref cb) = on_stderr {
283                                    cb(line);
284                                } else {
285                                    tracing::debug!(target: "claude_cli_stderr", "{}", line);
286                                }
287                            }
288                            Ok(None) | Err(_) => break,
289                        }
290                    }
291                }
292            }
293        });
294
295        let writer = TransportWriter::new(write_tx);
296        Ok((read_rx, writer))
297    }
298
299    async fn end_input(&self) -> Result<()> {
300        // Closing the writer channel will cause the writer task to exit,
301        // which effectively closes stdin. The caller drops the TransportWriter.
302        // For explicit shutdown, we cancel everything.
303        Ok(())
304    }
305
306    async fn close(&mut self) -> Result<()> {
307        self.ready = false;
308        self.cancel.cancel();
309
310        if let Some(ref mut child) = self.child {
311            let _ = child.kill().await;
312        }
313
314        self.child = None;
315        Ok(())
316    }
317
318    fn is_ready(&self) -> bool {
319        self.ready
320    }
321}
322
323impl Drop for SubprocessTransport {
324    fn drop(&mut self) {
325        self.cancel.cancel();
326    }
327}