use std::collections::HashMap;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use async_trait::async_trait;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, Command};
use tokio::sync::{Mutex, mpsc};
use tracing::{debug, error, trace, warn};
use crate::error::ClawError;
use crate::transport::Transport;
/// Receiver for messages produced from the CLI's stdout: `Ok` carries a parsed JSON value,
/// `Err` describes why a line could not be delivered as one.
type MessageReceiver = mpsc::UnboundedReceiver<Result<Value, ClawError>>;
/// [`Transport`] that runs the CLI as a child process.
///
/// Bytes passed to `write` are sent to the child's stdin. Stdout is read line by line in a
/// background task and each line is parsed as JSON and delivered through a
/// [`MessageReceiver`]. Stderr is logged, accumulated, and optionally forwarded to a callback.
pub struct SubprocessCLITransport {
/// PID of the most recently spawned child; used to signal it during shutdown.
child_pid: std::sync::Mutex<Option<u32>>,
/// Write end of the child's stdin; `None` before connecting or after `end_input`.
stdin: Arc<Mutex<Option<ChildStdin>>>,
/// Receiver waiting to be handed out by `messages()`, which takes it (at most once).
messages_rx: Arc<std::sync::Mutex<Option<MessageReceiver>>>,
/// Readiness flag: set by `connect`, cleared by the background tasks when stdout ends or
/// the process exits, and by `close` / drop.
connected: Arc<AtomicBool>,
/// Explicit CLI location given to `new`; passed to CLI discovery.
cli_path_arg: Option<PathBuf>,
/// CLI path resolved by discovery on the first connect and reused afterwards.
cli_path: Arc<Mutex<Option<PathBuf>>>,
/// Command-line arguments passed to the CLI.
args: Vec<String>,
/// Working directory for the child process, if set.
cwd: Option<PathBuf>,
/// Extra environment variables applied to the child process.
env: HashMap<String, String>,
/// Accumulated stderr output (newline-terminated lines); copied into the error the
/// monitor task returns when the process exits unsuccessfully.
stderr_buffer: Arc<Mutex<String>>,
/// Optional hook invoked with each stderr line.
stderr_callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
/// Set by `set_max_buffer_size`.
/// NOTE(review): not read by any code visible in this file — confirm the intended use.
max_buffer_size: Option<usize>,
}
impl SubprocessCLITransport {
pub fn new(cli_path: Option<PathBuf>, args: Vec<String>) -> Self {
Self {
child_pid: std::sync::Mutex::new(None),
stdin: Arc::new(Mutex::new(None)),
messages_rx: Arc::new(std::sync::Mutex::new(None)),
connected: Arc::new(AtomicBool::new(false)),
cli_path_arg: cli_path,
cli_path: Arc::new(Mutex::new(None)),
args,
cwd: None,
env: HashMap::new(),
stderr_buffer: Arc::new(Mutex::new(String::new())),
stderr_callback: None,
max_buffer_size: None,
}
}
pub fn set_cwd(&mut self, cwd: PathBuf) {
self.cwd = Some(cwd);
}
pub fn set_env(&mut self, env: HashMap<String, String>) {
self.env = env;
}
pub fn set_stderr_callback(&mut self, callback: impl Fn(String) + Send + Sync + 'static) {
self.stderr_callback = Some(Arc::new(callback));
}
pub fn set_max_buffer_size(&mut self, size: usize) {
self.max_buffer_size = Some(size);
}
fn spawn_reader_task(
stdout: tokio::process::ChildStdout,
tx: mpsc::UnboundedSender<Result<Value, ClawError>>,
connected: Arc<AtomicBool>,
) {
tokio::spawn(async move {
let reader = BufReader::new(stdout);
let mut lines = reader.lines();
debug!("Started stdout reader task");
while let Ok(Some(line)) = lines.next_line().await {
trace!("Received line: {}", line);
if line.trim().is_empty() {
continue;
}
match serde_json::from_str::<Value>(&line) {
Ok(value) => {
if tx.send(Ok(value)).is_err() {
debug!("Message receiver dropped, stopping reader task");
break;
}
}
Err(e) => {
error!("Failed to parse JSON line '{}': {}", line, e);
if tx.send(Err(ClawError::JsonDecode(e))).is_err() {
debug!("Message receiver dropped, stopping reader task");
break;
}
}
}
}
debug!("Stdout reader task finished");
connected.store(false, Ordering::SeqCst);
});
}
fn spawn_stderr_task(
stderr: tokio::process::ChildStderr,
buffer: Arc<Mutex<String>>,
callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
) {
tokio::spawn(async move {
let reader = BufReader::new(stderr);
let mut lines = reader.lines();
debug!("Started stderr reader task");
while let Ok(Some(line)) = lines.next_line().await {
warn!("CLI stderr: {}", line);
if let Some(cb) = &callback {
cb(line.clone());
}
let mut buf = buffer.lock().await;
buf.push_str(&line);
buf.push('\n');
}
debug!("Stderr reader task finished");
});
}
fn spawn_monitor_task(
mut child: Child,
connected: Arc<AtomicBool>,
stderr_buffer: Arc<Mutex<String>>,
) -> tokio::task::JoinHandle<Result<(), ClawError>> {
tokio::spawn(async move {
let status = child.wait().await.map_err(ClawError::Io)?;
debug!("Process exited with status: {:?}", status);
connected.store(false, Ordering::SeqCst);
if !status.success() {
let stderr = stderr_buffer.lock().await.clone();
return Err(ClawError::Process {
code: status.code().unwrap_or(-1),
stderr,
});
}
Ok(())
})
}
async fn graceful_shutdown(&self) -> Result<(), ClawError> {
debug!("Starting graceful shutdown");
self.end_input().await?;
tokio::time::sleep(Duration::from_millis(500)).await;
if self.connected.load(Ordering::SeqCst) {
let pid = *self.child_pid.lock().unwrap();
if let Some(pid) = pid {
debug!(
"Process still running after stdin close, sending signals to pid {}",
pid
);
self.force_shutdown_by_pid(pid).await?;
}
}
Ok(())
}
async fn force_shutdown_by_pid(&self, pid: u32) -> Result<(), ClawError> {
#[cfg(unix)]
{
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
let nix_pid = Pid::from_raw(pid as i32);
debug!("Sending SIGTERM to pid {}", pid);
let _ = kill(nix_pid, Signal::SIGTERM);
for _ in 0..50 {
tokio::time::sleep(Duration::from_millis(100)).await;
if !self.connected.load(Ordering::SeqCst) {
debug!("Process exited after SIGTERM");
return Ok(());
}
}
warn!("SIGTERM timed out, sending SIGKILL to pid {}", pid);
let _ = kill(nix_pid, Signal::SIGKILL);
}
#[cfg(not(unix))]
{
warn!("Signal-based shutdown not available on non-Unix; relying on kill_on_drop");
let _ = pid;
}
Ok(())
}
}
#[async_trait]
impl Transport for SubprocessCLITransport {
async fn connect(&mut self) -> Result<(), ClawError> {
if self.connected.load(Ordering::SeqCst) {
return Err(ClawError::Connection("already connected".to_string()));
}
let cli_path = {
let mut guard = self.cli_path.lock().await;
if guard.is_none() {
use crate::transport::CliDiscovery;
let discovered = CliDiscovery::find(self.cli_path_arg.as_deref()).await?;
let version = CliDiscovery::validate_version(&discovered).await?;
debug!(
"Using CLI at {} (version {})",
discovered.display(),
version
);
*guard = Some(discovered.clone());
discovered
} else {
guard.clone().unwrap()
}
};
debug!("Spawning CLI: {} {:?}", cli_path.display(), self.args);
let mut cmd = Command::new(&cli_path);
cmd.args(&self.args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true);
if let Some(cwd) = &self.cwd {
cmd.current_dir(cwd);
}
cmd.env("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
cmd.env("CLAUDE_AGENT_SDK_VERSION", env!("CARGO_PKG_VERSION"));
cmd.env_remove("CLAUDECODE");
if !self.env.is_empty() {
cmd.envs(&self.env);
}
let mut child = cmd.spawn().map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ClawError::CliNotFound
} else {
ClawError::Io(e)
}
})?;
let child_pid = child.id();
debug!("Process spawned with pid: {:?}", child_pid);
*self.child_pid.lock().unwrap() = child_pid;
let stdin = child
.stdin
.take()
.ok_or_else(|| ClawError::Connection("failed to capture stdin".to_string()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| ClawError::Connection("failed to capture stdout".to_string()))?;
let stderr = child
.stderr
.take()
.ok_or_else(|| ClawError::Connection("failed to capture stderr".to_string()))?;
let (tx, rx) = mpsc::unbounded_channel();
*self.messages_rx.lock().unwrap() = Some(rx);
*self.stdin.lock().await = Some(stdin);
Self::spawn_reader_task(stdout, tx, self.connected.clone());
Self::spawn_stderr_task(
stderr,
self.stderr_buffer.clone(),
self.stderr_callback.clone(),
);
let _monitor =
Self::spawn_monitor_task(child, self.connected.clone(), self.stderr_buffer.clone());
self.connected.store(true, Ordering::SeqCst);
debug!("Connection established");
Ok(())
}
async fn write(&self, message: &[u8]) -> Result<(), ClawError> {
if !self.is_ready() {
return Err(ClawError::Connection("not connected".to_string()));
}
let mut stdin_guard = self.stdin.lock().await;
let stdin = stdin_guard
.as_mut()
.ok_or_else(|| ClawError::Connection("stdin already closed".to_string()))?;
trace!("Writing {} bytes to stdin", message.len());
stdin.write_all(message).await.map_err(ClawError::Io)?;
stdin.flush().await.map_err(ClawError::Io)?;
Ok(())
}
fn messages(&self) -> MessageReceiver {
self.messages_rx
.lock()
.unwrap()
.take()
.expect("messages() can only be called once per connection")
}
async fn end_input(&self) -> Result<(), ClawError> {
debug!("Closing stdin");
let mut stdin_guard = self.stdin.lock().await;
if let Some(mut stdin) = stdin_guard.take() {
stdin.shutdown().await.map_err(ClawError::Io)?;
}
Ok(())
}
async fn close(&self) -> Result<(), ClawError> {
if !self.connected.load(Ordering::SeqCst) {
debug!("Already closed");
return Ok(());
}
let result = self.graceful_shutdown().await;
self.connected.store(false, Ordering::SeqCst);
result
}
fn is_ready(&self) -> bool {
self.connected.load(Ordering::SeqCst)
}
}
impl Drop for SubprocessCLITransport {
    /// Clears the shared readiness flag as the transport is destroyed.
    ///
    /// No signal is sent to the child from here; only the flag is updated and a debug
    /// message is logged.
    fn drop(&mut self) {
        let ready_flag = &self.connected;
        ready_flag.store(false, Ordering::SeqCst);
        debug!("SubprocessCLITransport dropped");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_transport() {
        let transport = SubprocessCLITransport::new(
            Some(PathBuf::from("claude")),
            vec!["--output-format".to_string(), "stream-json".to_string()],
        );
        assert!(!transport.is_ready());
        assert_eq!(transport.cli_path_arg, Some(PathBuf::from("claude")));
        assert_eq!(transport.args.len(), 2);
    }

    #[test]
    fn test_not_ready_before_connect() {
        let transport = SubprocessCLITransport::new(None, vec![]);
        assert!(!transport.is_ready());
    }

    /// `messages()` has nothing to hand out before `connect` and panics with its documented
    /// message.
    #[test]
    #[should_panic(expected = "messages() can only be called once per connection")]
    fn test_messages_before_connect_panics() {
        let transport = SubprocessCLITransport::new(None, vec![]);
        let _ = transport.messages();
    }

    #[tokio::test]
    async fn test_write_when_not_connected() {
        let transport = SubprocessCLITransport::new(None, vec![]);
        let result = transport.write(b"test").await;
        assert!(matches!(result, Err(ClawError::Connection(_))));
    }

    #[tokio::test]
    async fn test_end_input_when_not_connected() {
        let transport = SubprocessCLITransport::new(None, vec![]);
        let result = transport.end_input().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_close_when_not_connected() {
        let transport = SubprocessCLITransport::new(None, vec![]);
        let result = transport.close().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_connect_with_invalid_cli() {
        let temp_dir = std::env::temp_dir().join("rusty_claw_test_invalid");
        std::fs::create_dir_all(&temp_dir).ok();
        let invalid_path = temp_dir.join("nonexistent_claude_binary");
        let mut transport = SubprocessCLITransport::new(Some(invalid_path), vec![]);
        match transport.connect().await {
            Err(err) => assert!(
                matches!(
                    err,
                    ClawError::CliNotFound | ClawError::InvalidCliVersion { .. }
                ),
                "Expected CliNotFound or InvalidCliVersion, got: {:?}",
                err
            ),
            // Assumption: discovery may fall back to another installed CLI. If a process was
            // spawned, shut it down rather than leaving it to runtime teardown.
            Ok(()) => {
                let _ = transport.close().await;
            }
        }
    }

    #[tokio::test]
    async fn test_double_connect_fails() {
        let mut transport = SubprocessCLITransport::new(None, vec![]);
        // Only meaningful when a real CLI is available; otherwise connect fails and the test
        // has nothing to check.
        if transport.connect().await.is_ok() {
            let second = transport.connect().await;
            assert!(matches!(second, Err(ClawError::Connection(_))));
            // Best-effort shutdown; close() always leaves the transport not ready.
            let _ = transport.close().await;
            assert!(!transport.is_ready());
        }
    }
}