use crate::services::process_hidden::HideWindow;
use crate::services::process_limits::PostSpawnAction;
use crate::services::remote::channel::{AgentChannel, ChannelError};
use crate::services::remote::protocol::{decode_base64, exec_params};
use crate::types::ProcessLimits;
use std::path::Path;
use std::process::ExitStatus;
use std::sync::Arc;
use tokio::process::{ChildStderr, ChildStdin, ChildStdout};
/// Captured outcome of a command that was run to completion.
#[derive(Clone, Debug)]
pub struct SpawnResult {
    /// Text collected from the command's standard output.
    pub stdout: String,
    /// Text collected from the command's standard error.
    pub stderr: String,
    /// Exit code reported for the command. The spawners in this file store
    /// `-1` here when no exit code is available.
    pub exit_code: i32,
}
#[derive(Debug, thiserror::Error)]
pub enum SpawnError {
#[error("Channel error: {0}")]
Channel(#[from] ChannelError),
#[error("Process error: {0}")]
Process(String),
#[error("Decode error: {0}")]
Decode(String),
}
/// Runs a command and hands back everything it produced in one piece.
///
/// Implementors must be `Send + Sync` so a spawner can be shared between
/// tasks.
#[async_trait::async_trait]
pub trait ProcessSpawner: Send + Sync {
    /// Executes `command` with the given `args`, optionally inside the
    /// working directory `cwd`, and returns its collected stdout, stderr and
    /// exit code.
    ///
    /// # Errors
    ///
    /// Returns a [`SpawnError`] when the command cannot be run or its result
    /// cannot be obtained; the exact variant depends on the implementation.
    async fn spawn(
        &self,
        command: String,
        args: Vec<String>,
        cwd: Option<String>,
    ) -> Result<SpawnResult, SpawnError>;
}
/// [`ProcessSpawner`] that runs commands as child processes on this machine.
pub struct LocalProcessSpawner;

#[async_trait::async_trait]
impl ProcessSpawner for LocalProcessSpawner {
    /// Runs `command` with `args` (in `cwd` when given), waits for it to
    /// finish and returns its captured output.
    ///
    /// Stdout and stderr are converted to text lossily, so invalid UTF-8 is
    /// replaced rather than rejected. When the exit status carries no code,
    /// `exit_code` is set to `-1`.
    ///
    /// # Errors
    ///
    /// Returns [`SpawnError::Process`] with the I/O error text when the
    /// command cannot be started or its output cannot be collected.
    async fn spawn(
        &self,
        command: String,
        args: Vec<String>,
        cwd: Option<String>,
    ) -> Result<SpawnResult, SpawnError> {
        let mut process = tokio::process::Command::new(&command);
        process.args(&args);
        process.hide_window();
        if let Some(dir) = cwd.as_deref() {
            process.current_dir(dir);
        }

        let captured = match process.output().await {
            Ok(captured) => captured,
            Err(io_err) => return Err(SpawnError::Process(io_err.to_string())),
        };

        // `code()` can be `None`; `-1` stands in for "no exit code available".
        let exit_code = captured.status.code().unwrap_or(-1);
        let stdout = String::from_utf8_lossy(&captured.stdout).into_owned();
        let stderr = String::from_utf8_lossy(&captured.stderr).into_owned();

        Ok(SpawnResult {
            stdout,
            stderr,
            exit_code,
        })
    }
}
/// [`ProcessSpawner`] that runs commands through a shared [`AgentChannel`].
pub struct RemoteProcessSpawner {
    /// Channel over which execution requests are sent.
    channel: Arc<AgentChannel>,
}

impl RemoteProcessSpawner {
    /// Builds a spawner that issues its requests over `channel`.
    pub fn new(channel: Arc<AgentChannel>) -> Self {
        RemoteProcessSpawner { channel }
    }
}
#[async_trait::async_trait]
impl ProcessSpawner for RemoteProcessSpawner {
    /// Runs `command` through the agent channel using a streaming `exec`
    /// request and collects its output.
    ///
    /// Each streamed message may carry base64 text under `out` (stdout)
    /// and/or `err` (stderr). Decoded bytes are appended in arrival order and
    /// converted to text lossily once the stream ends. The exit code is read
    /// from the integer `code` field of the final result and defaults to `-1`
    /// when that field is missing or not an integer.
    ///
    /// # Errors
    ///
    /// * [`SpawnError::Channel`] if the request fails or the final result
    ///   cannot be received.
    /// * [`SpawnError::Process`] if the final result is itself an error.
    /// * [`SpawnError::Decode`] if any `out`/`err` chunk is not valid base64.
    ///   The stream is still drained to the end first, and channel/process
    ///   errors take precedence over a decode failure.
    async fn spawn(
        &self,
        command: String,
        args: Vec<String>,
        cwd: Option<String>,
    ) -> Result<SpawnResult, SpawnError> {
        let params = exec_params(&command, &args, cwd.as_deref());
        let (mut data_rx, result_rx) = self.channel.request_streaming("exec", params).await?;

        let mut stdout = Vec::new();
        let mut stderr = Vec::new();
        // First decode failure seen, if any. Chunks that fail to decode used
        // to be dropped silently, which handed callers truncated output with
        // an `Ok` result. We keep draining so the request finishes normally,
        // then report the failure once the final result has been checked.
        let mut decode_failure: Option<&'static str> = None;

        while let Some(data) = data_rx.recv().await {
            if let Some(out) = data.get("out").and_then(|v| v.as_str()) {
                match decode_base64(out) {
                    Ok(decoded) => stdout.extend_from_slice(&decoded),
                    Err(_) => {
                        decode_failure.get_or_insert("invalid base64 in stdout chunk");
                    }
                }
            }
            if let Some(err) = data.get("err").and_then(|v| v.as_str()) {
                match decode_base64(err) {
                    Ok(decoded) => stderr.extend_from_slice(&decoded),
                    Err(_) => {
                        decode_failure.get_or_insert("invalid base64 in stderr chunk");
                    }
                }
            }
        }

        let result = result_rx
            .await
            .map_err(|_| SpawnError::Channel(ChannelError::ChannelClosed))?
            .map_err(SpawnError::Process)?;

        if let Some(reason) = decode_failure {
            return Err(SpawnError::Decode(reason.to_string()));
        }

        // The `as` cast wraps values outside the `i32` range instead of
        // failing; that long-standing behaviour is intentionally preserved.
        let exit_code = result
            .get("code")
            .and_then(|v| v.as_i64())
            .map(|c| c as i32)
            .unwrap_or(-1);

        Ok(SpawnResult {
            stdout: String::from_utf8_lossy(&stdout).into_owned(),
            stderr: String::from_utf8_lossy(&stderr).into_owned(),
            exit_code,
        })
    }
}
/// A child process together with its stdio pipe handles, which can each be
/// taken out once by the caller.
pub struct StdioChild {
    /// Underlying tokio child handle, used for id/kill/wait.
    inner: tokio::process::Child,
    /// Stdin pipe, if the child had one and it has not been taken yet.
    stdin: Option<ChildStdin>,
    /// Stdout pipe, if the child had one and it has not been taken yet.
    stdout: Option<ChildStdout>,
    /// Stderr pipe, if the child had one and it has not been taken yet.
    stderr: Option<ChildStderr>,
    /// Flag supplied at construction; see [`StdioChild::spawned_locally`].
    spawned_locally: bool,
}

impl StdioChild {
    /// Wraps `child`, moving its stdin/stdout/stderr handles (whichever are
    /// present) into this wrapper and recording `spawned_locally` as given.
    pub fn from_tokio_child(mut child: tokio::process::Child, spawned_locally: bool) -> Self {
        // Detach the pipes first so `child` can then be moved into `inner`.
        let (stdin, stdout, stderr) = (
            child.stdin.take(),
            child.stdout.take(),
            child.stderr.take(),
        );
        StdioChild {
            stdin,
            stdout,
            stderr,
            spawned_locally,
            inner: child,
        }
    }

