use std::process::Stdio;
use std::sync::atomic::{AtomicI64, Ordering};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use crate::error::{Error, Result};
use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
/// Builder-style configuration for launching a child process that is then
/// driven over its stdin/stdout.
///
/// Derives `Clone` so a configured transport can be reused (e.g. to respawn a
/// crashed server), since `spawn` consumes the transport.
#[derive(Clone)]
pub struct ChildProcessTransport {
    /// Program to execute.
    program: String,
    /// Command-line arguments passed to the program, in order.
    args: Vec<String>,
    /// Extra environment variables set for the child, applied in order.
    envs: Vec<(String, String)>,
}
impl ChildProcessTransport {
    /// Creates a transport that will launch `program` with no extra arguments
    /// and no extra environment variables.
    pub fn new(program: impl Into<String>) -> Self {
        Self {
            program: program.into(),
            args: Vec::new(),
            envs: Vec::new(),
        }
    }

    /// Appends a single command-line argument.
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        self.args.push(arg.into());
        self
    }

    /// Appends every argument yielded by `args`, preserving their order.
    pub fn args<I, S>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.args.extend(args.into_iter().map(Into::into));
        self
    }

    /// Adds an environment variable for the child process.
    ///
    /// Variables are applied in insertion order, so a later entry for the same
    /// key overrides an earlier one.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.envs.push((key.into(), value.into()));
        self
    }

    /// Spawns the configured program and wraps it in a [`ChildProcessConnection`].
    ///
    /// The child's stdin and stdout are piped for message exchange. Its stderr
    /// is inherited from the current process.
    ///
    /// The child is configured with `kill_on_drop(true)`. A connection dropped
    /// without an explicit `shutdown`/`kill` therefore does not leave an
    /// orphaned process behind (tokio does not kill children on drop by default).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Transport`] if the program cannot be spawned or its
    /// stdio handles cannot be captured.
    pub async fn spawn(self) -> Result<ChildProcessConnection> {
        let mut cmd = Command::new(&self.program);
        cmd.args(&self.args)
            .envs(self.envs.iter().map(|(key, value)| (key, value)))
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);
        let child = cmd
            .spawn()
            .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", self.program, e)))?;
        tracing::info!(program = %self.program, "Spawned child process");
        ChildProcessConnection::new(child)
    }
}
/// A live connection to a spawned child process, exchanging newline-delimited
/// JSON-RPC messages over the child's stdin and stdout.
#[derive(Debug)]
pub struct ChildProcessConnection {
    /// Handle to the spawned process, used for status checks, waiting and killing.
    child: Child,
    /// Write end used to send requests and notifications to the child.
    stdin: tokio::process::ChildStdin,
    /// Buffered read end from which replies are read line by line.
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Id to assign to the next outgoing JSON-RPC request.
    request_id: AtomicI64,
}
impl ChildProcessConnection {
fn new(mut child: Child) -> Result<Self> {
let stdin = child
.stdin
.take()
.ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
Ok(Self {
child,
stdin,
stdout: BufReader::new(stdout),
request_id: AtomicI64::new(1),
})
}
pub async fn send_request(
&mut self,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value> {
let id = self.request_id.fetch_add(1, Ordering::Relaxed);
let request = JsonRpcRequest::new(id, method).with_params(params);
let request_json = serde_json::to_string(&request)
.map_err(|e| Error::Transport(format!("Failed to serialize request: {}", e)))?;
tracing::debug!(method = %method, id = %id, "Sending request to child");
self.stdin
.write_all(request_json.as_bytes())
.await
.map_err(|e| Error::Transport(format!("Failed to write to child stdin: {}", e)))?;
self.stdin
.write_all(b"\n")
.await
.map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
self.stdin
.flush()
.await
.map_err(|e| Error::Transport(format!("Failed to flush stdin: {}", e)))?;
let mut line = String::new();
self.stdout
.read_line(&mut line)
.await
.map_err(|e| Error::Transport(format!("Failed to read from child stdout: {}", e)))?;
if line.is_empty() {
return Err(Error::Transport("Child process closed stdout".to_string()));
}
tracing::debug!(response = %line.trim(), "Received response from child");
let response: JsonRpcResponse = serde_json::from_str(line.trim())
.map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
match response {
JsonRpcResponse::Result(r) => Ok(r.result),
JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
_ => Err(Error::Transport("unexpected response variant".to_string())),
}
}
pub async fn send_notification(
&mut self,
method: &str,
params: serde_json::Value,
) -> Result<()> {
let notification = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params
});
let json = serde_json::to_string(¬ification)
.map_err(|e| Error::Transport(format!("Failed to serialize notification: {}", e)))?;
tracing::debug!(method = %method, "Sending notification to child");
self.stdin
.write_all(json.as_bytes())
.await
.map_err(|e| Error::Transport(format!("Failed to write notification: {}", e)))?;
self.stdin
.write_all(b"\n")
.await
.map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
self.stdin
.flush()
.await
.map_err(|e| Error::Transport(format!("Failed to flush stdin: {}", e)))?;
Ok(())
}
pub async fn initialize(
&mut self,
client_name: &str,
client_version: &str,
) -> Result<serde_json::Value> {
self.send_request(
"initialize",
serde_json::json!({
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": {
"name": client_name,
"version": client_version
}
}),
)
.await
}
pub async fn send_initialized(&mut self) -> Result<()> {
self.send_notification("notifications/initialized", serde_json::json!({}))
.await
}
pub async fn list_tools(&mut self) -> Result<serde_json::Value> {
self.send_request("tools/list", serde_json::json!({})).await
}
pub async fn call_tool(
&mut self,
name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value> {
self.send_request(
"tools/call",
serde_json::json!({
"name": name,
"arguments": arguments
}),
)
.await
}
pub fn is_running(&mut self) -> bool {
matches!(self.child.try_wait(), Ok(None))
}
pub async fn shutdown(mut self) -> Result<()> {
drop(self.stdin);
let result =
tokio::time::timeout(std::time::Duration::from_secs(5), self.child.wait()).await;
match result {
Ok(Ok(status)) => {
tracing::info!(status = ?status, "Child process exited");
Ok(())
}
Ok(Err(e)) => {
tracing::error!(error = %e, "Error waiting for child process");
Err(Error::Transport(format!("Child process error: {}", e)))
}
Err(_) => {
tracing::warn!("Child process did not exit gracefully, killing");
self.child
.kill()
.await
.map_err(|e| Error::Transport(format!("Failed to kill child: {}", e)))?;
Ok(())
}
}
}
pub async fn kill(mut self) -> Result<()> {
self.child
.kill()
.await
.map_err(|e| Error::Transport(format!("Failed to kill child: {}", e)))?;
tracing::info!("Child process killed");
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builder tests are synchronous; they need no async runtime.

    #[test]
    fn test_transport_builder() {
        let transport = ChildProcessTransport::new("echo")
            .arg("hello")
            .env("FOO", "bar");
        assert_eq!(transport.program, "echo");
        assert_eq!(transport.args, vec!["hello"]);
        assert_eq!(transport.envs, vec![("FOO".to_string(), "bar".to_string())]);
    }

    #[test]
    fn test_transport_args() {
        let transport = ChildProcessTransport::new("cmd").args(["--flag1", "--flag2"]);
        assert_eq!(transport.args, vec!["--flag1", "--flag2"]);
    }

    #[test]
    fn test_transport_env() {
        let transport = ChildProcessTransport::new("prog")
            .env("KEY1", "val1")
            .env("KEY2", "val2");
        assert_eq!(transport.envs.len(), 2);
        assert_eq!(transport.envs[0], ("KEY1".to_string(), "val1".to_string()));
    }

    #[tokio::test]
    async fn test_spawn_nonexistent_fails() {
        let result = ChildProcessTransport::new("nonexistent-program-xyz-123")
            .spawn()
            .await;
        assert!(result.is_err());
    }

    // The tests below depend on the Unix utilities `cat` and `true`.

    #[cfg(unix)]
    #[tokio::test]
    async fn test_spawn_and_communicate() {
        // `cat` echoes the request back, which is not a valid response.
        let mut conn = ChildProcessTransport::new("cat").spawn().await.unwrap();
        assert!(conn.is_running());
        let response = conn
            .send_request("echo", serde_json::json!({"msg": "hello"}))
            .await;
        assert!(response.is_err());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_shutdown_graceful() {
        let conn = ChildProcessTransport::new("cat").spawn().await.unwrap();
        conn.shutdown().await.unwrap();
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_is_running_after_exit() {
        let mut conn = ChildProcessTransport::new("true").spawn().await.unwrap();
        // Wait for the exit deterministically instead of sleeping, which can
        // race on a loaded machine. tokio caches the exit status, so the
        // subsequent `try_wait` still reports it.
        conn.child.wait().await.unwrap();
        assert!(!conn.is_running());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_send_notification() {
        let mut conn = ChildProcessTransport::new("cat").spawn().await.unwrap();
        conn.send_notification("test/notify", serde_json::json!({"data": 1}))
            .await
            .unwrap();
        conn.shutdown().await.unwrap();
    }
}