1use crate::task::task::Task;
2use crate::utils::error::Result;
3use std::sync::Arc;
4use tokio::time::{interval, Duration};
5
6pub struct TaskScheduler {
7 task: Arc<Task>,
8}
9
10impl TaskScheduler {
11 pub fn new(task: Arc<Task>) -> Self {
12 TaskScheduler { task }
13 }
14
15 pub async fn start(&self) -> Result<()> {
16 if self.task.config.is_periodic {
17 self.run_periodic().await
18 } else {
19 self.run_once().await
20 }
21 }
22
23 async fn run_once(&self) -> Result<()> {
24 self.task.run().await
25 }
26
27 async fn run_periodic(&self) -> Result<()> {
28 let interval_duration = self.parse_interval()?;
29 let mut interval = interval(interval_duration);
30
31 loop {
32 interval.tick().await;
33
34 if !self.should_run() {
35 break;
36 }
37
38 self.task.run().await?;
39 }
40
41 Ok(())
42 }
43
44 fn should_run(&self) -> bool {
45 self.task
47 .execution_lock
48 .load(std::sync::atomic::Ordering::Relaxed)
49 }
50
51 fn parse_interval(&self) -> Result<Duration> {
52 self.task
54 .config
55 .interval
56 .parse::<u64>()
57 .map(Duration::from_millis)
58 .map_err(|_| {
59 crate::utils::error::NightError::Task("Invalid interval format".to_string())
60 })
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::types::{TaskConfig, TaskStatus};
    use std::collections::HashMap;
    use uuid::Uuid;

    /// Builds a minimal `echo Hello` task with the given scheduling settings.
    fn create_test_task(is_periodic: bool, interval: &str) -> Arc<Task> {
        let config = TaskConfig {
            name: "Test Task".to_string(),
            id: Uuid::new_v4(),
            command: "echo Hello".to_string(),
            is_periodic,
            interval: interval.to_string(),
            importance: 1,
            dependencies: vec![],
        };
        let address_map = Arc::new(HashMap::new());
        let depend = config.dependencies.clone();
        Arc::new(Task::new(config, address_map, depend))
    }

    #[tokio::test]
    async fn test_run_once() {
        let task = create_test_task(false, "0");
        let scheduler = TaskScheduler::new(task.clone());

        scheduler.start().await.unwrap();

        assert_eq!(*task.status.lock().unwrap(), TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_run_periodic() {
        let task = create_test_task(true, "100");
        let scheduler = TaskScheduler::new(task.clone());

        // Keep the handle: a detached task would swallow a panic from
        // `unwrap()` and the test could pass even though the scheduler failed.
        let handle = tokio::spawn(async move {
            scheduler.start().await.unwrap();
        });

        // Let several periods elapse before asking the scheduler to stop.
        tokio::time::sleep(Duration::from_millis(350)).await;

        task.set_execution_lock(false);

        // The loop must exit at the next tick once the flag is cleared; a hang
        // fails on the timeout and a panic fails on the join result.
        tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("scheduler did not stop after the execution lock was cleared")
            .expect("scheduler task panicked");

        assert_eq!(*task.status.lock().unwrap(), TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_run_periodic_invalid_interval() {
        let task = create_test_task(true, "not-a-number");
        let scheduler = TaskScheduler::new(task);

        assert!(scheduler.start().await.is_err());
    }
}