use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, RwLock};
use uuid::Uuid;
use crate::error::UbiquityError;
/// Lifecycle and output events describing a single command run.
///
/// Serialized with serde's internally-tagged representation (`#[serde(tag = "type")]`):
/// each event becomes one object whose `"type"` field holds the variant name
/// (e.g. `"Stdout"`) next to that variant's own fields.
///
/// Every variant carries the `command_id` of the request it belongs to and a UTC
/// `timestamp`. NOTE(review): which variants are emitted, and in what order, is decided
/// by the `CommandExecutor` implementation, which is not visible here — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CommandEvent {
    /// The command was launched with program `command` and arguments `args`.
    Started {
        command_id: Uuid,
        command: String,
        args: Vec<String>,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// A chunk of the command's standard output.
    /// NOTE(review): chunk granularity (per line vs. arbitrary reads) is executor-defined — confirm.
    Stdout {
        command_id: Uuid,
        data: String,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// A chunk of the command's standard error (same granularity caveat as `Stdout`).
    Stderr {
        command_id: Uuid,
        data: String,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// A progress report. `ProgressTracker::update` clamps `percentage` to `0.0..=100.0`
    /// before emitting it; events constructed elsewhere are not range-checked.
    Progress {
        command_id: Uuid,
        percentage: f32,
        message: String,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// The command finished with `exit_code`.
    /// NOTE(review): `duration_ms` presumably measures elapsed time since start — confirm.
    Completed {
        command_id: Uuid,
        exit_code: i32,
        duration_ms: u64,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// The command failed; `error` is a human-readable description.
    /// NOTE(review): presumably covers launch errors and abnormal termination — confirm.
    Failed {
        command_id: Uuid,
        error: String,
        duration_ms: u64,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
    /// The command was cancelled before finishing.
    Cancelled {
        command_id: Uuid,
        duration_ms: u64,
        timestamp: chrono::DateTime<chrono::Utc>,
    },
}
/// Description of a command to run.
///
/// Build one with `CommandRequest::new` and the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandRequest {
    /// Identifier for this request; `CommandRequest::new` assigns a random v4 UUID.
    pub id: Uuid,
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Extra environment variables.
    /// NOTE(review): whether these extend or replace the inherited environment is up to the
    /// executor — confirm.
    pub env: HashMap<String, String>,
    /// Working directory; `None` leaves the choice to the executor.
    pub working_dir: Option<String>,
    /// Optional time limit. Not enforced by this type; presumably applied by the executor.
    pub timeout: Option<Duration>,
    /// Text to feed to the command's standard input, if any.
    pub stdin: Option<String>,
}
impl CommandRequest {
    /// Creates a request for `command` with a fresh random id and no arguments,
    /// environment overrides, working directory, timeout, or stdin.
    #[must_use]
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            id: Uuid::new_v4(),
            command: command.into(),
            args: Vec::new(),
            env: HashMap::new(),
            working_dir: None,
            timeout: None,
            stdin: None,
        }
    }

    /// Replaces the argument list with `args`. Any previously set arguments are discarded.
    #[must_use]
    pub fn with_args(mut self, args: Vec<String>) -> Self {
        self.args = args;
        self
    }

    /// Sets environment variable `key` to `value`, overwriting an earlier value for `key`.
    #[must_use]
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }

    /// Sets the working directory the command should run in.
    #[must_use]
    pub fn with_working_dir(mut self, dir: impl Into<String>) -> Self {
        self.working_dir = Some(dir.into());
        self
    }

    /// Sets the time limit for the command.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the text to be written to the command's standard input.
    #[must_use]
    pub fn with_stdin(mut self, stdin: impl Into<String>) -> Self {
        self.stdin = Some(stdin.into());
        self
    }
}
/// Snapshot of a command's outcome.
///
/// Returned by `CommandExecutor::status` and delivered over `CommandHandle::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandResult {
    /// Id of the command this result describes (presumably the `CommandRequest::id`).
    pub id: Uuid,
    /// Exit code, or `None` when none is available.
    /// NOTE(review): presumably `None` while running or after cancellation — confirm.
    pub exit_code: Option<i32>,
    /// Captured standard output (presumably everything collected so far).
    pub stdout: String,
    /// Captured standard error (presumably everything collected so far).
    pub stderr: String,
    /// Elapsed time in milliseconds.
    pub duration_ms: u64,
    /// `true` if the command was cancelled.
    pub cancelled: bool,
}
/// Backend that runs [`CommandRequest`]s and reports on them.
///
/// The `Send + Sync` supertraits allow one executor to be shared across tasks
/// (e.g. behind an `Arc`).
#[async_trait]
pub trait CommandExecutor: Send + Sync {
    /// Starts `request` and returns a stream of [`CommandEvent`]s for it.
    ///
    /// NOTE(review): an `Err` presumably means the command could not be started, and whether
    /// the stream ends after a terminal event (`Completed`/`Failed`/`Cancelled`) is
    /// implementation-defined — confirm with implementors.
    async fn execute(
        &self,
        request: CommandRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = CommandEvent> + Send>>, UbiquityError>;
    /// Requests cancellation of the command identified by `command_id`.
    async fn cancel(&self, command_id: Uuid) -> Result<(), UbiquityError>;
    /// Returns the current or final result of `command_id`.
    /// NOTE(review): presumably `Ok(None)` for unknown ids — confirm.
    async fn status(&self, command_id: Uuid) -> Result<Option<CommandResult>, UbiquityError>;
}
/// Registry of in-flight commands, keyed by command id.
///
/// The map sits behind `Arc<tokio::sync::RwLock<..>>`, so registry access is awaited
/// instead of blocking the executor thread.
pub struct CommandContext {
    // Command id -> control handle for that command.
    active_commands: Arc<RwLock<HashMap<Uuid, CommandHandle>>>,
}
impl CommandContext {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            active_commands: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `handle` under `id`, replacing any handle already stored under that id.
    pub async fn register(&self, id: Uuid, handle: CommandHandle) {
        self.active_commands.write().await.insert(id, handle);
    }

