docbox_database/models/
tasks.rs1use std::future::Future;
2
3use crate::{DbExecutor, DbPool, DbResult};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use sqlx::{Database, Decode, error::BoxDynError, prelude::FromRow};
7use tracing::Instrument;
8use utoipa::ToSchema;
9use uuid::Uuid;
10
11use super::document_box::DocumentBoxScopeRaw;
12
/// Unique identifier of a [`Task`]; an alias for [`Uuid`].
pub type TaskId = Uuid;
14
/// A task record as stored in the `docbox_tasks` table.
///
/// A task is inserted in the [`TaskStatus::Pending`] state by
/// [`Task::create`] and later finalized through [`Task::complete_task`].
#[derive(Debug, FromRow, Serialize, ToSchema)]
pub struct Task {
    /// Unique ID of the task
    pub id: Uuid,

    /// Scope of the document box the task belongs to
    pub document_box: DocumentBoxScopeRaw,

    /// Current status of the task
    pub status: TaskStatus,

    /// Output recorded when the task is completed, `None` before that
    pub output_data: Option<serde_json::Value>,

    /// When the task was created
    pub created_at: DateTime<Utc>,

    /// When the task was completed, `None` while it has not been completed
    pub completed_at: Option<DateTime<Utc>>,
}
36
37#[derive(
38 Debug, Clone, Copy, strum::EnumString, strum::Display, Deserialize, Serialize, ToSchema,
39)]
40pub enum TaskStatus {
41 Pending,
42 Completed,
43 Failed,
44}
45
46impl<DB: Database> sqlx::Type<DB> for TaskStatus
47where
48 String: sqlx::Type<DB>,
49{
50 fn type_info() -> DB::TypeInfo {
51 String::type_info()
52 }
53}
54
55impl<'r, DB: Database> Decode<'r, DB> for TaskStatus
56where
57 String: Decode<'r, DB>,
58{
59 fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
60 let value = <String as Decode<DB>>::decode(value)?;
61 Ok(value.parse()?)
62 }
63}
64
65pub async fn background_task<Fut>(
66 db: DbPool,
67 scope: DocumentBoxScopeRaw,
68 future: Fut,
69) -> DbResult<(TaskId, DateTime<Utc>)>
70where
71 Fut: Future<Output = (TaskStatus, serde_json::Value)> + Send + 'static,
72{
73 let task = Task::create(&db, scope).await?;
75
76 let task_id = task.id;
77 let created_at = task.created_at;
78
79 let span = tracing::Span::current();
80
81 tokio::spawn(
83 async move {
84 let (status, output) = future.await;
85
86 if let Err(cause) = task.complete_task(&db, status, Some(output)).await {
88 tracing::error!(?cause, "failed to mark task as complete");
89 }
90 }
91 .instrument(span),
92 );
93
94 Ok((task_id, created_at))
95}
96
97impl Task {
98 pub async fn create(
100 db: impl DbExecutor<'_>,
101 document_box: DocumentBoxScopeRaw,
102 ) -> DbResult<Task> {
103 let task_id = Uuid::new_v4();
104 let status = TaskStatus::Pending;
105 let created_at = Utc::now();
106
107 sqlx::query(
108 r#"
109 INSERT INTO "docbox_tasks" ("id", "document_box", "status", "created_at")
110 VALUES ($1, $2, $3, $4)
111 "#,
112 )
113 .bind(task_id)
114 .bind(document_box.as_str())
115 .bind(status.to_string())
116 .bind(created_at)
117 .execute(db)
118 .await?;
119
120 Ok(Task {
121 id: task_id,
122 document_box,
123 status,
124 output_data: None,
125 created_at,
126 completed_at: None,
127 })
128 }
129
130 pub async fn find(
131 db: impl DbExecutor<'_>,
132 id: TaskId,
133 document_box: &DocumentBoxScopeRaw,
134 ) -> DbResult<Option<Task>> {
135 sqlx::query_as(r#"SELECT * FROM "docbox_tasks" WHERE "id" = $1 AND "document_box" = $2"#)
136 .bind(id)
137 .bind(document_box)
138 .fetch_optional(db)
139 .await
140 }
141
142 pub async fn complete_task(
144 mut self,
145 db: impl DbExecutor<'_>,
146 status: TaskStatus,
147 output_data: Option<serde_json::Value>,
148 ) -> DbResult<Task> {
149 let completed_at = Utc::now();
150
151 sqlx::query(
152 r#"UPDATE "docbox_tasks" SET
153 "status" = $1,
154 "output_data" = $2,
155 "completed_at" = $3
156 WHERE "id" = $4"#,
157 )
158 .bind(status.to_string())
159 .bind(output_data.clone())
160 .bind(completed_at)
161 .bind(self.id)
162 .execute(db)
163 .await?;
164
165 self.status = status;
166 self.output_data = output_data.clone();
167 self.completed_at = Some(completed_at);
168
169 Ok(self)
170 }
171}