1use chrono::prelude::*;
2use serde::Serialize;
3use serde_json::Value;
4use sqlx::{Acquire, PgExecutor, Pool, Postgres};
5use std::collections::HashSet;
6use std::time::Instant;
7use uuid::Uuid;
8
9use crate::{Error, Result, TaskTableProvider, TaskType};
10
11#[derive(Default)]
12pub struct TaskBuilder {
13 task_type: String,
14 req: Option<Value>,
15 parent: Option<Uuid>,
16 id: Option<Uuid>,
17}
18
19impl TaskBuilder {
20 pub fn new(task_type: impl TaskType) -> Self {
21 Self {
22 task_type: task_type.to_string(),
23 ..Default::default()
24 }
25 }
26
27 pub fn with_request(mut self, req: impl Serialize) -> Result<Self> {
28 self.req = Some(serde_json::to_value(req)?);
29 Ok(self)
30 }
31
32 #[must_use]
33 pub fn with_parent(mut self, parent: Uuid) -> Self {
34 self.parent = Some(parent);
35 self
36 }
37
38 #[must_use]
39 pub fn with_id(mut self, id: Uuid) -> Self {
40 self.id = Some(id);
41 self
42 }
43
44 pub async fn build<'a, DB, P>(self, db: DB, tables: &P) -> Result<Task>
45 where
46 P: TaskTableProvider,
47 DB: Acquire<'a, Database = Postgres>,
48 {
49 let Self {
50 task_type,
51 req,
52 parent,
53 id,
54 } = self;
55 Task::create_task(db, tables, id, task_type, req, parent).await
56 }
57}
58
59#[derive(Debug, Clone, sqlx::FromRow)]
60pub struct Task {
61 pub id: Uuid,
62 pub parent: Option<Uuid>,
63 pub created_at: DateTime<Utc>,
64 pub updated_at: DateTime<Utc>,
65 pub task_type: String,
66 pub request: Option<Value>,
67 pub result: Option<Value>,
68 pub error: Option<Value>,
69 pub in_progress: bool,
70 pub done: bool,
71}
72
73impl Task {
74 #[instrument(level = "trace", skip(db, tables, req))]
75 pub async fn create_task<'a, DB, P>(
76 db: DB,
77 tables: &P,
78 id: Option<Uuid>,
79 task_type: String,
80 req: Option<Value>,
81 parent: Option<Uuid>,
82 ) -> Result<Self>
83 where
84 P: TaskTableProvider,
85 DB: Acquire<'a, Database = Postgres>,
86 {
87 let id = id.unwrap_or_else(Uuid::new_v4);
88 let req = Some(serde_json::to_value(req)?);
89
90 let mut tx = db.begin().await?;
91
92 let table = tables.tasks_table_full_name();
93 let notify_fn = tables.tasks_notify_fn_full_name();
94
95 let sql = format!(
96 "
97INSERT INTO {table} (id, task_type, request, parent, in_progress, done)
98VALUES ($1, $2, $3, $4, false, false)
99RETURNING *
100"
101 );
102 let task: Self = sqlx::query_as(&sql)
103 .bind(id)
104 .bind(task_type.to_string())
105 .bind(req)
106 .bind(parent)
107 .fetch_one(&mut *tx)
108 .await?;
109
110 let sql = format!("SELECT {notify_fn}($1)");
111 sqlx::query(&sql).bind(task.id).execute(&mut *tx).await?;
112
113 tx.commit().await?;
114
115 debug!("created task {task_type:?} {id}");
116
117 Ok(task)
118 }
119
120 pub async fn load(
121 db: impl PgExecutor<'_>,
122 tables: &dyn TaskTableProvider,
123 id: Uuid,
124 ) -> Result<Option<Self>> {
125 let table = tables.tasks_table_full_name();
126 let sql = format!("SELECT * FROM {table} WHERE id = $1");
127 Ok(sqlx::query_as(&sql).bind(id).fetch_optional(db).await?)
128 }
129
130 pub async fn with_children(
131 self,
132 db: impl PgExecutor<'_>,
133 tables: &dyn TaskTableProvider,
134 recursive: bool,
135 ) -> Result<Vec<Self>> {
136 Self::load_children(db, tables, self.id, true, recursive).await
137 }
138
139 pub async fn children(
140 &self,
141 db: impl PgExecutor<'_>,
142 tables: &dyn TaskTableProvider,
143 recursive: bool,
144 ) -> Result<Vec<Self>> {
145 Self::load_children(db, tables, self.id, false, recursive).await
146 }
147
148 #[instrument(level = "trace", skip(db, tables))]
149 pub async fn load_children(
150 db: impl PgExecutor<'_>,
151 tables: &dyn TaskTableProvider,
152 id: Uuid,
153 include_self: bool,
154 recursive: bool,
155 ) -> Result<Vec<Self>> {
156 let table = tables.tasks_table_full_name();
157 let sql = if recursive {
158 let where_clause = if include_self { "" } else { "WHERE t.id != $1" };
159 format!(
160 "
161WITH RECURSIVE tasks_and_subtasks(id, parent) AS (
162 SELECT t.* FROM {table} t
163 WHERE t.id = $1
164 UNION ALL
165 SELECT child_task.*
166 FROM tasks_and_subtasks t, {table} child_task
167 WHERE t.id = child_task.parent
168)
169SELECT * FROM tasks_and_subtasks {where_clause}
170"
171 )
172 } else {
173 let self_condition = if include_self { "OR t.id = $1" } else { "" };
174 format!("SELECT * FROM {table} WHERE parent = $1 {self_condition}")
175 };
176
177 Ok(sqlx::query_as(&sql).bind(id).fetch_all(db).await?)
178 }
179
180 #[instrument(level = "trace", skip(db, tables))]
181 pub async fn load_any_waiting(
182 db: impl PgExecutor<'_>,
183 tables: &dyn TaskTableProvider,
184 allowed_types: &[impl TaskType],
185 ) -> Result<Option<Self>> {
186 let allowed_types = allowed_types
187 .iter()
188 .map(|ea| ea.to_string())
189 .collect::<Vec<_>>();
190 let table = tables.tasks_table_full_name();
191 let table_ready = tables.tasks_ready_view_full_name();
192 let sql = format!(
193 "
194UPDATE {table}
195SET updated_at = NOW(),
196 in_progress = true
197WHERE id = (SELECT id
198 FROM {table}
199 WHERE task_type = ANY($1) AND id IN (SELECT id FROM {table_ready})
200 LIMIT 1
201 FOR UPDATE SKIP LOCKED)
202RETURNING *;
203"
204 );
205 let task: Option<Self> = sqlx::query_as(&sql)
206 .bind(allowed_types)
207 .fetch_optional(db)
208 .await?;
209
210 Ok(task)
211 }
212
213 #[instrument(level = "trace", skip(db, tables))]
214 pub async fn load_waiting(
215 id: Uuid,
216 db: impl PgExecutor<'_>,
217 tables: &dyn TaskTableProvider,
218 allowed_types: &[impl TaskType],
219 ) -> Result<Option<Self>> {
220 let allowed_types = allowed_types
221 .iter()
222 .map(|ea| ea.to_string())
223 .collect::<Vec<_>>();
224 let table = tables.tasks_table_full_name();
225 let table_ready = tables.tasks_ready_view_full_name();
226 let sql = format!(
227 "
228UPDATE {table}
229SET updated_at = NOW(),
230 in_progress = true
231WHERE id = (SELECT id
232 FROM {table}
233 WHERE id = $1 AND task_type = ANY($2) AND id IN (SELECT id FROM {table_ready})
234 LIMIT 1
235 FOR UPDATE SKIP LOCKED)
236RETURNING *;
237"
238 );
239 let task: Option<Self> = sqlx::query_as(&sql)
240 .bind(sqlx::types::Uuid::from_u128(id.as_u128()))
241 .bind(allowed_types)
242 .fetch_optional(db)
243 .await?;
244
245 Ok(task)
246 }
247
248 #[instrument(level = "trace", skip(db, tables))]
250 pub async fn set_error(
251 id: Uuid,
252 db: impl Acquire<'_, Database = Postgres>,
253 tables: &dyn TaskTableProvider,
254 error: Value,
255 ) -> Result<()> {
256 let mut tx = db.begin().await?;
257
258 let table = tables.tasks_table_full_name();
259 let sql = format!(
260 "
261UPDATE {table}
262SET updated_at = NOW(),
263 error = $2,
264 in_progress = false
265WHERE id = $1"
266 );
267
268 sqlx::query(&sql)
269 .bind(id)
270 .bind(error)
271 .execute(&mut *tx)
272 .await?;
273
274 tx.commit().await?;
287
288 Ok(())
289 }
290
291 #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
293 pub async fn save(
294 &self,
295 db: impl Acquire<'_, Database = Postgres>,
296 tables: &dyn TaskTableProvider,
297 ) -> Result<()> {
298 let mut tx = db.begin().await?;
299
300 let table = tables.tasks_table_full_name();
301 let sql = format!(
302 "
303UPDATE {table}
304SET updated_at = NOW(),
305 request = $2,
306 result = $3,
307 error = $4,
308 in_progress = $5,
309 done = $6
310WHERE id = $1"
311 );
312
313 sqlx::query(&sql)
314 .bind(self.id)
315 .bind(&self.request)
316 .bind(&self.result)
317 .bind(&self.error)
318 .bind(self.in_progress)
319 .bind(self.done)
320 .execute(&mut *tx)
321 .await?;
322
323 if self.done {
324 debug!("notifying about task done {}", self.id);
325 let notify_fn = tables.tasks_notify_done_fn_full_name();
326 let sql = format!("SELECT {notify_fn}($1)");
327 sqlx::query(&sql).bind(self.id).execute(&mut *tx).await?;
328 } else {
329 debug!("notifying about task ready again {}", self.id);
330 let notify_fn = tables.tasks_notify_fn_full_name();
331 let sql = format!("SELECT {notify_fn}($1)");
332 sqlx::query(&sql).bind(self.id).execute(&mut *tx).await?;
333 }
334
335 tx.commit().await?;
336
337 Ok(())
338 }
339
340 #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
342 pub async fn update(
343 &mut self,
344 db: impl PgExecutor<'_>,
345 tables: &dyn TaskTableProvider,
346 ) -> Result<()> {
347 info!("UPDATING {}", self.id);
348 let table = tables.tasks_table_full_name();
349 let sql = format!("SELECT * FROM {table} WHERE id = $1");
350 let me: Self = sqlx::query_as(&sql).bind(self.id).fetch_one(db).await?;
351
352 let _ = std::mem::replace(self, me);
353
354 Ok(())
355 }
356
357 #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
360 pub async fn delete(
361 &self,
362 db: impl Acquire<'_, Database = Postgres>,
363 tables: &dyn TaskTableProvider,
364 ) -> Result<()> {
365 let mut tx = db.begin().await?;
366 let con = tx.acquire().await?;
367
368 let table = tables.tasks_table_full_name();
369 let sql = format!("DELETE FROM {table} WHERE id = $1");
370 sqlx::query(&sql).bind(self.id).execute(&mut *con).await?;
371
372 let notify_fn = tables.tasks_notify_done_fn_full_name();
373 let sql = format!("SELECT {notify_fn}($1)");
374 sqlx::query(&sql).bind(self.id).execute(&mut *con).await?;
375
376 tx.commit().await?;
377
378 Ok(())
379 }
380
381 #[instrument(level = "trace", skip(self, db, tables, result, error), fields(id=%self.id))]
383 pub async fn fullfill(
384 &mut self,
385 db: impl Acquire<'_, Database = Postgres>,
386 tables: &dyn TaskTableProvider,
387 result: Option<impl Into<Value>>,
388 error: Option<impl Into<Value>>,
389 ) -> Result<()> {
390 self.result = result.map(Into::into);
391 self.error = error.map(Into::into);
392 self.in_progress = false;
393 self.done = true;
394 self.save(db, tables).await
395 }
396
397 #[instrument(level = "trace", skip(pool, tables))]
402 async fn wait_for_tasks_to_be_done(
403 tasks: Vec<&mut Self>,
404 pool: &Pool<Postgres>,
405 tables: &dyn TaskTableProvider,
406 poll_interval: Option<std::time::Duration>,
407 ) -> Result<()> {
408 fn update<'a>(
409 tasks_pending: Vec<&'a mut Task>,
410 tasks_done: &mut Vec<&'a mut Task>,
411 ids: std::collections::HashSet<Uuid>,
412 ) -> Vec<&'a mut Task> {
413 let (done, rest): (Vec<_>, Vec<_>) = tasks_pending
414 .into_iter()
415 .partition(|task| ids.contains(&task.id));
416 tasks_done.extend(done);
417 trace!("still waiting for {} tasks", rest.len());
418 rest
419 }
420
421 let start_time = Instant::now();
422 let mut tasks_pending = tasks.into_iter().collect::<Vec<_>>();
423 let mut tasks_done = Vec::new();
424
425 let mut listener = sqlx::postgres::PgListener::connect_with(pool).await?;
426 let queue_name = tables.tasks_queue_done_name();
427 listener.listen(&queue_name).await?;
428
429 let tasks_table = tables.tasks_table();
430 let ready_sql = format!("SELECT id FROM {tasks_table} WHERE id = ANY($1) AND done = true");
431 let existing_sql = format!("SELECT id FROM {tasks_table} WHERE id = ANY($1)");
432
433 loop {
434 trace!("waiting for task {} to be done", tasks_pending.len());
435
436 let ready: Vec<(Uuid,)> = sqlx::query_as(&ready_sql)
437 .bind(tasks_pending.iter().map(|ea| ea.id).collect::<Vec<_>>())
438 .fetch_all(pool)
439 .await?;
440
441 let existing: Vec<(Uuid,)> = sqlx::query_as(&existing_sql)
442 .bind(tasks_pending.iter().map(|ea| ea.id).collect::<Vec<_>>())
443 .fetch_all(pool)
444 .await?;
445 let existing = existing.into_iter().map(|(id,)| id).collect::<HashSet<_>>();
446
447 tasks_pending = update(
448 tasks_pending,
449 &mut tasks_done,
450 HashSet::from_iter(ready.into_iter().map(|(id,)| id)),
451 );
452
453 if tasks_pending.is_empty() {
454 break;
455 }
456
457 for ea in &tasks_pending {
459 if !existing.contains(&ea.id) {
460 return Err(Error::TaskDeleted { task: ea.id });
461 }
462 }
463
464 let notification = if let Some(poll_interval) = poll_interval {
465 tokio::select! {
466 _ = tokio::time::sleep(poll_interval) => {
467 continue;
468 },
469 notification = listener.recv() => notification,
470 }
471 } else {
472 listener.recv().await
473 };
474
475 if let Ok(notification) = notification {
476 if let Ok(id) = Uuid::parse_str(notification.payload()) {
477 tasks_pending =
478 update(tasks_pending, &mut tasks_done, HashSet::from_iter([id]));
479 if tasks_pending.is_empty() {
480 break;
481 }
482 }
483 }
484 }
485
486 for task in &mut tasks_done {
487 task.update(pool, tables).await?;
488 }
489
490 debug!(
491 "{} tasks done, wait time: {}ms",
492 tasks_done.len(),
493 (Instant::now() - start_time).as_millis()
494 );
495
496 Ok(())
497 }
498
499 #[instrument(level = "trace", skip(self, db,tables), fields(id=%self.id))]
500 pub async fn wait_until_done(
501 &mut self,
502 db: &Pool<Postgres>,
503 tables: &dyn TaskTableProvider,
504 poll_interval: Option<std::time::Duration>,
505 ) -> Result<()> {
506 Self::wait_for_tasks_to_be_done(vec![self], db, tables, poll_interval).await?;
507 Ok(())
508 }
509
510 pub async fn wait_until_done_and_delete(
511 &mut self,
512 pool: &Pool<Postgres>,
513 tables: &dyn TaskTableProvider,
514 poll_interval: Option<std::time::Duration>,
515 ) -> Result<()> {
516 self.wait_until_done(pool, tables, poll_interval).await?;
517 self.delete(pool, tables).await?;
518 Ok(())
519 }
520
521 pub fn request<R: serde::de::DeserializeOwned>(&mut self) -> Result<R> {
523 let request = match self.request.take() {
524 None => {
525 return Err(Error::TaskError {
526 task: self.id,
527 message: "Task has no request JSON data".to_string(),
528 })
529 }
530 Some(request) => request,
531 };
532
533 serde_json::from_value(request).map_err(|err| Error::TaskError {
534 task: self.id,
535 message: format!("Error deserializing request JSON from task: {err}"),
536 })
537 }
538
539 fn error(&mut self) -> Option<Error> {
541 self.error.take().map(|error| Error::TaskError {
542 task: self.id,
543 message: match error.get("error").and_then(|msg| msg.as_str()) {
544 Some(msg) => msg.to_string(),
545 None => error.to_string(),
546 },
547 })
548 }
549
550 fn result<R: serde::de::DeserializeOwned>(&mut self) -> Result<Option<R>> {
552 match self.result.take() {
553 None => Ok(None),
554 Some(request) => Ok(Some(serde_json::from_value(request)?)),
555 }
556 }
557
558 #[allow(dead_code)]
560 pub fn result_cloned<R: serde::de::DeserializeOwned>(&self) -> Result<Option<R>> {
561 if let Some(result) = &self.result {
562 Ok(Some(serde_json::from_value(result.clone())?))
563 } else {
564 Ok(None)
565 }
566 }
567
568 pub fn as_result<R: serde::de::DeserializeOwned>(&mut self) -> Result<Option<R>> {
570 self.error().map(Err).unwrap_or_else(|| self.result())
571 }
572}