persistent_scheduler/nativedb/
meta.rs

1use crate::core::cron::next_run;
2use crate::core::model::TaskMeta;
3use crate::core::model::TaskStatus;
4use crate::core::store::TaskStore;
5use crate::nativedb::init_nativedb;
6use crate::nativedb::TaskMetaEntity;
7use crate::nativedb::TaskMetaEntityKey;
8use crate::nativedb::{get_database, TaskKindEntity};
9use crate::utc_now;
10use async_trait::async_trait;
11use itertools::Itertools;
12use native_db::Database;
13use std::sync::Arc;
14use std::time::Instant;
15use thiserror::Error;
16use tracing::debug;
17
18#[derive(Error, Debug)]
19pub enum NativeDbTaskStoreError {
20    #[error("Task not found")]
21    TaskNotFound,
22
23    #[error("Invalid task status")]
24    InvalidTaskStatus,
25
26    #[error("Task ID conflict: The task with ID '{0}' already exists.")]
27    TaskIdConflict(String),
28
29    #[error("NativeDb error: {0:#?}")]
30    NativeDb(#[from] native_db::db_type::Error),
31
32    #[error("{0:#?}")]
33    Tokio(#[from] tokio::task::JoinError),
34}
35
36#[derive(Clone)]
37pub struct NativeDbTaskStore {
38    pub store: Arc<&'static Database<'static>>,
39}
40
41impl Default for NativeDbTaskStore {
42    fn default() -> Self {
43        NativeDbTaskStore::new(None, None)
44    }
45}
46
47impl NativeDbTaskStore {
48    pub fn new(db_path: Option<String>, cache_size: Option<u64>) -> Self {
49        let store = if let Ok(database) = get_database() {
50            Arc::new(database)
51        } else {
52            let database = init_nativedb(db_path, cache_size)
53                .expect("Failed to initialize the native database.");
54            Arc::new(database)
55        };
56        Self { store }
57    }
58
59    pub fn init(database: &'static Database<'static>) -> Self {
60        Self {
61            store: Arc::new(database),
62        }
63    }
64
65    pub fn fetch_and_lock_task(
66        db: Arc<&'static Database<'static>>,
67        queue: String,
68        runner_id: String,
69    ) -> Result<Option<TaskMeta>, NativeDbTaskStoreError> {
70        // Start the read transaction
71        let r = db.r_transaction()?;
72        let scan = r
73            .scan()
74            .secondary::<TaskMetaEntity>(TaskMetaEntityKey::queue_name)?;
75
76        // Start scanning for tasks in the given queue
77        let mut iter = scan.start_with(queue)?;
78
79        // Find the first task that meets the candidate criteria and is due to run
80        if let Some(task) = iter
81            .find(|item| {
82                item.as_ref().is_ok_and(|e| {
83                    is_candidate_task(&e.kind, &e.status) && e.next_run <= utc_now!()
84                })
85            })
86            .transpose()?
87        {
88            // Start a read-write transaction to update the task's status
89            let rw = db.rw_transaction()?;
90            let current = rw.get().primary::<TaskMetaEntity>(task.id)?;
91
92            match current {
93                Some(mut current) => {
94                    // If the task is still a candidate and ready to run, update it
95                    if is_candidate_task(&current.kind, &current.status)
96                        && current.next_run <= utc_now!()
97                    {
98                        let old = current.clone();
99                        current.runner_id = Some(runner_id);
100                        current.status = TaskStatus::Running;
101                        current.updated_at = utc_now!();
102
103                        // Perform the update in the same transaction
104                        rw.update(old.clone(), current.clone())?;
105                        rw.commit()?;
106
107                        Ok(Some(old.into()))
108                    } else {
109                        // Task status is not valid, return None
110                        Ok(None)
111                    }
112                }
113                None => {
114                    // Task not found, return None
115                    Ok(None)
116                }
117            }
118        } else {
119            // No task found, return None
120            Ok(None)
121        }
122    }
123
124    pub fn fetch_pending_tasks(
125        db: Arc<&'static Database<'static>>,
126    ) -> Result<Vec<TaskMeta>, NativeDbTaskStoreError> {
127        let start = Instant::now();
128        let r = db.r_transaction()?;
129        let scan = r
130            .scan()
131            .secondary::<TaskMetaEntity>(TaskMetaEntityKey::candidate_task)?;
132
133        let iter = scan.start_with("true")?;
134        let tasks: Vec<TaskMetaEntity> = iter
135            .filter_map(|item| item.ok().filter(|e| e.next_run <= utc_now!()))
136            .take(200)
137            .collect();
138
139        let rw = db.rw_transaction()?;
140        let mut result = Vec::new();
141        for entity in tasks.into_iter() {
142            let mut updated = entity.clone();
143            updated.status = TaskStatus::Running;
144            updated.updated_at = utc_now!();
145            rw.update(entity.clone(), updated)?;
146            result.push(entity.into());
147        }
148        rw.commit()?;
149        debug!(
150            "Time taken to fetch task from native_db: {:#?}",
151            start.elapsed()
152        );
153
154        Ok(result)
155    }
156
157    fn update_status(
158        db: Arc<&'static Database<'static>>,
159        task_id: String,
160        is_success: bool,
161        last_error: Option<String>,
162        next_run: Option<i64>,
163    ) -> Result<(), NativeDbTaskStoreError> {
164        let rw = db.rw_transaction()?;
165        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
166
167        let task = match task {
168            Some(t) => t,
169            None => return Err(NativeDbTaskStoreError::TaskNotFound),
170        };
171
172        if task.status == TaskStatus::Stopped || task.status == TaskStatus::Removed {
173            return Ok(());
174        }
175
176        let mut updated_task = task.clone();
177        if is_success {
178            updated_task.success_count += 1;
179            updated_task.status = TaskStatus::Success;
180        } else {
181            updated_task.failure_count += 1;
182            updated_task.status = TaskStatus::Failed;
183            updated_task.last_error = last_error;
184        }
185
186        if let Some(next_run_time) = next_run {
187            updated_task.last_run = updated_task.next_run;
188            updated_task.next_run = next_run_time;
189        }
190
191        updated_task.updated_at = utc_now!();
192
193        rw.update(task, updated_task)?;
194        rw.commit()?;
195
196        Ok(())
197    }
198
199    pub fn clean_up(db: Arc<&'static Database<'static>>) -> Result<(), NativeDbTaskStoreError> {
200        let rw = db.rw_transaction()?;
201        let entities: Vec<TaskMetaEntity> = rw
202            .scan()
203            .secondary(TaskMetaEntityKey::clean_up)?
204            .start_with("true")?
205            .try_collect()?;
206        //Only tasks finished older than 30 minutes are actually cleaned.
207        for entity in entities {
208            if (utc_now!() - entity.updated_at) > 30 * 60 * 1000 {
209                rw.remove(entity)?;
210            }
211        }
212        rw.commit()?;
213        Ok(())
214    }
215
216    pub fn set_status(
217        db: Arc<&'static Database<'static>>,
218        task_id: String,
219        status: TaskStatus,
220    ) -> Result<(), NativeDbTaskStoreError> {
221        assert!(matches!(status, TaskStatus::Removed | TaskStatus::Stopped));
222
223        let rw = db.rw_transaction()?;
224        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
225
226        if let Some(mut task) = task {
227            let old = task.clone();
228            task.status = TaskStatus::Removed;
229            task.updated_at = utc_now!();
230            rw.update(old, task)?;
231            rw.commit()?;
232            Ok(())
233        } else {
234            Err(NativeDbTaskStoreError::TaskNotFound)
235        }
236    }
237
238    pub fn heartbeat(
239        db: Arc<&'static Database<'static>>,
240        task_id: String,
241        runner_id: String,
242    ) -> Result<(), NativeDbTaskStoreError> {
243        let rw = db.rw_transaction()?;
244        let task = rw.get().primary::<TaskMetaEntity>(task_id)?;
245
246        if let Some(mut task) = task {
247            let old = task.clone();
248            task.heartbeat_at = utc_now!();
249            task.runner_id = Some(runner_id.to_string());
250            rw.update(old, task)?;
251            rw.commit()?;
252            Ok(())
253        } else {
254            Err(NativeDbTaskStoreError::TaskNotFound)
255        }
256    }
257
258    pub fn restore(db: Arc<&'static Database<'static>>) -> Result<(), NativeDbTaskStoreError> {
259        tracing::info!("starting task restore...");
260        let rw = db.rw_transaction()?;
261        let entities: Vec<TaskMetaEntity> = rw
262            .scan()
263            .primary::<TaskMetaEntity>()?
264            .all()?
265            .try_collect()?;
266
267        // Exclude stopped and Removed tasks
268        let targets: Vec<TaskMetaEntity> = entities
269            .into_iter()
270            .filter(|e| !matches!(e.status, TaskStatus::Removed | TaskStatus::Stopped))
271            .collect();
272        for entity in targets
273            .iter()
274            .filter(|e| matches!(e.status, TaskStatus::Running))
275        {
276            let mut updated_entity = entity.clone(); // Clone to modify
277            match updated_entity.kind {
278                TaskKindEntity::Cron | TaskKindEntity::Repeat => {
279                    updated_entity.status = TaskStatus::Scheduled; // Change status to Scheduled for Cron and Repeat
280                }
281                TaskKindEntity::Once => {
282                    updated_entity.status = TaskStatus::Removed; // Remove Once tasks if they didn't complete
283                }
284            }
285
286            // Handle potential error without using `?` in a map
287            rw.update(entity.clone(), updated_entity)?;
288        }
289
290        // Handle next run time for repeatable tasks
291        for entity in targets
292            .iter()
293            .filter(|e| matches!(e.kind, TaskKindEntity::Cron | TaskKindEntity::Repeat))
294        {
295            let mut updated = entity.clone();
296            match entity.kind {
297                TaskKindEntity::Cron => {
298                    if let (Some(cron_schedule), Some(cron_timezone)) =
299                        (entity.cron_schedule.clone(), entity.cron_timezone.clone())
300                    {
301                        updated.next_run = next_run(
302                            cron_schedule.as_str(),
303                            cron_timezone.as_str(),
304                            utc_now!(),
305                        )
306                        .unwrap_or_else(|| {
307                            updated.status = TaskStatus::Stopped; // Invalid configuration leads to Stopped
308                            updated.stopped_reason = Some("Invalid cron configuration (automatically stopped during task restoration)".to_string());
309                            updated.next_run // Keep current next_run
310                        });
311                    } else {
312                        updated.status = TaskStatus::Stopped; // Configuration error leads to Stopped
313                        updated.stopped_reason = Some("Missing cron schedule or timezone (automatically stopped during task restoration)".to_string());
314                    }
315                }
316                TaskKindEntity::Repeat => {
317                    updated.last_run = updated.next_run;
318                    let calculated_next_run =
319                        updated.last_run + (updated.repeat_interval * 1000) as i64;
320                    updated.next_run = if calculated_next_run <= utc_now!() {
321                        utc_now!()
322                    } else {
323                        calculated_next_run
324                    };
325                }
326                _ => {}
327            }
328
329            rw.update(entity.clone(), updated)?;
330        }
331
332        rw.commit()?;
333        tracing::info!("finished task restore.");
334        Ok(())
335    }
336
337    pub fn get(
338        db: Arc<&'static Database<'static>>,
339        task_id: String,
340    ) -> Result<Option<TaskMeta>, NativeDbTaskStoreError> {
341        let r = db.r_transaction()?;
342        Ok(r.get().primary(task_id)?.map(|e: TaskMetaEntity| e.into()))
343    }
344
345    pub fn list(
346        db: Arc<&'static Database<'static>>,
347    ) -> Result<Vec<TaskMeta>, NativeDbTaskStoreError> {
348        let r = db.r_transaction()?;
349        let list: Vec<TaskMetaEntity> = r.scan().primary()?.all()?.try_collect()?;
350        Ok(list.into_iter().map(|e| e.into()).collect())
351    }
352
353    pub fn store_one(
354        db: Arc<&'static Database<'static>>,
355        task: TaskMeta,
356    ) -> Result<(), NativeDbTaskStoreError> {
357        let rw = db.rw_transaction()?;
358        let entity: TaskMetaEntity = task.into();
359        rw.insert(entity)?;
360        rw.commit()?;
361        Ok(())
362    }
363
364    pub fn store_many(
365        db: Arc<&'static Database<'static>>,
366        tasks: Vec<TaskMeta>,
367    ) -> Result<(), NativeDbTaskStoreError> {
368        let rw = db.rw_transaction()?;
369        for task in tasks {
370            let entity: TaskMetaEntity = task.into();
371            rw.insert(entity)?;
372        }
373        rw.commit()?;
374        Ok(())
375    }
376}
377
378/// Determines if a task can be executed based on its kind and status.
379pub fn is_candidate_task(kind: &TaskKindEntity, status: &TaskStatus) -> bool {
380    match kind {
381        TaskKindEntity::Cron | TaskKindEntity::Repeat => matches!(
382            status,
383            TaskStatus::Scheduled | TaskStatus::Success | TaskStatus::Failed
384        ),
385        TaskKindEntity::Once => *status == TaskStatus::Scheduled,
386    }
387}
388
389#[async_trait]
390impl TaskStore for NativeDbTaskStore {
391    type Error = NativeDbTaskStoreError;
392
393    async fn restore_tasks(&self) -> Result<(), Self::Error> {
394        let db = self.store.clone();
395        tokio::task::spawn_blocking(move || Self::restore(db)).await?
396    }
397
398    async fn get(&self, task_id: &str) -> Result<Option<TaskMeta>, Self::Error> {
399        let db = self.store.clone();
400        let task_id = task_id.to_string();
401        tokio::task::spawn_blocking(move || Self::get(db, task_id)).await?
402    }
403
404    async fn list(&self) -> Result<Vec<TaskMeta>, Self::Error> {
405        let db = self.store.clone();
406        tokio::task::spawn_blocking(move || Self::list(db)).await?
407    }
408
409    async fn store_task(&self, task: TaskMeta) -> Result<(), Self::Error> {
410        let db = self.store.clone();
411        tokio::task::spawn_blocking(move || Self::store_one(db, task)).await?
412    }
413
414    async fn store_tasks(&self, tasks: Vec<TaskMeta>) -> Result<(), Self::Error> {
415        let db = self.store.clone();
416        tokio::task::spawn_blocking(move || Self::store_many(db, tasks)).await?
417    }
418
419    async fn fetch_pending_tasks(&self) -> Result<Vec<TaskMeta>, Self::Error> {
420        let db = self.store.clone();
421        tokio::task::spawn_blocking(move || Self::fetch_pending_tasks(db)).await?
422    }
423
424    async fn update_task_execution_status(
425        &self,
426        task_id: &str,
427        is_success: bool,
428        last_error: Option<String>,
429        next_run: Option<i64>,
430    ) -> Result<(), Self::Error> {
431        let db = self.store.clone();
432        let task_id = task_id.to_string();
433        tokio::task::spawn_blocking(move || {
434            Self::update_status(db, task_id, is_success, last_error, next_run)
435        })
436        .await?
437    }
438
439    async fn heartbeat(&self, task_id: &str, runner_id: &str) -> Result<(), Self::Error> {
440        let db = self.store.clone();
441        let task_id = task_id.to_string();
442        let runner_id = runner_id.to_string();
443        tokio::task::spawn_blocking(move || Self::heartbeat(db, task_id, runner_id)).await?
444    }
445
446    async fn set_task_stopped(&self, task_id: &str) -> Result<(), Self::Error> {
447        let db = self.store.clone();
448        let task_id = task_id.to_string();
449
450        tokio::task::spawn_blocking(move || Self::set_status(db, task_id, TaskStatus::Stopped))
451            .await?
452    }
453
454    async fn set_task_removed(&self, task_id: &str) -> Result<(), Self::Error> {
455        let db = self.store.clone();
456        let task_id = task_id.to_string();
457
458        tokio::task::spawn_blocking(move || Self::set_status(db, task_id, TaskStatus::Removed))
459            .await?
460    }
461
462    async fn cleanup(&self) -> Result<(), Self::Error> {
463        let db = self.store.clone();
464        tokio::task::spawn_blocking(move || Self::clean_up(db)).await?
465    }
466}