// pg_taskq/task.rs

1use chrono::prelude::*;
2use serde::Serialize;
3use serde_json::Value;
4use sqlx::{Acquire, PgExecutor, Pool, Postgres};
5use std::collections::HashSet;
6use std::time::Instant;
7use uuid::Uuid;
8
9use crate::{Error, Result, TaskTableProvider, TaskType};
10
11#[derive(Default)]
12pub struct TaskBuilder {
13    task_type: String,
14    req: Option<Value>,
15    parent: Option<Uuid>,
16    id: Option<Uuid>,
17}
18
19impl TaskBuilder {
20    pub fn new(task_type: impl TaskType) -> Self {
21        Self {
22            task_type: task_type.to_string(),
23            ..Default::default()
24        }
25    }
26
27    pub fn with_request(mut self, req: impl Serialize) -> Result<Self> {
28        self.req = Some(serde_json::to_value(req)?);
29        Ok(self)
30    }
31
32    #[must_use]
33    pub fn with_parent(mut self, parent: Uuid) -> Self {
34        self.parent = Some(parent);
35        self
36    }
37
38    #[must_use]
39    pub fn with_id(mut self, id: Uuid) -> Self {
40        self.id = Some(id);
41        self
42    }
43
44    pub async fn build<'a, DB, P>(self, db: DB, tables: &P) -> Result<Task>
45    where
46        P: TaskTableProvider,
47        DB: Acquire<'a, Database = Postgres>,
48    {
49        let Self {
50            task_type,
51            req,
52            parent,
53            id,
54        } = self;
55        Task::create_task(db, tables, id, task_type, req, parent).await
56    }
57}
58
59#[derive(Debug, Clone, sqlx::FromRow)]
60pub struct Task {
61    pub id: Uuid,
62    pub parent: Option<Uuid>,
63    pub created_at: DateTime<Utc>,
64    pub updated_at: DateTime<Utc>,
65    pub task_type: String,
66    pub request: Option<Value>,
67    pub result: Option<Value>,
68    pub error: Option<Value>,
69    pub in_progress: bool,
70    pub done: bool,
71}
72
73impl Task {
74    #[instrument(level = "trace", skip(db, tables, req))]
75    pub async fn create_task<'a, DB, P>(
76        db: DB,
77        tables: &P,
78        id: Option<Uuid>,
79        task_type: String,
80        req: Option<Value>,
81        parent: Option<Uuid>,
82    ) -> Result<Self>
83    where
84        P: TaskTableProvider,
85        DB: Acquire<'a, Database = Postgres>,
86    {
87        let id = id.unwrap_or_else(Uuid::new_v4);
88        let req = Some(serde_json::to_value(req)?);
89
90        let mut tx = db.begin().await?;
91
92        let table = tables.tasks_table_full_name();
93        let notify_fn = tables.tasks_notify_fn_full_name();
94
95        let sql = format!(
96            "
97INSERT INTO {table} (id, task_type, request, parent, in_progress, done)
98VALUES ($1, $2, $3, $4, false, false)
99RETURNING *
100"
101        );
102        let task: Self = sqlx::query_as(&sql)
103            .bind(id)
104            .bind(task_type.to_string())
105            .bind(req)
106            .bind(parent)
107            .fetch_one(&mut *tx)
108            .await?;
109
110        let sql = format!("SELECT {notify_fn}($1)");
111        sqlx::query(&sql).bind(task.id).execute(&mut *tx).await?;
112
113        tx.commit().await?;
114
115        debug!("created task {task_type:?} {id}");
116
117        Ok(task)
118    }
119
120    pub async fn load(
121        db: impl PgExecutor<'_>,
122        tables: &dyn TaskTableProvider,
123        id: Uuid,
124    ) -> Result<Option<Self>> {
125        let table = tables.tasks_table_full_name();
126        let sql = format!("SELECT * FROM {table} WHERE id = $1");
127        Ok(sqlx::query_as(&sql).bind(id).fetch_optional(db).await?)
128    }
129
130    pub async fn with_children(
131        self,
132        db: impl PgExecutor<'_>,
133        tables: &dyn TaskTableProvider,
134        recursive: bool,
135    ) -> Result<Vec<Self>> {
136        Self::load_children(db, tables, self.id, true, recursive).await
137    }
138
139    pub async fn children(
140        &self,
141        db: impl PgExecutor<'_>,
142        tables: &dyn TaskTableProvider,
143        recursive: bool,
144    ) -> Result<Vec<Self>> {
145        Self::load_children(db, tables, self.id, false, recursive).await
146    }
147
148    #[instrument(level = "trace", skip(db, tables))]
149    pub async fn load_children(
150        db: impl PgExecutor<'_>,
151        tables: &dyn TaskTableProvider,
152        id: Uuid,
153        include_self: bool,
154        recursive: bool,
155    ) -> Result<Vec<Self>> {
156        let table = tables.tasks_table_full_name();
157        let sql = if recursive {
158            let where_clause = if include_self { "" } else { "WHERE t.id != $1" };
159            format!(
160                "
161WITH RECURSIVE tasks_and_subtasks(id, parent) AS (
162     SELECT t.* FROM {table} t
163     WHERE t.id = $1
164     UNION ALL
165     SELECT child_task.*
166     FROM tasks_and_subtasks t, {table} child_task
167     WHERE t.id = child_task.parent
168)
169SELECT * FROM tasks_and_subtasks {where_clause}
170"
171            )
172        } else {
173            let self_condition = if include_self { "OR t.id = $1" } else { "" };
174            format!("SELECT * FROM {table} WHERE parent = $1 {self_condition}")
175        };
176
177        Ok(sqlx::query_as(&sql).bind(id).fetch_all(db).await?)
178    }
179
180    #[instrument(level = "trace", skip(db, tables))]
181    pub async fn load_any_waiting(
182        db: impl PgExecutor<'_>,
183        tables: &dyn TaskTableProvider,
184        allowed_types: &[impl TaskType],
185    ) -> Result<Option<Self>> {
186        let allowed_types = allowed_types
187            .iter()
188            .map(|ea| ea.to_string())
189            .collect::<Vec<_>>();
190        let table = tables.tasks_table_full_name();
191        let table_ready = tables.tasks_ready_view_full_name();
192        let sql = format!(
193            "
194UPDATE {table}
195SET updated_at = NOW(),
196    in_progress = true
197WHERE id = (SELECT id
198            FROM {table}
199            WHERE task_type = ANY($1) AND id IN (SELECT id FROM {table_ready})
200            LIMIT 1
201            FOR UPDATE SKIP LOCKED)
202RETURNING *;
203"
204        );
205        let task: Option<Self> = sqlx::query_as(&sql)
206            .bind(allowed_types)
207            .fetch_optional(db)
208            .await?;
209
210        Ok(task)
211    }
212
213    #[instrument(level = "trace", skip(db, tables))]
214    pub async fn load_waiting(
215        id: Uuid,
216        db: impl PgExecutor<'_>,
217        tables: &dyn TaskTableProvider,
218        allowed_types: &[impl TaskType],
219    ) -> Result<Option<Self>> {
220        let allowed_types = allowed_types
221            .iter()
222            .map(|ea| ea.to_string())
223            .collect::<Vec<_>>();
224        let table = tables.tasks_table_full_name();
225        let table_ready = tables.tasks_ready_view_full_name();
226        let sql = format!(
227            "
228UPDATE {table}
229SET updated_at = NOW(),
230    in_progress = true
231WHERE id = (SELECT id
232            FROM {table}
233            WHERE id = $1 AND task_type = ANY($2) AND id IN (SELECT id FROM {table_ready})
234            LIMIT 1
235            FOR UPDATE SKIP LOCKED)
236RETURNING *;
237"
238        );
239        let task: Option<Self> = sqlx::query_as(&sql)
240            .bind(sqlx::types::Uuid::from_u128(id.as_u128()))
241            .bind(allowed_types)
242            .fetch_optional(db)
243            .await?;
244
245        Ok(task)
246    }
247
248    /// Sets the error value for the task
249    #[instrument(level = "trace", skip(db, tables))]
250    pub async fn set_error(
251        id: Uuid,
252        db: impl Acquire<'_, Database = Postgres>,
253        tables: &dyn TaskTableProvider,
254        error: Value,
255    ) -> Result<()> {
256        let mut tx = db.begin().await?;
257
258        let table = tables.tasks_table_full_name();
259        let sql = format!(
260            "
261UPDATE {table}
262SET updated_at = NOW(),
263    error = $2,
264    in_progress = false
265WHERE id = $1"
266        );
267
268        sqlx::query(&sql)
269            .bind(id)
270            .bind(error)
271            .execute(&mut *tx)
272            .await?;
273
274        // if self.done {
275        //     debug!("notifying about task done {}", self.id);
276        //     let notify_fn = tables.tasks_notify_done_fn_full_name();
277        //     let sql = format!("SELECT {notify_fn}($1)");
278        //     sqlx::query(&sql).bind(self.id).execute(&mut *tx).await?;
279        // } else {
280        //     debug!("notifying about task ready again {}", self.id);
281        //     let notify_fn = tables.tasks_notify_done_fn_full_name();
282        //     let sql = format!("SELECT {notify_fn}($1)");
283        //     sqlx::query(&sql).bind(self.id).execute(&mut *tx).await?;
284        // }
285
286        tx.commit().await?;
287
288        Ok(())
289    }
290
291    /// Saves current state in DB
292    #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
293    pub async fn save(
294        &self,
295        db: impl Acquire<'_, Database = Postgres>,
296        tables: &dyn TaskTableProvider,
297    ) -> Result<()> {
298        let mut tx = db.begin().await?;
299
300        let table = tables.tasks_table_full_name();
301        let sql = format!(
302            "
303UPDATE {table}
304SET updated_at = NOW(),
305    request = $2,
306    result = $3,
307    error = $4,
308    in_progress = $5,
309    done = $6
310WHERE id = $1"
311        );
312
313        sqlx::query(&sql)
314            .bind(self.id)
315            .bind(&self.request)
316            .bind(&self.result)
317            .bind(&self.error)
318            .bind(self.in_progress)
319            .bind(self.done)
320            .execute(&mut *tx)
321            .await?;
322
323        if self.done {
324            debug!("notifying about task done {}", self.id);
325            let notify_fn = tables.tasks_notify_done_fn_full_name();
326            let sql = format!("SELECT {notify_fn}($1)");
327            sqlx::query(&sql).bind(self.id).execute(&mut *tx).await?;
328        } else {
329            debug!("notifying about task ready again {}", self.id);
330            let notify_fn = tables.tasks_notify_fn_full_name();
331            let sql = format!("SELECT {notify_fn}($1)");
332            sqlx::query(&sql).bind(self.id).execute(&mut *tx).await?;
333        }
334
335        tx.commit().await?;
336
337        Ok(())
338    }
339
340    /// Queries state of this task from DB and updates self.
341    #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
342    pub async fn update(
343        &mut self,
344        db: impl PgExecutor<'_>,
345        tables: &dyn TaskTableProvider,
346    ) -> Result<()> {
347        info!("UPDATING {}", self.id);
348        let table = tables.tasks_table_full_name();
349        let sql = format!("SELECT * FROM {table} WHERE id = $1");
350        let me: Self = sqlx::query_as(&sql).bind(self.id).fetch_one(db).await?;
351
352        let _ = std::mem::replace(self, me);
353
354        Ok(())
355    }
356
357    /// Deletes this and all child tasks (recursively) from the DB. Call this
358    /// when the task is done.
359    #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
360    pub async fn delete(
361        &self,
362        db: impl Acquire<'_, Database = Postgres>,
363        tables: &dyn TaskTableProvider,
364    ) -> Result<()> {
365        let mut tx = db.begin().await?;
366        let con = tx.acquire().await?;
367
368        let table = tables.tasks_table_full_name();
369        let sql = format!("DELETE FROM {table} WHERE id = $1");
370        sqlx::query(&sql).bind(self.id).execute(&mut *con).await?;
371
372        let notify_fn = tables.tasks_notify_done_fn_full_name();
373        let sql = format!("SELECT {notify_fn}($1)");
374        sqlx::query(&sql).bind(self.id).execute(&mut *con).await?;
375
376        tx.commit().await?;
377
378        Ok(())
379    }
380
381    /// Fullfill task with result/error and mark as done
382    #[instrument(level = "trace", skip(self, db, tables, result, error), fields(id=%self.id))]
383    pub async fn fullfill(
384        &mut self,
385        db: impl Acquire<'_, Database = Postgres>,
386        tables: &dyn TaskTableProvider,
387        result: Option<impl Into<Value>>,
388        error: Option<impl Into<Value>>,
389    ) -> Result<()> {
390        self.result = result.map(Into::into);
391        self.error = error.map(Into::into);
392        self.in_progress = false;
393        self.done = true;
394        self.save(db, tables).await
395    }
396
397    /// Takes `tasks` and listens for notifications until all tasks are done.
398    /// Inbetween notifications, at `poll_interval`, will manually query for
399    /// updated tasks. Will ensure that those tasks are updated when this method
400    /// returns.
401    #[instrument(level = "trace", skip(pool, tables))]
402    async fn wait_for_tasks_to_be_done(
403        tasks: Vec<&mut Self>,
404        pool: &Pool<Postgres>,
405        tables: &dyn TaskTableProvider,
406        poll_interval: Option<std::time::Duration>,
407    ) -> Result<()> {
408        fn update<'a>(
409            tasks_pending: Vec<&'a mut Task>,
410            tasks_done: &mut Vec<&'a mut Task>,
411            ids: std::collections::HashSet<Uuid>,
412        ) -> Vec<&'a mut Task> {
413            let (done, rest): (Vec<_>, Vec<_>) = tasks_pending
414                .into_iter()
415                .partition(|task| ids.contains(&task.id));
416            tasks_done.extend(done);
417            trace!("still waiting for {} tasks", rest.len());
418            rest
419        }
420
421        let start_time = Instant::now();
422        let mut tasks_pending = tasks.into_iter().collect::<Vec<_>>();
423        let mut tasks_done = Vec::new();
424
425        let mut listener = sqlx::postgres::PgListener::connect_with(pool).await?;
426        let queue_name = tables.tasks_queue_done_name();
427        listener.listen(&queue_name).await?;
428
429        let tasks_table = tables.tasks_table();
430        let ready_sql = format!("SELECT id FROM {tasks_table} WHERE id = ANY($1) AND done = true");
431        let existing_sql = format!("SELECT id FROM {tasks_table} WHERE id = ANY($1)");
432
433        loop {
434            trace!("waiting for task {} to be done", tasks_pending.len());
435
436            let ready: Vec<(Uuid,)> = sqlx::query_as(&ready_sql)
437                .bind(tasks_pending.iter().map(|ea| ea.id).collect::<Vec<_>>())
438                .fetch_all(pool)
439                .await?;
440
441            let existing: Vec<(Uuid,)> = sqlx::query_as(&existing_sql)
442                .bind(tasks_pending.iter().map(|ea| ea.id).collect::<Vec<_>>())
443                .fetch_all(pool)
444                .await?;
445            let existing = existing.into_iter().map(|(id,)| id).collect::<HashSet<_>>();
446
447            tasks_pending = update(
448                tasks_pending,
449                &mut tasks_done,
450                HashSet::from_iter(ready.into_iter().map(|(id,)| id)),
451            );
452
453            if tasks_pending.is_empty() {
454                break;
455            }
456
457            // in case one of the tasks we are waiting for was deleted
458            for ea in &tasks_pending {
459                if !existing.contains(&ea.id) {
460                    return Err(Error::TaskDeleted { task: ea.id });
461                }
462            }
463
464            let notification = if let Some(poll_interval) = poll_interval {
465                tokio::select! {
466                    _ = tokio::time::sleep(poll_interval) => {
467                        continue;
468                    },
469                    notification = listener.recv() => notification,
470                }
471            } else {
472                listener.recv().await
473            };
474
475            if let Ok(notification) = notification {
476                if let Ok(id) = Uuid::parse_str(notification.payload()) {
477                    tasks_pending =
478                        update(tasks_pending, &mut tasks_done, HashSet::from_iter([id]));
479                    if tasks_pending.is_empty() {
480                        break;
481                    }
482                }
483            }
484        }
485
486        for task in &mut tasks_done {
487            task.update(pool, tables).await?;
488        }
489
490        debug!(
491            "{} tasks done, wait time: {}ms",
492            tasks_done.len(),
493            (Instant::now() - start_time).as_millis()
494        );
495
496        Ok(())
497    }
498
499    #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
500    pub async fn wait_until_done(
501        &mut self,
502        db: &Pool<Postgres>,
503        tables: &dyn TaskTableProvider,
504        poll_interval: Option<std::time::Duration>,
505    ) -> Result<()> {
506        Self::wait_for_tasks_to_be_done(vec![self], db, tables, poll_interval).await?;
507        Ok(())
508    }
509
510    pub async fn wait_until_done_and_delete(
511        &mut self,
512        pool: &Pool<Postgres>,
513        tables: &dyn TaskTableProvider,
514        poll_interval: Option<std::time::Duration>,
515    ) -> Result<()> {
516        self.wait_until_done(pool, tables, poll_interval).await?;
517        self.delete(pool, tables).await?;
518        Ok(())
519    }
520
521    /// Takes request returns it deserialized.
522    pub fn request<R: serde::de::DeserializeOwned>(&mut self) -> Result<R> {
523        let request = match self.request.take() {
524            None => {
525                return Err(Error::TaskError {
526                    task: self.id,
527                    message: "Task has no request JSON data".to_string(),
528                })
529            }
530            Some(request) => request,
531        };
532
533        serde_json::from_value(request).map_err(|err| Error::TaskError {
534            task: self.id,
535            message: format!("Error deserializing request JSON from task: {err}"),
536        })
537    }
538
539    /// Converts self into the result payload (or error).
540    fn error(&mut self) -> Option<Error> {
541        self.error.take().map(|error| Error::TaskError {
542            task: self.id,
543            message: match error.get("error").and_then(|msg| msg.as_str()) {
544                Some(msg) => msg.to_string(),
545                None => error.to_string(),
546            },
547        })
548    }
549
550    /// Takes the result returns it deserialized.
551    fn result<R: serde::de::DeserializeOwned>(&mut self) -> Result<Option<R>> {
552        match self.result.take() {
553            None => Ok(None),
554            Some(request) => Ok(Some(serde_json::from_value(request)?)),
555        }
556    }
557
558    /// Takes the result returns it deserialized.
559    #[allow(dead_code)]
560    pub fn result_cloned<R: serde::de::DeserializeOwned>(&self) -> Result<Option<R>> {
561        if let Some(result) = &self.result {
562            Ok(Some(serde_json::from_value(result.clone())?))
563        } else {
564            Ok(None)
565        }
566    }
567
568    /// Converts self into the result payload (or error).
569    pub fn as_result<R: serde::de::DeserializeOwned>(&mut self) -> Result<Option<R>> {
570        self.error().map(Err).unwrap_or_else(|| self.result())
571    }
572}