    /// Removes the handle stored under `id`. Removing an unknown id is a no-op.
    pub async fn unregister(&self, id: &Uuid) {
        self.active_commands.write().await.remove(id);
    }

    /// Returns a clone of the handle registered under `id`, if any.
    ///
    /// The read lock is released before this returns, so callers can await on the
    /// returned handle without blocking registry writers.
    pub async fn get(&self, id: &Uuid) -> Option<CommandHandle> {
        self.active_commands.read().await.get(id).cloned()
    }

    /// Sends a cancel signal to the command registered under `id`.
    ///
    /// # Errors
    /// Returns `UbiquityError::NotFound` if no command is registered under `id`;
    /// otherwise returns whatever `CommandHandle::cancel` returns.
    pub async fn cancel(&self, id: &Uuid) -> Result<(), UbiquityError> {
        match self.get(id).await {
            Some(handle) => handle.cancel().await,
            None => Err(UbiquityError::NotFound(format!("Command {} not found", id))),
        }
    }

    /// Sends a cancel signal to every registered command.
    ///
    /// The registry is snapshotted first so the read lock is not held while each cancel
    /// is awaited. Returns one `(id, result)` pair per snapshotted command; a failure for
    /// one command does not stop the others. Order follows `HashMap` iteration and is
    /// therefore unspecified.
    pub async fn cancel_all(&self) -> Vec<(Uuid, Result<(), UbiquityError>)> {
        let snapshot: Vec<(Uuid, CommandHandle)> = {
            let guard = self.active_commands.read().await;
            guard
                .iter()
                .map(|(id, handle)| (*id, handle.clone()))
                .collect()
        };
        let mut results = Vec::with_capacity(snapshot.len());
        for (id, handle) in snapshot {
            let result = handle.cancel().await;
            results.push((id, result));
        }
        results
    }
}

/// Equivalent to [`CommandContext::new`]: an empty registry.
impl Default for CommandContext {
    fn default() -> Self {
        Self::new()
    }
}
/// Cloneable control handle for one command.
///
/// Holds only the sending halves of two channels. NOTE(review): the matching receivers
/// are presumably owned by the task running the command — confirm at the construction site.
#[derive(Clone)]
pub struct CommandHandle {
    /// Id of the command this handle controls.
    pub id: Uuid,
    // Carries `()` cancel signals to the command's owner.
    cancel_tx: mpsc::Sender<()>,
    // Carries one-shot reply senders; the owner answers each with a `CommandResult`.
    status_tx: mpsc::Sender<oneshot::Sender<CommandResult>>,
}
impl CommandHandle {
pub fn new(
id: Uuid,
cancel_tx: mpsc::Sender<()>,
status_tx: mpsc::Sender<oneshot::Sender<CommandResult>>,
) -> Self {
Self {
id,
cancel_tx,
status_tx,
}
}
pub async fn cancel(&self) -> Result<(), UbiquityError> {
self.cancel_tx
.send(())
.await
.map_err(|_| UbiquityError::Internal("Failed to send cancel signal".to_string()))
}
pub async fn status(&self) -> Result<CommandResult, UbiquityError> {
let (tx, rx) = oneshot::channel();
self.status_tx
.send(tx)
.await
.map_err(|_| UbiquityError::Internal("Failed to request status".to_string()))?;
rx.await
.map_err(|_| UbiquityError::Internal("Failed to receive status".to_string()))
}
}
/// Emits `CommandEvent::Progress` events for one command onto an event channel.
pub struct ProgressTracker {
    // Id stamped onto every emitted event.
    command_id: Uuid,
    // Channel the progress events are sent on.
    tx: mpsc::Sender<CommandEvent>,
}
impl ProgressTracker {
    /// Creates a tracker that reports progress for `command_id` on `tx`.
    pub fn new(command_id: Uuid, tx: mpsc::Sender<CommandEvent>) -> Self {
        Self { command_id, tx }
    }

