docbox_database/models/tasks.rs

1use std::future::Future;
2
3use crate::{DbExecutor, DbPool, DbResult};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use sqlx::{Database, Decode, error::BoxDynError, prelude::FromRow};
7use tracing::Instrument;
8use utoipa::ToSchema;
9use uuid::Uuid;
10
11use super::document_box::DocumentBoxScopeRaw;
12
13pub type TaskId = Uuid;
14
15/// Represents a stored asynchronous task progress
16#[derive(Debug, FromRow, Serialize, ToSchema)]
17pub struct Task {
18    /// Unique ID of the task
19    pub id: Uuid,
20
21    /// ID of the document box the task belongs to
22    pub document_box: DocumentBoxScopeRaw,
23
24    /// Status of the task
25    pub status: TaskStatus,
26
27    /// Output data from the task completion
28    pub output_data: Option<serde_json::Value>,
29
30    /// When the task was created
31    pub created_at: DateTime<Utc>,
32
33    // When execution of the task completed
34    pub completed_at: Option<DateTime<Utc>>,
35}
36
37#[derive(
38    Debug, Clone, Copy, strum::EnumString, strum::Display, Deserialize, Serialize, ToSchema,
39)]
40pub enum TaskStatus {
41    Pending,
42    Completed,
43    Failed,
44}
45
46impl<DB: Database> sqlx::Type<DB> for TaskStatus
47where
48    String: sqlx::Type<DB>,
49{
50    fn type_info() -> DB::TypeInfo {
51        String::type_info()
52    }
53}
54
55impl<'r, DB: Database> Decode<'r, DB> for TaskStatus
56where
57    String: Decode<'r, DB>,
58{
59    fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
60        let value = <String as Decode<DB>>::decode(value)?;
61        Ok(value.parse()?)
62    }
63}
64
65pub async fn background_task<Fut>(
66    db: DbPool,
67    scope: DocumentBoxScopeRaw,
68    future: Fut,
69) -> DbResult<(TaskId, DateTime<Utc>)>
70where
71    Fut: Future<Output = (TaskStatus, serde_json::Value)> + Send + 'static,
72{
73    // Create task for progression
74    let task = Task::create(&db, scope).await?;
75
76    let task_id = task.id;
77    let created_at = task.created_at;
78
79    let span = tracing::Span::current();
80
81    // Swap background task
82    tokio::spawn(
83        async move {
84            let (status, output) = future.await;
85
86            // Update task completion
87            if let Err(cause) = task.complete_task(&db, status, Some(output)).await {
88                tracing::error!(?cause, "failed to mark task as complete");
89            }
90        }
91        .instrument(span),
92    );
93
94    Ok((task_id, created_at))
95}
96
97impl Task {
98    /// Stores / updates the stored user data, returns back the user ID
99    pub async fn create(
100        db: impl DbExecutor<'_>,
101        document_box: DocumentBoxScopeRaw,
102    ) -> DbResult<Task> {
103        let task_id = Uuid::new_v4();
104        let status = TaskStatus::Pending;
105        let created_at = Utc::now();
106
107        sqlx::query(
108            r#"
109            INSERT INTO "docbox_tasks" ("id", "document_box", "status", "created_at") 
110            VALUES ($1, $2, $3, $4)
111        "#,
112        )
113        .bind(task_id)
114        .bind(document_box.as_str())
115        .bind(status.to_string())
116        .bind(created_at)
117        .execute(db)
118        .await?;
119
120        Ok(Task {
121            id: task_id,
122            document_box,
123            status,
124            output_data: None,
125            created_at,
126            completed_at: None,
127        })
128    }
129
130    pub async fn find(
131        db: impl DbExecutor<'_>,
132        id: TaskId,
133        document_box: &DocumentBoxScopeRaw,
134    ) -> DbResult<Option<Task>> {
135        sqlx::query_as(r#"SELECT * FROM "docbox_tasks" WHERE "id" = $1 AND "document_box" = $2"#)
136            .bind(id)
137            .bind(document_box)
138            .fetch_optional(db)
139            .await
140    }
141
142    /// Mark the task as completed and set its output data
143    pub async fn complete_task(
144        mut self,
145        db: impl DbExecutor<'_>,
146        status: TaskStatus,
147        output_data: Option<serde_json::Value>,
148    ) -> DbResult<Task> {
149        let completed_at = Utc::now();
150
151        sqlx::query(
152            r#"UPDATE "docbox_tasks" SET 
153            "status" = $1, 
154            "output_data" = $2, 
155            "completed_at" = $3
156            WHERE "id" = $4"#,
157        )
158        .bind(status.to_string())
159        .bind(output_data.clone())
160        .bind(completed_at)
161        .bind(self.id)
162        .execute(db)
163        .await?;
164
165        self.status = status;
166        self.output_data = output_data.clone();
167        self.completed_at = Some(completed_at);
168
169        Ok(self)
170    }
171}