Skip to main content

execution_engine/
engine.rs

1//! ExecutionEngine - Main entry point for command execution
2//!
3//! Provides high-level API with state management and concurrency control.
4
5use crate::config::ExecutionConfig;
6use crate::errors::{ExecutionError, Result};
7use crate::events::EventHandler;
8use crate::executor::Executor;
9use crate::types::{
10    ExecutionRequest, ExecutionResult, ExecutionState, ExecutionStatus, ExecutionSummary,
11};
12use once_cell::sync::OnceCell;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::{RwLock, Semaphore};
16use tokio_util::sync::CancellationToken;
17use uuid::Uuid;
18
19static INSTANCE: OnceCell<ExecutionEngine> = OnceCell::new();
20
21/// Main execution engine
22///
23/// Thread-safe, async execution engine with:
24/// - Semaphore-based concurrency limiting
25/// - In-memory state management
26/// - Event emission support
27/// - Automatic cleanup (optional)
28#[derive(Clone)]
29pub struct ExecutionEngine {
30    config: ExecutionConfig,
31    executions: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ExecutionState>>>>>,
32    cancellation_tokens: Arc<RwLock<HashMap<Uuid, CancellationToken>>>,
33    event_handler: Option<Arc<dyn EventHandler>>,
34    semaphore: Arc<Semaphore>,
35    executor: Arc<Executor>,
36}
37
38impl ExecutionEngine {
39    /// Initialize the global singleton instance with an optional event handler
40    ///
41    /// This method ensures that the engine is initialized only once.
42    /// It enforces `max_concurrent_executions = 1` for serial execution safety.
43    pub fn init_global_with_handler(
44        mut config: ExecutionConfig,
45        handler: Option<Arc<dyn EventHandler>>,
46    ) -> Result<&'static ExecutionEngine> {
47        // Enforce serial execution as requested
48        if config.max_concurrent_executions != 1 {
49            tracing::warn!(
50                "Overriding max_concurrent_executions from {} to 1 for global singleton",
51                config.max_concurrent_executions
52            );
53            config.max_concurrent_executions = 1;
54        }
55
56        let mut engine = ExecutionEngine::new(config)?;
57
58        // Attach handler if provided
59        if let Some(h) = handler {
60            engine = engine.with_event_handler(h);
61        }
62
63        INSTANCE.set(engine).map_err(|_| {
64            ExecutionError::Internal("ExecutionEngine already initialized".to_string())
65        })?;
66
67        Ok(INSTANCE.get().expect("ExecutionEngine just initialized"))
68    }
69
70    /// Initialize the global singleton instance
71    ///
72    /// This method ensures that the engine is initialized only once.
73    /// It enforces `max_concurrent_executions = 1` for serial execution safety.
74    pub fn init_global(config: ExecutionConfig) -> Result<&'static ExecutionEngine> {
75        Self::init_global_with_handler(config, None)
76    }
77
78    /// Get reference to the global singleton instance
79    ///
80    /// # Panics
81    /// Panics if `init_global` has not been called.
82    pub fn global() -> &'static ExecutionEngine {
83        INSTANCE.get().expect("ExecutionEngine not initialized")
84    }
85
86    /// Create new ExecutionEngine
87    pub fn new(config: ExecutionConfig) -> Result<Self> {
88        // Validate config
89        config.validate().map_err(ExecutionError::InvalidConfig)?;
90
91        let executor = Executor::new(config.clone());
92        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_executions));
93
94        Ok(Self {
95            config,
96            executions: Arc::new(RwLock::new(HashMap::new())),
97            cancellation_tokens: Arc::new(RwLock::new(HashMap::new())),
98            event_handler: None,
99            semaphore,
100            executor: Arc::new(executor),
101        })
102    }
103
104    /// Set event handler (builder pattern)
105    pub fn with_event_handler(mut self, handler: Arc<dyn EventHandler>) -> Self {
106        self.event_handler = Some(handler.clone());
107
108        // Update executor with handler
109        let executor = Executor::new(self.config.clone()).with_event_handler(handler);
110        self.executor = Arc::new(executor);
111
112        self
113    }
114
115    /// Execute a command asynchronously
116    ///
117    /// Returns execution ID immediately and spawns background task.
118    /// Use get_status() or get_result() to check progress.
119    pub async fn execute(&self, request: ExecutionRequest) -> Result<Uuid> {
120        let execution_id = request.id;
121
122        // Create execution state with cancellation token
123        let cancel_token = CancellationToken::new();
124        let state = Arc::new(RwLock::new(ExecutionState::new(request.clone())));
125
126        // Store state and cancellation token
127        {
128            let mut executions = self.executions.write().await;
129            executions.insert(execution_id, state.clone());
130        }
131        {
132            let mut tokens = self.cancellation_tokens.write().await;
133            tokens.insert(execution_id, cancel_token.clone());
134        }
135
136        // Try to acquire semaphore permit (non-blocking check)
137        let semaphore = self.semaphore.clone();
138        let current_permits = semaphore.available_permits();
139
140        if current_permits == 0 {
141            // No permits available - at concurrency limit
142            return Err(ExecutionError::ConcurrencyLimitReached(
143                self.config.max_concurrent_executions,
144            ));
145        }
146
147        // Acquire permit (will block if at limit, but we checked above)
148        let permit = semaphore
149            .clone()
150            .acquire_owned()
151            .await
152            .map_err(|_| ExecutionError::Internal("Semaphore closed".to_string()))?;
153
154        // Spawn background execution task
155        let executor = self.executor.clone();
156
157        tokio::spawn(async move {
158            // Execute command
159            let result = executor.execute(request, state.clone(), cancel_token).await;
160
161            // Write logs if successful
162            if let Ok(ref exec_result) = result {
163                let _ = executor.write_logs(execution_id, exec_result).await;
164            }
165
166            // Release semaphore permit (via drop)
167            drop(permit);
168
169            // Note: We keep the state in memory for later retrieval
170            // Cleanup task will remove old executions based on retention policy
171
172            result
173        });
174
175        Ok(execution_id)
176    }
177
178    /// Get current status of an execution
179    pub async fn get_status(&self, execution_id: Uuid) -> Result<ExecutionStatus> {
180        let executions = self.executions.read().await;
181        let state = executions
182            .get(&execution_id)
183            .ok_or(ExecutionError::NotFound(execution_id))?;
184
185        let state_lock = state.read().await;
186        Ok(state_lock.status)
187    }
188
189    /// Get execution result (returns error if not complete)
190    pub async fn get_result(&self, execution_id: Uuid) -> Result<ExecutionResult> {
191        let executions = self.executions.read().await;
192        let state = executions
193            .get(&execution_id)
194            .ok_or(ExecutionError::NotFound(execution_id))?;
195
196        let state_lock = state.read().await;
197
198        if !state_lock.status.is_terminal() {
199            return Err(ExecutionError::Internal(format!(
200                "Execution {} is still running (status: {:?})",
201                execution_id, state_lock.status
202            )));
203        }
204
205        Ok(state_lock.to_result())
206    }
207
208    /// Wait for execution to complete and return result
209    pub async fn wait_for_completion(&self, execution_id: Uuid) -> Result<ExecutionResult> {
210        // Poll status until complete
211        loop {
212            let status = self.get_status(execution_id).await?;
213
214            if status.is_terminal() {
215                return self.get_result(execution_id).await;
216            }
217
218            // Sleep briefly before checking again
219            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
220        }
221    }
222
223    /// Cancel a running execution
224    pub async fn cancel(&self, execution_id: Uuid) -> Result<()> {
225        // Check if execution exists and get its state
226        let state = {
227            let executions = self.executions.read().await;
228            executions
229                .get(&execution_id)
230                .ok_or(ExecutionError::NotFound(execution_id))?
231                .clone()
232        };
233
234        // Check if already terminal
235        {
236            let state_lock = state.read().await;
237            if state_lock.status.is_terminal() {
238                return Err(ExecutionError::Internal(format!(
239                    "Cannot cancel execution {} - already in terminal state: {:?}",
240                    execution_id, state_lock.status
241                )));
242            }
243        }
244
245        // Get and trigger the cancellation token
246        let cancel_token = {
247            let tokens = self.cancellation_tokens.read().await;
248            tokens
249                .get(&execution_id)
250                .ok_or(ExecutionError::Internal(format!(
251                    "Cancellation token not found for execution {}",
252                    execution_id
253                )))?
254                .clone()
255        };
256
257        // Trigger cancellation
258        cancel_token.cancel();
259
260        Ok(())
261    }
262
263    /// List all executions in memory
264    pub async fn list_executions(&self) -> Vec<ExecutionSummary> {
265        let executions = self.executions.read().await;
266        let mut summaries = Vec::new();
267
268        for (id, state) in executions.iter() {
269            let state_lock = state.read().await;
270            let duration = state_lock.completed_at.map(|completed| {
271                (completed - state_lock.started_at)
272                    .to_std()
273                    .unwrap_or(std::time::Duration::from_secs(0))
274            });
275
276            summaries.push(ExecutionSummary {
277                id: *id,
278                status: state_lock.status,
279                started_at: state_lock.started_at,
280                duration,
281            });
282        }
283
284        // Sort by started_at (newest first)
285        summaries.sort_by(|a, b| b.started_at.cmp(&a.started_at));
286
287        summaries
288    }
289
290    /// Get number of currently running executions
291    pub async fn running_count(&self) -> usize {
292        let executions = self.executions.read().await;
293        let mut count = 0;
294
295        for (_, state) in executions.iter() {
296            let state_lock = state.read().await;
297            if state_lock.status == ExecutionStatus::Running
298                || state_lock.status == ExecutionStatus::Pending
299            {
300                count += 1;
301            }
302        }
303
304        count
305    }
306
307    /// Get total number of executions in memory
308    pub async fn total_count(&self) -> usize {
309        let executions = self.executions.read().await;
310        executions.len()
311    }
312
313    /// Read logs for an execution
314    pub async fn read_logs(&self, execution_id: Uuid) -> Result<String> {
315        self.executor.read_logs(execution_id).await
316    }
317
318    /// Get configuration
319    pub fn config(&self) -> &ExecutionConfig {
320        &self.config
321    }
322
323    /// Get available semaphore permits (concurrency slots)
324    pub fn available_permits(&self) -> usize {
325        self.semaphore.available_permits()
326    }
327
328    /// Clean up old executions based on retention policy
329    ///
330    /// Removes executions based on:
331    /// 1. Age: Older than `execution_retention_secs`
332    /// 2. Count: Exceeds `max_in_memory_executions`
333    ///
334    /// Returns the number of executions removed.
335    pub async fn cleanup_old_executions(&self) -> usize {
336        crate::cleanup::cleanup_old_executions(
337            &self.executions,
338            &self.cancellation_tokens,
339            self.config.execution_retention_secs,
340            self.config.max_in_memory_executions,
341        )
342        .await
343    }
344
345    /// Remove a specific execution from memory
346    ///
347    /// Returns `Ok(())` if removed, or `NotFound` error if execution doesn't exist.
348    pub async fn remove_execution(&self, execution_id: Uuid) -> Result<()> {
349        let removed = crate::cleanup::remove_execution(&self.executions, execution_id).await;
350
351        if removed {
352            // Also remove the cancellation token
353            let mut tokens = self.cancellation_tokens.write().await;
354            tokens.remove(&execution_id);
355            Ok(())
356        } else {
357            Err(ExecutionError::NotFound(execution_id))
358        }
359    }
360
361    /// Start automatic cleanup task
362    ///
363    /// Spawns a background task that runs every 5 minutes to clean up old executions.
364    /// Only starts if `enable_auto_cleanup` is true in config.
365    ///
366    /// This method consumes self by value and requires Arc wrapper.
367    pub fn start_cleanup_task(self: Arc<Self>) {
368        if !self.config.enable_auto_cleanup {
369            return;
370        }
371
372        tokio::spawn(async move {
373            let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); // 5 minutes
374
375            loop {
376                interval.tick().await;
377
378                let removed = self.cleanup_old_executions().await;
379
380                if removed > 0 {
381                    tracing::info!("Cleanup task removed {} old executions", removed);
382                }
383            }
384        });
385    }
386}
387
388// ============================================================================
389// Tests
390// ============================================================================
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Command;
    use std::collections::HashMap;

    /// Build a bash shell request running `command` with a 5 s timeout.
    ///
    /// Tests that need a different timeout override it with struct-update syntax.
    fn shell_request(command: &str) -> ExecutionRequest {
        ExecutionRequest {
            id: Uuid::new_v4(),
            command: Command::Shell {
                command: command.to_string(),
                shell: "bash".to_string(),
            },
            env: HashMap::new(),
            working_dir: None,
            timeout_ms: Some(5000),
            output_log_path: None,
            metadata: Default::default(),
        }
    }

    /// Short command that finishes almost immediately.
    fn create_test_request() -> ExecutionRequest {
        shell_request("echo 'test'")
    }

    /// Command that keeps its concurrency slot busy for a couple of seconds.
    fn create_long_running_request() -> ExecutionRequest {
        ExecutionRequest {
            timeout_ms: Some(10000),
            ..shell_request("sleep 2")
        }
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config);
        assert!(engine.is_ok());
    }

    #[tokio::test]
    async fn test_engine_invalid_config() {
        let config = ExecutionConfig {
            max_concurrent_executions: 0, // Invalid
            ..Default::default()
        };

        let engine = ExecutionEngine::new(config);
        assert!(engine.is_err());
    }

    #[tokio::test]
    async fn test_engine_execute_simple() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config).unwrap();

        let request = create_test_request();
        let execution_id = engine.execute(request).await.unwrap();

        // Wait for the terminal state instead of sleeping a fixed amount, so the
        // assertion does not depend on how fast the machine runs `echo`.
        engine.wait_for_completion(execution_id).await.unwrap();

        let status = engine.get_status(execution_id).await.unwrap();
        assert_eq!(status, ExecutionStatus::Completed);
    }

    #[tokio::test]
    async fn test_engine_wait_for_completion() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config).unwrap();

        let request = create_test_request();
        let execution_id = engine.execute(request).await.unwrap();

        let result = engine.wait_for_completion(execution_id).await.unwrap();
        assert_eq!(result.status, ExecutionStatus::Completed);
        assert_eq!(result.exit_code, 0);
    }

    #[tokio::test]
    async fn test_engine_get_result_before_complete() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config).unwrap();

        let request = shell_request("sleep 1");
        let execution_id = engine.execute(request).await.unwrap();

        // Try to get result immediately (should fail)
        let result = engine.get_result(execution_id).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_engine_list_executions() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config).unwrap();

        // Execute multiple commands
        let _id1 = engine.execute(create_test_request()).await.unwrap();
        let _id2 = engine.execute(create_test_request()).await.unwrap();

        // Executions are registered before `execute` returns, so no wait is needed.
        let list = engine.list_executions().await;
        assert_eq!(list.len(), 2);
    }

    #[tokio::test]
    async fn test_engine_running_count() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config).unwrap();

        assert_eq!(engine.running_count().await, 0);

        // Execute a long-running command
        let _id = engine.execute(create_long_running_request()).await.unwrap();

        // Check running count (should be 1)
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let count = engine.running_count().await;
        assert!(count > 0);
    }

    #[tokio::test]
    async fn test_engine_concurrency_limit() {
        let config = ExecutionConfig {
            max_concurrent_executions: 2,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();

        // Start 2 long-running commands (at limit). Each one holds its slot from
        // the moment `execute` returns, so no extra wait is required.
        let _id1 = engine.execute(create_long_running_request()).await.unwrap();
        let _id2 = engine.execute(create_long_running_request()).await.unwrap();

        // Try to execute a 3rd command (should fail)
        let result = engine.execute(create_test_request()).await;
        assert!(matches!(
            result,
            Err(ExecutionError::ConcurrencyLimitReached(_))
        ));
    }

    #[tokio::test]
    async fn test_engine_available_permits() {
        let config = ExecutionConfig {
            max_concurrent_executions: 5,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();

        assert_eq!(engine.available_permits(), 5);

        // A long-running command keeps its slot, so exactly one permit is taken
        // as soon as `execute` returns.
        let _id = engine.execute(create_long_running_request()).await.unwrap();

        assert_eq!(engine.available_permits(), 4);
    }

    #[tokio::test]
    async fn test_engine_not_found() {
        let config = ExecutionConfig::default();
        let engine = ExecutionEngine::new(config).unwrap();

        let fake_id = Uuid::new_v4();
        let result = engine.get_status(fake_id).await;

        assert!(matches!(result, Err(ExecutionError::NotFound(_))));
    }
}
617}