    /// Wraps a locally spawned `child` (so `spawned_locally` is `true`) and,
    /// when the child still reports a process id, passes that id to
    /// `post_spawn.apply_to_child`.
    pub fn from_local_tokio_child(
        child: tokio::process::Child,
        post_spawn: PostSpawnAction,
    ) -> Self {
        let wrapped = Self::from_tokio_child(child, true);
        if let Some(pid) = wrapped.id() {
            post_spawn.apply_to_child(pid);
        }
        wrapped
    }

    /// Hands over the stdin pipe; later calls return `None`.
    pub fn take_stdin(&mut self) -> Option<ChildStdin> {
        self.stdin.take()
    }

    /// Hands over the stdout pipe; later calls return `None`.
    pub fn take_stdout(&mut self) -> Option<ChildStdout> {
        self.stdout.take()
    }

    /// Hands over the stderr pipe; later calls return `None`.
    pub fn take_stderr(&mut self) -> Option<ChildStderr> {
        self.stderr.take()
    }

    /// Process id as reported by the underlying tokio child handle.
    pub fn id(&self) -> Option<u32> {
        self.inner.id()
    }

    /// Returns the `spawned_locally` flag recorded at construction.
    pub fn spawned_locally(&self) -> bool {
        self.spawned_locally
    }

    /// Forwards to the underlying child's `kill`.
    pub async fn kill(&mut self) -> std::io::Result<()> {
        self.inner.kill().await
    }

