Skip to main content

rover/storage/
tasks.rs

1//! `tasks` table async API.
2//!
3//! Sibling of `storage::pages`/`storage::robots`: a thin async wrapper that
4//! hops into the `tokio-rusqlite` actor for SQLite work. Timestamps are
5//! epoch milliseconds (sub-second ordering matters for the event stream).
6//!
7//! Helpers covering `task_events` live in `storage::events`.
8
9use rusqlite::{OptionalExtension, params};
10
11use super::{Db, StorageError, StringErr};
12
13/// Build a `StorageError` for an unknown enum text decoded from SQLite.
14///
15/// `tokio_rusqlite` 0.7's `Error` enum has no `Other` variant (only
16/// `ConnectionClosed`, `Close`, `Error(rusqlite::Error)`), so we wrap a
17/// synthetic `rusqlite::Error::FromSqlConversionFailure` — semantically
18/// correct here since we're failing to map a SQL text value back to a Rust
19/// enum. Matches the pattern used by `storage::robots::RobotsState`.
20fn unknown_enum_err(column_index: usize, message: String) -> StorageError {
21    StorageError::Backend(tokio_rusqlite::Error::Error(
22        rusqlite::Error::FromSqlConversionFailure(
23            column_index,
24            rusqlite::types::Type::Text,
25            Box::new(StringErr(message)),
26        ),
27    ))
28}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq)]
31pub enum TaskKind {
32    BatchFetch,
33    Retry,
34    Revalidate,
35    Summarize,
36}
37
38impl TaskKind {
39    pub fn as_str(self) -> &'static str {
40        match self {
41            Self::BatchFetch => "batch_fetch",
42            Self::Retry => "retry",
43            Self::Revalidate => "revalidate",
44            Self::Summarize => "summarize",
45        }
46    }
47
48    pub fn from_db(s: &str) -> Result<Self, StorageError> {
49        Ok(match s {
50            "batch_fetch" => Self::BatchFetch,
51            "retry" => Self::Retry,
52            "revalidate" => Self::Revalidate,
53            "summarize" => Self::Summarize,
54            other => {
55                // Column index 1 matches the `kind` position in the SELECT
56                // projections used by `get` / `list_orphans`.
57                return Err(unknown_enum_err(1, format!("unknown tasks.kind = {other}")));
58            }
59        })
60    }
61
62    /// Whether the worker can resume from persisted progress after an
63    /// owner-PID handoff. `summarize` is not resumable per design §2.3.
64    pub fn is_resumable(self) -> bool {
65        matches!(self, Self::BatchFetch | Self::Retry | Self::Revalidate)
66    }
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub enum TaskStatus {
71    Pending,
72    Running,
73    Completed,
74    Failed,
75    Cancelled,
76}
77
78impl TaskStatus {
79    pub fn as_str(self) -> &'static str {
80        match self {
81            Self::Pending => "pending",
82            Self::Running => "running",
83            Self::Completed => "completed",
84            Self::Failed => "failed",
85            Self::Cancelled => "cancelled",
86        }
87    }
88
89    pub fn from_db(s: &str) -> Result<Self, StorageError> {
90        Ok(match s {
91            "pending" => Self::Pending,
92            "running" => Self::Running,
93            "completed" => Self::Completed,
94            "failed" => Self::Failed,
95            "cancelled" => Self::Cancelled,
96            other => {
97                // Column index 2 matches the `status` position in the SELECT
98                // projections used by `get` / `list_orphans`.
99                return Err(unknown_enum_err(
100                    2,
101                    format!("unknown tasks.status = {other}"),
102                ));
103            }
104        })
105    }
106
107    pub fn is_terminal(self) -> bool {
108        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
109    }
110}
111
112/// Row shape returned by query helpers.
113#[derive(Debug, Clone, PartialEq, Eq)]
114pub struct TaskRow {
115    pub id: String,
116    pub kind: TaskKind,
117    pub status: TaskStatus,
118    pub created_at: i64,
119    pub updated_at: i64,
120    pub params_json: String,
121    pub result_json: Option<String>,
122    pub error: Option<String>,
123    pub cancellation_requested: bool,
124    pub owner_pid: Option<i64>,
125}
126
127/// Input for inserting a new task (status & timestamps set internally).
128#[derive(Debug, Clone)]
129pub struct TaskInsert {
130    pub id: String,
131    pub kind: TaskKind,
132    pub params_json: String,
133    pub owner_pid: Option<i64>,
134}
135
136pub async fn insert(db: &Db, input: TaskInsert) -> Result<(), StorageError> {
137    let TaskInsert {
138        id,
139        kind,
140        params_json,
141        owner_pid,
142    } = input;
143    let kind_s = kind.as_str().to_string();
144    let now = now_epoch_ms();
145    let id_for_notify = id.clone();
146    db.conn
147        .call(move |c| {
148            c.execute(
149                "INSERT INTO tasks
150                   (id, kind, status, created_at, updated_at, params_json,
151                    result_json, error, cancellation_requested, owner_pid)
152                 VALUES (?1, ?2, 'running', ?3, ?3, ?4, NULL, NULL, 0, ?5)",
153                params![id, kind_s, now, params_json, owner_pid],
154            )?;
155            Ok::<_, rusqlite::Error>(())
156        })
157        .await?;
158    // Notify the scheduler if a listener has been installed. This is the
159    // single source of truth for new-task dispatch — every insert site
160    // (MCP tool, fetcher SWR, deferred retry, retry chain) benefits.
161    // A poisoned mutex or a closed channel is non-fatal: the orphan scan
162    // will pick the row up eventually if the inserting process dies.
163    if let Ok(guard) = db.new_task_tx.lock()
164        && let Some(tx) = guard.as_ref()
165        && let Err(e) = tx.send(id_for_notify)
166    {
167        tracing::debug!(
168            target: "rover::storage",
169            error = ?e,
170            "new-task notify channel closed",
171        );
172    }
173    Ok(())
174}
175
176pub async fn get(db: &Db, id: &str) -> Result<Option<TaskRow>, StorageError> {
177    let id = id.to_string();
178    let row = db
179        .conn
180        .call(move |c| {
181            c.query_row(
182                "SELECT id, kind, status, created_at, updated_at, params_json,
183                        result_json, error, cancellation_requested, owner_pid
184                 FROM tasks WHERE id = ?1",
185                [&id],
186                |r| {
187                    Ok((
188                        r.get::<_, String>(0)?,
189                        r.get::<_, String>(1)?,
190                        r.get::<_, String>(2)?,
191                        r.get::<_, i64>(3)?,
192                        r.get::<_, i64>(4)?,
193                        r.get::<_, String>(5)?,
194                        r.get::<_, Option<String>>(6)?,
195                        r.get::<_, Option<String>>(7)?,
196                        r.get::<_, i64>(8)?,
197                        r.get::<_, Option<i64>>(9)?,
198                    ))
199                },
200            )
201            .optional()
202        })
203        .await?;
204    let Some((
205        id,
206        kind_s,
207        status_s,
208        created_at,
209        updated_at,
210        params_json,
211        result_json,
212        error,
213        canc,
214        owner_pid,
215    )) = row
216    else {
217        return Ok(None);
218    };
219    Ok(Some(TaskRow {
220        id,
221        kind: TaskKind::from_db(&kind_s)?,
222        status: TaskStatus::from_db(&status_s)?,
223        created_at,
224        updated_at,
225        params_json,
226        result_json,
227        error,
228        cancellation_requested: canc != 0,
229        owner_pid,
230    }))
231}
232
233pub async fn set_status(
234    db: &Db,
235    id: &str,
236    status: TaskStatus,
237    result_json: Option<String>,
238    error: Option<String>,
239) -> Result<(), StorageError> {
240    let id = id.to_string();
241    let status_s = status.as_str().to_string();
242    let now = now_epoch_ms();
243    db.conn
244        .call(move |c| {
245            c.execute(
246                "UPDATE tasks
247                    SET status = ?1, updated_at = ?2,
248                        result_json = COALESCE(?3, result_json),
249                        error = COALESCE(?4, error)
250                  WHERE id = ?5",
251                params![status_s, now, result_json, error, id],
252            )?;
253            Ok::<_, rusqlite::Error>(())
254        })
255        .await?;
256    Ok(())
257}
258
259pub async fn set_cancellation_requested(db: &Db, id: &str) -> Result<bool, StorageError> {
260    let id = id.to_string();
261    let now = now_epoch_ms();
262    let changed = db
263        .conn
264        .call(move |c| {
265            let n = c.execute(
266                "UPDATE tasks
267                    SET cancellation_requested = 1, updated_at = ?1
268                  WHERE id = ?2 AND cancellation_requested = 0",
269                params![now, id],
270            )?;
271            Ok::<_, rusqlite::Error>(n)
272        })
273        .await?;
274    Ok(changed == 1)
275}
276
277pub async fn is_cancelled(db: &Db, id: &str) -> Result<bool, StorageError> {
278    let id = id.to_string();
279    let flag = db
280        .conn
281        .call(move |c| {
282            c.query_row(
283                "SELECT cancellation_requested FROM tasks WHERE id = ?1",
284                [&id],
285                |r| r.get::<_, i64>(0),
286            )
287            .optional()
288        })
289        .await?;
290    Ok(flag.unwrap_or(0) != 0)
291}
292
293pub async fn list_orphans(db: &Db) -> Result<Vec<TaskRow>, StorageError> {
294    let rows = db
295        .conn
296        .call(|c| {
297            let mut stmt = c.prepare(
298                "SELECT id, kind, status, created_at, updated_at, params_json,
299                        result_json, error, cancellation_requested, owner_pid
300                 FROM tasks
301                 WHERE status = 'running'
302                   AND owner_pid IS NOT NULL
303                   AND owner_pid NOT IN (SELECT pid FROM servers)",
304            )?;
305            let iter = stmt.query_map([], |r| {
306                Ok((
307                    r.get::<_, String>(0)?,
308                    r.get::<_, String>(1)?,
309                    r.get::<_, String>(2)?,
310                    r.get::<_, i64>(3)?,
311                    r.get::<_, i64>(4)?,
312                    r.get::<_, String>(5)?,
313                    r.get::<_, Option<String>>(6)?,
314                    r.get::<_, Option<String>>(7)?,
315                    r.get::<_, i64>(8)?,
316                    r.get::<_, Option<i64>>(9)?,
317                ))
318            })?;
319            let mut out = Vec::new();
320            for r in iter {
321                out.push(r?);
322            }
323            Ok::<_, rusqlite::Error>(out)
324        })
325        .await?;
326    let mut tasks = Vec::with_capacity(rows.len());
327    for (
328        id,
329        kind_s,
330        status_s,
331        created_at,
332        updated_at,
333        params_json,
334        result_json,
335        error,
336        canc,
337        owner_pid,
338    ) in rows
339    {
340        tasks.push(TaskRow {
341            id,
342            kind: TaskKind::from_db(&kind_s)?,
343            status: TaskStatus::from_db(&status_s)?,
344            created_at,
345            updated_at,
346            params_json,
347            result_json,
348            error,
349            cancellation_requested: canc != 0,
350            owner_pid,
351        });
352    }
353    Ok(tasks)
354}
355
356pub async fn claim_orphan(
357    db: &Db,
358    id: &str,
359    orphan_pid: i64,
360    own_pid: i64,
361) -> Result<bool, StorageError> {
362    let id = id.to_string();
363    let now = now_epoch_ms();
364    let changed = db
365        .conn
366        .call(move |c| {
367            let n = c.execute(
368                "UPDATE tasks
369                    SET owner_pid = ?1, updated_at = ?2
370                  WHERE id = ?3 AND owner_pid = ?4 AND status = 'running'",
371                params![own_pid, now, id, orphan_pid],
372            )?;
373            Ok::<_, rusqlite::Error>(n)
374        })
375        .await?;
376    Ok(changed == 1)
377}
378
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 when the system clock reads earlier than the epoch. The
/// millisecond count is a `u128`; it is converted with a checked conversion
/// and saturates at `i64::MAX` instead of being silently truncated by an
/// `as` cast.
fn now_epoch_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX)
        })
}
386
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Open a database at `rover.db` inside a newly created temp directory.
    async fn fresh_db() -> Db {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("rover.db");
        let db = Db::open(db_path).await.unwrap();
        // Deliberately skip the directory handle's destructor — presumably so
        // the directory is not removed while the returned `Db` is in use.
        std::mem::forget(dir);
        db
    }

    /// Build a `BatchFetch` insert with a fixed single-URL params payload.
    fn sample_insert(id: &str, pid: Option<i64>) -> TaskInsert {
        let params_json = r#"{"urls":["https://a.example/"]}"#.into();
        TaskInsert {
            id: id.into(),
            kind: TaskKind::BatchFetch,
            params_json,
            owner_pid: pid,
        }
    }

    #[tokio::test]
    async fn insert_and_get_round_trip() {
        let db = fresh_db().await;
        insert(&db, sample_insert("t1", Some(7))).await.unwrap();

        let row = get(&db, "t1").await.unwrap().expect("row missing");
        assert_eq!(row.id, "t1");
        assert_eq!(row.kind, TaskKind::BatchFetch);
        assert_eq!(row.status, TaskStatus::Running);
        assert_eq!(row.owner_pid, Some(7));
        assert!(!row.cancellation_requested);
    }

    #[tokio::test]
    async fn get_unknown_returns_none() {
        let db = fresh_db().await;
        let missing = get(&db, "nope").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn set_status_terminal_writes_result_and_error() {
        let db = fresh_db().await;
        insert(&db, sample_insert("t1", Some(7))).await.unwrap();

        let failure = Some("owner_died".into());
        set_status(&db, "t1", TaskStatus::Failed, None, failure)
            .await
            .unwrap();

        let row = get(&db, "t1").await.unwrap().unwrap();
        assert_eq!(row.status, TaskStatus::Failed);
        assert_eq!(row.error.as_deref(), Some("owner_died"));
    }

    #[tokio::test]
    async fn set_cancellation_requested_is_idempotent() {
        let db = fresh_db().await;
        insert(&db, sample_insert("t1", Some(7))).await.unwrap();

        let changed_first = set_cancellation_requested(&db, "t1").await.unwrap();
        let changed_again = set_cancellation_requested(&db, "t1").await.unwrap();

        assert!(changed_first);
        assert!(!changed_again, "second call should be a no-op");
        assert!(is_cancelled(&db, "t1").await.unwrap());
    }

    #[tokio::test]
    async fn set_cancellation_requested_on_missing_id_returns_false() {
        let db = fresh_db().await;
        let changed = set_cancellation_requested(&db, "ghost").await.unwrap();
        assert!(!changed);
    }

    #[tokio::test]
    async fn list_orphans_excludes_live_pids() {
        let db = fresh_db().await;
        // Register pid 100 as a live server; pid 999 stays unregistered.
        db.upsert_server_self(100, "v".into()).await.unwrap();
        insert(&db, sample_insert("live", Some(100))).await.unwrap();
        insert(&db, sample_insert("dead", Some(999))).await.unwrap();

        let orphans = list_orphans(&db).await.unwrap();
        let ids: Vec<&str> = orphans.iter().map(|task| task.id.as_str()).collect();
        assert_eq!(ids, vec!["dead"]);
    }

    #[tokio::test]
    async fn list_orphans_excludes_terminal_tasks() {
        let db = fresh_db().await;
        insert(&db, sample_insert("dead_done", Some(999)))
            .await
            .unwrap();
        // Move the task out of 'running' so the orphan scan must skip it.
        set_status(&db, "dead_done", TaskStatus::Completed, None, None)
            .await
            .unwrap();

        let orphans = list_orphans(&db).await.unwrap();
        assert!(
            orphans.is_empty(),
            "completed orphan should not appear: {orphans:?}",
        );
    }

    #[tokio::test]
    async fn claim_orphan_cas_wins_then_loses() {
        let db = fresh_db().await;
        insert(&db, sample_insert("orphan", Some(999)))
            .await
            .unwrap();

        // Both claimers name pid 999 as the previous owner; only the first
        // still finds it in the row.
        let won = claim_orphan(&db, "orphan", 999, 1).await.unwrap();
        let lost = claim_orphan(&db, "orphan", 999, 2).await.unwrap();

        assert!(won, "first claimer should win");
        assert!(!lost, "second claimer should lose");
        let row = get(&db, "orphan").await.unwrap().unwrap();
        assert_eq!(row.owner_pid, Some(1));
    }
}
501}