persistent_scheduler/core/
store.rs

1use crate::core::task_kind::TaskKind;
2use crate::{
3    core::model::{TaskMeta, TaskStatus},
4    utc_now,
5};
6use ahash::AHashMap;
7use async_trait::async_trait;
8use std::{error::Error, sync::Arc};
9use thiserror::Error;
10use tokio::sync::RwLock;
11
12#[async_trait::async_trait]
13pub trait TaskStore: Clone + Send {
14    type Error: Error + Send + Sync;
15
16    /// Restores task states by cleaning up all tasks in a running state and handling their next run times.
17    ///
18    /// This method performs the following actions:
19    /// - Cleans up all tasks that are currently in the `Running` state and may handle their `next_run` fields.
20    /// - Additional restoration logic can be added within this method.
21    ///
22    /// # Returns
23    /// Returns a `Result`, which is `Ok(())` if the operation succeeds; otherwise, it returns the appropriate error.
24    ///
25    /// # Examples
26    /// ```
27    /// # async fn example() {
28    /// #     let store = TaskStore::new();
29    /// #     store.restore_tasks().await.unwrap();
30    /// # }
31    /// ```
32    async fn restore_tasks(&self) -> Result<(), Self::Error>;
33
34    /// Retrieves task metadata based on the task ID.
35    ///
36    /// # Arguments
37    ///
38    /// * `task_id`: A unique identifier for the task.
39    ///
40    /// # Returns
41    ///
42    /// Returns an `Option<TaskMetaEntity>`. If the task is found, it returns `Some(TaskMetaEntity)`, otherwise it returns `None`.
43    async fn get(&self, task_id: &str) -> Result<Option<TaskMeta>, Self::Error>;
44
45    /// Lists all task metadata.
46    ///
47    /// # Returns
48    ///
49    /// Returns a vector containing all task metadata.
50    async fn list(&self) -> Result<Vec<TaskMeta>, Self::Error>;
51
52    /// Stores task metadata.
53    ///
54    /// # Arguments
55    ///
56    /// * `task`: The task metadata to be stored.
57    ///
58    /// # Returns
59    ///
60    /// Returns `Ok(())` if the task is successfully stored; returns an error if the task ID already exists.
61    async fn store_task(&self, task: TaskMeta) -> Result<(), Self::Error>;
62
63    /// Stores tasks metadata.
64    ///
65    /// # Arguments
66    ///
67    /// * `tasks`: The task metadata to be stored.
68    ///
69    /// # Returns
70    ///
71    /// Returns `Ok(())` if the tasks is successfully stored; returns an error if any task ID already exists.
72    async fn store_tasks(&self, tasks: Vec<TaskMeta>) -> Result<(), Self::Error>;
73
74    /// Fetches all pending tasks from the store.
75    ///
76    /// # Returns
77    ///
78    /// Returns a `Result` containing a `Vec<TaskMeta>` if successful, or an error of type `Self::Error` if fetching tasks fails.
79    ///
80    /// The returned `Vec<TaskMeta>` contains all tasks that are currently in a pending state, ready for processing.
81    ///
82    /// # Errors
83    ///
84    /// This function will return an error of type `Self::Error` if there is an issue querying the task store.
85    async fn fetch_pending_tasks(&self) -> Result<Vec<TaskMeta>, Self::Error>;
86
87    /// Updates the execution status of a task.
88    ///
89    /// # Arguments
90    ///
91    /// * `task_id`: The ID of the task to update.
92    /// * `is_success`: A boolean indicating whether the task succeeded.
93    /// * `last_error`: An optional string containing the last error message (if applicable).
94    /// * `next_run`: An optional timestamp for the next scheduled run of the task.
95    ///
96    /// # Returns
97    ///
98    /// Returns `Ok(())` if the update is successful; returns an error if the task is not found or if it is stopped or removed.
99    async fn update_task_execution_status(
100        &self,
101        task_id: &str,
102        is_success: bool,
103        last_error: Option<String>,
104        next_run: Option<i64>,
105    ) -> Result<(), Self::Error>;
106
107    /// Updates the heartbeat for a task.
108    ///
109    /// # Arguments
110    ///
111    /// * `task_id`: The ID of the task to update.
112    /// * `runner_id`: The ID of the runner that is currently executing the task.
113    ///
114    /// # Returns
115    ///
116    /// Returns `Ok(())` if the update is successful; returns an error if the task is not found.
117    async fn heartbeat(&self, task_id: &str, runner_id: &str) -> Result<(), Self::Error>;
118
119    /// Marks a task as stopped.
120    ///
121    /// # Arguments
122    ///
123    /// * `task_id`: The ID of the task to mark as stopped.
124    ///
125    /// # Returns
126    ///
127    /// Returns `Ok(())` if the task is successfully marked; returns an error if the task is not found.
128    async fn set_task_stopped(
129        &self,
130        task_id: &str,
131        reason: Option<String>,
132    ) -> Result<(), Self::Error>;
133
134    /// Marks a task as removed.
135    ///
136    /// # Arguments
137    ///
138    /// * `task_id`: The ID of the task to mark as removed.
139    ///
140    /// # Returns
141    ///
142    /// Returns `Ok(())` if the task is successfully marked; returns an error if the task is not found.
143    async fn set_task_removed(&self, task_id: &str) -> Result<(), Self::Error>;
144
145    /// Cleans up the task store by removing tasks marked as removed.
146    ///
147    /// # Returns
148    ///
149    /// Returns `Ok(())` if the cleanup is successful.
150    async fn cleanup(&self) -> Result<(), Self::Error>;
151}
152
153#[derive(Error, Debug)]
154pub enum InMemoryTaskStoreError {
155    #[error("Task not found")]
156    TaskNotFound,
157    #[error("Task ID conflict: The task with ID '{0}' already exists.")]
158    TaskIdConflict(String),
159}
160
161#[derive(Clone, Default)]
162pub struct InMemoryTaskStore {
163    tasks: Arc<RwLock<AHashMap<String, TaskMeta>>>,
164}
165
166impl InMemoryTaskStore {
167    /// Creates a new instance of `InMemoryTaskStore`.
168    pub fn new() -> Self {
169        Self {
170            tasks: Arc::new(RwLock::new(AHashMap::new())),
171        }
172    }
173}
174
175/// Determines if a task can be executed based on its kind and status.
176pub fn is_candidate_task(kind: &TaskKind, status: &TaskStatus) -> bool {
177    match kind {
178        TaskKind::Cron { .. } | TaskKind::Repeat { .. } => matches!(
179            status,
180            TaskStatus::Scheduled | TaskStatus::Success | TaskStatus::Failed
181        ),
182        TaskKind::Once => *status == TaskStatus::Scheduled,
183    }
184}
185
186#[async_trait]
187impl TaskStore for InMemoryTaskStore {
188    type Error = InMemoryTaskStoreError;
189
190    async fn restore_tasks(&self) -> Result<(), Self::Error> {
191        Ok(())
192    }
193
194    async fn get(&self, task_id: &str) -> Result<Option<TaskMeta>, Self::Error> {
195        let tasks = self.tasks.read().await;
196        Ok(tasks.get(task_id).cloned())
197    }
198
199    async fn list(&self) -> Result<Vec<TaskMeta>, Self::Error> {
200        let tasks = self.tasks.read().await;
201        Ok(tasks.values().cloned().collect())
202    }
203
204    async fn store_task(&self, task: TaskMeta) -> Result<(), Self::Error> {
205        let mut tasks = self.tasks.write().await;
206        if tasks.contains_key(&task.id) {
207            return Err(InMemoryTaskStoreError::TaskIdConflict(task.id.clone()));
208        }
209        tasks.insert(task.id.clone(), task);
210        Ok(())
211    }
212
213    async fn store_tasks(&self, tasks: Vec<TaskMeta>) -> Result<(), Self::Error> {
214        let mut w_tasks = self.tasks.write().await;
215        for task in tasks {
216            if w_tasks.contains_key(&task.id) {
217                return Err(InMemoryTaskStoreError::TaskIdConflict(task.id.clone()));
218            }
219            w_tasks.insert(task.id.clone(), task);
220        }
221        Ok(())
222    }
223
224    async fn fetch_pending_tasks(&self) -> Result<Vec<TaskMeta>, Self::Error> {
225        let mut tasks = self.tasks.write().await;
226        let mut result = Vec::new();
227        for task in tasks.values_mut() {
228            if is_candidate_task(&task.kind, &task.status) && task.next_run <= utc_now!() {
229                let t = task.clone();
230                task.status = TaskStatus::Running;
231                task.updated_at = utc_now!();
232                result.push(t);
233            }
234        }
235        Ok(result)
236    }
237
238    async fn update_task_execution_status(
239        &self,
240        task_id: &str,
241        is_success: bool,
242        last_error: Option<String>,
243        next_run: Option<i64>, // when is None?
244    ) -> Result<(), Self::Error> {
245        let mut tasks = self.tasks.write().await;
246
247        let task = tasks
248            .get_mut(task_id)
249            .ok_or(InMemoryTaskStoreError::TaskNotFound)?;
250
251        if task.status == TaskStatus::Stopped || task.status == TaskStatus::Removed {
252            return Ok(());
253        }
254
255        if is_success {
256            task.success_count += 1;
257            task.status = TaskStatus::Success;
258        } else {
259            task.failure_count += 1;
260            task.status = TaskStatus::Failed;
261            task.last_error = last_error;
262        }
263
264        if let Some(next_run_time) = next_run {
265            task.last_run = task.next_run;
266            task.next_run = next_run_time;
267        }
268
269        task.updated_at = utc_now!();
270
271        Ok(())
272    }
273
274    async fn heartbeat(&self, task_id: &str, runner_id: &str) -> Result<(), Self::Error> {
275        let mut tasks = self.tasks.write().await;
276        if let Some(task) = tasks.get_mut(task_id) {
277            task.heartbeat_at = utc_now!();
278            task.runner_id = Some(runner_id.to_string());
279            Ok(())
280        } else {
281            Err(InMemoryTaskStoreError::TaskNotFound)
282        }
283    }
284
285    async fn set_task_stopped(
286        &self,
287        task_id: &str,
288        reason: Option<String>,
289    ) -> Result<(), Self::Error> {
290        let mut tasks = self.tasks.write().await;
291        if let Some(task) = tasks.get_mut(task_id) {
292            task.updated_at = utc_now!();
293            task.stopped_reason = reason;
294            task.status = TaskStatus::Stopped;
295            Ok(())
296        } else {
297            Err(InMemoryTaskStoreError::TaskNotFound)
298        }
299    }
300
301    async fn set_task_removed(&self, task_id: &str) -> Result<(), Self::Error> {
302        let mut tasks = self.tasks.write().await;
303        if let Some(task) = tasks.get_mut(task_id) {
304            task.updated_at = utc_now!();
305            task.status = TaskStatus::Removed;
306            Ok(())
307        } else {
308            Err(InMemoryTaskStoreError::TaskNotFound)
309        }
310    }
311
312    async fn cleanup(&self) -> Result<(), Self::Error> {
313        let mut tasks = self.tasks.write().await;
314        tasks.retain(|_, task| task.status != TaskStatus::Removed);
315        Ok(())
316    }
317}