Skip to main content

dragoon_server/
tasks_repo.rs

1//! Tasks: ID generation, CRUD, state machine, per-worker monotonic seq.
2//! Mirrors `python/.../server/tasks_repo.py`.
3
4use std::collections::{HashMap, HashSet};
5
6use anyhow::Result;
7use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
8use chrono::{DateTime, Utc};
9use rand::{rngs::OsRng, RngCore};
10use rusqlite::{params, Connection, OptionalExtension, Row};
11
12use dragoon_proto::models::{Artifact, Task, TaskKind, TaskLimits, TaskState};
13
14fn iso(dt: DateTime<Utc>) -> String {
15    dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
16}
17
18fn parse_iso(s: &str) -> anyhow::Result<DateTime<Utc>> {
19    let s = if let Some(stripped) = s.strip_suffix('Z') {
20        format!("{stripped}+00:00")
21    } else {
22        s.to_owned()
23    };
24    Ok(DateTime::parse_from_rfc3339(&s)?.with_timezone(&Utc))
25}
26
27/// Generate a fresh task id: `tsk_` + 16 random url-safe base64 bytes.
28pub fn new_task_id() -> String {
29    let mut bytes = [0u8; 16];
30    OsRng.fill_bytes(&mut bytes);
31    format!("tsk_{}", URL_SAFE_NO_PAD.encode(bytes))
32}
33
// --------------------------------------------------------------------------
// State machine (design §4.2)
// --------------------------------------------------------------------------
37
38fn allowed_transitions() -> HashMap<TaskState, HashSet<TaskState>> {
39    use TaskState::*;
40    let mut m: HashMap<TaskState, HashSet<TaskState>> = HashMap::new();
41    m.insert(Queued, [Running, Cancelling, Cancelled].into_iter().collect());
42    m.insert(
43        Running,
44        [Completed, Failed, Timeout, Cancelling].into_iter().collect(),
45    );
46    m.insert(
47        Cancelling,
48        [Cancelled, Failed, Completed].into_iter().collect(),
49    );
50    m.insert(Completed, HashSet::new());
51    m.insert(Failed, HashSet::new());
52    m.insert(Timeout, HashSet::new());
53    m.insert(Cancelled, HashSet::new());
54    m
55}
56
57pub fn is_terminal(s: TaskState) -> bool {
58    matches!(
59        s,
60        TaskState::Completed | TaskState::Failed | TaskState::Timeout | TaskState::Cancelled
61    )
62}
63
64pub fn can_transition(src: TaskState, dst: TaskState) -> bool {
65    allowed_transitions()
66        .get(&src)
67        .is_some_and(|set| set.contains(&dst))
68}
69
70// --------------------------------------------------------------------------
71// SQL row -> Task
72// --------------------------------------------------------------------------
73
74fn task_state_from_str(s: &str) -> anyhow::Result<TaskState> {
75    Ok(match s {
76        "QUEUED" => TaskState::Queued,
77        "RUNNING" => TaskState::Running,
78        "COMPLETED" => TaskState::Completed,
79        "FAILED" => TaskState::Failed,
80        "TIMEOUT" => TaskState::Timeout,
81        "CANCELLING" => TaskState::Cancelling,
82        "CANCELLED" => TaskState::Cancelled,
83        other => anyhow::bail!("unknown task state {other}"),
84    })
85}
86
87fn task_state_str(s: TaskState) -> &'static str {
88    match s {
89        TaskState::Queued => "QUEUED",
90        TaskState::Running => "RUNNING",
91        TaskState::Completed => "COMPLETED",
92        TaskState::Failed => "FAILED",
93        TaskState::Timeout => "TIMEOUT",
94        TaskState::Cancelling => "CANCELLING",
95        TaskState::Cancelled => "CANCELLED",
96    }
97}
98
99fn task_kind_from_str(s: &str) -> anyhow::Result<TaskKind> {
100    Ok(match s {
101        "command" => TaskKind::Command,
102        "script" => TaskKind::Script,
103        "fetch" => TaskKind::Fetch,
104        other => anyhow::bail!("unknown task kind {other}"),
105    })
106}
107
108fn row_to_task(conn: &Connection, r: &Row<'_>) -> anyhow::Result<Task> {
109    let task_id: String = r.get("task_id")?;
110    let collect_json: String = r.get("collect_json")?;
111    let limits_json: String = r.get("limits_json")?;
112    let state_s: String = r.get("state")?;
113    let submitted_at: String = r.get("submitted_at")?;
114    let started_at: Option<String> = r.get("started_at")?;
115    let finished_at: Option<String> = r.get("finished_at")?;
116    let kind_s: String = r.get("kind")?;
117
118    let mut artifacts = Vec::new();
119    let mut stmt = conn.prepare(
120        "SELECT path, size, sha256 FROM artifacts WHERE task_id=? ORDER BY id ASC",
121    )?;
122    for art in stmt.query_map([&task_id], |ar| {
123        Ok(Artifact {
124            path: ar.get(0)?,
125            size: ar.get::<_, i64>(1)? as u64,
126            sha256: ar.get(2)?,
127        })
128    })? {
129        artifacts.push(art?);
130    }
131
132    Ok(Task {
133        task_id: task_id.clone(),
134        worker_name: r.get("worker_name")?,
135        submitter: r.get("submitter")?,
136        kind: task_kind_from_str(&kind_s)?,
137        payload: r.get("payload")?,
138        collect: serde_json::from_str(&collect_json)?,
139        limits: serde_json::from_str(&limits_json)?,
140        state: task_state_from_str(&state_s)?,
141        submitted_at: parse_iso(&submitted_at)?,
142        started_at: started_at.as_deref().map(parse_iso).transpose()?,
143        finished_at: finished_at.as_deref().map(parse_iso).transpose()?,
144        exit_code: r.get("exit_code")?,
145        final_pwd: r.get("final_pwd")?,
146        artifacts,
147        error: r.get("error")?,
148        fetch_path: r.get("fetch_path")?,
149        worker_seq: r.get("worker_seq")?,
150    })
151}
152
153fn next_worker_seq(conn: &Connection, worker_name: &str) -> Result<i64> {
154    let m: Option<i64> = conn
155        .query_row(
156            "SELECT COALESCE(MAX(worker_seq), 0) FROM tasks WHERE worker_name=?",
157            [worker_name],
158            |r| r.get(0),
159        )
160        .optional()?;
161    Ok(m.unwrap_or(0) + 1)
162}
163
164#[allow(clippy::too_many_arguments)]
165pub fn insert_task(
166    conn: &Connection,
167    task_id: &str,
168    worker_name: &str,
169    submitter: &str,
170    kind: TaskKind,
171    payload: &str,
172    collect: &[String],
173    limits: &TaskLimits,
174    fetch_path: Option<&str>,
175) -> Result<Task> {
176    let submitted = Utc::now();
177    let seq = next_worker_seq(conn, worker_name)?;
178    conn.execute(
179        "INSERT INTO tasks
180            (task_id, worker_name, submitter, kind, payload, collect_json, limits_json,
181             state, submitted_at, fetch_path, last_access_at, worker_seq)
182         VALUES (?,?,?,?,?,?,?,?,?,?,?,?)",
183        params![
184            task_id,
185            worker_name,
186            submitter,
187            match kind {
188                TaskKind::Command => "command",
189                TaskKind::Script => "script",
190                TaskKind::Fetch => "fetch",
191            },
192            payload,
193            serde_json::to_string(collect)?,
194            serde_json::to_string(limits)?,
195            "QUEUED",
196            iso(submitted),
197            fetch_path,
198            iso(submitted),
199            seq,
200        ],
201    )?;
202    Ok(get_task(conn, task_id)?.expect("just inserted"))
203}
204
205pub fn get_task(conn: &Connection, task_id: &str) -> Result<Option<Task>> {
206    let row: Option<Task> = conn
207        .prepare("SELECT * FROM tasks WHERE task_id=?")?
208        .query_row([task_id], |r| {
209            row_to_task(conn, r).map_err(|e| {
210                rusqlite::Error::FromSqlConversionFailure(
211                    0,
212                    rusqlite::types::Type::Text,
213                    Box::new(std::io::Error::new(
214                        std::io::ErrorKind::InvalidData,
215                        e.to_string(),
216                    )),
217                )
218            })
219        })
220        .optional()?;
221    Ok(row)
222}
223
224/// FIFO peek: lowest worker_seq among QUEUED tasks for `worker_name`.
225pub fn next_queued_for_worker(conn: &Connection, worker_name: &str) -> Result<Option<Task>> {
226    conn.prepare(
227        "SELECT * FROM tasks WHERE worker_name=? AND state=? ORDER BY worker_seq ASC LIMIT 1",
228    )?
229    .query_row(params![worker_name, "QUEUED"], |r| {
230        row_to_task(conn, r).map_err(|e| {
231            rusqlite::Error::FromSqlConversionFailure(
232                0,
233                rusqlite::types::Type::Text,
234                Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())),
235            )
236        })
237    })
238    .optional()
239    .map_err(Into::into)
240}
241
242#[derive(Default, Debug, Clone)]
243pub struct TransitionUpdate {
244    pub started_at: Option<DateTime<Utc>>,
245    pub finished_at: Option<DateTime<Utc>>,
246    pub exit_code: Option<i32>,
247    pub final_pwd: Option<String>,
248    pub error: Option<String>,
249}
250
251pub fn transition(
252    conn: &Connection,
253    task_id: &str,
254    new_state: TaskState,
255    update: TransitionUpdate,
256) -> Result<Task> {
257    let cur = get_task(conn, task_id)?
258        .ok_or_else(|| anyhow::anyhow!("task {task_id} not found"))?;
259    if !can_transition(cur.state, new_state) {
260        anyhow::bail!(
261            "cannot transition {} -> {}",
262            task_state_str(cur.state),
263            task_state_str(new_state)
264        );
265    }
266
267    let mut sets: Vec<&str> = vec!["state=?"];
268    let mut vals: Vec<rusqlite::types::Value> =
269        vec![rusqlite::types::Value::Text(task_state_str(new_state).into())];
270
271    if let Some(ts) = update.started_at {
272        sets.push("started_at=?");
273        vals.push(rusqlite::types::Value::Text(iso(ts)));
274    }
275    if let Some(ts) = update.finished_at {
276        sets.push("finished_at=?");
277        vals.push(rusqlite::types::Value::Text(iso(ts)));
278    }
279    if let Some(code) = update.exit_code {
280        sets.push("exit_code=?");
281        vals.push(rusqlite::types::Value::Integer(code.into()));
282    }
283    if let Some(pwd) = update.final_pwd {
284        sets.push("final_pwd=?");
285        vals.push(rusqlite::types::Value::Text(pwd));
286    }
287    if let Some(err) = update.error {
288        sets.push("error=?");
289        vals.push(rusqlite::types::Value::Text(err));
290    }
291    sets.push("last_access_at=?");
292    vals.push(rusqlite::types::Value::Text(iso(Utc::now())));
293
294    let sql = format!(
295        "UPDATE tasks SET {} WHERE task_id=?",
296        sets.join(", ")
297    );
298    vals.push(rusqlite::types::Value::Text(task_id.into()));
299    conn.execute(&sql, rusqlite::params_from_iter(vals.iter()))?;
300    Ok(get_task(conn, task_id)?.expect("present after update"))
301}
302
303pub fn request_cancel(conn: &Connection, task_id: &str) -> Result<Task> {
304    let cur = get_task(conn, task_id)?
305        .ok_or_else(|| anyhow::anyhow!("task {task_id} not found"))?;
306    if is_terminal(cur.state) {
307        return Ok(cur);
308    }
309    conn.execute(
310        "UPDATE tasks SET cancel_requested=1 WHERE task_id=?",
311        [task_id],
312    )?;
313    if cur.state == TaskState::Queued {
314        return transition(
315            conn,
316            task_id,
317            TaskState::Cancelled,
318            TransitionUpdate {
319                finished_at: Some(Utc::now()),
320                error: Some("cancelled_before_start".into()),
321                ..Default::default()
322            },
323        );
324    }
325    if cur.state == TaskState::Running {
326        return transition(
327            conn,
328            task_id,
329            TaskState::Cancelling,
330            TransitionUpdate::default(),
331        );
332    }
333    Ok(cur)
334}
335
336pub fn consume_cancel_signal(conn: &Connection, task_id: &str) -> Result<bool> {
337    let r: Option<i64> = conn
338        .query_row(
339            "SELECT cancel_requested FROM tasks WHERE task_id=?",
340            [task_id],
341            |r| r.get(0),
342        )
343        .optional()?;
344    Ok(r.unwrap_or(0) != 0)
345}
346
347pub fn add_artifact(
348    conn: &Connection,
349    task_id: &str,
350    artifact: &Artifact,
351    blob_path: &str,
352) -> Result<()> {
353    conn.execute(
354        "INSERT INTO artifacts (task_id, path, size, sha256, blob_path) VALUES (?,?,?,?,?)",
355        params![task_id, artifact.path, artifact.size as i64, artifact.sha256, blob_path],
356    )?;
357    Ok(())
358}
359
360pub fn touch_access(conn: &Connection, task_id: &str) -> Result<()> {
361    conn.execute(
362        "UPDATE tasks SET last_access_at=? WHERE task_id=?",
363        params![iso(Utc::now()), task_id],
364    )?;
365    Ok(())
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database with the schema bootstrapped.
    fn fresh() -> Connection {
        let c = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&c).unwrap();
        c
    }

    /// Insert a `command` task with id `id` for worker `name`, using default limits.
    fn insert(conn: &Connection, name: &str, payload: &str, id: &str) -> Task {
        insert_task(
            conn,
            id,
            name,
            "alice",
            TaskKind::Command,
            payload,
            &[],
            &TaskLimits::default(),
            None,
        )
        .unwrap()
    }

    /// Move task `id` from QUEUED to RUNNING.
    fn start(conn: &Connection, id: &str) -> Task {
        transition(
            conn,
            id,
            TaskState::Running,
            TransitionUpdate {
                started_at: Some(Utc::now()),
                ..Default::default()
            },
        )
        .unwrap()
    }

    #[test]
    fn legal_transitions() {
        assert!(can_transition(TaskState::Queued, TaskState::Running));
        assert!(can_transition(TaskState::Queued, TaskState::Cancelled));
        assert!(can_transition(TaskState::Running, TaskState::Completed));
        assert!(can_transition(TaskState::Running, TaskState::Cancelling));
        assert!(can_transition(TaskState::Cancelling, TaskState::Cancelled));
        assert!(can_transition(TaskState::Cancelling, TaskState::Completed));
        assert!(!can_transition(TaskState::Queued, TaskState::Completed));
        assert!(!can_transition(TaskState::Completed, TaskState::Running));
        assert!(!can_transition(TaskState::Cancelled, TaskState::Queued));
    }

    #[test]
    fn terminal_states_are_dead_ends() {
        use TaskState::*;
        let all = [Queued, Running, Completed, Failed, Timeout, Cancelling, Cancelled];
        for src in [Completed, Failed, Timeout, Cancelled] {
            assert!(is_terminal(src));
            for dst in all {
                assert!(!can_transition(src, dst));
            }
        }
        for s in [Queued, Running, Cancelling] {
            assert!(!is_terminal(s));
        }
    }

    #[test]
    fn iso_round_trips_through_parse_iso() {
        let t = DateTime::parse_from_rfc3339("2024-05-01T12:34:56.123456Z")
            .unwrap()
            .with_timezone(&Utc);
        let s = iso(t);
        assert_eq!(s, "2024-05-01T12:34:56.123456Z");
        assert_eq!(parse_iso(&s).unwrap(), t);
    }

    #[test]
    fn worker_seq_strictly_increasing() {
        let c = fresh();
        let a = insert(&c, "w1", "x", "a");
        let b = insert(&c, "w1", "x", "b");
        let c2 = insert(&c, "w1", "x", "c");
        assert_eq!((a.worker_seq, b.worker_seq, c2.worker_seq), (1, 2, 3));
        // A different worker keeps its own counter.
        let other = insert(&c, "w2", "x", "z");
        assert_eq!(other.worker_seq, 1);
    }

    #[test]
    fn missing_task_paths() {
        let c = fresh();
        assert!(get_task(&c, "nope").unwrap().is_none());
        assert!(transition(&c, "nope", TaskState::Running, Default::default()).is_err());
        assert!(request_cancel(&c, "nope").is_err());
        assert!(!consume_cancel_signal(&c, "nope").unwrap());
    }

    #[test]
    fn next_queued_orders_by_seq() {
        let c = fresh();
        insert(&c, "w1", "first", "a");
        insert(&c, "w1", "second", "b");
        let nxt = next_queued_for_worker(&c, "w1").unwrap().unwrap();
        assert_eq!(nxt.task_id, "a");
    }

    #[test]
    fn next_queued_skips_non_queued_and_other_workers() {
        let c = fresh();
        insert(&c, "w1", "first", "a");
        insert(&c, "w1", "second", "b");
        request_cancel(&c, "a").unwrap();
        let nxt = next_queued_for_worker(&c, "w1").unwrap().unwrap();
        assert_eq!(nxt.task_id, "b");
        assert!(next_queued_for_worker(&c, "w2").unwrap().is_none());
    }

    #[test]
    fn transition_invalid_rejected() {
        let c = fresh();
        insert(&c, "w", "x", "t");
        let r = transition(&c, "t", TaskState::Completed, Default::default());
        assert!(r.is_err());
        // The failed attempt must not have changed the stored state.
        let t = get_task(&c, "t").unwrap().unwrap();
        assert_eq!(t.state, TaskState::Queued);
    }

    #[test]
    fn request_cancel_queued_terminates_immediately() {
        let c = fresh();
        insert(&c, "w", "x", "t");
        let t = request_cancel(&c, "t").unwrap();
        assert_eq!(t.state, TaskState::Cancelled);
        assert_eq!(t.error.as_deref(), Some("cancelled_before_start"));
        assert!(t.finished_at.is_some());
        assert!(consume_cancel_signal(&c, "t").unwrap());
    }

    #[test]
    fn request_cancel_running_goes_cancelling() {
        let c = fresh();
        insert(&c, "w", "x", "t");
        start(&c, "t");
        let t = request_cancel(&c, "t").unwrap();
        assert_eq!(t.state, TaskState::Cancelling);
        assert!(consume_cancel_signal(&c, "t").unwrap());
    }

    #[test]
    fn request_cancel_terminal_is_noop() {
        let c = fresh();
        insert(&c, "w", "x", "t");
        start(&c, "t");
        transition(
            &c,
            "t",
            TaskState::Completed,
            TransitionUpdate {
                finished_at: Some(Utc::now()),
                ..Default::default()
            },
        )
        .unwrap();
        let t = request_cancel(&c, "t").unwrap();
        assert_eq!(t.state, TaskState::Completed);
        assert!(transition(&c, "t", TaskState::Running, Default::default()).is_err());
    }

    #[test]
    fn add_artifact_round_trip() {
        let c = fresh();
        insert(&c, "w", "x", "t");
        let a = Artifact {
            path: "outputs/a.log".into(),
            size: 10,
            sha256: "ab".repeat(32),
        };
        add_artifact(&c, "t", &a, "blobs/t/artifacts/outputs/a.log").unwrap();
        let got = get_task(&c, "t").unwrap().unwrap();
        assert_eq!(got.artifacts.len(), 1);
        assert_eq!(got.artifacts[0], a);
    }

    #[test]
    fn task_id_format() {
        let id = new_task_id();
        assert!(id.starts_with("tsk_"));
        // 16 bytes -> 22 base64-no-pad chars
        assert_eq!(id.len(), "tsk_".len() + 22);
    }
}
477        assert!(id.starts_with("tsk_"));
478        // 16 bytes -> 22 base64-no-pad chars
479        assert_eq!(id.len(), "tsk_".len() + 22);
480    }
481}