// persistent_scheduler/core/store.rs

1use crate::core::task_kind::TaskKind;
2use crate::{
3    core::model::{TaskMeta, TaskStatus},
4    utc_now,
5};
6use ahash::AHashMap;
7use async_trait::async_trait;
8use std::{error::Error, sync::Arc};
9use thiserror::Error;
10use tokio::sync::RwLock;
11
12#[async_trait::async_trait]
13pub trait TaskStore: Clone + Send {
14    type Error: Error + Send + Sync;
15
16    /// Restores task states by cleaning up all tasks in a running state and handling their next run times.
17    ///
18    /// This method performs the following actions:
19    /// - Cleans up all tasks that are currently in the `Running` state and may handle their `next_run` fields.
20    /// - Additional restoration logic can be added within this method.
21    ///
22    /// # Returns
23    /// Returns a `Result`, which is `Ok(())` if the operation succeeds; otherwise, it returns the appropriate error.
24    ///
25    /// # Examples
26    /// ```
27    /// # async fn example() {
28    /// #     let store = TaskStore::new();
29    /// #     store.restore_tasks().await.unwrap();
30    /// # }
31    /// ```
32    async fn restore_tasks(&self) -> Result<(), Self::Error>;
33
34    /// Retrieves task metadata based on the task ID.
35    ///
36    /// # Arguments
37    ///
38    /// * `task_id`: A unique identifier for the task.
39    ///
40    /// # Returns
41    ///
42    /// Returns an `Option<TaskMetaEntity>`. If the task is found, it returns `Some(TaskMetaEntity)`, otherwise it returns `None`.
43    async fn get(&self, task_id: &str) -> Result<Option<TaskMeta>, Self::Error>;
44
45    /// Lists all task metadata.
46    ///
47    /// # Returns
48    ///
49    /// Returns a vector containing all task metadata.
50    async fn list(&self) -> Result<Vec<TaskMeta>, Self::Error>;
51
52    /// Stores task metadata.
53    ///
54    /// # Arguments
55    ///
56    /// * `task`: The task metadata to be stored.
57    ///
58    /// # Returns
59    ///
60    /// Returns `Ok(())` if the task is successfully stored; returns an error if the task ID already exists.
61    async fn store_task(&self, task: TaskMeta) -> Result<(), Self::Error>;
62
63    /// Stores tasks metadata.
64    ///
65    /// # Arguments
66    ///
67    /// * `tasks`: The task metadata to be stored.
68    ///
69    /// # Returns
70    ///
71    /// Returns `Ok(())` if the tasks is successfully stored; returns an error if any task ID already exists.
72    async fn store_tasks(&self, tasks: Vec<TaskMeta>) -> Result<(), Self::Error>;
73
74    /// Fetches all pending tasks from the store.
75    ///
76    /// # Returns
77    ///
78    /// Returns a `Result` containing a `Vec<TaskMeta>` if successful, or an error of type `Self::Error` if fetching tasks fails.
79    ///
80    /// The returned `Vec<TaskMeta>` contains all tasks that are currently in a pending state, ready for processing.
81    ///
82    /// # Errors
83    ///
84    /// This function will return an error of type `Self::Error` if there is an issue querying the task store.
85    async fn fetch_pending_tasks(&self) -> Result<Vec<TaskMeta>, Self::Error>;
86
87    /// Updates the execution status of a task.
88    ///
89    /// # Arguments
90    ///
91    /// * `task_id`: The ID of the task to update.
92    /// * `is_success`: A boolean indicating whether the task succeeded.
93    /// * `last_error`: An optional string containing the last error message (if applicable).
94    /// * `next_run`: An optional timestamp for the next scheduled run of the task.
95    ///
96    /// # Returns
97    ///
98    /// Returns `Ok(())` if the update is successful; returns an error if the task is not found or if it is stopped or removed.
99    async fn update_task_execution_status(
100        &self,
101        task_id: &str,
102        is_success: bool,
103        last_error: Option<String>,
104        last_duration_ms: Option<usize>,
105        last_retry_count: Option<usize>,
106        next_run: Option<i64>,
107    ) -> Result<(), Self::Error>;
108
109    /// Updates the heartbeat for a task.
110    ///
111    /// # Arguments
112    ///
113    /// * `task_id`: The ID of the task to update.
114    /// * `runner_id`: The ID of the runner that is currently executing the task.
115    ///
116    /// # Returns
117    ///
118    /// Returns `Ok(())` if the update is successful; returns an error if the task is not found.
119    async fn heartbeat(&self, task_id: &str, runner_id: &str) -> Result<(), Self::Error>;
120
121    /// Marks a task as stopped.
122    ///
123    /// # Arguments
124    ///
125    /// * `task_id`: The ID of the task to mark as stopped.
126    ///
127    /// # Returns
128    ///
129    /// Returns `Ok(())` if the task is successfully marked; returns an error if the task is not found.
130    async fn set_task_stopped(
131        &self,
132        task_id: &str,
133        reason: Option<String>,
134    ) -> Result<(), Self::Error>;
135
136    /// Marks a task as removed.
137    ///
138    /// # Arguments
139    ///
140    /// * `task_id`: The ID of the task to mark as removed.
141    ///
142    /// # Returns
143    ///
144    /// Returns `Ok(())` if the task is successfully marked; returns an error if the task is not found.
145    async fn set_task_removed(&self, task_id: &str) -> Result<(), Self::Error>;
146
147    /// Cleans up the task store by removing tasks marked as removed.
148    ///
149    /// # Returns
150    ///
151    /// Returns `Ok(())` if the cleanup is successful.
152    async fn cleanup(&self) -> Result<(), Self::Error>;
153}
154
155#[derive(Error, Debug)]
156pub enum InMemoryTaskStoreError {
157    #[error("Task not found")]
158    TaskNotFound,
159    #[error("Task ID conflict: The task with ID '{0}' already exists.")]
160    TaskIdConflict(String),
161}
162
163#[derive(Clone, Default)]
164pub struct InMemoryTaskStore {
165    tasks: Arc<RwLock<AHashMap<String, TaskMeta>>>,
166}
167
168impl InMemoryTaskStore {
169    /// Creates a new instance of `InMemoryTaskStore`.
170    pub fn new() -> Self {
171        Self {
172            tasks: Arc::new(RwLock::new(AHashMap::new())),
173        }
174    }
175}
176
177/// Determines if a task can be executed based on its kind and status.
178pub fn is_candidate_task(kind: &TaskKind, status: &TaskStatus) -> bool {
179    match kind {
180        TaskKind::Cron { .. } | TaskKind::Repeat { .. } => matches!(
181            status,
182            TaskStatus::Scheduled | TaskStatus::Success | TaskStatus::Failed
183        ),
184        TaskKind::Once => *status == TaskStatus::Scheduled,
185    }
186}
187
188#[async_trait]
189impl TaskStore for InMemoryTaskStore {
190    type Error = InMemoryTaskStoreError;
191
192    async fn restore_tasks(&self) -> Result<(), Self::Error> {
193        Ok(())
194    }
195
196    async fn get(&self, task_id: &str) -> Result<Option<TaskMeta>, Self::Error> {
197        let tasks = self.tasks.read().await;
198        Ok(tasks.get(task_id).cloned())
199    }
200
201    async fn list(&self) -> Result<Vec<TaskMeta>, Self::Error> {
202        let tasks = self.tasks.read().await;
203        Ok(tasks.values().cloned().collect())
204    }
205
206    async fn store_task(&self, task: TaskMeta) -> Result<(), Self::Error> {
207        let mut tasks = self.tasks.write().await;
208        if tasks.contains_key(&task.id) {
209            return Err(InMemoryTaskStoreError::TaskIdConflict(task.id.clone()));
210        }
211        tasks.insert(task.id.clone(), task);
212        Ok(())
213    }
214
215    async fn store_tasks(&self, tasks: Vec<TaskMeta>) -> Result<(), Self::Error> {
216        let mut w_tasks = self.tasks.write().await;
217        for task in tasks {
218            if w_tasks.contains_key(&task.id) {
219                return Err(InMemoryTaskStoreError::TaskIdConflict(task.id.clone()));
220            }
221            w_tasks.insert(task.id.clone(), task);
222        }
223        Ok(())
224    }
225
226    async fn fetch_pending_tasks(&self) -> Result<Vec<TaskMeta>, Self::Error> {
227        let mut tasks = self.tasks.write().await;
228        let mut result = Vec::new();
229        for task in tasks.values_mut() {
230            if is_candidate_task(&task.kind, &task.status) && task.next_run <= utc_now!() {
231                let t = task.clone();
232                task.status = TaskStatus::Running;
233                task.updated_at = utc_now!();
234                result.push(t);
235            }
236        }
237        Ok(result)
238    }
239
240    async fn update_task_execution_status(
241        &self,
242        task_id: &str,
243        is_success: bool,
244        last_error: Option<String>,
245        last_duration_ms: Option<usize>,
246        last_retry_count: Option<usize>,
247        next_run: Option<i64>, // when is None?
248    ) -> Result<(), Self::Error> {
249        let mut tasks = self.tasks.write().await;
250
251        let task = tasks
252            .get_mut(task_id)
253            .ok_or(InMemoryTaskStoreError::TaskNotFound)?;
254
255        if task.status == TaskStatus::Stopped || task.status == TaskStatus::Removed {
256            return Ok(());
257        }
258
259        task.last_retry_count = last_retry_count;
260        task.last_duration_ms = last_duration_ms;
261        if is_success {
262            task.success_count += 1;
263            task.status = TaskStatus::Success;
264        } else {
265            task.failure_count += 1;
266            task.status = TaskStatus::Failed;
267            task.last_error = last_error;
268        }
269
270        if let Some(next_run_time) = next_run {
271            task.last_run = task.next_run;
272            task.next_run = next_run_time;
273        }
274
275        task.updated_at = utc_now!();
276
277        Ok(())
278    }
279
280    async fn heartbeat(&self, task_id: &str, runner_id: &str) -> Result<(), Self::Error> {
281        let mut tasks = self.tasks.write().await;
282        if let Some(task) = tasks.get_mut(task_id) {
283            task.heartbeat_at = utc_now!();
284            task.runner_id = Some(runner_id.to_string());
285            Ok(())
286        } else {
287            Err(InMemoryTaskStoreError::TaskNotFound)
288        }
289    }
290
291    async fn set_task_stopped(
292        &self,
293        task_id: &str,
294        reason: Option<String>,
295    ) -> Result<(), Self::Error> {
296        let mut tasks = self.tasks.write().await;
297        if let Some(task) = tasks.get_mut(task_id) {
298            task.updated_at = utc_now!();
299            task.stopped_reason = reason;
300            task.status = TaskStatus::Stopped;
301            Ok(())
302        } else {
303            Err(InMemoryTaskStoreError::TaskNotFound)
304        }
305    }
306
307    async fn set_task_removed(&self, task_id: &str) -> Result<(), Self::Error> {
308        let mut tasks = self.tasks.write().await;
309        if let Some(task) = tasks.get_mut(task_id) {
310            task.updated_at = utc_now!();
311            task.status = TaskStatus::Removed;
312            Ok(())
313        } else {
314            Err(InMemoryTaskStoreError::TaskNotFound)
315        }
316    }
317
318    async fn cleanup(&self) -> Result<(), Self::Error> {
319        let mut tasks = self.tasks.write().await;
320        tasks.retain(|_, task| task.status != TaskStatus::Removed);
321        Ok(())
322    }
323}