    /// Sends a `CommandEvent::Progress` event stamped with the current UTC time.
    ///
    /// `percentage` is clamped to `0.0..=100.0`, and NaN is reported as `0.0`. The send
    /// waits for channel capacity if the event channel is full.
    ///
    /// # Errors
    /// Returns `UbiquityError::Internal` if the event receiver has been dropped.
    pub async fn update(
        &self,
        percentage: f32,
        message: impl Into<String>,
    ) -> Result<(), UbiquityError> {
        self.tx
            .send(CommandEvent::Progress {
                command_id: self.command_id,
                percentage: Self::normalize_percentage(percentage),
                message: message.into(),
                timestamp: chrono::Utc::now(),
            })
            .await
            .map_err(|_| UbiquityError::Internal("Failed to send progress update".to_string()))
    }

    /// Maps `percentage` into `0.0..=100.0`, treating NaN as `0.0`.
    ///
    /// `f32::clamp` returns NaN unchanged for a NaN input, so clamping alone does not keep
    /// the value in range. NaN is also not a meaningful progress value and does not
    /// round-trip through JSON.
    fn normalize_percentage(percentage: f32) -> f32 {
        if percentage.is_nan() {
            0.0
        } else {
            percentage.clamp(0.0, 100.0)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_request_builder() {
        let request = CommandRequest::new("echo")
            .with_args(vec!["hello".to_string(), "world".to_string()])
            .with_env("PATH", "/usr/bin")
            .with_working_dir("/tmp")
            .with_timeout(Duration::from_secs(30))
            .with_stdin("test input");
        assert_eq!(request.command, "echo");
        assert_eq!(request.args, vec!["hello", "world"]);
        assert_eq!(request.env.get("PATH"), Some(&"/usr/bin".to_string()));
        assert_eq!(request.working_dir, Some("/tmp".to_string()));
        assert_eq!(request.timeout, Some(Duration::from_secs(30)));
        assert_eq!(request.stdin, Some("test input".to_string()));
    }

    #[test]
    fn test_command_request_defaults() {
        let request = CommandRequest::new("ls");
        assert_eq!(request.command, "ls");
        assert!(request.args.is_empty());
        assert!(request.env.is_empty());
        assert!(request.working_dir.is_none());
        assert!(request.timeout.is_none());
        assert!(request.stdin.is_none());
        // Each request gets its own random id.
        assert_ne!(request.id, CommandRequest::new("ls").id);
    }

    #[tokio::test]
    async fn test_command_context() {
        let context = CommandContext::new();
        let id = Uuid::new_v4();
        let (cancel_tx, _cancel_rx) = mpsc::channel(1);
        let (status_tx, _status_rx) = mpsc::channel(1);
        let handle = CommandHandle::new(id, cancel_tx, status_tx);
        context.register(id, handle.clone()).await;
        let retrieved = context.get(&id).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, id);
        context.unregister(&id).await;
        assert!(context.get(&id).await.is_none());
    }

    #[tokio::test]
    async fn test_context_cancel_unknown_id() {
        let context = CommandContext::new();
        let result = context.cancel(&Uuid::new_v4()).await;
        assert!(matches!(result, Err(UbiquityError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_context_cancel_delivers_signal() {
        let context = CommandContext::new();
        let id = Uuid::new_v4();
        let (cancel_tx, mut cancel_rx) = mpsc::channel(1);
        let (status_tx, _status_rx) = mpsc::channel(1);
        context
            .register(id, CommandHandle::new(id, cancel_tx, status_tx))
            .await;
        context.cancel(&id).await.unwrap();
        assert_eq!(cancel_rx.recv().await, Some(()));
    }

    #[tokio::test]
    async fn test_handle_cancel_fails_when_receiver_dropped() {
        let id = Uuid::new_v4();
        let (cancel_tx, cancel_rx) = mpsc::channel(1);
        let (status_tx, _status_rx) = mpsc::channel(1);
        drop(cancel_rx);
        let handle = CommandHandle::new(id, cancel_tx, status_tx);
        assert!(handle.cancel().await.is_err());
    }

    #[tokio::test]
    async fn test_cancel_all_reports_every_command() {
        let context = CommandContext::new();
        let mut cancel_receivers = Vec::new();
        let mut status_receivers = Vec::new();
        for _ in 0..3 {
            let id = Uuid::new_v4();
            let (cancel_tx, cancel_rx) = mpsc::channel(1);
            let (status_tx, status_rx) = mpsc::channel(1);
            context
                .register(id, CommandHandle::new(id, cancel_tx, status_tx))
                .await;
            cancel_receivers.push(cancel_rx);
            status_receivers.push(status_rx);
        }
        let results = context.cancel_all().await;
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|(_, result)| result.is_ok()));
        for rx in cancel_receivers.iter_mut() {
            assert_eq!(rx.recv().await, Some(()));
        }
        drop(status_receivers);
    }

    #[tokio::test]
    async fn test_handle_status_round_trip() {
        let id = Uuid::new_v4();
        let (cancel_tx, _cancel_rx) = mpsc::channel(1);
        let (status_tx, mut status_rx) = mpsc::channel::<oneshot::Sender<CommandResult>>(1);
        let handle = CommandHandle::new(id, cancel_tx, status_tx);
        let responder = tokio::spawn(async move {
            let reply = status_rx.recv().await.unwrap();
            let _ = reply.send(CommandResult {
                id,
                exit_code: Some(0),
                stdout: "ok".to_string(),
                stderr: String::new(),
                duration_ms: 5,
                cancelled: false,
            });
        });
        let result = handle.status().await.unwrap();
        responder.await.unwrap();
        assert_eq!(result.id, id);
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.stdout, "ok");
        assert!(!result.cancelled);
    }

    #[tokio::test]
    async fn test_progress_tracker() {
        let id = Uuid::new_v4();
        let (tx, mut rx) = mpsc::channel(10);
        let tracker = ProgressTracker::new(id, tx);
        tracker.update(50.0, "Half way there").await.unwrap();
        let event = rx.recv().await.unwrap();
        match event {
            CommandEvent::Progress { command_id, percentage, message, .. } => {
                assert_eq!(command_id, id);
                assert_eq!(percentage, 50.0);
                assert_eq!(message, "Half way there");
            }
            _ => panic!("Expected Progress event"),
        }
    }

    #[tokio::test]
    async fn test_progress_tracker_clamps_out_of_range() {
        let id = Uuid::new_v4();
        let (tx, mut rx) = mpsc::channel(10);
        let tracker = ProgressTracker::new(id, tx);
        tracker.update(150.0, "over").await.unwrap();
        tracker.update(-5.0, "under").await.unwrap();
        for expected in [100.0_f32, 0.0].iter() {
            match rx.recv().await.unwrap() {
                CommandEvent::Progress { percentage, .. } => assert_eq!(percentage, *expected),
                other => panic!("Expected Progress event, got {:?}", other),
            }
        }
    }

    #[tokio::test]
    async fn test_progress_tracker_errors_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);
        let tracker = ProgressTracker::new(Uuid::new_v4(), tx);
        assert!(tracker.update(10.0, "lost").await.is_err());
    }
}