use std::collections::{HashMap, HashSet};
use anyhow::Result;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use chrono::{DateTime, Utc};
use rand::{rngs::OsRng, RngCore};
use rusqlite::{params, Connection, OptionalExtension, Row};
use dragoon_proto::models::{Artifact, Task, TaskKind, TaskLimits, TaskState};
fn iso(dt: DateTime<Utc>) -> String {
dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
}
/// Parses an RFC 3339 timestamp string and converts it to UTC.
///
/// A trailing `Z` is rewritten to the explicit `+00:00` offset before the
/// string is handed to the parser.
///
/// # Errors
/// Returns an error when the (normalised) text is not valid RFC 3339.
fn parse_iso(s: &str) -> anyhow::Result<DateTime<Utc>> {
    let normalized = match s.strip_suffix('Z') {
        Some(head) => format!("{head}+00:00"),
        None => s.to_owned(),
    };
    let parsed = DateTime::parse_from_rfc3339(&normalized)?;
    Ok(parsed.with_timezone(&Utc))
}
/// Generates a fresh task identifier of the form `tsk_<token>`.
///
/// The token is 16 bytes drawn from the OS random number generator, encoded
/// as unpadded URL-safe base64 (22 characters).
pub fn new_task_id() -> String {
    let mut entropy = [0u8; 16];
    OsRng.fill_bytes(&mut entropy);
    let token = URL_SAFE_NO_PAD.encode(entropy);
    format!("tsk_{token}")
}
/// Returns the task state machine: for every state, the set of states it may
/// move to directly.
///
/// Terminal states (`Completed`, `Failed`, `Timeout`, `Cancelled`) map to an
/// empty set. The table is built once on first use and shared afterwards, so
/// repeated transition checks do not re-allocate the map and its sets.
fn allowed_transitions() -> &'static HashMap<TaskState, HashSet<TaskState>> {
    use TaskState::*;

    static TABLE: std::sync::OnceLock<HashMap<TaskState, HashSet<TaskState>>> =
        std::sync::OnceLock::new();

    TABLE.get_or_init(|| {
        HashMap::from([
            (Queued, HashSet::from([Running, Cancelling, Cancelled])),
            (
                Running,
                HashSet::from([Completed, Failed, Timeout, Cancelling]),
            ),
            (Cancelling, HashSet::from([Cancelled, Failed, Completed])),
            // Terminal states have no outgoing edges.
            (Completed, HashSet::new()),
            (Failed, HashSet::new()),
            (Timeout, HashSet::new()),
            (Cancelled, HashSet::new()),
        ])
    })
}
pub fn is_terminal(s: TaskState) -> bool {
matches!(
s,
TaskState::Completed | TaskState::Failed | TaskState::Timeout | TaskState::Cancelled
)
}
/// Reports whether the state machine permits moving directly from `src` to
/// `dst`.
///
/// A source state that is missing from the transition table permits nothing.
pub fn can_transition(src: TaskState, dst: TaskState) -> bool {
    let table = allowed_transitions();
    match table.get(&src) {
        Some(targets) => targets.contains(&dst),
        None => false,
    }
}
/// Decodes the upper-case state name stored in the database into a
/// [`TaskState`].
///
/// # Errors
/// Returns an error naming the offending text when it is not one of the
/// seven known state names.
fn task_state_from_str(s: &str) -> anyhow::Result<TaskState> {
    let state = match s {
        "QUEUED" => TaskState::Queued,
        "RUNNING" => TaskState::Running,
        "COMPLETED" => TaskState::Completed,
        "FAILED" => TaskState::Failed,
        "TIMEOUT" => TaskState::Timeout,
        "CANCELLING" => TaskState::Cancelling,
        "CANCELLED" => TaskState::Cancelled,
        other => return Err(anyhow::anyhow!("unknown task state {other}")),
    };
    Ok(state)
}
fn task_state_str(s: TaskState) -> &'static str {
match s {
TaskState::Queued => "QUEUED",
TaskState::Running => "RUNNING",
TaskState::Completed => "COMPLETED",
TaskState::Failed => "FAILED",
TaskState::Timeout => "TIMEOUT",
TaskState::Cancelling => "CANCELLING",
TaskState::Cancelled => "CANCELLED",
}
}
/// Decodes the lower-case kind name stored in the database into a
/// [`TaskKind`].
///
/// # Errors
/// Returns an error naming the offending text when it is not `command`,
/// `script` or `fetch`.
fn task_kind_from_str(s: &str) -> anyhow::Result<TaskKind> {
    let kind = match s {
        "command" => TaskKind::Command,
        "script" => TaskKind::Script,
        "fetch" => TaskKind::Fetch,
        other => return Err(anyhow::anyhow!("unknown task kind {other}")),
    };
    Ok(kind)
}
fn row_to_task(conn: &Connection, r: &Row<'_>) -> anyhow::Result<Task> {
let task_id: String = r.get("task_id")?;
let collect_json: String = r.get("collect_json")?;
let limits_json: String = r.get("limits_json")?;
let state_s: String = r.get("state")?;
let submitted_at: String = r.get("submitted_at")?;
let started_at: Option<String> = r.get("started_at")?;
let finished_at: Option<String> = r.get("finished_at")?;
let kind_s: String = r.get("kind")?;
let mut artifacts = Vec::new();
let mut stmt = conn.prepare(
"SELECT path, size, sha256 FROM artifacts WHERE task_id=? ORDER BY id ASC",
)?;
for art in stmt.query_map([&task_id], |ar| {
Ok(Artifact {
path: ar.get(0)?,
size: ar.get::<_, i64>(1)? as u64,
sha256: ar.get(2)?,
})
})? {
artifacts.push(art?);
}
Ok(Task {
task_id: task_id.clone(),
worker_name: r.get("worker_name")?,
submitter: r.get("submitter")?,
kind: task_kind_from_str(&kind_s)?,
payload: r.get("payload")?,
collect: serde_json::from_str(&collect_json)?,
limits: serde_json::from_str(&limits_json)?,
state: task_state_from_str(&state_s)?,
submitted_at: parse_iso(&submitted_at)?,
started_at: started_at.as_deref().map(parse_iso).transpose()?,
finished_at: finished_at.as_deref().map(parse_iso).transpose()?,
exit_code: r.get("exit_code")?,
final_pwd: r.get("final_pwd")?,
artifacts,
error: r.get("error")?,
fetch_path: r.get("fetch_path")?,
worker_seq: r.get("worker_seq")?,
})
}
/// Returns the next per-worker sequence number: one more than the highest
/// `worker_seq` recorded for `worker_name`, or `1` when the worker has no
/// tasks yet.
///
/// NOTE(review): the value is read in its own statement; whether concurrent
/// writers are serialised elsewhere (transaction, uniqueness constraint) is
/// not visible from here — confirm.
fn next_worker_seq(conn: &Connection, worker_name: &str) -> Result<i64> {
    let current_max = conn
        .query_row(
            "SELECT COALESCE(MAX(worker_seq), 0) FROM tasks WHERE worker_name=?",
            [worker_name],
            |row| row.get::<_, i64>(0),
        )
        .optional()?
        .unwrap_or(0);
    Ok(current_max + 1)
}
#[allow(clippy::too_many_arguments)]
pub fn insert_task(
conn: &Connection,
task_id: &str,
worker_name: &str,
submitter: &str,
kind: TaskKind,
payload: &str,
collect: &[String],
limits: &TaskLimits,
fetch_path: Option<&str>,
) -> Result<Task> {
let submitted = Utc::now();
let seq = next_worker_seq(conn, worker_name)?;
conn.execute(
"INSERT INTO tasks
(task_id, worker_name, submitter, kind, payload, collect_json, limits_json,
state, submitted_at, fetch_path, last_access_at, worker_seq)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?)",
params![
task_id,
worker_name,
submitter,
match kind {
TaskKind::Command => "command",
TaskKind::Script => "script",
TaskKind::Fetch => "fetch",
},
payload,
serde_json::to_string(collect)?,
serde_json::to_string(limits)?,
"QUEUED",
iso(submitted),
fetch_path,
iso(submitted),
seq,
],
)?;
Ok(get_task(conn, task_id)?.expect("just inserted"))
}
pub fn get_task(conn: &Connection, task_id: &str) -> Result<Option<Task>> {
let row: Option<Task> = conn
.prepare("SELECT * FROM tasks WHERE task_id=?")?
.query_row([task_id], |r| {
row_to_task(conn, r).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(
0,
rusqlite::types::Type::Text,
Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
e.to_string(),
)),
)
})
})
.optional()?;
Ok(row)
}
/// Returns the `QUEUED` task with the lowest `worker_seq` for `worker_name`,
/// or `Ok(None)` when the worker has nothing queued.
///
/// # Errors
/// Fails when the query fails or the selected row cannot be decoded.
pub fn next_queued_for_worker(conn: &Connection, worker_name: &str) -> Result<Option<Task>> {
    let mut stmt = conn.prepare(
        "SELECT * FROM tasks WHERE worker_name=? AND state=? ORDER BY worker_seq ASC LIMIT 1",
    )?;
    let head = stmt
        .query_row(params![worker_name, "QUEUED"], |row| {
            row_to_task(conn, row).map_err(|cause| {
                // Decoding errors are carried across the callback as text.
                let wrapped =
                    std::io::Error::new(std::io::ErrorKind::InvalidData, cause.to_string());
                rusqlite::Error::FromSqlConversionFailure(
                    0,
                    rusqlite::types::Type::Text,
                    Box::new(wrapped),
                )
            })
        })
        .optional()?;
    Ok(head)
}
/// Optional column values to write together with a state change in
/// `transition`.
///
/// Every field left as `None` leaves the corresponding column untouched;
/// `Default::default()` therefore changes nothing besides the state.
#[derive(Default, Debug, Clone)]
pub struct TransitionUpdate {
/// New value for the `started_at` column.
pub started_at: Option<DateTime<Utc>>,
/// New value for the `finished_at` column.
pub finished_at: Option<DateTime<Utc>>,
/// New value for the `exit_code` column.
pub exit_code: Option<i32>,
/// New value for the `final_pwd` column.
pub final_pwd: Option<String>,
/// New value for the `error` column.
pub error: Option<String>,
}
/// Moves a task to `new_state`, writing any fields set in `update` and
/// refreshing `last_access_at`, then returns the task as stored.
///
/// The `UPDATE` is conditioned on the task still being in the state that was
/// validated, so a change made through another connection between the read
/// and the write cannot be overwritten by a now-illegal transition.
///
/// # Errors
/// Fails when the task does not exist, when the state machine forbids the
/// transition, when the task's state changed before the update was applied,
/// or when a database statement fails.
pub fn transition(
    conn: &Connection,
    task_id: &str,
    new_state: TaskState,
    update: TransitionUpdate,
) -> Result<Task> {
    use rusqlite::types::Value;

    let cur = get_task(conn, task_id)?
        .ok_or_else(|| anyhow::anyhow!("task {task_id} not found"))?;
    if !can_transition(cur.state, new_state) {
        anyhow::bail!(
            "cannot transition {} -> {}",
            task_state_str(cur.state),
            task_state_str(new_state)
        );
    }

    // Build the SET clause from whichever optional fields were supplied.
    let mut sets: Vec<&str> = vec!["state=?"];
    let mut vals: Vec<Value> = vec![Value::Text(task_state_str(new_state).into())];
    if let Some(ts) = update.started_at {
        sets.push("started_at=?");
        vals.push(Value::Text(iso(ts)));
    }
    if let Some(ts) = update.finished_at {
        sets.push("finished_at=?");
        vals.push(Value::Text(iso(ts)));
    }
    if let Some(code) = update.exit_code {
        sets.push("exit_code=?");
        vals.push(Value::Integer(code.into()));
    }
    if let Some(pwd) = update.final_pwd {
        sets.push("final_pwd=?");
        vals.push(Value::Text(pwd));
    }
    if let Some(err) = update.error {
        sets.push("error=?");
        vals.push(Value::Text(err));
    }
    sets.push("last_access_at=?");
    vals.push(Value::Text(iso(Utc::now())));

    // Compare-and-swap on the state that was just validated.
    let sql = format!(
        "UPDATE tasks SET {} WHERE task_id=? AND state=?",
        sets.join(", ")
    );
    vals.push(Value::Text(task_id.into()));
    vals.push(Value::Text(task_state_str(cur.state).into()));

    let changed = conn.execute(&sql, rusqlite::params_from_iter(vals.iter()))?;
    if changed == 0 {
        anyhow::bail!(
            "task {task_id} is no longer {}; transition to {} not applied",
            task_state_str(cur.state),
            task_state_str(new_state)
        );
    }

    get_task(conn, task_id)?
        .ok_or_else(|| anyhow::anyhow!("task {task_id} missing after update"))
}
pub fn request_cancel(conn: &Connection, task_id: &str) -> Result<Task> {
let cur = get_task(conn, task_id)?
.ok_or_else(|| anyhow::anyhow!("task {task_id} not found"))?;
if is_terminal(cur.state) {
return Ok(cur);
}
conn.execute(
"UPDATE tasks SET cancel_requested=1 WHERE task_id=?",
[task_id],
)?;
if cur.state == TaskState::Queued {
return transition(
conn,
task_id,
TaskState::Cancelled,
TransitionUpdate {
finished_at: Some(Utc::now()),
error: Some("cancelled_before_start".into()),
..Default::default()
},
);
}
if cur.state == TaskState::Running {
return transition(
conn,
task_id,
TaskState::Cancelling,
TransitionUpdate::default(),
);
}
Ok(cur)
}
/// Reports whether cancellation has been requested for `task_id`.
///
/// Returns `false` when the task does not exist. Despite the name, this
/// function only reads the `cancel_requested` flag; it does not clear it.
pub fn consume_cancel_signal(conn: &Connection, task_id: &str) -> Result<bool> {
    let flag = conn
        .query_row(
            "SELECT cancel_requested FROM tasks WHERE task_id=?",
            [task_id],
            |row| row.get::<_, i64>(0),
        )
        .optional()?;
    Ok(matches!(flag, Some(value) if value != 0))
}
pub fn add_artifact(
conn: &Connection,
task_id: &str,
artifact: &Artifact,
blob_path: &str,
) -> Result<()> {
conn.execute(
"INSERT INTO artifacts (task_id, path, size, sha256, blob_path) VALUES (?,?,?,?,?)",
params![task_id, artifact.path, artifact.size as i64, artifact.sha256, blob_path],
)?;
Ok(())
}
pub fn touch_access(conn: &Connection, task_id: &str) -> Result<()> {
conn.execute(
"UPDATE tasks SET last_access_at=? WHERE task_id=?",
params![iso(Utc::now()), task_id],
)?;
Ok(())
}
// Unit tests for the task store, run against an in-memory database.
#[cfg(test)]
mod tests {
use super::*;
// Opens an in-memory database and bootstraps it via `crate::db::bootstrap`.
fn fresh() -> Connection {
let c = crate::db::connect_in_memory().unwrap();
crate::db::bootstrap(&c).unwrap();
c
}
// Inserts a `Command` task submitted by "alice" with nothing to collect,
// default limits and no fetch path.
fn insert(conn: &Connection, name: &str, payload: &str, id: &str) -> Task {
insert_task(
conn,
id,
name,
"alice",
TaskKind::Command,
payload,
&[],
&TaskLimits::default(),
None,
)
.unwrap()
}
// Spot-checks allowed and forbidden edges of the state machine.
#[test]
fn legal_transitions() {
assert!(can_transition(TaskState::Queued, TaskState::Running));
assert!(can_transition(TaskState::Running, TaskState::Completed));
assert!(can_transition(TaskState::Running, TaskState::Cancelling));
assert!(can_transition(TaskState::Cancelling, TaskState::Cancelled));
assert!(!can_transition(TaskState::Queued, TaskState::Completed));
assert!(!can_transition(TaskState::Completed, TaskState::Running));
}
// Sequence numbers count up from 1 independently for each worker.
#[test]
fn worker_seq_strictly_increasing() {
let c = fresh();
let a = insert(&c, "w1", "x", "a");
let b = insert(&c, "w1", "x", "b");
let c2 = insert(&c, "w1", "x", "c");
assert_eq!((a.worker_seq, b.worker_seq, c2.worker_seq), (1, 2, 3));
let other = insert(&c, "w2", "x", "z");
assert_eq!(other.worker_seq, 1);
}
// The earliest-inserted queued task is handed out first.
#[test]
fn next_queued_orders_by_seq() {
let c = fresh();
insert(&c, "w1", "first", "a");
insert(&c, "w1", "second", "b");
let nxt = next_queued_for_worker(&c, "w1").unwrap().unwrap();
assert_eq!(nxt.task_id, "a");
}
// Queued -> Completed is not an allowed edge, so `transition` must fail.
#[test]
fn transition_invalid_rejected() {
let c = fresh();
insert(&c, "w", "x", "t");
let r = transition(&c, "t", TaskState::Completed, Default::default());
assert!(r.is_err());
}
// Cancelling a task that never started ends it immediately.
#[test]
fn request_cancel_queued_terminates_immediately() {
let c = fresh();
insert(&c, "w", "x", "t");
let t = request_cancel(&c, "t").unwrap();
assert_eq!(t.state, TaskState::Cancelled);
}
// Cancelling a running task moves it to Cancelling and raises the flag.
#[test]
fn request_cancel_running_goes_cancelling() {
let c = fresh();
insert(&c, "w", "x", "t");
let _ = transition(
&c,
"t",
TaskState::Running,
TransitionUpdate {
started_at: Some(Utc::now()),
..Default::default()
},
)
.unwrap();
let t = request_cancel(&c, "t").unwrap();
assert_eq!(t.state, TaskState::Cancelling);
assert!(consume_cancel_signal(&c, "t").unwrap());
}
// An artifact added to a task is returned unchanged by `get_task`.
#[test]
fn add_artifact_round_trip() {
let c = fresh();
insert(&c, "w", "x", "t");
let a = Artifact {
path: "outputs/a.log".into(),
size: 10,
sha256: "ab".repeat(32),
};
add_artifact(&c, "t", &a, "blobs/t/artifacts/outputs/a.log").unwrap();
let got = get_task(&c, "t").unwrap().unwrap();
assert_eq!(got.artifacts.len(), 1);
assert_eq!(got.artifacts[0], a);
}
// 16 random bytes encode to 22 unpadded base64 characters after the prefix.
#[test]
fn task_id_format() {
let id = new_task_id();
assert!(id.starts_with("tsk_"));
assert_eq!(id.len(), "tsk_".len() + 22);
}
}