persistent_scheduler/nativedb/
meta.rs

1use crate::core::cron::next_run;
2use crate::core::model::TaskMeta;
3use crate::core::model::TaskStatus;
4use crate::core::store::TaskStore;
5use crate::nativedb::init_nativedb;
6use crate::nativedb::TaskMetaEntity;
7use crate::nativedb::TaskMetaEntityKey;
8use crate::nativedb::{get_database, TaskKindEntity};
9use crate::utc_now;
10use async_trait::async_trait;
11use itertools::Itertools;
12use native_db::Database;
13use std::sync::Arc;
14use std::time::Instant;
15use thiserror::Error;
16use tracing::debug;
17
18#[derive(Error, Debug)]
19pub enum NativeDbTaskStoreError {
20    #[error("Task not found")]
21    TaskNotFound,
22
23    #[error("Invalid task status")]
24    InvalidTaskStatus,
25
26    #[error("Task ID conflict: The task with ID '{0}' already exists.")]
27    TaskIdConflict(String),
28
29    #[error("NativeDb error: {0:#?}")]
30    NativeDb(#[from] native_db::db_type::Error),
31
32    #[error("{0:#?}")]
33    Tokio(#[from] tokio::task::JoinError),
34}
35
36#[derive(Clone)]
37pub struct NativeDbTaskStore {
38    pub store: Arc<&'static Database<'static>>,
39}
40
41impl Default for NativeDbTaskStore {
42    fn default() -> Self {
43        NativeDbTaskStore::new(None, None)
44    }
45}
46
47impl NativeDbTaskStore {
48    pub fn new(db_path: Option<String>, cache_size: Option<u64>) -> Self {
49        let store = if let Ok(database) = get_database() {
50            Arc::new(database)
51        } else {
52            let database = init_nativedb(db_path, cache_size)
53                .expect("Failed to initialize the native database.");
54            Arc::new(database)
55        };
56        Self { store }
57    }
58
59    pub fn init(database: &'static Database<'static>) -> Self {
60        Self {
61            store: Arc::new(database),
62        }
63    }
64
65    pub fn fetch_and_lock_task(
66        db: Arc<&'static Database<'static>>,
67        queue: String,
68        runner_id: String,
69    ) -> Result<Option<TaskMeta>, NativeDbTaskStoreError> {
70        // Start the read transaction
71        let r = db.r_transaction()?;
72        let scan = r
73            .scan()
74            .secondary::<TaskMetaEntity>(TaskMetaEntityKey::queue_name)?;
75
76        // Start scanning for tasks in the given queue
77        let mut iter = scan.start_with(queue)?;
78
79        // Find the first task that meets the candidate criteria and is due to run
80        if let Some(task) = iter
81            .find(|item| {
82                item.as_ref().is_ok_and(|e| {
83                    is_candidate_task(&e.kind, &e.status) && e.next_run <= utc_now!()
84                })
85            })
86            .transpose()?
87        {
88            // Start a read-write transaction to update the task's status
89            let rw = db.rw_transaction()?;
90            let current = rw.get().primary::<TaskMetaEntity>(task.id)?;
91
92            match current {
93                Some(mut current) => {
94                    // If the task is still a candidate and ready to run, update it
95                    if is_candidate_task(&current.kind, &current.status)
96                        && current.next_run <= utc_now!()
97                    {
98                        let old = current.clone();
99                        current.runner_id = Some(runner_id);
100                        current.status = TaskStatus::Running;
101                        current.updated_at = utc_now!();
102
103                        // Perform the update in the same transaction
104                        rw.update(old.clone(), current.clone())?;
105                        rw.commit()?;
106
107                        Ok(Some(old.into()))
108                    } else {
109                        // Task status is not valid, return None
110                        Ok(None)
111                    }
112                }
113                None => {
114                    // Task not found, return None
115                    Ok(None)
116                }
117            }
118        } else {
119            // No task found, return None
120            Ok(None)
121        }
122    }
123
124    pub fn fetch_pending_tasks(
125        db: Arc<&'static Database<'static>>,
126    ) -> Result<Vec<TaskMeta>, NativeDbTaskStoreError> {
127        let start = Instant::now();
128        let r = db.r_transaction()?;
129        let scan = r
130            .scan()
131            .secondary::<TaskMetaEntity>(TaskMetaEntityKey::candidate_task)?;
132
133        let iter = scan.start_with("true")?;
134        let tasks: Vec<TaskMetaEntity> = iter
135            .filter_map(|item| item.ok().filter(|e| e.next_run <= utc_now!()))
136            .take(200)
137            .collect();
138
139        let rw = db.rw_transaction()?;
140        let mut result = Vec::new();
141        for entity in tasks.into_iter() {
142            let mut updated = entity.clone();
143            updated.status = TaskStatus::Running;
144            updated.updated_at = utc_now!();
145            rw.update(entity.clone(), updated)?;
146            result.push(entity.into());
147        }
148        rw.commit()?;
149        debug!(
150            "Time taken to fetch task from native_db: {:#?}",
151            start.elapsed()
152        );
153
154        Ok(result)
155    }
156
157    fn update_status(
158        db: Arc<&'static Database<'static>>,
159        task_id: String,
160        is_success: bool,
161        last_error: Option<String>,
162        last_duration_ms: Option<usize>,
163        last_retry_count: Option<usize>,
164        next_run: Option<i64>,
165    ) -> Result<(), NativeDbTaskStoreError> {
166        let rw = db.rw_transaction()?;
167        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
168
169        let task = match task {
170            Some(t) => t,
171            None => return Err(NativeDbTaskStoreError::TaskNotFound),
172        };
173
174        if task.status == TaskStatus::Stopped || task.status == TaskStatus::Removed {
175            return Ok(());
176        }
177
178        let mut updated_task = task.clone();
179        updated_task.last_duration_ms = last_duration_ms;
180        updated_task.last_retry_count = last_retry_count;
181
182        if is_success {
183            updated_task.success_count += 1;
184            updated_task.status = TaskStatus::Success;
185        } else {
186            updated_task.failure_count += 1;
187            updated_task.status = TaskStatus::Failed;
188            updated_task.last_error = last_error;
189        }
190
191        if let Some(next_run_time) = next_run {
192            updated_task.last_run = updated_task.next_run;
193            updated_task.next_run = next_run_time;
194        }
195
196        updated_task.updated_at = utc_now!();
197
198        rw.update(task, updated_task)?;
199        rw.commit()?;
200
201        Ok(())
202    }
203
204    pub fn clean_up(db: Arc<&'static Database<'static>>) -> Result<(), NativeDbTaskStoreError> {
205        let rw = db.rw_transaction()?;
206        let entities: Vec<TaskMetaEntity> = rw
207            .scan()
208            .secondary(TaskMetaEntityKey::clean_up)?
209            .start_with("true")?
210            .try_collect()?;
211        //Only tasks finished older than 30 minutes are actually cleaned.
212        for entity in entities {
213            if (utc_now!() - entity.updated_at) > 30 * 60 * 1000 {
214                rw.remove(entity)?;
215            }
216        }
217        rw.commit()?;
218        Ok(())
219    }
220
221    pub fn set_status(
222        db: Arc<&'static Database<'static>>,
223        task_id: String,
224        status: TaskStatus,
225        reason: Option<String>,
226    ) -> Result<(), NativeDbTaskStoreError> {
227        assert!(matches!(status, TaskStatus::Removed | TaskStatus::Stopped));
228
229        let rw = db.rw_transaction()?;
230        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
231
232        if let Some(mut task) = task {
233            let old = task.clone();
234            task.status = status;
235            task.stopped_reason = reason;
236            task.updated_at = utc_now!();
237            rw.update(old, task)?;
238            rw.commit()?;
239            Ok(())
240        } else {
241            Err(NativeDbTaskStoreError::TaskNotFound)
242        }
243    }
244
245    pub fn heartbeat(
246        db: Arc<&'static Database<'static>>,
247        task_id: String,
248        runner_id: String,
249    ) -> Result<(), NativeDbTaskStoreError> {
250        let rw = db.rw_transaction()?;
251        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
252
253        if let Some(mut task) = task {
254            let old = task.clone();
255            task.heartbeat_at = utc_now!();
256            task.runner_id = Some(runner_id.to_string());
257            rw.update(old, task)?;
258            rw.commit()?;
259            Ok(())
260        } else {
261            Err(NativeDbTaskStoreError::TaskNotFound)
262        }
263    }
264
265    pub fn restore(db: Arc<&'static Database<'static>>) -> Result<(), NativeDbTaskStoreError> {
266        tracing::info!("starting task restore...");
267        let rw = db.rw_transaction()?;
268        let entities: Vec<TaskMetaEntity> = rw
269            .scan()
270            .primary::<TaskMetaEntity>()?
271            .all()?
272            .try_collect()?;
273
274        // Exclude stopped and Removed tasks
275        let targets: Vec<TaskMetaEntity> = entities
276            .into_iter()
277            .filter(|e| !matches!(e.status, TaskStatus::Removed | TaskStatus::Stopped))
278            .collect();
279        for entity in targets
280            .iter()
281            .filter(|e| matches!(e.status, TaskStatus::Running))
282        {
283            let mut updated_entity = entity.clone(); // Clone to modify
284            match updated_entity.kind {
285                TaskKindEntity::Cron | TaskKindEntity::Repeat => {
286                    updated_entity.status = TaskStatus::Scheduled; // Change status to Scheduled for Cron and Repeat
287                }
288                TaskKindEntity::Once => {
289                    updated_entity.status = TaskStatus::Removed; // Remove Once tasks if they didn't complete
290                }
291            }
292
293            // Handle potential error without using `?` in a map
294            rw.update(entity.clone(), updated_entity)?;
295        }
296
297        // Handle next run time for repeatable tasks
298        for entity in targets
299            .iter()
300            .filter(|e| matches!(e.kind, TaskKindEntity::Cron | TaskKindEntity::Repeat))
301        {
302            let mut updated = entity.clone();
303            match entity.kind {
304                TaskKindEntity::Cron => {
305                    if let (Some(cron_schedule), Some(cron_timezone)) =
306                        (entity.cron_schedule.clone(), entity.cron_timezone.clone())
307                    {
308                        updated.next_run = next_run(
309                            cron_schedule.as_str(),
310                            cron_timezone.as_str(),
311                            utc_now!(),
312                        )
313                        .unwrap_or_else(|| {
314                            updated.status = TaskStatus::Stopped; // Invalid configuration leads to Stopped
315                            updated.stopped_reason = Some("Invalid cron configuration (automatically stopped during task restoration)".to_string());
316                            updated.next_run // Keep current next_run
317                        });
318                    } else {
319                        updated.status = TaskStatus::Stopped; // Configuration error leads to Stopped
320                        updated.stopped_reason = Some("Missing cron schedule or timezone (automatically stopped during task restoration)".to_string());
321                    }
322                }
323                TaskKindEntity::Repeat => {
324                    updated.last_run = updated.next_run;
325                    let calculated_next_run =
326                        updated.last_run + (updated.repeat_interval * 1000) as i64;
327                    updated.next_run = if calculated_next_run <= utc_now!() {
328                        utc_now!()
329                    } else {
330                        calculated_next_run
331                    };
332                }
333                _ => {}
334            }
335
336            rw.update(entity.clone(), updated)?;
337        }
338
339        rw.commit()?;
340        tracing::info!("finished task restore.");
341        Ok(())
342    }
343
344    pub fn get(
345        db: Arc<&'static Database<'static>>,
346        task_id: String,
347    ) -> Result<Option<TaskMeta>, NativeDbTaskStoreError> {
348        let r = db.r_transaction()?;
349        Ok(r.get().primary(task_id)?.map(|e: TaskMetaEntity| e.into()))
350    }
351
352    pub fn list(
353        db: Arc<&'static Database<'static>>,
354    ) -> Result<Vec<TaskMeta>, NativeDbTaskStoreError> {
355        let r = db.r_transaction()?;
356        let list: Vec<TaskMetaEntity> = r.scan().primary()?.all()?.try_collect()?;
357        Ok(list.into_iter().map(|e| e.into()).collect())
358    }
359
360    pub fn store_one(
361        db: Arc<&'static Database<'static>>,
362        task: TaskMeta,
363    ) -> Result<(), NativeDbTaskStoreError> {
364        let rw = db.rw_transaction()?;
365        let entity: TaskMetaEntity = task.into();
366        rw.insert(entity)?;
367        rw.commit()?;
368        Ok(())
369    }
370
371    pub fn store_many(
372        db: Arc<&'static Database<'static>>,
373        tasks: Vec<TaskMeta>,
374    ) -> Result<(), NativeDbTaskStoreError> {
375        let rw = db.rw_transaction()?;
376        for task in tasks {
377            let entity: TaskMetaEntity = task.into();
378            rw.insert(entity)?;
379        }
380        rw.commit()?;
381        Ok(())
382    }
383}
384
385/// Determines if a task can be executed based on its kind and status.
386pub fn is_candidate_task(kind: &TaskKindEntity, status: &TaskStatus) -> bool {
387    match kind {
388        TaskKindEntity::Cron | TaskKindEntity::Repeat => matches!(
389            status,
390            TaskStatus::Scheduled | TaskStatus::Success | TaskStatus::Failed
391        ),
392        TaskKindEntity::Once => *status == TaskStatus::Scheduled,
393    }
394}
395
396#[async_trait]
397impl TaskStore for NativeDbTaskStore {
398    type Error = NativeDbTaskStoreError;
399
400    async fn restore_tasks(&self) -> Result<(), Self::Error> {
401        let db = self.store.clone();
402        tokio::task::spawn_blocking(move || Self::restore(db)).await?
403    }
404
405    async fn get(&self, task_id: &str) -> Result<Option<TaskMeta>, Self::Error> {
406        let db = self.store.clone();
407        let task_id = task_id.to_string();
408        tokio::task::spawn_blocking(move || Self::get(db, task_id)).await?
409    }
410
411    async fn list(&self) -> Result<Vec<TaskMeta>, Self::Error> {
412        let db = self.store.clone();
413        tokio::task::spawn_blocking(move || Self::list(db)).await?
414    }
415
416    async fn store_task(&self, task: TaskMeta) -> Result<(), Self::Error> {
417        let db = self.store.clone();
418        tokio::task::spawn_blocking(move || Self::store_one(db, task)).await?
419    }
420
421    async fn store_tasks(&self, tasks: Vec<TaskMeta>) -> Result<(), Self::Error> {
422        let db = self.store.clone();
423        tokio::task::spawn_blocking(move || Self::store_many(db, tasks)).await?
424    }
425
426    async fn fetch_pending_tasks(&self) -> Result<Vec<TaskMeta>, Self::Error> {
427        let db = self.store.clone();
428        tokio::task::spawn_blocking(move || Self::fetch_pending_tasks(db)).await?
429    }
430
431    async fn update_task_execution_status(
432        &self,
433        task_id: &str,
434        is_success: bool,
435        last_error: Option<String>,
436        last_duration_ms: Option<usize>,
437        last_retry_count: Option<usize>,
438        next_run: Option<i64>,
439    ) -> Result<(), Self::Error> {
440        let db = self.store.clone();
441        let task_id = task_id.to_string();
442        tokio::task::spawn_blocking(move || {
443            Self::update_status(
444                db,
445                task_id,
446                is_success,
447                last_error,
448                last_duration_ms,
449                last_retry_count,
450                next_run,
451            )
452        })
453        .await?
454    }
455
456    async fn heartbeat(&self, task_id: &str, runner_id: &str) -> Result<(), Self::Error> {
457        let db = self.store.clone();
458        let task_id = task_id.to_string();
459        let runner_id = runner_id.to_string();
460        tokio::task::spawn_blocking(move || Self::heartbeat(db, task_id, runner_id)).await?
461    }
462
463    async fn set_task_stopped(
464        &self,
465        task_id: &str,
466        reason: Option<String>,
467    ) -> Result<(), Self::Error> {
468        let db = self.store.clone();
469        let task_id = task_id.to_string();
470
471        tokio::task::spawn_blocking(move || {
472            Self::set_status(db, task_id, TaskStatus::Stopped, reason)
473        })
474        .await?
475    }
476
477    async fn set_task_removed(&self, task_id: &str) -> Result<(), Self::Error> {
478        let db = self.store.clone();
479        let task_id = task_id.to_string();
480
481        tokio::task::spawn_blocking(move || {
482            Self::set_status(db, task_id, TaskStatus::Removed, None)
483        })
484        .await?
485    }
486
487    async fn cleanup(&self) -> Result<(), Self::Error> {
488        let db = self.store.clone();
489        tokio::task::spawn_blocking(move || Self::clean_up(db)).await?
490    }
491}