1use crate::config::ExecutionConfig;
6use crate::errors::{ExecutionError, Result};
7use crate::events::EventHandler;
8use crate::executor::Executor;
9use crate::types::{
10 ExecutionRequest, ExecutionResult, ExecutionState, ExecutionStatus, ExecutionSummary,
11};
12use once_cell::sync::OnceCell;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::{RwLock, Semaphore};
16use tokio_util::sync::CancellationToken;
17use uuid::Uuid;
18
19static INSTANCE: OnceCell<ExecutionEngine> = OnceCell::new();
20
21#[derive(Clone)]
29pub struct ExecutionEngine {
30 config: ExecutionConfig,
31 executions: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ExecutionState>>>>>,
32 cancellation_tokens: Arc<RwLock<HashMap<Uuid, CancellationToken>>>,
33 event_handler: Option<Arc<dyn EventHandler>>,
34 semaphore: Arc<Semaphore>,
35 executor: Arc<Executor>,
36}
37
38impl ExecutionEngine {
39 pub fn init_global_with_handler(
44 mut config: ExecutionConfig,
45 handler: Option<Arc<dyn EventHandler>>,
46 ) -> Result<&'static ExecutionEngine> {
47 if config.max_concurrent_executions != 1 {
49 tracing::warn!(
50 "Overriding max_concurrent_executions from {} to 1 for global singleton",
51 config.max_concurrent_executions
52 );
53 config.max_concurrent_executions = 1;
54 }
55
56 let mut engine = ExecutionEngine::new(config)?;
57
58 if let Some(h) = handler {
60 engine = engine.with_event_handler(h);
61 }
62
63 INSTANCE.set(engine).map_err(|_| {
64 ExecutionError::Internal("ExecutionEngine already initialized".to_string())
65 })?;
66
67 Ok(INSTANCE.get().expect("ExecutionEngine just initialized"))
68 }
69
70 pub fn init_global(config: ExecutionConfig) -> Result<&'static ExecutionEngine> {
75 Self::init_global_with_handler(config, None)
76 }
77
78 pub fn global() -> &'static ExecutionEngine {
83 INSTANCE.get().expect("ExecutionEngine not initialized")
84 }
85
86 pub fn new(config: ExecutionConfig) -> Result<Self> {
88 config.validate().map_err(ExecutionError::InvalidConfig)?;
90
91 let executor = Executor::new(config.clone());
92 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_executions));
93
94 Ok(Self {
95 config,
96 executions: Arc::new(RwLock::new(HashMap::new())),
97 cancellation_tokens: Arc::new(RwLock::new(HashMap::new())),
98 event_handler: None,
99 semaphore,
100 executor: Arc::new(executor),
101 })
102 }
103
104 pub fn with_event_handler(mut self, handler: Arc<dyn EventHandler>) -> Self {
106 self.event_handler = Some(handler.clone());
107
108 let executor = Executor::new(self.config.clone()).with_event_handler(handler);
110 self.executor = Arc::new(executor);
111
112 self
113 }
114
115 pub async fn execute(&self, request: ExecutionRequest) -> Result<Uuid> {
120 let execution_id = request.id;
121
122 let cancel_token = CancellationToken::new();
124 let state = Arc::new(RwLock::new(ExecutionState::new(request.clone())));
125
126 {
128 let mut executions = self.executions.write().await;
129 executions.insert(execution_id, state.clone());
130 }
131 {
132 let mut tokens = self.cancellation_tokens.write().await;
133 tokens.insert(execution_id, cancel_token.clone());
134 }
135
136 let semaphore = self.semaphore.clone();
138 let current_permits = semaphore.available_permits();
139
140 if current_permits == 0 {
141 return Err(ExecutionError::ConcurrencyLimitReached(
143 self.config.max_concurrent_executions,
144 ));
145 }
146
147 let permit = semaphore
149 .clone()
150 .acquire_owned()
151 .await
152 .map_err(|_| ExecutionError::Internal("Semaphore closed".to_string()))?;
153
154 let executor = self.executor.clone();
156
157 tokio::spawn(async move {
158 let result = executor.execute(request, state.clone(), cancel_token).await;
160
161 if let Ok(ref exec_result) = result {
163 let _ = executor.write_logs(execution_id, exec_result).await;
164 }
165
166 drop(permit);
168
169 result
173 });
174
175 Ok(execution_id)
176 }
177
178 pub async fn get_status(&self, execution_id: Uuid) -> Result<ExecutionStatus> {
180 let executions = self.executions.read().await;
181 let state = executions
182 .get(&execution_id)
183 .ok_or(ExecutionError::NotFound(execution_id))?;
184
185 let state_lock = state.read().await;
186 Ok(state_lock.status)
187 }
188
189 pub async fn get_result(&self, execution_id: Uuid) -> Result<ExecutionResult> {
191 let executions = self.executions.read().await;
192 let state = executions
193 .get(&execution_id)
194 .ok_or(ExecutionError::NotFound(execution_id))?;
195
196 let state_lock = state.read().await;
197
198 if !state_lock.status.is_terminal() {
199 return Err(ExecutionError::Internal(format!(
200 "Execution {} is still running (status: {:?})",
201 execution_id, state_lock.status
202 )));
203 }
204
205 Ok(state_lock.to_result())
206 }
207
208 pub async fn wait_for_completion(&self, execution_id: Uuid) -> Result<ExecutionResult> {
210 loop {
212 let status = self.get_status(execution_id).await?;
213
214 if status.is_terminal() {
215 return self.get_result(execution_id).await;
216 }
217
218 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
220 }
221 }
222
223 pub async fn cancel(&self, execution_id: Uuid) -> Result<()> {
225 let state = {
227 let executions = self.executions.read().await;
228 executions
229 .get(&execution_id)
230 .ok_or(ExecutionError::NotFound(execution_id))?
231 .clone()
232 };
233
234 {
236 let state_lock = state.read().await;
237 if state_lock.status.is_terminal() {
238 return Err(ExecutionError::Internal(format!(
239 "Cannot cancel execution {} - already in terminal state: {:?}",
240 execution_id, state_lock.status
241 )));
242 }
243 }
244
245 let cancel_token = {
247 let tokens = self.cancellation_tokens.read().await;
248 tokens
249 .get(&execution_id)
250 .ok_or(ExecutionError::Internal(format!(
251 "Cancellation token not found for execution {}",
252 execution_id
253 )))?
254 .clone()
255 };
256
257 cancel_token.cancel();
259
260 Ok(())
261 }
262
263 pub async fn list_executions(&self) -> Vec<ExecutionSummary> {
265 let executions = self.executions.read().await;
266 let mut summaries = Vec::new();
267
268 for (id, state) in executions.iter() {
269 let state_lock = state.read().await;
270 let duration = state_lock.completed_at.map(|completed| {
271 (completed - state_lock.started_at)
272 .to_std()
273 .unwrap_or(std::time::Duration::from_secs(0))
274 });
275
276 summaries.push(ExecutionSummary {
277 id: *id,
278 status: state_lock.status,
279 started_at: state_lock.started_at,
280 duration,
281 });
282 }
283
284 summaries.sort_by(|a, b| b.started_at.cmp(&a.started_at));
286
287 summaries
288 }
289
290 pub async fn running_count(&self) -> usize {
292 let executions = self.executions.read().await;
293 let mut count = 0;
294
295 for (_, state) in executions.iter() {
296 let state_lock = state.read().await;
297 if state_lock.status == ExecutionStatus::Running
298 || state_lock.status == ExecutionStatus::Pending
299 {
300 count += 1;
301 }
302 }
303
304 count
305 }
306
307 pub async fn total_count(&self) -> usize {
309 let executions = self.executions.read().await;
310 executions.len()
311 }
312
313 pub async fn read_logs(&self, execution_id: Uuid) -> Result<String> {
315 self.executor.read_logs(execution_id).await
316 }
317
318 pub fn config(&self) -> &ExecutionConfig {
320 &self.config
321 }
322
323 pub fn available_permits(&self) -> usize {
325 self.semaphore.available_permits()
326 }
327
328 pub async fn cleanup_old_executions(&self) -> usize {
336 crate::cleanup::cleanup_old_executions(
337 &self.executions,
338 &self.cancellation_tokens,
339 self.config.execution_retention_secs,
340 self.config.max_in_memory_executions,
341 )
342 .await
343 }
344
345 pub async fn remove_execution(&self, execution_id: Uuid) -> Result<()> {
349 let removed = crate::cleanup::remove_execution(&self.executions, execution_id).await;
350
351 if removed {
352 let mut tokens = self.cancellation_tokens.write().await;
354 tokens.remove(&execution_id);
355 Ok(())
356 } else {
357 Err(ExecutionError::NotFound(execution_id))
358 }
359 }
360
361 pub fn start_cleanup_task(self: Arc<Self>) {
368 if !self.config.enable_auto_cleanup {
369 return;
370 }
371
372 tokio::spawn(async move {
373 let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); loop {
376 interval.tick().await;
377
378 let removed = self.cleanup_old_executions().await;
379
380 if removed > 0 {
381 tracing::info!("Cleanup task removed {} old executions", removed);
382 }
383 }
384 });
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Command;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Upper bound on how long any test waits for an execution to finish.
    const WAIT_LIMIT: Duration = Duration::from_secs(10);

    /// Builds a bash request for `command` with a fresh id and a 5 s timeout.
    fn shell_request(command: &str) -> ExecutionRequest {
        ExecutionRequest {
            id: Uuid::new_v4(),
            command: Command::Shell {
                command: command.to_string(),
                shell: "bash".to_string(),
            },
            env: HashMap::new(),
            working_dir: None,
            timeout_ms: Some(5000),
            output_log_path: None,
            metadata: Default::default(),
        }
    }

    /// A `sleep 2` request (10 s timeout) that stays active for the whole assertion window.
    fn long_running_request() -> ExecutionRequest {
        let mut request = shell_request("sleep 2");
        request.timeout_ms = Some(10000);
        request
    }

    /// A quick request that echoes a fixed string.
    fn create_test_request() -> ExecutionRequest {
        shell_request("echo 'test'")
    }

    /// Waits for `id` to finish. Fails the test instead of hanging if it takes too long.
    async fn wait_done(engine: &ExecutionEngine, id: Uuid) -> ExecutionResult {
        tokio::time::timeout(WAIT_LIMIT, engine.wait_for_completion(id))
            .await
            .expect("execution did not finish in time")
            .expect("execution should be tracked")
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let engine = ExecutionEngine::new(ExecutionConfig::default());
        assert!(engine.is_ok());
    }

    #[tokio::test]
    async fn test_engine_invalid_config() {
        let mut config = ExecutionConfig::default();
        config.max_concurrent_executions = 0;

        let engine = ExecutionEngine::new(config);
        assert!(engine.is_err());
    }

    #[tokio::test]
    async fn test_engine_execute_simple() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(create_test_request()).await.unwrap();

        // Wait for the run rather than sleeping a fixed time, which flakes on slow hosts.
        wait_done(&engine, execution_id).await;

        let status = engine.get_status(execution_id).await.unwrap();
        assert_eq!(status, ExecutionStatus::Completed);
    }

    #[tokio::test]
    async fn test_engine_wait_for_completion() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(create_test_request()).await.unwrap();

        let result = wait_done(&engine, execution_id).await;
        assert_eq!(result.status, ExecutionStatus::Completed);
        assert_eq!(result.exit_code, 0);
    }

    #[tokio::test]
    async fn test_engine_get_result_before_complete() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(shell_request("sleep 1")).await.unwrap();

        let result = engine.get_result(execution_id).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_engine_list_executions() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();

        let id1 = engine.execute(create_test_request()).await.unwrap();
        let id2 = engine.execute(create_test_request()).await.unwrap();

        wait_done(&engine, id1).await;
        wait_done(&engine, id2).await;

        let list = engine.list_executions().await;
        assert_eq!(list.len(), 2);
    }

    #[tokio::test]
    async fn test_engine_running_count() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        assert_eq!(engine.running_count().await, 0);

        let _id = engine.execute(long_running_request()).await.unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(engine.running_count().await > 0);
    }

    #[tokio::test]
    async fn test_engine_concurrency_limit() {
        let config = ExecutionConfig {
            max_concurrent_executions: 2,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();

        let _id1 = engine.execute(long_running_request()).await.unwrap();
        let _id2 = engine.execute(long_running_request()).await.unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;

        let result = engine.execute(create_test_request()).await;
        assert!(matches!(
            result,
            Err(ExecutionError::ConcurrencyLimitReached(_))
        ));
    }

    #[tokio::test]
    async fn test_engine_available_permits() {
        let config = ExecutionConfig {
            max_concurrent_executions: 5,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();
        assert_eq!(engine.available_permits(), 5);

        // A slot is reserved before `execute` returns and held while the command runs.
        let id = engine.execute(long_running_request()).await.unwrap();
        assert_eq!(engine.available_permits(), 4);

        // Best-effort tidy-up so the background command does not outlive the test.
        let _ = engine.cancel(id).await;
    }

    #[tokio::test]
    async fn test_engine_not_found() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let fake_id = Uuid::new_v4();

        assert!(matches!(
            engine.get_status(fake_id).await,
            Err(ExecutionError::NotFound(_))
        ));
        assert!(matches!(
            engine.get_result(fake_id).await,
            Err(ExecutionError::NotFound(_))
        ));
        assert!(matches!(
            engine.cancel(fake_id).await,
            Err(ExecutionError::NotFound(_))
        ));
    }
}
617}