1use std::str::FromStr;
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use sqlx::{FromRow, QueryBuilder, Sqlite, SqlitePool};
6use ts_rs_forge::TS;
7use uuid::Uuid;
8
/// Kind of draft stored in the `drafts` table.
///
/// Serialised by serde and ts-rs in snake_case; the strings persisted in the
/// `draft_type` column are those returned by `DraftType::as_str`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, TS, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[ts(rename_all = "snake_case")]
pub enum DraftType {
    /// Follow-up draft. `Draft::clear_after_send` empties it instead of deleting it.
    FollowUp,
    /// Retry draft. `Draft::upsert` requires a `retry_process_id` for it, and
    /// `Draft::clear_after_send` deletes the row.
    Retry,
}
16
17impl DraftType {
18 pub fn as_str(&self) -> &'static str {
19 match self {
20 DraftType::FollowUp => "follow_up",
21 DraftType::Retry => "retry",
22 }
23 }
24}
25
26impl FromStr for DraftType {
27 type Err = ();
28
29 fn from_str(s: &str) -> Result<Self, Self::Err> {
30 match s {
31 "follow_up" => Ok(DraftType::FollowUp),
32 "retry" => Ok(DraftType::Retry),
33 _ => Err(()),
34 }
35 }
36}
37
/// A saved prompt draft for a task attempt, as exposed to callers and the
/// frontend (via serde / ts-rs).
#[derive(Debug, Clone, Serialize, Deserialize, TS)]
pub struct Draft {
    /// Row identifier (`Draft::upsert` generates it with `Uuid::new_v4` on insert).
    pub id: Uuid,
    /// Owning task attempt; together with `draft_type` this is the upsert conflict key.
    pub task_attempt_id: Uuid,
    pub draft_type: DraftType,
    /// Required for `DraftType::Retry` drafts (enforced by `Draft::upsert`);
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_process_id: Option<Uuid>,
    /// Draft text. Blank (whitespace-only) prompts cannot be queued or marked sending.
    pub prompt: String,
    /// Whether the draft is queued for sending.
    pub queued: bool,
    /// Set by `Draft::try_mark_sending`; reset by `Draft::clear_after_send`.
    pub sending: bool,
    /// Optional variant label; serialized as `null` when absent.
    /// NOTE(review): meaning of the label is defined by callers, not this module.
    pub variant: Option<String>,
    /// Attached image ids (stored as a JSON array in the database);
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_ids: Option<Vec<Uuid>>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Incremented on every mutation; usable as an optimistic-concurrency
    /// token (see `Draft::set_queued`'s `expected_version`).
    pub version: i64,
}
55
/// Raw row shape returned by the `drafts` queries, before conversion into
/// [`Draft`].
///
/// `draft_type` holds the TEXT value written by `DraftType::as_str`, and
/// `image_ids` holds a JSON-encoded array of UUIDs; both are decoded in
/// `From<DraftRow> for Draft`. Field names must match the selected column aliases.
#[derive(Debug, Clone, FromRow)]
struct DraftRow {
    pub id: Uuid,
    pub task_attempt_id: Uuid,
    pub draft_type: String,
    pub retry_process_id: Option<Uuid>,
    pub prompt: String,
    pub queued: bool,
    pub sending: bool,
    pub variant: Option<String>,
    pub image_ids: Option<String>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    pub version: i64,
}
71
72impl From<DraftRow> for Draft {
73 fn from(r: DraftRow) -> Self {
74 let image_ids = r
75 .image_ids
76 .as_deref()
77 .and_then(|s| serde_json::from_str::<Vec<Uuid>>(s).ok());
78 Draft {
79 id: r.id,
80 task_attempt_id: r.task_attempt_id,
81 draft_type: DraftType::from_str(&r.draft_type).unwrap_or(DraftType::FollowUp),
82 retry_process_id: r.retry_process_id,
83 prompt: r.prompt,
84 queued: r.queued,
85 sending: r.sending,
86 variant: r.variant,
87 image_ids,
88 created_at: r.created_at,
89 updated_at: r.updated_at,
90 version: r.version,
91 }
92 }
93}
94
/// Input for `Draft::upsert`: the full editable state of a draft.
#[derive(Debug, Deserialize, TS)]
pub struct UpsertDraft {
    pub task_attempt_id: Uuid,
    pub draft_type: DraftType,
    /// Must be `Some` when `draft_type` is `DraftType::Retry`; `Draft::upsert`
    /// rejects the request otherwise.
    pub retry_process_id: Option<Uuid>,
    pub prompt: String,
    pub queued: bool,
    pub variant: Option<String>,
    /// Stored as a JSON array; `None` stores NULL.
    pub image_ids: Option<Vec<Uuid>>,
}
105
106impl Draft {
107 pub async fn find_by_rowid(pool: &SqlitePool, rowid: i64) -> Result<Option<Self>, sqlx::Error> {
108 sqlx::query_as!(
109 DraftRow,
110 r#"SELECT
111 id as "id!: Uuid",
112 task_attempt_id as "task_attempt_id!: Uuid",
113 draft_type,
114 retry_process_id as "retry_process_id?: Uuid",
115 prompt,
116 queued as "queued!: bool",
117 sending as "sending!: bool",
118 variant,
119 image_ids,
120 created_at as "created_at!: DateTime<Utc>",
121 updated_at as "updated_at!: DateTime<Utc>",
122 version as "version!: i64"
123 FROM drafts
124 WHERE rowid = $1"#,
125 rowid
126 )
127 .fetch_optional(pool)
128 .await
129 .map(|opt| opt.map(Draft::from))
130 }
131
132 pub async fn find_by_task_attempt_and_type(
133 pool: &SqlitePool,
134 task_attempt_id: Uuid,
135 draft_type: DraftType,
136 ) -> Result<Option<Self>, sqlx::Error> {
137 let draft_type_str = draft_type.as_str();
138 sqlx::query_as!(
139 DraftRow,
140 r#"SELECT
141 id as "id!: Uuid",
142 task_attempt_id as "task_attempt_id!: Uuid",
143 draft_type,
144 retry_process_id as "retry_process_id?: Uuid",
145 prompt,
146 queued as "queued!: bool",
147 sending as "sending!: bool",
148 variant,
149 image_ids,
150 created_at as "created_at!: DateTime<Utc>",
151 updated_at as "updated_at!: DateTime<Utc>",
152 version as "version!: i64"
153 FROM drafts
154 WHERE task_attempt_id = $1 AND draft_type = $2"#,
155 task_attempt_id,
156 draft_type_str
157 )
158 .fetch_optional(pool)
159 .await
160 .map(|opt| opt.map(Draft::from))
161 }
162
163 pub async fn upsert(pool: &SqlitePool, data: &UpsertDraft) -> Result<Self, sqlx::Error> {
164 if data.draft_type == DraftType::Retry && data.retry_process_id.is_none() {
166 return Err(sqlx::Error::Protocol(
167 "retry_process_id is required for retry drafts".into(),
168 ));
169 }
170
171 let id = Uuid::new_v4();
172 let image_ids_json = data
173 .image_ids
174 .as_ref()
175 .map(|ids| serde_json::to_string(ids).unwrap_or_else(|_| "[]".to_string()));
176 let draft_type_str = data.draft_type.as_str();
177 let prompt = data.prompt.clone();
178 let variant = data.variant.clone();
179 sqlx::query_as!(
180 DraftRow,
181 r#"INSERT INTO drafts (id, task_attempt_id, draft_type, retry_process_id, prompt, queued, variant, image_ids)
182 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
183 ON CONFLICT(task_attempt_id, draft_type) DO UPDATE SET
184 retry_process_id = excluded.retry_process_id,
185 prompt = excluded.prompt,
186 queued = excluded.queued,
187 variant = excluded.variant,
188 image_ids = excluded.image_ids,
189 version = drafts.version + 1
190 RETURNING
191 id as "id!: Uuid",
192 task_attempt_id as "task_attempt_id!: Uuid",
193 draft_type,
194 retry_process_id as "retry_process_id?: Uuid",
195 prompt,
196 queued as "queued!: bool",
197 sending as "sending!: bool",
198 variant,
199 image_ids,
200 created_at as "created_at!: DateTime<Utc>",
201 updated_at as "updated_at!: DateTime<Utc>",
202 version as "version!: i64""#,
203 id,
204 data.task_attempt_id,
205 draft_type_str,
206 data.retry_process_id,
207 prompt,
208 data.queued,
209 variant,
210 image_ids_json
211 )
212 .fetch_one(pool)
213 .await
214 .map(Draft::from)
215 }
216
217 pub async fn clear_after_send(
218 pool: &SqlitePool,
219 task_attempt_id: Uuid,
220 draft_type: DraftType,
221 ) -> Result<(), sqlx::Error> {
222 let draft_type_str = draft_type.as_str();
223
224 match draft_type {
225 DraftType::FollowUp => {
226 sqlx::query(
228 r#"UPDATE drafts
229 SET prompt = '', queued = 0, sending = 0, image_ids = NULL, updated_at = CURRENT_TIMESTAMP, version = version + 1
230 WHERE task_attempt_id = ? AND draft_type = ?"#,
231 )
232 .bind(task_attempt_id)
233 .bind(draft_type_str)
234 .execute(pool)
235 .await?;
236 }
237 DraftType::Retry => {
238 Self::delete_by_task_attempt_and_type(pool, task_attempt_id, draft_type).await?;
240 }
241 }
242 Ok(())
243 }
244
245 pub async fn delete_by_task_attempt_and_type(
246 pool: &SqlitePool,
247 task_attempt_id: Uuid,
248 draft_type: DraftType,
249 ) -> Result<(), sqlx::Error> {
250 sqlx::query(r#"DELETE FROM drafts WHERE task_attempt_id = ? AND draft_type = ?"#)
251 .bind(task_attempt_id)
252 .bind(draft_type.as_str())
253 .execute(pool)
254 .await?;
255
256 Ok(())
257 }
258
259 pub async fn try_mark_sending(
262 pool: &SqlitePool,
263 task_attempt_id: Uuid,
264 draft_type: DraftType,
265 ) -> Result<bool, sqlx::Error> {
266 let draft_type_str = draft_type.as_str();
267 let result = sqlx::query(
268 r#"UPDATE drafts
269 SET sending = 1, updated_at = CURRENT_TIMESTAMP, version = version + 1
270 WHERE task_attempt_id = ?
271 AND draft_type = ?
272 AND queued = 1
273 AND sending = 0
274 AND TRIM(prompt) != ''"#,
275 )
276 .bind(task_attempt_id)
277 .bind(draft_type_str)
278 .execute(pool)
279 .await?;
280
281 Ok(result.rows_affected() > 0)
282 }
283
284 pub async fn update_partial(
287 pool: &SqlitePool,
288 task_attempt_id: Uuid,
289 draft_type: DraftType,
290 prompt: Option<String>,
291 variant: Option<Option<String>>,
292 image_ids: Option<Vec<Uuid>>,
293 retry_process_id: Option<Uuid>,
294 ) -> Result<(), sqlx::Error> {
295 if retry_process_id.is_none()
296 && prompt.is_none()
297 && variant.is_none()
298 && image_ids.is_none()
299 {
300 return Ok(());
301 }
302 let mut query = QueryBuilder::<Sqlite>::new("UPDATE drafts SET ");
303
304 let mut separated = query.separated(", ");
305 if let Some(rpid) = retry_process_id {
306 separated.push("retry_process_id = ");
307 separated.push_bind_unseparated(rpid);
308 }
309 if let Some(p) = prompt {
310 separated.push("prompt = ");
311 separated.push_bind_unseparated(p);
312 }
313 if let Some(v_opt) = variant {
314 separated.push("variant = ");
315 match v_opt {
316 Some(v) => separated.push_bind_unseparated(v),
317 None => separated.push_bind_unseparated(Option::<String>::None),
318 };
319 }
320 if let Some(ids) = image_ids {
321 let image_ids_json = serde_json::to_string(&ids).unwrap_or_else(|_| "[]".to_string());
322 separated.push("image_ids = ");
323 separated.push_bind_unseparated(image_ids_json);
324 }
325 separated.push("updated_at = CURRENT_TIMESTAMP");
326 separated.push("version = version + 1");
327
328 query.push(" WHERE task_attempt_id = ");
329 query.push_bind(task_attempt_id);
330 query.push(" AND draft_type = ");
331 query.push_bind(draft_type.as_str());
332 query.build().execute(pool).await?;
333 Ok(())
334 }
335
336 pub async fn set_queued(
338 pool: &SqlitePool,
339 task_attempt_id: Uuid,
340 draft_type: DraftType,
341 queued: bool,
342 expected_queued: Option<bool>,
343 expected_version: Option<i64>,
344 ) -> Result<u64, sqlx::Error> {
345 let result = sqlx::query(
346 r#"UPDATE drafts
347 SET queued = CASE
348 WHEN ?1 THEN (TRIM(prompt) <> '')
349 ELSE 0
350 END,
351 updated_at = CURRENT_TIMESTAMP,
352 version = version + 1
353 WHERE task_attempt_id = ?2
354 AND draft_type = ?3
355 AND (?4 IS NULL OR queued = ?4)
356 AND (?5 IS NULL OR version = ?5)"#,
357 )
358 .bind(queued as i64)
359 .bind(task_attempt_id)
360 .bind(draft_type.as_str())
361 .bind(expected_queued.map(|value| value as i64))
362 .bind(expected_version)
363 .execute(pool)
364 .await?;
365
366 Ok(result.rows_affected())
367 }
368}