1use rusqlite::{OptionalExtension, params};
10
11use super::{Db, StorageError, StringErr};
12
13fn unknown_enum_err(column_index: usize, message: String) -> StorageError {
21 StorageError::Backend(tokio_rusqlite::Error::Error(
22 rusqlite::Error::FromSqlConversionFailure(
23 column_index,
24 rusqlite::types::Type::Text,
25 Box::new(StringErr(message)),
26 ),
27 ))
28}
29
/// Kind of work a row in the `tasks` table represents.
///
/// Persisted as text in `tasks.kind`; see `TaskKind::as_str` and
/// `TaskKind::from_db` for the mapping.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TaskKind {
    /// Stored as `batch_fetch`.
    BatchFetch,
    /// Stored as `retry`.
    Retry,
    /// Stored as `revalidate`.
    Revalidate,
    /// Stored as `summarize`.
    Summarize,
}
37
38impl TaskKind {
39 pub fn as_str(self) -> &'static str {
40 match self {
41 Self::BatchFetch => "batch_fetch",
42 Self::Retry => "retry",
43 Self::Revalidate => "revalidate",
44 Self::Summarize => "summarize",
45 }
46 }
47
48 pub fn from_db(s: &str) -> Result<Self, StorageError> {
49 Ok(match s {
50 "batch_fetch" => Self::BatchFetch,
51 "retry" => Self::Retry,
52 "revalidate" => Self::Revalidate,
53 "summarize" => Self::Summarize,
54 other => {
55 return Err(unknown_enum_err(1, format!("unknown tasks.kind = {other}")));
58 }
59 })
60 }
61
62 pub fn is_resumable(self) -> bool {
65 matches!(self, Self::BatchFetch | Self::Retry | Self::Revalidate)
66 }
67}
68
/// Lifecycle state of a task, persisted as text in `tasks.status`.
///
/// See `TaskStatus::as_str` / `TaskStatus::from_db` for the mapping and
/// `TaskStatus::is_terminal` for which states are final.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TaskStatus {
    /// Stored as `pending`.
    Pending,
    /// Stored as `running`.
    Running,
    /// Stored as `completed`.
    Completed,
    /// Stored as `failed`.
    Failed,
    /// Stored as `cancelled`.
    Cancelled,
}
77
78impl TaskStatus {
79 pub fn as_str(self) -> &'static str {
80 match self {
81 Self::Pending => "pending",
82 Self::Running => "running",
83 Self::Completed => "completed",
84 Self::Failed => "failed",
85 Self::Cancelled => "cancelled",
86 }
87 }
88
89 pub fn from_db(s: &str) -> Result<Self, StorageError> {
90 Ok(match s {
91 "pending" => Self::Pending,
92 "running" => Self::Running,
93 "completed" => Self::Completed,
94 "failed" => Self::Failed,
95 "cancelled" => Self::Cancelled,
96 other => {
97 return Err(unknown_enum_err(
100 2,
101 format!("unknown tasks.status = {other}"),
102 ));
103 }
104 })
105 }
106
107 pub fn is_terminal(self) -> bool {
108 matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
109 }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq)]
114pub struct TaskRow {
115 pub id: String,
116 pub kind: TaskKind,
117 pub status: TaskStatus,
118 pub created_at: i64,
119 pub updated_at: i64,
120 pub params_json: String,
121 pub result_json: Option<String>,
122 pub error: Option<String>,
123 pub cancellation_requested: bool,
124 pub owner_pid: Option<i64>,
125}
126
127#[derive(Debug, Clone)]
129pub struct TaskInsert {
130 pub id: String,
131 pub kind: TaskKind,
132 pub params_json: String,
133 pub owner_pid: Option<i64>,
134}
135
136pub async fn insert(db: &Db, input: TaskInsert) -> Result<(), StorageError> {
137 let TaskInsert {
138 id,
139 kind,
140 params_json,
141 owner_pid,
142 } = input;
143 let kind_s = kind.as_str().to_string();
144 let now = now_epoch_ms();
145 let id_for_notify = id.clone();
146 db.conn
147 .call(move |c| {
148 c.execute(
149 "INSERT INTO tasks
150 (id, kind, status, created_at, updated_at, params_json,
151 result_json, error, cancellation_requested, owner_pid)
152 VALUES (?1, ?2, 'running', ?3, ?3, ?4, NULL, NULL, 0, ?5)",
153 params![id, kind_s, now, params_json, owner_pid],
154 )?;
155 Ok::<_, rusqlite::Error>(())
156 })
157 .await?;
158 if let Ok(guard) = db.new_task_tx.lock()
164 && let Some(tx) = guard.as_ref()
165 && let Err(e) = tx.send(id_for_notify)
166 {
167 tracing::debug!(
168 target: "rover::storage",
169 error = ?e,
170 "new-task notify channel closed",
171 );
172 }
173 Ok(())
174}
175
176pub async fn get(db: &Db, id: &str) -> Result<Option<TaskRow>, StorageError> {
177 let id = id.to_string();
178 let row = db
179 .conn
180 .call(move |c| {
181 c.query_row(
182 "SELECT id, kind, status, created_at, updated_at, params_json,
183 result_json, error, cancellation_requested, owner_pid
184 FROM tasks WHERE id = ?1",
185 [&id],
186 |r| {
187 Ok((
188 r.get::<_, String>(0)?,
189 r.get::<_, String>(1)?,
190 r.get::<_, String>(2)?,
191 r.get::<_, i64>(3)?,
192 r.get::<_, i64>(4)?,
193 r.get::<_, String>(5)?,
194 r.get::<_, Option<String>>(6)?,
195 r.get::<_, Option<String>>(7)?,
196 r.get::<_, i64>(8)?,
197 r.get::<_, Option<i64>>(9)?,
198 ))
199 },
200 )
201 .optional()
202 })
203 .await?;
204 let Some((
205 id,
206 kind_s,
207 status_s,
208 created_at,
209 updated_at,
210 params_json,
211 result_json,
212 error,
213 canc,
214 owner_pid,
215 )) = row
216 else {
217 return Ok(None);
218 };
219 Ok(Some(TaskRow {
220 id,
221 kind: TaskKind::from_db(&kind_s)?,
222 status: TaskStatus::from_db(&status_s)?,
223 created_at,
224 updated_at,
225 params_json,
226 result_json,
227 error,
228 cancellation_requested: canc != 0,
229 owner_pid,
230 }))
231}
232
233pub async fn set_status(
234 db: &Db,
235 id: &str,
236 status: TaskStatus,
237 result_json: Option<String>,
238 error: Option<String>,
239) -> Result<(), StorageError> {
240 let id = id.to_string();
241 let status_s = status.as_str().to_string();
242 let now = now_epoch_ms();
243 db.conn
244 .call(move |c| {
245 c.execute(
246 "UPDATE tasks
247 SET status = ?1, updated_at = ?2,
248 result_json = COALESCE(?3, result_json),
249 error = COALESCE(?4, error)
250 WHERE id = ?5",
251 params![status_s, now, result_json, error, id],
252 )?;
253 Ok::<_, rusqlite::Error>(())
254 })
255 .await?;
256 Ok(())
257}
258
259pub async fn set_cancellation_requested(db: &Db, id: &str) -> Result<bool, StorageError> {
260 let id = id.to_string();
261 let now = now_epoch_ms();
262 let changed = db
263 .conn
264 .call(move |c| {
265 let n = c.execute(
266 "UPDATE tasks
267 SET cancellation_requested = 1, updated_at = ?1
268 WHERE id = ?2 AND cancellation_requested = 0",
269 params![now, id],
270 )?;
271 Ok::<_, rusqlite::Error>(n)
272 })
273 .await?;
274 Ok(changed == 1)
275}
276
277pub async fn is_cancelled(db: &Db, id: &str) -> Result<bool, StorageError> {
278 let id = id.to_string();
279 let flag = db
280 .conn
281 .call(move |c| {
282 c.query_row(
283 "SELECT cancellation_requested FROM tasks WHERE id = ?1",
284 [&id],
285 |r| r.get::<_, i64>(0),
286 )
287 .optional()
288 })
289 .await?;
290 Ok(flag.unwrap_or(0) != 0)
291}
292
293pub async fn list_orphans(db: &Db) -> Result<Vec<TaskRow>, StorageError> {
294 let rows = db
295 .conn
296 .call(|c| {
297 let mut stmt = c.prepare(
298 "SELECT id, kind, status, created_at, updated_at, params_json,
299 result_json, error, cancellation_requested, owner_pid
300 FROM tasks
301 WHERE status = 'running'
302 AND owner_pid IS NOT NULL
303 AND owner_pid NOT IN (SELECT pid FROM servers)",
304 )?;
305 let iter = stmt.query_map([], |r| {
306 Ok((
307 r.get::<_, String>(0)?,
308 r.get::<_, String>(1)?,
309 r.get::<_, String>(2)?,
310 r.get::<_, i64>(3)?,
311 r.get::<_, i64>(4)?,
312 r.get::<_, String>(5)?,
313 r.get::<_, Option<String>>(6)?,
314 r.get::<_, Option<String>>(7)?,
315 r.get::<_, i64>(8)?,
316 r.get::<_, Option<i64>>(9)?,
317 ))
318 })?;
319 let mut out = Vec::new();
320 for r in iter {
321 out.push(r?);
322 }
323 Ok::<_, rusqlite::Error>(out)
324 })
325 .await?;
326 let mut tasks = Vec::with_capacity(rows.len());
327 for (
328 id,
329 kind_s,
330 status_s,
331 created_at,
332 updated_at,
333 params_json,
334 result_json,
335 error,
336 canc,
337 owner_pid,
338 ) in rows
339 {
340 tasks.push(TaskRow {
341 id,
342 kind: TaskKind::from_db(&kind_s)?,
343 status: TaskStatus::from_db(&status_s)?,
344 created_at,
345 updated_at,
346 params_json,
347 result_json,
348 error,
349 cancellation_requested: canc != 0,
350 owner_pid,
351 });
352 }
353 Ok(tasks)
354}
355
356pub async fn claim_orphan(
357 db: &Db,
358 id: &str,
359 orphan_pid: i64,
360 own_pid: i64,
361) -> Result<bool, StorageError> {
362 let id = id.to_string();
363 let now = now_epoch_ms();
364 let changed = db
365 .conn
366 .call(move |c| {
367 let n = c.execute(
368 "UPDATE tasks
369 SET owner_pid = ?1, updated_at = ?2
370 WHERE id = ?3 AND owner_pid = ?4 AND status = 'running'",
371 params![own_pid, now, id, orphan_pid],
372 )?;
373 Ok::<_, rusqlite::Error>(n)
374 })
375 .await?;
376 Ok(changed == 1)
377}
378
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
/// Saturates at `i64::MAX` rather than silently wrapping if the `u128`
/// millisecond count ever exceeds the `i64` range.
fn now_epoch_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX))
        .unwrap_or(0)
}
386
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a brand-new database file inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is leaked with `mem::forget` so its drop (which
    /// removes the directory) never runs while the returned `Db` still uses
    /// the file. The directory is intentionally left behind after the test.
    async fn fresh_db() -> Db {
        let tmp = tempdir().unwrap();
        let db = Db::open(tmp.path().join("rover.db")).await.unwrap();
        std::mem::forget(tmp);
        db
    }

    /// Builds a `BatchFetch` insert with fixed params, owned by `pid`.
    fn sample_insert(id: &str, pid: Option<i64>) -> TaskInsert {
        TaskInsert {
            id: id.into(),
            kind: TaskKind::BatchFetch,
            params_json: r#"{"urls":["https://a.example/"]}"#.into(),
            owner_pid: pid,
        }
    }

    /// A freshly inserted task reads back as `Running`, with its owner and a
    /// cleared cancellation flag.
    #[tokio::test]
    async fn insert_and_get_round_trip() {
        let db = fresh_db().await;
        insert(&db, sample_insert("t1", Some(7))).await.unwrap();
        let got = get(&db, "t1").await.unwrap().expect("row missing");
        assert_eq!(got.id, "t1");
        assert_eq!(got.kind, TaskKind::BatchFetch);
        assert_eq!(got.status, TaskStatus::Running);
        assert_eq!(got.owner_pid, Some(7));
        assert!(!got.cancellation_requested);
    }

    /// Looking up an id that was never inserted yields `None`, not an error.
    #[tokio::test]
    async fn get_unknown_returns_none() {
        let db = fresh_db().await;
        assert!(get(&db, "nope").await.unwrap().is_none());
    }

    /// Moving to a terminal status persists both the status and the error text.
    #[tokio::test]
    async fn set_status_terminal_writes_result_and_error() {
        let db = fresh_db().await;
        insert(&db, sample_insert("t1", Some(7))).await.unwrap();
        set_status(
            &db,
            "t1",
            TaskStatus::Failed,
            None,
            Some("owner_died".into()),
        )
        .await
        .unwrap();
        let got = get(&db, "t1").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Failed);
        assert_eq!(got.error.as_deref(), Some("owner_died"));
    }

    /// Only the first request flips the flag; the repeat reports `false` but
    /// the flag stays set.
    #[tokio::test]
    async fn set_cancellation_requested_is_idempotent() {
        let db = fresh_db().await;
        insert(&db, sample_insert("t1", Some(7))).await.unwrap();
        let first = set_cancellation_requested(&db, "t1").await.unwrap();
        let second = set_cancellation_requested(&db, "t1").await.unwrap();
        assert!(first);
        assert!(!second, "second call should be a no-op");
        assert!(is_cancelled(&db, "t1").await.unwrap());
    }

    /// Cancelling an unknown id is reported as "nothing changed", not an error.
    #[tokio::test]
    async fn set_cancellation_requested_on_missing_id_returns_false() {
        let db = fresh_db().await;
        assert!(!set_cancellation_requested(&db, "ghost").await.unwrap());
    }

    /// A task owned by a registered server pid is not an orphan; one owned by
    /// an unregistered pid is.
    #[tokio::test]
    async fn list_orphans_excludes_live_pids() {
        let db = fresh_db().await;
        // Presumably registers pid 100 in `servers`, which `list_orphans`
        // consults. TODO confirm against `Db::upsert_server_self`.
        db.upsert_server_self(100, "v".into()).await.unwrap();
        insert(&db, sample_insert("live", Some(100))).await.unwrap();
        insert(&db, sample_insert("dead", Some(999))).await.unwrap();
        let orphans = list_orphans(&db).await.unwrap();
        let ids: Vec<&str> = orphans.iter().map(|t| t.id.as_str()).collect();
        assert_eq!(ids, vec!["dead"]);
    }

    /// Only `running` tasks count as orphans, even when their owner is gone.
    #[tokio::test]
    async fn list_orphans_excludes_terminal_tasks() {
        let db = fresh_db().await;
        insert(&db, sample_insert("dead_done", Some(999)))
            .await
            .unwrap();
        set_status(&db, "dead_done", TaskStatus::Completed, None, None)
            .await
            .unwrap();
        let orphans = list_orphans(&db).await.unwrap();
        assert!(
            orphans.is_empty(),
            "completed orphan should not appear: {orphans:?}",
        );
    }

    /// Two claimers racing on the same orphan: the first swap wins; the second
    /// no longer matches `owner_pid = 999` and loses.
    #[tokio::test]
    async fn claim_orphan_cas_wins_then_loses() {
        let db = fresh_db().await;
        insert(&db, sample_insert("orphan", Some(999)))
            .await
            .unwrap();
        let first = claim_orphan(&db, "orphan", 999, 1).await.unwrap();
        let second = claim_orphan(&db, "orphan", 999, 2).await.unwrap();
        assert!(first, "first claimer should win");
        assert!(!second, "second claimer should lose");
        let got = get(&db, "orphan").await.unwrap().unwrap();
        assert_eq!(got.owner_pid, Some(1));
    }
}