// forge_core_db/models/draft.rs

1use std::str::FromStr;
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use sqlx::{FromRow, QueryBuilder, Sqlite, SqlitePool};
6use ts_rs_forge::TS;
7use uuid::Uuid;
8
/// Kind of draft stored for a task attempt.
///
/// Serialized by serde and ts-rs in `snake_case` (`"follow_up"` / `"retry"`), which is
/// the same text produced by `DraftType::as_str` and written to the `drafts.draft_type`
/// column by the queries on `Draft`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, TS, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[ts(rename_all = "snake_case")]
pub enum DraftType {
    /// Follow-up draft. After sending, `Draft::clear_after_send` empties the row
    /// instead of deleting it.
    FollowUp,
    /// Retry draft. `Draft::upsert` requires a `retry_process_id` for this type, and
    /// `Draft::clear_after_send` deletes the row after sending.
    Retry,
}
16
17impl DraftType {
18    pub fn as_str(&self) -> &'static str {
19        match self {
20            DraftType::FollowUp => "follow_up",
21            DraftType::Retry => "retry",
22        }
23    }
24}
25
26impl FromStr for DraftType {
27    type Err = ();
28
29    fn from_str(s: &str) -> Result<Self, Self::Err> {
30        match s {
31            "follow_up" => Ok(DraftType::FollowUp),
32            "retry" => Ok(DraftType::Retry),
33            _ => Err(()),
34        }
35    }
36}
37
/// A draft for a task attempt, as exposed to callers and serialized to clients.
///
/// At most one draft exists per `(task_attempt_id, draft_type)` pair: `Draft::upsert`
/// resolves conflicts on exactly that pair.
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
pub struct Draft {
    /// Row id; generated with `Uuid::new_v4` when `Draft::upsert` first inserts the row.
    pub id: Uuid,
    /// Task attempt this draft belongs to.
    pub task_attempt_id: Uuid,
    /// Kind of draft; together with `task_attempt_id` identifies the row.
    pub draft_type: DraftType,
    /// Process being retried; `Draft::upsert` requires it for `DraftType::Retry`.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_process_id: Option<Uuid>,
    /// Draft text.
    pub prompt: String,
    /// Whether the draft is queued for sending. `Draft::set_queued` only stores `true`
    /// when the trimmed prompt is non-empty.
    pub queued: bool,
    /// Set by `Draft::try_mark_sending` (the "send lock"); cleared by
    /// `Draft::clear_after_send` for follow-up drafts.
    pub sending: bool,
    /// Optional variant string; its meaning is defined by callers.
    pub variant: Option<String>,
    /// Attached image ids (stored as JSON text in the database).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_ids: Option<Vec<Uuid>>,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Last-modification timestamp.
    pub updated_at: DateTime<Utc>,
    /// Revision counter incremented by every update; used for optimistic concurrency
    /// via `Draft::set_queued`'s `expected_version`.
    pub version: i64,
}
55
/// Raw `drafts` row as read from SQLite, before conversion into [`Draft`].
///
/// It differs from [`Draft`] only in `draft_type` (stored text) and `image_ids`
/// (stored JSON text); see `impl From<DraftRow> for Draft` for the conversion.
#[derive(Debug, Clone, FromRow)]
struct DraftRow {
    pub id: Uuid,
    pub task_attempt_id: Uuid,
    /// Text form of [`DraftType`] as produced by `DraftType::as_str`.
    pub draft_type: String,
    pub retry_process_id: Option<Uuid>,
    pub prompt: String,
    pub queued: bool,
    pub sending: bool,
    pub variant: Option<String>,
    /// JSON-encoded `Vec<Uuid>`, or NULL.
    pub image_ids: Option<String>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub version: i64,
}
71
72impl From<DraftRow> for Draft {
73    fn from(r: DraftRow) -> Self {
74        let image_ids = r
75            .image_ids
76            .as_deref()
77            .and_then(|s| serde_json::from_str::<Vec<Uuid>>(s).ok());
78        Draft {
79            id: r.id,
80            task_attempt_id: r.task_attempt_id,
81            draft_type: DraftType::from_str(&r.draft_type).unwrap_or(DraftType::FollowUp),
82            retry_process_id: r.retry_process_id,
83            prompt: r.prompt,
84            queued: r.queued,
85            sending: r.sending,
86            variant: r.variant,
87            image_ids,
88            created_at: r.created_at,
89            updated_at: r.updated_at,
90            version: r.version,
91        }
92    }
93}
94
/// Input for [`Draft::upsert`].
///
/// Derives `Deserialize` and `TS`, so it can be built from JSON and exported as a
/// TypeScript type.
#[derive(Debug, Deserialize, TS)]
pub struct UpsertDraft {
    /// Together with `draft_type`, identifies the draft to insert or overwrite.
    pub task_attempt_id: Uuid,
    pub draft_type: DraftType,
    /// Required when `draft_type` is `DraftType::Retry`; `Draft::upsert` rejects
    /// the input otherwise.
    pub retry_process_id: Option<Uuid>,
    pub prompt: String,
    /// Written as-is: unlike `Draft::set_queued`, no empty-prompt check is applied.
    pub queued: bool,
    pub variant: Option<String>,
    /// Stored as JSON text; `None` stores NULL (and clears existing ids on conflict).
    pub image_ids: Option<Vec<Uuid>>,
}
105
106impl Draft {
107    pub async fn find_by_rowid(pool: &SqlitePool, rowid: i64) -> Result<Option<Self>, sqlx::Error> {
108        sqlx::query_as!(
109            DraftRow,
110            r#"SELECT
111                id                       as "id!: Uuid",
112                task_attempt_id          as "task_attempt_id!: Uuid",
113                draft_type,
114                retry_process_id         as "retry_process_id?: Uuid",
115                prompt,
116                queued                   as "queued!: bool",
117                sending                  as "sending!: bool",
118                variant,
119                image_ids,
120                created_at               as "created_at!: DateTime<Utc>",
121                updated_at               as "updated_at!: DateTime<Utc>",
122                version                  as "version!: i64"
123              FROM drafts
124             WHERE rowid = $1"#,
125            rowid
126        )
127        .fetch_optional(pool)
128        .await
129        .map(|opt| opt.map(Draft::from))
130    }
131
132    pub async fn find_by_task_attempt_and_type(
133        pool: &SqlitePool,
134        task_attempt_id: Uuid,
135        draft_type: DraftType,
136    ) -> Result<Option<Self>, sqlx::Error> {
137        let draft_type_str = draft_type.as_str();
138        sqlx::query_as!(
139            DraftRow,
140            r#"SELECT
141                id                       as "id!: Uuid",
142                task_attempt_id          as "task_attempt_id!: Uuid",
143                draft_type,
144                retry_process_id         as "retry_process_id?: Uuid",
145                prompt,
146                queued                   as "queued!: bool",
147                sending                  as "sending!: bool",
148                variant,
149                image_ids,
150                created_at               as "created_at!: DateTime<Utc>",
151                updated_at               as "updated_at!: DateTime<Utc>",
152                version                  as "version!: i64"
153              FROM drafts
154             WHERE task_attempt_id = $1 AND draft_type = $2"#,
155            task_attempt_id,
156            draft_type_str
157        )
158        .fetch_optional(pool)
159        .await
160        .map(|opt| opt.map(Draft::from))
161    }
162
163    pub async fn upsert(pool: &SqlitePool, data: &UpsertDraft) -> Result<Self, sqlx::Error> {
164        // Validate retry_process_id requirement
165        if data.draft_type == DraftType::Retry && data.retry_process_id.is_none() {
166            return Err(sqlx::Error::Protocol(
167                "retry_process_id is required for retry drafts".into(),
168            ));
169        }
170
171        let id = Uuid::new_v4();
172        let image_ids_json = data
173            .image_ids
174            .as_ref()
175            .map(|ids| serde_json::to_string(ids).unwrap_or_else(|_| "[]".to_string()));
176        let draft_type_str = data.draft_type.as_str();
177        let prompt = data.prompt.clone();
178        let variant = data.variant.clone();
179        sqlx::query_as!(
180            DraftRow,
181            r#"INSERT INTO drafts (id, task_attempt_id, draft_type, retry_process_id, prompt, queued, variant, image_ids)
182               VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
183               ON CONFLICT(task_attempt_id, draft_type) DO UPDATE SET
184                 retry_process_id = excluded.retry_process_id,
185                 prompt = excluded.prompt,
186                 queued = excluded.queued,
187                 variant = excluded.variant,
188                 image_ids = excluded.image_ids,
189                 version = drafts.version + 1
190               RETURNING
191                 id                       as "id!: Uuid",
192                 task_attempt_id          as "task_attempt_id!: Uuid",
193                 draft_type,
194                 retry_process_id         as "retry_process_id?: Uuid",
195                 prompt,
196                 queued                   as "queued!: bool",
197                 sending                  as "sending!: bool",
198                 variant,
199                 image_ids,
200                 created_at               as "created_at!: DateTime<Utc>",
201                 updated_at               as "updated_at!: DateTime<Utc>",
202                 version                  as "version!: i64""#,
203            id,
204            data.task_attempt_id,
205            draft_type_str,
206            data.retry_process_id,
207            prompt,
208            data.queued,
209            variant,
210            image_ids_json
211        )
212        .fetch_one(pool)
213        .await
214        .map(Draft::from)
215    }
216
217    pub async fn clear_after_send(
218        pool: &SqlitePool,
219        task_attempt_id: Uuid,
220        draft_type: DraftType,
221    ) -> Result<(), sqlx::Error> {
222        let draft_type_str = draft_type.as_str();
223
224        match draft_type {
225            DraftType::FollowUp => {
226                // Follow-up drafts: update to empty
227                sqlx::query(
228                    r#"UPDATE drafts
229                       SET prompt = '', queued = 0, sending = 0, image_ids = NULL, updated_at = CURRENT_TIMESTAMP, version = version + 1
230                     WHERE task_attempt_id = ? AND draft_type = ?"#,
231                )
232                .bind(task_attempt_id)
233                .bind(draft_type_str)
234                .execute(pool)
235                .await?;
236            }
237            DraftType::Retry => {
238                // Retry drafts: delete the record
239                Self::delete_by_task_attempt_and_type(pool, task_attempt_id, draft_type).await?;
240            }
241        }
242        Ok(())
243    }
244
245    pub async fn delete_by_task_attempt_and_type(
246        pool: &SqlitePool,
247        task_attempt_id: Uuid,
248        draft_type: DraftType,
249    ) -> Result<(), sqlx::Error> {
250        sqlx::query(r#"DELETE FROM drafts WHERE task_attempt_id = ? AND draft_type = ?"#)
251            .bind(task_attempt_id)
252            .bind(draft_type.as_str())
253            .execute(pool)
254            .await?;
255
256        Ok(())
257    }
258
259    /// Attempt to atomically mark this draft as "sending" if it's currently queued and non-empty.
260    /// Returns true if the row was updated (we acquired the send lock), false otherwise.
261    pub async fn try_mark_sending(
262        pool: &SqlitePool,
263        task_attempt_id: Uuid,
264        draft_type: DraftType,
265    ) -> Result<bool, sqlx::Error> {
266        let draft_type_str = draft_type.as_str();
267        let result = sqlx::query(
268            r#"UPDATE drafts
269               SET sending = 1, updated_at = CURRENT_TIMESTAMP, version = version + 1
270             WHERE task_attempt_id = ?
271               AND draft_type = ?
272               AND queued = 1
273               AND sending = 0
274               AND TRIM(prompt) != ''"#,
275        )
276        .bind(task_attempt_id)
277        .bind(draft_type_str)
278        .execute(pool)
279        .await?;
280
281        Ok(result.rows_affected() > 0)
282    }
283
284    /// Partial update on a draft by attempt and type. Updates only provided fields
285    /// and bumps `updated_at` and `version` when any change occurs.
286    pub async fn update_partial(
287        pool: &SqlitePool,
288        task_attempt_id: Uuid,
289        draft_type: DraftType,
290        prompt: Option<String>,
291        variant: Option<Option<String>>,
292        image_ids: Option<Vec<Uuid>>,
293        retry_process_id: Option<Uuid>,
294    ) -> Result<(), sqlx::Error> {
295        if retry_process_id.is_none()
296            && prompt.is_none()
297            && variant.is_none()
298            && image_ids.is_none()
299        {
300            return Ok(());
301        }
302        let mut query = QueryBuilder::<Sqlite>::new("UPDATE drafts SET ");
303
304        let mut separated = query.separated(", ");
305        if let Some(rpid) = retry_process_id {
306            separated.push("retry_process_id = ");
307            separated.push_bind_unseparated(rpid);
308        }
309        if let Some(p) = prompt {
310            separated.push("prompt = ");
311            separated.push_bind_unseparated(p);
312        }
313        if let Some(v_opt) = variant {
314            separated.push("variant = ");
315            match v_opt {
316                Some(v) => separated.push_bind_unseparated(v),
317                None => separated.push_bind_unseparated(Option::<String>::None),
318            };
319        }
320        if let Some(ids) = image_ids {
321            let image_ids_json = serde_json::to_string(&ids).unwrap_or_else(|_| "[]".to_string());
322            separated.push("image_ids = ");
323            separated.push_bind_unseparated(image_ids_json);
324        }
325        separated.push("updated_at = CURRENT_TIMESTAMP");
326        separated.push("version = version + 1");
327
328        query.push(" WHERE task_attempt_id = ");
329        query.push_bind(task_attempt_id);
330        query.push(" AND draft_type = ");
331        query.push_bind(draft_type.as_str());
332        query.build().execute(pool).await?;
333        Ok(())
334    }
335
336    /// Set queued flag (and bump metadata) for a draft by attempt and type.
337    pub async fn set_queued(
338        pool: &SqlitePool,
339        task_attempt_id: Uuid,
340        draft_type: DraftType,
341        queued: bool,
342        expected_queued: Option<bool>,
343        expected_version: Option<i64>,
344    ) -> Result<u64, sqlx::Error> {
345        let result = sqlx::query(
346            r#"UPDATE drafts
347                   SET queued = CASE
348                                   WHEN ?1 THEN (TRIM(prompt) <> '')
349                                   ELSE 0
350                                 END,
351                       updated_at = CURRENT_TIMESTAMP,
352                       version    = version + 1
353                 WHERE task_attempt_id = ?2
354                   AND draft_type      = ?3
355                   AND (?4 IS NULL OR queued  = ?4)
356                   AND (?5 IS NULL OR version = ?5)"#,
357        )
358        .bind(queued as i64)
359        .bind(task_attempt_id)
360        .bind(draft_type.as_str())
361        .bind(expected_queued.map(|value| value as i64))
362        .bind(expected_version)
363        .execute(pool)
364        .await?;
365
366        Ok(result.rows_affected())
367    }
368}