use crate::config::ExecutionConfig;
use crate::errors::{ExecutionError, Result};
use crate::events::EventHandler;
use crate::executor::Executor;
use crate::types::{
ExecutionRequest, ExecutionResult, ExecutionState, ExecutionStatus, ExecutionSummary,
};
use once_cell::sync::OnceCell;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, Semaphore};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
/// Process-wide engine installed by `ExecutionEngine::init_global*` and read by
/// `ExecutionEngine::global`; it can be set only once.
static INSTANCE: OnceCell<ExecutionEngine> = OnceCell::new();
/// Registers, runs, tracks and cancels command executions while bounding how many
/// may hold a concurrency permit at once.
///
/// Cloning is cheap: every field except `config` sits behind an `Arc`, so a clone
/// shares the same execution registry, cancellation tokens, semaphore and executor
/// (`config` itself is cloned).
#[derive(Clone)]
pub struct ExecutionEngine {
    /// Settings the engine was built with (validated in `new`).
    config: ExecutionConfig,
    /// Execution id -> shared, lockable state for that execution.
    executions: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ExecutionState>>>>>,
    /// Execution id -> token that `cancel` triggers to stop that execution.
    cancellation_tokens: Arc<RwLock<HashMap<Uuid, CancellationToken>>>,
    /// Optional event sink; `with_event_handler` also installs it on `executor`.
    event_handler: Option<Arc<dyn EventHandler>>,
    /// Created with `config.max_concurrent_executions` permits; each running
    /// execution holds one until its task finishes.
    semaphore: Arc<Semaphore>,
    /// Runs requests and reads/writes their logs.
    executor: Arc<Executor>,
}
impl ExecutionEngine {
pub fn init_global_with_handler(
mut config: ExecutionConfig,
handler: Option<Arc<dyn EventHandler>>,
) -> Result<&'static ExecutionEngine> {
if config.max_concurrent_executions != 1 {
tracing::warn!(
"Overriding max_concurrent_executions from {} to 1 for global singleton",
config.max_concurrent_executions
);
config.max_concurrent_executions = 1;
}
let mut engine = ExecutionEngine::new(config)?;
if let Some(h) = handler {
engine = engine.with_event_handler(h);
}
INSTANCE.set(engine).map_err(|_| {
ExecutionError::Internal("ExecutionEngine already initialized".to_string())
})?;
Ok(INSTANCE.get().expect("ExecutionEngine just initialized"))
}
pub fn init_global(config: ExecutionConfig) -> Result<&'static ExecutionEngine> {
Self::init_global_with_handler(config, None)
}
pub fn global() -> &'static ExecutionEngine {
INSTANCE.get().expect("ExecutionEngine not initialized")
}
pub fn new(config: ExecutionConfig) -> Result<Self> {
config.validate().map_err(ExecutionError::InvalidConfig)?;
let executor = Executor::new(config.clone());
let semaphore = Arc::new(Semaphore::new(config.max_concurrent_executions));
Ok(Self {
config,
executions: Arc::new(RwLock::new(HashMap::new())),
cancellation_tokens: Arc::new(RwLock::new(HashMap::new())),
event_handler: None,
semaphore,
executor: Arc::new(executor),
})
}
pub fn with_event_handler(mut self, handler: Arc<dyn EventHandler>) -> Self {
self.event_handler = Some(handler.clone());
let executor = Executor::new(self.config.clone()).with_event_handler(handler);
self.executor = Arc::new(executor);
self
}
pub async fn execute(&self, request: ExecutionRequest) -> Result<Uuid> {
let execution_id = request.id;
let cancel_token = CancellationToken::new();
let state = Arc::new(RwLock::new(ExecutionState::new(request.clone())));
{
let mut executions = self.executions.write().await;
executions.insert(execution_id, state.clone());
}
{
let mut tokens = self.cancellation_tokens.write().await;
tokens.insert(execution_id, cancel_token.clone());
}
let semaphore = self.semaphore.clone();
let current_permits = semaphore.available_permits();
if current_permits == 0 {
return Err(ExecutionError::ConcurrencyLimitReached(
self.config.max_concurrent_executions,
));
}
let permit = semaphore
.clone()
.acquire_owned()
.await
.map_err(|_| ExecutionError::Internal("Semaphore closed".to_string()))?;
let executor = self.executor.clone();
tokio::spawn(async move {
let result = executor.execute(request, state.clone(), cancel_token).await;
if let Ok(ref exec_result) = result {
let _ = executor.write_logs(execution_id, exec_result).await;
}
drop(permit);
result
});
Ok(execution_id)
}
pub async fn get_status(&self, execution_id: Uuid) -> Result<ExecutionStatus> {
let executions = self.executions.read().await;
let state = executions
.get(&execution_id)
.ok_or(ExecutionError::NotFound(execution_id))?;
let state_lock = state.read().await;
Ok(state_lock.status)
}
pub async fn get_result(&self, execution_id: Uuid) -> Result<ExecutionResult> {
let executions = self.executions.read().await;
let state = executions
.get(&execution_id)
.ok_or(ExecutionError::NotFound(execution_id))?;
let state_lock = state.read().await;
if !state_lock.status.is_terminal() {
return Err(ExecutionError::Internal(format!(
"Execution {} is still running (status: {:?})",
execution_id, state_lock.status
)));
}
Ok(state_lock.to_result())
}
pub async fn wait_for_completion(&self, execution_id: Uuid) -> Result<ExecutionResult> {
loop {
let status = self.get_status(execution_id).await?;
if status.is_terminal() {
return self.get_result(execution_id).await;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
pub async fn cancel(&self, execution_id: Uuid) -> Result<()> {
let state = {
let executions = self.executions.read().await;
executions
.get(&execution_id)
.ok_or(ExecutionError::NotFound(execution_id))?
.clone()
};
{
let state_lock = state.read().await;
if state_lock.status.is_terminal() {
return Err(ExecutionError::Internal(format!(
"Cannot cancel execution {} - already in terminal state: {:?}",
execution_id, state_lock.status
)));
}
}
let cancel_token = {
let tokens = self.cancellation_tokens.read().await;
tokens
.get(&execution_id)
.ok_or(ExecutionError::Internal(format!(
"Cancellation token not found for execution {}",
execution_id
)))?
.clone()
};
cancel_token.cancel();
Ok(())
}
pub async fn list_executions(&self) -> Vec<ExecutionSummary> {
let executions = self.executions.read().await;
let mut summaries = Vec::new();
for (id, state) in executions.iter() {
let state_lock = state.read().await;
let duration = state_lock.completed_at.map(|completed| {
(completed - state_lock.started_at)
.to_std()
.unwrap_or(std::time::Duration::from_secs(0))
});
summaries.push(ExecutionSummary {
id: *id,
status: state_lock.status,
started_at: state_lock.started_at,
duration,
});
}
summaries.sort_by(|a, b| b.started_at.cmp(&a.started_at));
summaries
}
pub async fn running_count(&self) -> usize {
let executions = self.executions.read().await;
let mut count = 0;
for (_, state) in executions.iter() {
let state_lock = state.read().await;
if state_lock.status == ExecutionStatus::Running
|| state_lock.status == ExecutionStatus::Pending
{
count += 1;
}
}
count
}
pub async fn total_count(&self) -> usize {
let executions = self.executions.read().await;
executions.len()
}
pub async fn read_logs(&self, execution_id: Uuid) -> Result<String> {
self.executor.read_logs(execution_id).await
}
pub fn config(&self) -> &ExecutionConfig {
&self.config
}
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
pub async fn cleanup_old_executions(&self) -> usize {
crate::cleanup::cleanup_old_executions(
&self.executions,
&self.cancellation_tokens,
self.config.execution_retention_secs,
self.config.max_in_memory_executions,
)
.await
}
pub async fn remove_execution(&self, execution_id: Uuid) -> Result<()> {
let removed = crate::cleanup::remove_execution(&self.executions, execution_id).await;
if removed {
let mut tokens = self.cancellation_tokens.write().await;
tokens.remove(&execution_id);
Ok(())
} else {
Err(ExecutionError::NotFound(execution_id))
}
}
pub fn start_cleanup_task(self: Arc<Self>) {
if !self.config.enable_auto_cleanup {
return;
}
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
let removed = self.cleanup_old_executions().await;
if removed > 0 {
tracing::info!("Cleanup task removed {} old executions", removed);
}
}
});
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Command;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a short command. A regression then
    /// fails the test instead of hanging the suite.
    const COMPLETION_TIMEOUT: Duration = Duration::from_secs(10);

    /// Builds a bash request for `command` with a fresh id and a 5 s timeout.
    fn shell_request(command: &str) -> ExecutionRequest {
        ExecutionRequest {
            id: Uuid::new_v4(),
            command: Command::Shell {
                command: command.to_string(),
                shell: "bash".to_string(),
            },
            env: HashMap::new(),
            working_dir: None,
            timeout_ms: Some(5000),
            output_log_path: None,
            metadata: Default::default(),
        }
    }

    /// Like `shell_request`, but with a 10 s timeout for commands that sleep.
    fn long_shell_request(command: &str) -> ExecutionRequest {
        ExecutionRequest {
            timeout_ms: Some(10000),
            ..shell_request(command)
        }
    }

    /// A quick `echo` request that completes successfully.
    fn create_test_request() -> ExecutionRequest {
        shell_request("echo 'test'")
    }

    /// Waits for `execution_id` to finish, failing the test after `COMPLETION_TIMEOUT`.
    async fn wait_done(engine: &ExecutionEngine, execution_id: Uuid) -> ExecutionResult {
        tokio::time::timeout(COMPLETION_TIMEOUT, engine.wait_for_completion(execution_id))
            .await
            .expect("execution did not finish within COMPLETION_TIMEOUT")
            .expect("wait_for_completion failed")
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config);
        assert!(engine.is_ok());
    }

    #[tokio::test]
    async fn test_engine_invalid_config() {
        let mut config = ExecutionConfig::default();
        config.max_concurrent_executions = 0;
        let engine = ExecutionEngine::new(config);
        assert!(engine.is_err());
    }

    #[tokio::test]
    async fn test_engine_execute_simple() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(create_test_request()).await.unwrap();
        // Wait for completion instead of sleeping a fixed amount, which is flaky under load.
        wait_done(&engine, execution_id).await;
        let status = engine.get_status(execution_id).await.unwrap();
        assert_eq!(status, ExecutionStatus::Completed);
    }

    #[tokio::test]
    async fn test_engine_wait_for_completion() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(create_test_request()).await.unwrap();
        let result = wait_done(&engine, execution_id).await;
        assert_eq!(result.status, ExecutionStatus::Completed);
        assert_eq!(result.exit_code, 0);
    }

    #[tokio::test]
    async fn test_engine_get_result_before_complete() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(shell_request("sleep 1")).await.unwrap();
        let result = engine.get_result(execution_id).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_engine_list_executions() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let id1 = engine.execute(create_test_request()).await.unwrap();
        let id2 = engine.execute(create_test_request()).await.unwrap();
        wait_done(&engine, id1).await;
        wait_done(&engine, id2).await;
        let list = engine.list_executions().await;
        assert_eq!(list.len(), 2);
    }

    #[tokio::test]
    async fn test_engine_running_count() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        assert_eq!(engine.running_count().await, 0);
        // `execute` registers the execution before returning, and Pending counts as
        // running, so no sleep is needed here.
        let _id = engine.execute(long_shell_request("sleep 2")).await.unwrap();
        let count = engine.running_count().await;
        assert!(count > 0);
    }

    #[tokio::test]
    async fn test_engine_concurrency_limit() {
        let config = ExecutionConfig {
            max_concurrent_executions: 2,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();
        // Permits are taken inside `execute`, so both slots are held once these return.
        let _id1 = engine.execute(long_shell_request("sleep 2")).await.unwrap();
        let _id2 = engine.execute(long_shell_request("sleep 2")).await.unwrap();
        let result = engine.execute(create_test_request()).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ExecutionError::ConcurrencyLimitReached(_)
        ));
    }

    #[tokio::test]
    async fn test_engine_available_permits() {
        let config = ExecutionConfig {
            max_concurrent_executions: 5,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();
        assert_eq!(engine.available_permits(), 5);
        // A long-running command keeps its permit, so exactly one slot is in use.
        let _id = engine.execute(long_shell_request("sleep 2")).await.unwrap();
        assert_eq!(engine.available_permits(), 4);
    }

    #[tokio::test]
    async fn test_engine_not_found() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let fake_id = Uuid::new_v4();
        let result = engine.get_status(fake_id).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ExecutionError::NotFound(_)));
    }
}