    /// Forwards to the underlying child's `wait` and yields its exit status.
    pub async fn wait(&mut self) -> std::io::Result<ExitStatus> {
        self.inner.wait().await
    }
}
/// Starts processes whose stdio stays connected to the caller, returning a
/// [`StdioChild`] handle instead of waiting for completion.
///
/// Implementors must be `Send + Sync` so a spawner can be shared between
/// tasks.
#[async_trait::async_trait]
pub trait LongRunningSpawner: Send + Sync {
    /// Starts `command` with `args` and returns a handle to the running
    /// child.
    ///
    /// * `env` - `(name, value)` environment variable pairs for the child.
    /// * `cwd` - working directory for the child, if any.
    /// * `limits` - process limits to apply, if any.
    ///
    /// # Errors
    ///
    /// Returns a [`SpawnError`] when the process cannot be configured or
    /// started.
    async fn spawn_stdio(
        &self,
        command: &str,
        args: &[String],
        env: Vec<(String, String)>,
        cwd: Option<&Path>,
        limits: Option<&ProcessLimits>,
    ) -> Result<StdioChild, SpawnError>;

    /// Reports whether `command` can be found by this spawner.
    async fn command_exists(&self, command: &str) -> bool;
}
/// [`LongRunningSpawner`] that starts child processes on this machine.
pub struct LocalLongRunningSpawner;

#[async_trait::async_trait]
impl LongRunningSpawner for LocalLongRunningSpawner {
    /// Starts `command` locally with stdin, stdout and stderr all piped and
    /// returns the running child wrapped in a [`StdioChild`].
    ///
    /// The command is configured with `kill_on_drop(true)`. When `limits` is
    /// given it is applied to the command before spawning, and the
    /// [`PostSpawnAction`] it yields is handed to
    /// [`StdioChild::from_local_tokio_child`]; otherwise the default action
    /// is used.
    ///
    /// # Errors
    ///
    /// Returns [`SpawnError::Process`] when the limits cannot be applied or
    /// the process fails to start.
    async fn spawn_stdio(
        &self,
        command: &str,
        args: &[String],
        env: Vec<(String, String)>,
        cwd: Option<&Path>,
        limits: Option<&ProcessLimits>,
    ) -> Result<StdioChild, SpawnError> {
        use std::process::Stdio;

        let mut builder = tokio::process::Command::new(command);
        builder.args(args);
        builder.envs(env);
        builder.stdin(Stdio::piped());
        builder.stdout(Stdio::piped());
        builder.stderr(Stdio::piped());
        builder.hide_window();
        builder.kill_on_drop(true);

        if let Some(dir) = cwd {
            builder.current_dir(dir);
        }

        // Limits are applied after the rest of the command is configured.
        let post_spawn = if let Some(lim) = limits {
            match lim.apply_to_command(&mut builder) {
                Ok(action) => action,
                Err(e) => {
                    return Err(SpawnError::Process(format!(
                        "Failed to apply process limits: {e}"
                    )))
                }
            }
        } else {
            PostSpawnAction::default()
        };

        match builder.spawn() {
            Ok(child) => Ok(StdioChild::from_local_tokio_child(child, post_spawn)),
            Err(io_err) => Err(SpawnError::Process(io_err.to_string())),
        }
    }

    /// Returns `true` when `which::which` can resolve `command`.
    async fn command_exists(&self, command: &str) -> bool {
        matches!(which::which(command), Ok(_))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;

    /// `echo hello` through the one-shot local spawner exits 0 and prints
    /// `hello`.
    #[tokio::test]
    async fn test_local_spawner() {
        let outcome = LocalProcessSpawner
            .spawn(String::from("echo"), vec![String::from("hello")], None)
            .await
            .unwrap();

        assert_eq!(outcome.exit_code, 0);
        assert!(outcome.stdout.trim() == "hello");
    }

    /// A child started via `spawn_stdio` exposes its stdout as a readable
    /// pipe, exits successfully and is flagged as locally spawned.
    #[tokio::test]
    async fn local_long_running_spawn_stdio_pipes_output() {
        let shell_args: [String; 2] = ["-c".into(), "echo hi".into()];
        let mut child = LocalLongRunningSpawner
            .spawn_stdio("sh", &shell_args, Vec::new(), None, None)
            .await
            .expect("spawn succeeds");

        let mut pipe = child.take_stdout().expect("stdout piped");
        let mut text = String::new();
        pipe.read_to_string(&mut text).await.unwrap();
        assert_eq!(text.trim(), "hi");

        let status = child.wait().await.unwrap();
        assert!(status.success());
        assert!(child.spawned_locally());
    }

    /// `sh` is found, while a made-up binary name is not.
    #[tokio::test]
    async fn local_long_running_command_exists_for_sh() {
        let spawner = LocalLongRunningSpawner;

        let has_sh = spawner.command_exists("sh").await;
        assert!(has_sh);

        let has_bogus = spawner
            .command_exists("fresh-unlikely-binary-name-ygzu9")
            .await;
        assert!(!has_bogus);
    }
}