use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::sync::{mpsc, oneshot};
use tracing::error;
use uuid::Uuid;
use crate::command::{
CommandContext, CommandEvent, CommandExecutor, CommandHandle, CommandRequest, CommandResult,
};
use crate::error::UbiquityError;
/// Runs commands as child processes on the local machine.
///
/// Every command in flight is registered in the shared [`CommandContext`], which is how
/// `cancel` and `status` locate it by id.
pub struct LocalCommandExecutor {
    /// Capacity of the event channel created for every executed command.
    event_buffer_size: usize,
    /// Registry of in-flight commands, shared with the tasks that run them.
    context: Arc<CommandContext>,
}
impl LocalCommandExecutor {
pub fn new() -> Self {
Self {
context: Arc::new(CommandContext::new()),
event_buffer_size: 1024,
}
}
pub fn with_event_buffer_size(mut self, size: usize) -> Self {
self.event_buffer_size = size;
self
}
async fn execute_process(
request: CommandRequest,
event_tx: mpsc::Sender<CommandEvent>,
cancel_rx: mpsc::Receiver<()>,
status_rx: mpsc::Receiver<oneshot::Sender<CommandResult>>,
) -> Result<(), UbiquityError> {
let start_time = Instant::now();
let command_id = request.id;
event_tx
.send(CommandEvent::Started {
command_id,
command: request.command.clone(),
args: request.args.clone(),
timestamp: chrono::Utc::now(),
})
.await
.map_err(|_| UbiquityError::Internal("Failed to send start event".to_string()))?;
let mut cmd = Command::new(&request.command);
cmd.args(&request.args);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
cmd.stdin(Stdio::piped());
for (key, value) in &request.env {
cmd.env(key, value);
}
if let Some(dir) = &request.working_dir {
cmd.current_dir(dir);
}
let mut child = cmd
.spawn()
.map_err(|e| UbiquityError::CommandExecution(format!("Failed to spawn process: {}", e)))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| UbiquityError::Internal("Failed to capture stdout".to_string()))?;
let stderr = child
.stderr
.take()
.ok_or_else(|| UbiquityError::Internal("Failed to capture stderr".to_string()))?;
let mut stdin = child
.stdin
.take()
.ok_or_else(|| UbiquityError::Internal("Failed to capture stdin".to_string()))?;
if let Some(input) = &request.stdin {
stdin
.write_all(input.as_bytes())
.await
.map_err(|e| UbiquityError::CommandExecution(format!("Failed to write stdin: {}", e)))?;
stdin.shutdown().await.ok();
}
let stdout_reader = BufReader::new(stdout);
let stderr_reader = BufReader::new(stderr);
let collected_stdout;
let collected_stderr;
let event_tx_stdout = event_tx.clone();
let event_tx_stderr = event_tx.clone();
let stdout_task = tokio::spawn(async move {
let mut lines = stdout_reader.lines();
let mut output = Vec::new();
while let Ok(Some(line)) = lines.next_line().await {
output.push(line.clone());
let _ = event_tx_stdout
.send(CommandEvent::Stdout {
command_id,
data: line,
timestamp: chrono::Utc::now(),
})
.await;
}
output
});
let stderr_task = tokio::spawn(async move {
let mut lines = stderr_reader.lines();
let mut output = Vec::new();
while let Ok(Some(line)) = lines.next_line().await {
output.push(line.clone());
let _ = event_tx_stderr
.send(CommandEvent::Stderr {
command_id,
data: line,
timestamp: chrono::Utc::now(),
})
.await;
}
output
});
let process_task = tokio::spawn(async move {
child.wait().await
});
let (cancel_tx, mut cancel_rx_internal) = mpsc::channel::<()>(1);
let (status_tx_internal, mut status_rx_internal) = mpsc::channel::<oneshot::Sender<CommandResult>>(1);
let cancel_forward = tokio::spawn(async move {
let mut cancel_rx = cancel_rx;
while let Some(()) = cancel_rx.recv().await {
let _ = cancel_tx.send(()).await;
}
});
let status_forward = tokio::spawn(async move {
let mut status_rx = status_rx;
while let Some(tx) = status_rx.recv().await {
let _ = status_tx_internal.send(tx).await;
}
});
let process_task_handle = process_task.abort_handle();
let result = tokio::select! {
exit_status = process_task => {
match exit_status {
Ok(Ok(status)) => {
let exit_code = status.code().unwrap_or(-1);
let duration_ms = start_time.elapsed().as_millis() as u64;
collected_stdout = stdout_task.await.unwrap_or_default();
collected_stderr = stderr_task.await.unwrap_or_default();
event_tx
.send(CommandEvent::Completed {
command_id,
exit_code,
duration_ms,
timestamp: chrono::Utc::now(),
})
.await
.ok();
Ok(CommandResult {
id: command_id,
exit_code: Some(exit_code),
stdout: collected_stdout.join("\n"),
stderr: collected_stderr.join("\n"),
duration_ms,
cancelled: false,
})
}
Ok(Err(e)) => {
let duration_ms = start_time.elapsed().as_millis() as u64;
let error_msg = format!("Process error: {}", e);
event_tx
.send(CommandEvent::Failed {
command_id,
error: error_msg.clone(),
duration_ms,
timestamp: chrono::Utc::now(),
})
.await
.ok();
Err(UbiquityError::CommandExecution(error_msg))
}
Err(e) => {
let duration_ms = start_time.elapsed().as_millis() as u64;
let error_msg = format!("Task join error: {}", e);
event_tx
.send(CommandEvent::Failed {
command_id,
error: error_msg.clone(),
duration_ms,
timestamp: chrono::Utc::now(),
})
.await
.ok();
Err(UbiquityError::Internal(error_msg))
}
}
}
_ = cancel_rx_internal.recv() => {
let duration_ms = start_time.elapsed().as_millis() as u64;
if let Ok(mut child) = cmd.spawn() {
let _ = child.kill().await;
}
stdout_task.abort();
stderr_task.abort();
process_task_handle.abort();
event_tx
.send(CommandEvent::Cancelled {
command_id,
duration_ms,
timestamp: chrono::Utc::now(),
})
.await
.ok();
Ok(CommandResult {
id: command_id,
exit_code: None,
stdout: String::new(),
stderr: String::new(),
duration_ms,
cancelled: true,
})
}
_ = async {
if let Some(timeout_duration) = request.timeout {
tokio::time::sleep(timeout_duration).await
} else {
std::future::pending::<()>().await
}
} => {
let duration_ms = start_time.elapsed().as_millis() as u64;
if let Ok(mut child) = cmd.spawn() {
let _ = child.kill().await;
}
stdout_task.abort();
stderr_task.abort();
process_task_handle.abort();
let error_msg = format!("Command timed out after {:?}", request.timeout.unwrap());
event_tx
.send(CommandEvent::Failed {
command_id,
error: error_msg.clone(),
duration_ms,
timestamp: chrono::Utc::now(),
})
.await
.ok();
Err(UbiquityError::Timeout(error_msg))
}
};
let result_for_status = result.as_ref().ok().cloned();
tokio::spawn(async move {
while let Some(response_tx) = status_rx_internal.recv().await {
if let Some(ref cmd_result) = result_for_status {
let _ = response_tx.send(cmd_result.clone());
}
}
});
cancel_forward.abort();
status_forward.abort();
result.map(|_| ())
}
}
impl Default for LocalCommandExecutor {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl CommandExecutor for LocalCommandExecutor {
async fn execute(
&self,
request: CommandRequest,
) -> Result<Pin<Box<dyn Stream<Item = CommandEvent> + Send>>, UbiquityError> {
let (event_tx, event_rx) = mpsc::channel(self.event_buffer_size);
let (cancel_tx, cancel_rx) = mpsc::channel(1);
let (status_tx, status_rx) = mpsc::channel(1);
let command_id = request.id;
let handle = CommandHandle::new(command_id, cancel_tx, status_tx);
self.context.register(command_id, handle).await;
let context = self.context.clone();
let event_tx_clone = event_tx.clone();
tokio::spawn(async move {
let result = Self::execute_process(request, event_tx_clone, cancel_rx, status_rx).await;
context.unregister(&command_id).await;
if let Err(e) = result {
error!("Command execution error: {}", e);
}
});
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(event_rx)))
}
async fn cancel(&self, command_id: Uuid) -> Result<(), UbiquityError> {
self.context.cancel(&command_id).await
}
async fn status(&self, command_id: Uuid) -> Result<Option<CommandResult>, UbiquityError> {
if let Some(handle) = self.context.get(&command_id).await {
Ok(Some(handle.status().await?))
} else {
Ok(None)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    // The parent module imports only `Instant`, so `super::*` does not provide `Duration`.
    use std::time::Duration;

    /// A short-lived command emits both a `Started` and a `Completed` event.
    #[tokio::test]
    async fn test_local_command_execution() {
        let executor = LocalCommandExecutor::new();
        let request = CommandRequest::new("echo").with_args(vec!["hello world".to_string()]);
        let mut stream = executor.execute(request).await.unwrap();
        let mut events = Vec::new();
        while let Some(event) = stream.next().await {
            events.push(event);
        }
        assert!(!events.is_empty());
        let has_started = events.iter().any(|e| matches!(e, CommandEvent::Started { .. }));
        let has_completed = events.iter().any(|e| matches!(e, CommandEvent::Completed { .. }));
        assert!(has_started);
        assert!(has_completed);
    }

    /// Cancelling a long-running command yields a `Cancelled` event.
    #[tokio::test]
    async fn test_command_cancellation() {
        let executor = LocalCommandExecutor::new();
        let request = CommandRequest::new("sleep").with_args(vec!["10".to_string()]);
        let command_id = request.id;
        let mut stream = executor.execute(request).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        executor.cancel(command_id).await.unwrap();
        let mut cancelled = false;
        while let Some(event) = stream.next().await {
            if matches!(event, CommandEvent::Cancelled { .. }) {
                cancelled = true;
                break;
            }
        }
        assert!(cancelled);
    }

    /// Exceeding the configured timeout yields a `Failed` event mentioning "timed out".
    #[tokio::test]
    async fn test_command_timeout() {
        let executor = LocalCommandExecutor::new();
        let request = CommandRequest::new("sleep")
            .with_args(vec!["10".to_string()])
            .with_timeout(Duration::from_millis(100));
        let mut stream = executor.execute(request).await.unwrap();
        let mut timed_out = false;
        while let Some(event) = stream.next().await {
            if let CommandEvent::Failed { error, .. } = event {
                if error.contains("timed out") {
                    timed_out = true;
                    break;
                }
            }
        }
        assert!(timed_out);
    }

    /// Input supplied via `with_stdin` reaches the child and is echoed back by `cat`.
    #[tokio::test]
    async fn test_command_with_stdin() {
        let executor = LocalCommandExecutor::new();
        let request = CommandRequest::new("cat").with_stdin("test input data");
        let mut stream = executor.execute(request).await.unwrap();
        let mut stdout_data = String::new();
        while let Some(event) = stream.next().await {
            if let CommandEvent::Stdout { data, .. } = event {
                stdout_data.push_str(&data);
            }
        }
        assert_eq!(stdout_data.trim(), "test input data");
    }
}