1use std::collections::{HashMap, HashSet};
5
6use anyhow::Result;
7use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
8use chrono::{DateTime, Utc};
9use rand::{rngs::OsRng, RngCore};
10use rusqlite::{params, Connection, OptionalExtension, Row};
11
12use dragoon_proto::models::{Artifact, Task, TaskKind, TaskLimits, TaskState};
13
14fn iso(dt: DateTime<Utc>) -> String {
15 dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
16}
17
18fn parse_iso(s: &str) -> anyhow::Result<DateTime<Utc>> {
19 let s = if let Some(stripped) = s.strip_suffix('Z') {
20 format!("{stripped}+00:00")
21 } else {
22 s.to_owned()
23 };
24 Ok(DateTime::parse_from_rfc3339(&s)?.with_timezone(&Utc))
25}
26
27pub fn new_task_id() -> String {
29 let mut bytes = [0u8; 16];
30 OsRng.fill_bytes(&mut bytes);
31 format!("tsk_{}", URL_SAFE_NO_PAD.encode(bytes))
32}
33
34fn allowed_transitions() -> HashMap<TaskState, HashSet<TaskState>> {
39 use TaskState::*;
40 let mut m: HashMap<TaskState, HashSet<TaskState>> = HashMap::new();
41 m.insert(Queued, [Running, Cancelling, Cancelled].into_iter().collect());
42 m.insert(
43 Running,
44 [Completed, Failed, Timeout, Cancelling].into_iter().collect(),
45 );
46 m.insert(
47 Cancelling,
48 [Cancelled, Failed, Completed].into_iter().collect(),
49 );
50 m.insert(Completed, HashSet::new());
51 m.insert(Failed, HashSet::new());
52 m.insert(Timeout, HashSet::new());
53 m.insert(Cancelled, HashSet::new());
54 m
55}
56
57pub fn is_terminal(s: TaskState) -> bool {
58 matches!(
59 s,
60 TaskState::Completed | TaskState::Failed | TaskState::Timeout | TaskState::Cancelled
61 )
62}
63
64pub fn can_transition(src: TaskState, dst: TaskState) -> bool {
65 allowed_transitions()
66 .get(&src)
67 .is_some_and(|set| set.contains(&dst))
68}
69
70fn task_state_from_str(s: &str) -> anyhow::Result<TaskState> {
75 Ok(match s {
76 "QUEUED" => TaskState::Queued,
77 "RUNNING" => TaskState::Running,
78 "COMPLETED" => TaskState::Completed,
79 "FAILED" => TaskState::Failed,
80 "TIMEOUT" => TaskState::Timeout,
81 "CANCELLING" => TaskState::Cancelling,
82 "CANCELLED" => TaskState::Cancelled,
83 other => anyhow::bail!("unknown task state {other}"),
84 })
85}
86
87fn task_state_str(s: TaskState) -> &'static str {
88 match s {
89 TaskState::Queued => "QUEUED",
90 TaskState::Running => "RUNNING",
91 TaskState::Completed => "COMPLETED",
92 TaskState::Failed => "FAILED",
93 TaskState::Timeout => "TIMEOUT",
94 TaskState::Cancelling => "CANCELLING",
95 TaskState::Cancelled => "CANCELLED",
96 }
97}
98
99fn task_kind_from_str(s: &str) -> anyhow::Result<TaskKind> {
100 Ok(match s {
101 "command" => TaskKind::Command,
102 "script" => TaskKind::Script,
103 "fetch" => TaskKind::Fetch,
104 other => anyhow::bail!("unknown task kind {other}"),
105 })
106}
107
108fn row_to_task(conn: &Connection, r: &Row<'_>) -> anyhow::Result<Task> {
109 let task_id: String = r.get("task_id")?;
110 let collect_json: String = r.get("collect_json")?;
111 let limits_json: String = r.get("limits_json")?;
112 let state_s: String = r.get("state")?;
113 let submitted_at: String = r.get("submitted_at")?;
114 let started_at: Option<String> = r.get("started_at")?;
115 let finished_at: Option<String> = r.get("finished_at")?;
116 let kind_s: String = r.get("kind")?;
117
118 let mut artifacts = Vec::new();
119 let mut stmt = conn.prepare(
120 "SELECT path, size, sha256 FROM artifacts WHERE task_id=? ORDER BY id ASC",
121 )?;
122 for art in stmt.query_map([&task_id], |ar| {
123 Ok(Artifact {
124 path: ar.get(0)?,
125 size: ar.get::<_, i64>(1)? as u64,
126 sha256: ar.get(2)?,
127 })
128 })? {
129 artifacts.push(art?);
130 }
131
132 Ok(Task {
133 task_id: task_id.clone(),
134 worker_name: r.get("worker_name")?,
135 submitter: r.get("submitter")?,
136 kind: task_kind_from_str(&kind_s)?,
137 payload: r.get("payload")?,
138 collect: serde_json::from_str(&collect_json)?,
139 limits: serde_json::from_str(&limits_json)?,
140 state: task_state_from_str(&state_s)?,
141 submitted_at: parse_iso(&submitted_at)?,
142 started_at: started_at.as_deref().map(parse_iso).transpose()?,
143 finished_at: finished_at.as_deref().map(parse_iso).transpose()?,
144 exit_code: r.get("exit_code")?,
145 final_pwd: r.get("final_pwd")?,
146 artifacts,
147 error: r.get("error")?,
148 fetch_path: r.get("fetch_path")?,
149 worker_seq: r.get("worker_seq")?,
150 })
151}
152
153fn next_worker_seq(conn: &Connection, worker_name: &str) -> Result<i64> {
154 let m: Option<i64> = conn
155 .query_row(
156 "SELECT COALESCE(MAX(worker_seq), 0) FROM tasks WHERE worker_name=?",
157 [worker_name],
158 |r| r.get(0),
159 )
160 .optional()?;
161 Ok(m.unwrap_or(0) + 1)
162}
163
164#[allow(clippy::too_many_arguments)]
165pub fn insert_task(
166 conn: &Connection,
167 task_id: &str,
168 worker_name: &str,
169 submitter: &str,
170 kind: TaskKind,
171 payload: &str,
172 collect: &[String],
173 limits: &TaskLimits,
174 fetch_path: Option<&str>,
175) -> Result<Task> {
176 let submitted = Utc::now();
177 let seq = next_worker_seq(conn, worker_name)?;
178 conn.execute(
179 "INSERT INTO tasks
180 (task_id, worker_name, submitter, kind, payload, collect_json, limits_json,
181 state, submitted_at, fetch_path, last_access_at, worker_seq)
182 VALUES (?,?,?,?,?,?,?,?,?,?,?,?)",
183 params![
184 task_id,
185 worker_name,
186 submitter,
187 match kind {
188 TaskKind::Command => "command",
189 TaskKind::Script => "script",
190 TaskKind::Fetch => "fetch",
191 },
192 payload,
193 serde_json::to_string(collect)?,
194 serde_json::to_string(limits)?,
195 "QUEUED",
196 iso(submitted),
197 fetch_path,
198 iso(submitted),
199 seq,
200 ],
201 )?;
202 Ok(get_task(conn, task_id)?.expect("just inserted"))
203}
204
205pub fn get_task(conn: &Connection, task_id: &str) -> Result<Option<Task>> {
206 let row: Option<Task> = conn
207 .prepare("SELECT * FROM tasks WHERE task_id=?")?
208 .query_row([task_id], |r| {
209 row_to_task(conn, r).map_err(|e| {
210 rusqlite::Error::FromSqlConversionFailure(
211 0,
212 rusqlite::types::Type::Text,
213 Box::new(std::io::Error::new(
214 std::io::ErrorKind::InvalidData,
215 e.to_string(),
216 )),
217 )
218 })
219 })
220 .optional()?;
221 Ok(row)
222}
223
224pub fn next_queued_for_worker(conn: &Connection, worker_name: &str) -> Result<Option<Task>> {
226 conn.prepare(
227 "SELECT * FROM tasks WHERE worker_name=? AND state=? ORDER BY worker_seq ASC LIMIT 1",
228 )?
229 .query_row(params![worker_name, "QUEUED"], |r| {
230 row_to_task(conn, r).map_err(|e| {
231 rusqlite::Error::FromSqlConversionFailure(
232 0,
233 rusqlite::types::Type::Text,
234 Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())),
235 )
236 })
237 })
238 .optional()
239 .map_err(Into::into)
240}
241
242#[derive(Default, Debug, Clone)]
243pub struct TransitionUpdate {
244 pub started_at: Option<DateTime<Utc>>,
245 pub finished_at: Option<DateTime<Utc>>,
246 pub exit_code: Option<i32>,
247 pub final_pwd: Option<String>,
248 pub error: Option<String>,
249}
250
251pub fn transition(
252 conn: &Connection,
253 task_id: &str,
254 new_state: TaskState,
255 update: TransitionUpdate,
256) -> Result<Task> {
257 let cur = get_task(conn, task_id)?
258 .ok_or_else(|| anyhow::anyhow!("task {task_id} not found"))?;
259 if !can_transition(cur.state, new_state) {
260 anyhow::bail!(
261 "cannot transition {} -> {}",
262 task_state_str(cur.state),
263 task_state_str(new_state)
264 );
265 }
266
267 let mut sets: Vec<&str> = vec!["state=?"];
268 let mut vals: Vec<rusqlite::types::Value> =
269 vec![rusqlite::types::Value::Text(task_state_str(new_state).into())];
270
271 if let Some(ts) = update.started_at {
272 sets.push("started_at=?");
273 vals.push(rusqlite::types::Value::Text(iso(ts)));
274 }
275 if let Some(ts) = update.finished_at {
276 sets.push("finished_at=?");
277 vals.push(rusqlite::types::Value::Text(iso(ts)));
278 }
279 if let Some(code) = update.exit_code {
280 sets.push("exit_code=?");
281 vals.push(rusqlite::types::Value::Integer(code.into()));
282 }
283 if let Some(pwd) = update.final_pwd {
284 sets.push("final_pwd=?");
285 vals.push(rusqlite::types::Value::Text(pwd));
286 }
287 if let Some(err) = update.error {
288 sets.push("error=?");
289 vals.push(rusqlite::types::Value::Text(err));
290 }
291 sets.push("last_access_at=?");
292 vals.push(rusqlite::types::Value::Text(iso(Utc::now())));
293
294 let sql = format!(
295 "UPDATE tasks SET {} WHERE task_id=?",
296 sets.join(", ")
297 );
298 vals.push(rusqlite::types::Value::Text(task_id.into()));
299 conn.execute(&sql, rusqlite::params_from_iter(vals.iter()))?;
300 Ok(get_task(conn, task_id)?.expect("present after update"))
301}
302
303pub fn request_cancel(conn: &Connection, task_id: &str) -> Result<Task> {
304 let cur = get_task(conn, task_id)?
305 .ok_or_else(|| anyhow::anyhow!("task {task_id} not found"))?;
306 if is_terminal(cur.state) {
307 return Ok(cur);
308 }
309 conn.execute(
310 "UPDATE tasks SET cancel_requested=1 WHERE task_id=?",
311 [task_id],
312 )?;
313 if cur.state == TaskState::Queued {
314 return transition(
315 conn,
316 task_id,
317 TaskState::Cancelled,
318 TransitionUpdate {
319 finished_at: Some(Utc::now()),
320 error: Some("cancelled_before_start".into()),
321 ..Default::default()
322 },
323 );
324 }
325 if cur.state == TaskState::Running {
326 return transition(
327 conn,
328 task_id,
329 TaskState::Cancelling,
330 TransitionUpdate::default(),
331 );
332 }
333 Ok(cur)
334}
335
336pub fn consume_cancel_signal(conn: &Connection, task_id: &str) -> Result<bool> {
337 let r: Option<i64> = conn
338 .query_row(
339 "SELECT cancel_requested FROM tasks WHERE task_id=?",
340 [task_id],
341 |r| r.get(0),
342 )
343 .optional()?;
344 Ok(r.unwrap_or(0) != 0)
345}
346
347pub fn add_artifact(
348 conn: &Connection,
349 task_id: &str,
350 artifact: &Artifact,
351 blob_path: &str,
352) -> Result<()> {
353 conn.execute(
354 "INSERT INTO artifacts (task_id, path, size, sha256, blob_path) VALUES (?,?,?,?,?)",
355 params![task_id, artifact.path, artifact.size as i64, artifact.sha256, blob_path],
356 )?;
357 Ok(())
358}
359
360pub fn touch_access(conn: &Connection, task_id: &str) -> Result<()> {
361 conn.execute(
362 "UPDATE tasks SET last_access_at=? WHERE task_id=?",
363 params![iso(Utc::now()), task_id],
364 )?;
365 Ok(())
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database with the schema bootstrapped.
    fn fresh() -> Connection {
        let conn = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&conn).unwrap();
        conn
    }

    /// Inserts a `Command` task submitted by "alice" with default limits.
    fn insert(conn: &Connection, name: &str, payload: &str, id: &str) -> Task {
        let no_collect: &[String] = &[];
        let limits = TaskLimits::default();
        insert_task(
            conn,
            id,
            name,
            "alice",
            TaskKind::Command,
            payload,
            no_collect,
            &limits,
            None,
        )
        .unwrap()
    }

    #[test]
    fn legal_transitions() {
        use TaskState::*;

        for (src, dst) in [
            (Queued, Running),
            (Running, Completed),
            (Running, Cancelling),
            (Cancelling, Cancelled),
        ] {
            assert!(can_transition(src, dst));
        }
        for (src, dst) in [(Queued, Completed), (Completed, Running)] {
            assert!(!can_transition(src, dst));
        }
    }

    #[test]
    fn worker_seq_strictly_increasing() {
        let db = fresh();
        let seqs: Vec<_> = ["a", "b", "c"]
            .into_iter()
            .map(|id| insert(&db, "w1", "x", id).worker_seq)
            .collect();
        assert_eq!(seqs, [1, 2, 3]);

        // Sequences are counted per worker.
        let other = insert(&db, "w2", "x", "z");
        assert_eq!(other.worker_seq, 1);
    }

    #[test]
    fn next_queued_orders_by_seq() {
        let db = fresh();
        insert(&db, "w1", "first", "a");
        insert(&db, "w1", "second", "b");

        let head = next_queued_for_worker(&db, "w1")
            .unwrap()
            .expect("a queued task");
        assert_eq!(head.task_id, "a");
    }

    #[test]
    fn transition_invalid_rejected() {
        let db = fresh();
        insert(&db, "w", "x", "t");

        let outcome = transition(&db, "t", TaskState::Completed, TransitionUpdate::default());
        assert!(outcome.is_err());
    }

    #[test]
    fn request_cancel_queued_terminates_immediately() {
        let db = fresh();
        insert(&db, "w", "x", "t");

        let cancelled = request_cancel(&db, "t").unwrap();
        assert_eq!(cancelled.state, TaskState::Cancelled);
    }

    #[test]
    fn request_cancel_running_goes_cancelling() {
        let db = fresh();
        insert(&db, "w", "x", "t");

        let start = TransitionUpdate {
            started_at: Some(Utc::now()),
            ..Default::default()
        };
        transition(&db, "t", TaskState::Running, start).unwrap();

        let task = request_cancel(&db, "t").unwrap();
        assert_eq!(task.state, TaskState::Cancelling);
        assert!(consume_cancel_signal(&db, "t").unwrap());
    }

    #[test]
    fn add_artifact_round_trip() {
        let db = fresh();
        insert(&db, "w", "x", "t");

        let artifact = Artifact {
            path: "outputs/a.log".into(),
            size: 10,
            sha256: "ab".repeat(32),
        };
        add_artifact(&db, "t", &artifact, "blobs/t/artifacts/outputs/a.log").unwrap();

        let stored = get_task(&db, "t").unwrap().unwrap();
        assert_eq!(stored.artifacts.len(), 1);
        assert_eq!(stored.artifacts[0], artifact);
    }

    #[test]
    fn task_id_format() {
        let id = new_task_id();
        assert!(id.starts_with("tsk_"));
        // 16 bytes encode to 22 characters of unpadded base64.
        assert_eq!(id.len(), "tsk_".len() + 22);
    }
}