night/task/
scheduler.rs

1use crate::task::task::Task;
2use crate::utils::error::Result;
3use std::sync::Arc;
4use tokio::time::{interval, Duration};
5
6pub struct TaskScheduler {
7    task: Arc<Task>,
8}
9
10impl TaskScheduler {
11    pub fn new(task: Arc<Task>) -> Self {
12        TaskScheduler { task }
13    }
14
15    pub async fn start(&self) -> Result<()> {
16        if self.task.config.is_periodic {
17            self.run_periodic().await
18        } else {
19            self.run_once().await
20        }
21    }
22
23    async fn run_once(&self) -> Result<()> {
24        self.task.run().await
25    }
26
27    async fn run_periodic(&self) -> Result<()> {
28        let interval_duration = self.parse_interval()?;
29        let mut interval = interval(interval_duration);
30
31        loop {
32            interval.tick().await;
33
34            if !self.should_run() {
35                break;
36            }
37
38            self.task.run().await?;
39        }
40
41        Ok(())
42    }
43
44    fn should_run(&self) -> bool {
45        // Check if the task's execution lock is set
46        self.task
47            .execution_lock
48            .load(std::sync::atomic::Ordering::Relaxed)
49    }
50
51    fn parse_interval(&self) -> Result<Duration> {
52        // Parse the interval string into a Duration
53        self.task
54            .config
55            .interval
56            .parse::<u64>()
57            .map(Duration::from_millis)
58            .map_err(|_| {
59                crate::utils::error::NightError::Task("Invalid interval format".to_string())
60            })
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::types::{TaskConfig, TaskStatus};
    use std::collections::HashMap;
    use uuid::Uuid;

    /// Builds a task that runs `echo Hello` with the given periodicity and
    /// interval string, no dependencies, and an empty address map.
    fn create_test_task(is_periodic: bool, interval: &str) -> Arc<Task> {
        let config = TaskConfig {
            name: "Test Task".to_string(),
            id: Uuid::new_v4(),
            command: "echo Hello".to_string(),
            is_periodic,
            interval: interval.to_string(),
            importance: 1,
            dependencies: vec![],
        };
        let address_map = Arc::new(HashMap::new());
        let depend = config.dependencies.clone();
        Arc::new(Task::new(config, address_map, depend))
    }

    #[tokio::test]
    async fn test_run_once() {
        let task = create_test_task(false, "0");
        let scheduler = TaskScheduler::new(task.clone());

        scheduler.start().await.unwrap();

        assert_eq!(*task.status.lock().unwrap(), TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_run_periodic() {
        let task = create_test_task(true, "100");
        let scheduler = TaskScheduler::new(task.clone());

        // Keep the JoinHandle: dropping it detaches the task, so a panic from the
        // `unwrap()` below (i.e. a scheduler error) would never fail this test.
        let handle = tokio::spawn(async move {
            scheduler.start().await.unwrap();
        });

        // Wait for a bit to allow multiple executions
        tokio::time::sleep(Duration::from_millis(350)).await;

        task.set_execution_lock(false);

        // The scheduler notices the cleared lock on its next tick (at most one
        // 100 ms period away); the generous timeout only guards against a hang.
        tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("scheduler did not stop after the execution lock was cleared")
            .expect("scheduler task panicked");

        assert_eq!(*task.status.lock().unwrap(), TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_invalid_interval_is_rejected() {
        let task = create_test_task(true, "not-a-number");
        let scheduler = TaskScheduler::new(task);

        assert!(scheduler.start().await.is_err());
    }
}
116}