persistent_scheduler/nativedb/
meta.rs

1use crate::core::cron::next_run;
2use crate::core::model::TaskMeta;
3use crate::core::model::TaskStatus;
4use crate::core::store::TaskStore;
5use crate::nativedb::init_nativedb;
6use crate::nativedb::TaskMetaEntity;
7use crate::nativedb::TaskMetaEntityKey;
8use crate::nativedb::{get_database, TaskKindEntity};
9use crate::utc_now;
10use async_trait::async_trait;
11use itertools::Itertools;
12use native_db::Database;
13use std::sync::Arc;
14use std::time::Instant;
15use thiserror::Error;
16use tracing::debug;
17
18#[derive(Error, Debug)]
19pub enum NativeDbTaskStoreError {
20    #[error("Task not found")]
21    TaskNotFound,
22
23    #[error("Invalid task status")]
24    InvalidTaskStatus,
25
26    #[error("Task ID conflict: The task with ID '{0}' already exists.")]
27    TaskIdConflict(String),
28
29    #[error("NativeDb error: {0:#?}")]
30    NativeDb(#[from] native_db::db_type::Error),
31
32    #[error("{0:#?}")]
33    Tokio(#[from] tokio::task::JoinError),
34}
35
36#[derive(Clone)]
37pub struct NativeDbTaskStore {
38    pub store: Arc<&'static Database<'static>>,
39}
40
41impl Default for NativeDbTaskStore {
42    fn default() -> Self {
43        NativeDbTaskStore::new(None, None)
44    }
45}
46
47impl NativeDbTaskStore {
48    pub fn new(db_path: Option<String>, cache_size: Option<u64>) -> Self {
49        let store = if let Ok(database) = get_database() {
50            Arc::new(database)
51        } else {
52            let database = init_nativedb(db_path, cache_size)
53                .expect("Failed to initialize the native database.");
54            Arc::new(database)
55        };
56        Self { store }
57    }
58
59    pub fn init(database: &'static Database<'static>) -> Self {
60        Self {
61            store: Arc::new(database),
62        }
63    }
64
65    pub fn fetch_and_lock_task(
66        db: Arc<&'static Database<'static>>,
67        queue: String,
68        runner_id: String,
69    ) -> Result<Option<TaskMeta>, NativeDbTaskStoreError> {
70        // Start the read transaction
71        let r = db.r_transaction()?;
72        let scan = r
73            .scan()
74            .secondary::<TaskMetaEntity>(TaskMetaEntityKey::queue_name)?;
75
76        // Start scanning for tasks in the given queue
77        let mut iter = scan.start_with(queue)?;
78
79        // Find the first task that meets the candidate criteria and is due to run
80        if let Some(task) = iter
81            .find(|item| {
82                item.as_ref().is_ok_and(|e| {
83                    is_candidate_task(&e.kind, &e.status) && e.next_run <= utc_now!()
84                })
85            })
86            .transpose()?
87        {
88            // Start a read-write transaction to update the task's status
89            let rw = db.rw_transaction()?;
90            let current = rw.get().primary::<TaskMetaEntity>(task.id)?;
91
92            match current {
93                Some(mut current) => {
94                    // If the task is still a candidate and ready to run, update it
95                    if is_candidate_task(&current.kind, &current.status)
96                        && current.next_run <= utc_now!()
97                    {
98                        let old = current.clone();
99                        current.runner_id = Some(runner_id);
100                        current.status = TaskStatus::Running;
101                        current.updated_at = utc_now!();
102
103                        // Perform the update in the same transaction
104                        rw.update(old.clone(), current.clone())?;
105                        rw.commit()?;
106
107                        Ok(Some(old.into()))
108                    } else {
109                        // Task status is not valid, return None
110                        Ok(None)
111                    }
112                }
113                None => {
114                    // Task not found, return None
115                    Ok(None)
116                }
117            }
118        } else {
119            // No task found, return None
120            Ok(None)
121        }
122    }
123
124    pub fn fetch_pending_tasks(
125        db: Arc<&'static Database<'static>>,
126    ) -> Result<Vec<TaskMeta>, NativeDbTaskStoreError> {
127        let start = Instant::now();
128        let r = db.r_transaction()?;
129        let scan = r
130            .scan()
131            .secondary::<TaskMetaEntity>(TaskMetaEntityKey::candidate_task)?;
132
133        let iter = scan.start_with("true")?;
134        let tasks: Vec<TaskMetaEntity> = iter
135            .filter_map(|item| item.ok().filter(|e| e.next_run <= utc_now!()))
136            .take(200)
137            .collect();
138
139        let rw = db.rw_transaction()?;
140        let mut result = Vec::new();
141        for entity in tasks.into_iter() {
142            let mut updated = entity.clone();
143            updated.status = TaskStatus::Running;
144            updated.updated_at = utc_now!();
145            rw.update(entity.clone(), updated)?;
146            result.push(entity.into());
147        }
148        rw.commit()?;
149        debug!(
150            "Time taken to fetch task from native_db: {:#?}",
151            start.elapsed()
152        );
153
154        Ok(result)
155    }
156
157    fn update_status(
158        db: Arc<&'static Database<'static>>,
159        task_id: String,
160        is_success: bool,
161        last_error: Option<String>,
162        next_run: Option<i64>,
163    ) -> Result<(), NativeDbTaskStoreError> {
164        let rw = db.rw_transaction()?;
165        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
166
167        let task = match task {
168            Some(t) => t,
169            None => return Err(NativeDbTaskStoreError::TaskNotFound),
170        };
171
172        if task.status == TaskStatus::Stopped || task.status == TaskStatus::Removed {
173            return Ok(());
174        }
175
176        let mut updated_task = task.clone();
177        if is_success {
178            updated_task.success_count += 1;
179            updated_task.status = TaskStatus::Success;
180        } else {
181            updated_task.failure_count += 1;
182            updated_task.status = TaskStatus::Failed;
183            updated_task.last_error = last_error;
184        }
185
186        if let Some(next_run_time) = next_run {
187            updated_task.last_run = updated_task.next_run;
188            updated_task.next_run = next_run_time;
189        }
190
191        updated_task.updated_at = utc_now!();
192
193        rw.update(task, updated_task)?;
194        rw.commit()?;
195
196        Ok(())
197    }
198
199    pub fn clean_up(db: Arc<&'static Database<'static>>) -> Result<(), NativeDbTaskStoreError> {
200        let rw = db.rw_transaction()?;
201        let entities: Vec<TaskMetaEntity> = rw
202            .scan()
203            .secondary(TaskMetaEntityKey::clean_up)?
204            .start_with("true")?
205            .try_collect()?;
206        //Only tasks finished older than 30 minutes are actually cleaned.
207        for entity in entities {
208            if (utc_now!() - entity.updated_at) > 30 * 60 * 1000 {
209                rw.remove(entity)?;
210            }
211        }
212        rw.commit()?;
213        Ok(())
214    }
215
216    pub fn set_status(
217        db: Arc<&'static Database<'static>>,
218        task_id: String,
219        status: TaskStatus,
220        reason: Option<String>,
221    ) -> Result<(), NativeDbTaskStoreError> {
222        assert!(matches!(status, TaskStatus::Removed | TaskStatus::Stopped));
223
224        let rw = db.rw_transaction()?;
225        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
226
227        if let Some(mut task) = task {
228            let old = task.clone();
229            task.status = status;
230            task.stopped_reason = reason;
231            task.updated_at = utc_now!();
232            rw.update(old, task)?;
233            rw.commit()?;
234            Ok(())
235        } else {
236            Err(NativeDbTaskStoreError::TaskNotFound)
237        }
238    }
239
240    pub fn heartbeat(
241        db: Arc<&'static Database<'static>>,
242        task_id: String,
243        runner_id: String,
244    ) -> Result<(), NativeDbTaskStoreError> {
245        let rw = db.rw_transaction()?;
246        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
247
248        if let Some(mut task) = task {
249            let old = task.clone();
250            task.heartbeat_at = utc_now!();
251            task.runner_id = Some(runner_id.to_string());
252            rw.update(old, task)?;
253            rw.commit()?;
254            Ok(())
255        } else {
256            Err(NativeDbTaskStoreError::TaskNotFound)
257        }
258    }
259
260    pub fn restore(db: Arc<&'static Database<'static>>) -> Result<(), NativeDbTaskStoreError> {
261        tracing::info!("starting task restore...");
262        let rw = db.rw_transaction()?;
263        let entities: Vec<TaskMetaEntity> = rw
264            .scan()
265            .primary::<TaskMetaEntity>()?
266            .all()?
267            .try_collect()?;
268
269        // Exclude stopped and Removed tasks
270        let targets: Vec<TaskMetaEntity> = entities
271            .into_iter()
272            .filter(|e| !matches!(e.status, TaskStatus::Removed | TaskStatus::Stopped))
273            .collect();
274        for entity in targets
275            .iter()
276            .filter(|e| matches!(e.status, TaskStatus::Running))
277        {
278            let mut updated_entity = entity.clone(); // Clone to modify
279            match updated_entity.kind {
280                TaskKindEntity::Cron | TaskKindEntity::Repeat => {
281                    updated_entity.status = TaskStatus::Scheduled; // Change status to Scheduled for Cron and Repeat
282                }
283                TaskKindEntity::Once => {
284                    updated_entity.status = TaskStatus::Removed; // Remove Once tasks if they didn't complete
285                }
286            }
287
288            // Handle potential error without using `?` in a map
289            rw.update(entity.clone(), updated_entity)?;
290        }
291
292        // Handle next run time for repeatable tasks
293        for entity in targets
294            .iter()
295            .filter(|e| matches!(e.kind, TaskKindEntity::Cron | TaskKindEntity::Repeat))
296        {
297            let mut updated = entity.clone();
298            match entity.kind {
299                TaskKindEntity::Cron => {
300                    if let (Some(cron_schedule), Some(cron_timezone)) =
301                        (entity.cron_schedule.clone(), entity.cron_timezone.clone())
302                    {
303                        updated.next_run = next_run(
304                            cron_schedule.as_str(),
305                            cron_timezone.as_str(),
306                            utc_now!(),
307                        )
308                        .unwrap_or_else(|| {
309                            updated.status = TaskStatus::Stopped; // Invalid configuration leads to Stopped
310                            updated.stopped_reason = Some("Invalid cron configuration (automatically stopped during task restoration)".to_string());
311                            updated.next_run // Keep current next_run
312                        });
313                    } else {
314                        updated.status = TaskStatus::Stopped; // Configuration error leads to Stopped
315                        updated.stopped_reason = Some("Missing cron schedule or timezone (automatically stopped during task restoration)".to_string());
316                    }
317                }
318                TaskKindEntity::Repeat => {
319                    updated.last_run = updated.next_run;
320                    let calculated_next_run =
321                        updated.last_run + (updated.repeat_interval * 1000) as i64;
322                    updated.next_run = if calculated_next_run <= utc_now!() {
323                        utc_now!()
324                    } else {
325                        calculated_next_run
326                    };
327                }
328                _ => {}
329            }
330
331            rw.update(entity.clone(), updated)?;
332        }
333
334        rw.commit()?;
335        tracing::info!("finished task restore.");
336        Ok(())
337    }
338
339    pub fn get(
340        db: Arc<&'static Database<'static>>,
341        task_id: String,
342    ) -> Result<Option<TaskMeta>, NativeDbTaskStoreError> {
343        let r = db.r_transaction()?;
344        Ok(r.get().primary(task_id)?.map(|e: TaskMetaEntity| e.into()))
345    }
346
347    pub fn list(
348        db: Arc<&'static Database<'static>>,
349    ) -> Result<Vec<TaskMeta>, NativeDbTaskStoreError> {
350        let r = db.r_transaction()?;
351        let list: Vec<TaskMetaEntity> = r.scan().primary()?.all()?.try_collect()?;
352        Ok(list.into_iter().map(|e| e.into()).collect())
353    }
354
355    pub fn store_one(
356        db: Arc<&'static Database<'static>>,
357        task: TaskMeta,
358    ) -> Result<(), NativeDbTaskStoreError> {
359        let rw = db.rw_transaction()?;
360        let entity: TaskMetaEntity = task.into();
361        rw.insert(entity)?;
362        rw.commit()?;
363        Ok(())
364    }
365
366    pub fn store_many(
367        db: Arc<&'static Database<'static>>,
368        tasks: Vec<TaskMeta>,
369    ) -> Result<(), NativeDbTaskStoreError> {
370        let rw = db.rw_transaction()?;
371        for task in tasks {
372            let entity: TaskMetaEntity = task.into();
373            rw.insert(entity)?;
374        }
375        rw.commit()?;
376        Ok(())
377    }
378}
379
380/// Determines if a task can be executed based on its kind and status.
381pub fn is_candidate_task(kind: &TaskKindEntity, status: &TaskStatus) -> bool {
382    match kind {
383        TaskKindEntity::Cron | TaskKindEntity::Repeat => matches!(
384            status,
385            TaskStatus::Scheduled | TaskStatus::Success | TaskStatus::Failed
386        ),
387        TaskKindEntity::Once => *status == TaskStatus::Scheduled,
388    }
389}
390
391#[async_trait]
392impl TaskStore for NativeDbTaskStore {
393    type Error = NativeDbTaskStoreError;
394
395    async fn restore_tasks(&self) -> Result<(), Self::Error> {
396        let db = self.store.clone();
397        tokio::task::spawn_blocking(move || Self::restore(db)).await?
398    }
399
400    async fn get(&self, task_id: &str) -> Result<Option<TaskMeta>, Self::Error> {
401        let db = self.store.clone();
402        let task_id = task_id.to_string();
403        tokio::task::spawn_blocking(move || Self::get(db, task_id)).await?
404    }
405
406    async fn list(&self) -> Result<Vec<TaskMeta>, Self::Error> {
407        let db = self.store.clone();
408        tokio::task::spawn_blocking(move || Self::list(db)).await?
409    }
410
411    async fn store_task(&self, task: TaskMeta) -> Result<(), Self::Error> {
412        let db = self.store.clone();
413        tokio::task::spawn_blocking(move || Self::store_one(db, task)).await?
414    }
415
416    async fn store_tasks(&self, tasks: Vec<TaskMeta>) -> Result<(), Self::Error> {
417        let db = self.store.clone();
418        tokio::task::spawn_blocking(move || Self::store_many(db, tasks)).await?
419    }
420
421    async fn fetch_pending_tasks(&self) -> Result<Vec<TaskMeta>, Self::Error> {
422        let db = self.store.clone();
423        tokio::task::spawn_blocking(move || Self::fetch_pending_tasks(db)).await?
424    }
425
426    async fn update_task_execution_status(
427        &self,
428        task_id: &str,
429        is_success: bool,
430        last_error: Option<String>,
431        next_run: Option<i64>,
432    ) -> Result<(), Self::Error> {
433        let db = self.store.clone();
434        let task_id = task_id.to_string();
435        tokio::task::spawn_blocking(move || {
436            Self::update_status(db, task_id, is_success, last_error, next_run)
437        })
438        .await?
439    }
440
441    async fn heartbeat(&self, task_id: &str, runner_id: &str) -> Result<(), Self::Error> {
442        let db = self.store.clone();
443        let task_id = task_id.to_string();
444        let runner_id = runner_id.to_string();
445        tokio::task::spawn_blocking(move || Self::heartbeat(db, task_id, runner_id)).await?
446    }
447
448    async fn set_task_stopped(
449        &self,
450        task_id: &str,
451        reason: Option<String>,
452    ) -> Result<(), Self::Error> {
453        let db = self.store.clone();
454        let task_id = task_id.to_string();
455
456        tokio::task::spawn_blocking(move || {
457            Self::set_status(db, task_id, TaskStatus::Stopped, reason)
458        })
459        .await?
460    }
461
462    async fn set_task_removed(&self, task_id: &str) -> Result<(), Self::Error> {
463        let db = self.store.clone();
464        let task_id = task_id.to_string();
465
466        tokio::task::spawn_blocking(move || {
467            Self::set_status(db, task_id, TaskStatus::Removed, None)
468        })
469        .await?
470    }
471
472    async fn cleanup(&self) -> Result<(), Self::Error> {
473        let db = self.store.clone();
474        tokio::task::spawn_blocking(move || Self::clean_up(db)).await?
475    }
476}