use anyhow::Result;
use chrono::{Duration, Utc};
use rand::{rngs::OsRng, RngCore};
use rusqlite::{params, Connection, OptionalExtension};
use sha2::{Digest, Sha256};
use thiserror::Error;
/// A registered worker as stored in the `workers` table.
///
/// Only the operational columns are exposed; credential material (token hash,
/// registration code hash, client public key) stays in the database.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Worker {
    /// Row id of the worker in `workers`.
    pub id: i64,
    /// Worker name; the lookup key used by most functions in this module.
    pub name: String,
    /// Last status string passed to `update_status` (e.g. `"IDLE"`), if any.
    pub status: Option<String>,
    /// Last `current_pwd` reported via `update_status` — presumably the worker's
    /// working directory (TODO confirm with callers).
    pub current_pwd: Option<String>,
    /// Task id last reported via `update_status`; `None` when none was reported.
    pub current_task_id: Option<String>,
    /// RFC 3339 timestamp (microseconds, `Z` suffix) of the last `update_status` call.
    pub last_poll_at: Option<String>,
}
#[derive(Debug, Error)]
pub enum InvalidRegisterCode {
#[error("unknown worker")]
UnknownWorker,
#[error("no pending registration")]
NoPending,
#[error("bad register code")]
BadCode,
#[error("register code expired")]
Expired,
}
fn iso_now() -> String {
Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
}
fn iso_at(dt: chrono::DateTime<Utc>) -> String {
dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
}
/// Returns `byte_len` bytes from the OS random source, encoded as URL-safe
/// base64 without padding.
fn token_urlsafe(byte_len: usize) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;
    let mut entropy = vec![0u8; byte_len];
    OsRng.fill_bytes(entropy.as_mut_slice());
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Hex-encoded SHA-256 digest of the UTF-8 bytes of `s`.
///
/// Used so that only hashes of secrets (codes, tokens) are ever stored.
fn sha256_hex(s: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(s.as_bytes());
    hex::encode(hasher.finalize())
}
/// Issues a fresh one-time registration code for worker `name`, valid for `ttl_sec` seconds.
///
/// A new worker row is inserted when `name` is unknown. For an existing worker, the
/// pending code is replaced and its token hash and revocation stamp are cleared, so the
/// worker must re-register with the returned code. Only the SHA-256 of the code is
/// stored; the plaintext is returned to the caller and nowhere else.
///
/// # Errors
/// Propagates any SQLite error.
///
/// # Panics
/// If `ttl_sec` is so large that chrono cannot represent the duration or the resulting
/// expiry timestamp.
pub fn create_or_replace_register_code(
    conn: &Connection,
    name: &str,
    ttl_sec: i64,
) -> Result<String> {
    let code = token_urlsafe(16);
    let code_hash = sha256_hex(&code);
    let expires_at = iso_at(Utc::now() + Duration::seconds(ttl_sec));

    let existing_id = conn
        .query_row("SELECT id FROM workers WHERE name=?", [name], |row| {
            row.get::<_, i64>(0)
        })
        .optional()?;

    match existing_id {
        Some(_) => {
            conn.execute(
                "UPDATE workers SET register_code_hash=?, register_code_expires=?,\ntoken_hash=NULL, revoked_at=NULL WHERE name=?",
                params![code_hash, expires_at, name],
            )?;
        }
        None => {
            conn.execute(
                "INSERT INTO workers (name, register_code_hash, register_code_expires, created_at)\nVALUES (?,?,?,?)",
                params![name, code_hash, expires_at, iso_now()],
            )?;
        }
    }
    Ok(code)
}
/// Redeems a one-time registration code for worker `name` and issues its bearer token.
///
/// On success:
/// - the code is consumed (its hash and expiry are cleared);
/// - a fresh 32-byte URL-safe token is generated;
/// - the token's SHA-256 is stored together with `client_pubkey`;
/// - `(token, worker_id)` is returned.
///
/// The plaintext token is never persisted.
///
/// The final UPDATE only applies while the stored code hash is still the one that was
/// verified. If two callers race with the same code, only one gets a token; the other
/// gets [`InvalidRegisterCode::NoPending`].
///
/// # Errors
/// - `UnknownWorker`: no row for `name`, or a database error occurred. The error enum
///   has no storage variant, so DB failures are reported this way.
/// - `NoPending`: no outstanding code, or the code was consumed/replaced concurrently.
/// - `BadCode`: `register_code` does not match the stored hash.
/// - `Expired`: the stored expiry is missing, unparsable, or not in the future.
pub fn finalize_register(
    conn: &Connection,
    name: &str,
    register_code: &str,
    client_pubkey: Option<&[u8]>,
) -> std::result::Result<(String, i64), InvalidRegisterCode> {
    let row: Option<(i64, Option<String>, Option<String>, Option<String>)> = conn
        .query_row(
            "SELECT id, register_code_hash, register_code_expires, token_hash\nFROM workers WHERE name=?",
            [name],
            |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?)),
        )
        .optional()
        .map_err(|_| InvalidRegisterCode::UnknownWorker)?;
    let Some((wid, register_code_hash, register_code_expires, _token_hash)) = row else {
        return Err(InvalidRegisterCode::UnknownWorker);
    };
    let Some(stored_hash) = register_code_hash else {
        return Err(InvalidRegisterCode::NoPending);
    };

    // Compare every byte (no early exit) so response timing does not reveal how long
    // a prefix of the hash matched.
    let presented_hash = sha256_hex(register_code);
    let hashes_match = stored_hash.len() == presented_hash.len()
        && stored_hash
            .bytes()
            .zip(presented_hash.bytes())
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0;
    if !hashes_match {
        return Err(InvalidRegisterCode::BadCode);
    }

    // A code without a recorded expiry is treated as already expired.
    let Some(expires) = register_code_expires else {
        return Err(InvalidRegisterCode::Expired);
    };
    // chrono's RFC 3339 parser accepts a `Z` offset directly.
    let Ok(expires_at) = chrono::DateTime::parse_from_rfc3339(&expires) else {
        return Err(InvalidRegisterCode::Expired);
    };
    if expires_at.with_timezone(&Utc) <= Utc::now() {
        return Err(InvalidRegisterCode::Expired);
    }

    let token = token_urlsafe(32);
    let token_hash = sha256_hex(&token);
    // Conditioning on the verified hash makes redemption single-use even when another
    // caller (or a code re-issue) touched the row between our SELECT and this UPDATE.
    let updated = conn
        .execute(
            "UPDATE workers SET token_hash=?, register_code_hash=NULL,\nregister_code_expires=NULL, client_pubkey=? WHERE id=? AND register_code_hash=?",
            params![token_hash, client_pubkey, wid, stored_hash],
        )
        .map_err(|_| InvalidRegisterCode::UnknownWorker)?;
    if updated == 0 {
        return Err(InvalidRegisterCode::NoPending);
    }
    Ok((token, wid))
}
/// Runs a single-row worker query and hides revoked workers.
///
/// `sql` must select exactly `id, name, status, current_pwd, current_task_id,
/// last_poll_at, revoked_at` (in that order) and take one positional parameter, `key`.
/// Returns `Ok(None)` when no row matches or the matching row has `revoked_at` set.
fn fetch_unrevoked_worker(conn: &Connection, sql: &str, key: &str) -> Result<Option<Worker>> {
    let row = conn
        .query_row(sql, [key], |r| {
            let worker = Worker {
                id: r.get(0)?,
                name: r.get(1)?,
                status: r.get(2)?,
                current_pwd: r.get(3)?,
                current_task_id: r.get(4)?,
                last_poll_at: r.get(5)?,
            };
            let revoked_at: Option<String> = r.get(6)?;
            Ok((worker, revoked_at))
        })
        .optional()?;
    Ok(row.and_then(|(worker, revoked_at)| revoked_at.is_none().then_some(worker)))
}

/// Resolves a bearer token to its (non-revoked) worker.
///
/// The token is hashed with SHA-256 and matched against `token_hash`; the plaintext is
/// never compared or stored.
///
/// # Errors
/// Propagates SQLite errors (including row-decoding failures).
pub fn lookup_by_token(conn: &Connection, token: &str) -> Result<Option<Worker>> {
    let token_hash = sha256_hex(token);
    fetch_unrevoked_worker(
        conn,
        "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at\nFROM workers WHERE token_hash=?",
        &token_hash,
    )
}

/// Looks up a (non-revoked) worker by name.
///
/// # Errors
/// Propagates SQLite errors (including row-decoding failures).
pub fn lookup_by_name(conn: &Connection, name: &str) -> Result<Option<Worker>> {
    fetch_unrevoked_worker(
        conn,
        "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at\nFROM workers WHERE name=?",
        name,
    )
}
/// Revokes worker `name`, invalidating every credential it holds.
///
/// The update:
/// - stamps `revoked_at`;
/// - drops the bearer token hash;
/// - discards any pending registration code and its expiry.
///
/// Without that last step, a code issued before the revocation could still be redeemed
/// for a token that the lookups then refuse. Issuing a new code with
/// `create_or_replace_register_code` clears `revoked_at` again.
///
/// An unknown `name` is a silent no-op.
///
/// # Errors
/// Propagates any SQLite error.
pub fn revoke(conn: &Connection, name: &str) -> Result<()> {
    conn.execute(
        "UPDATE workers SET revoked_at=?, token_hash=NULL, register_code_hash=NULL,\nregister_code_expires=NULL WHERE name=?",
        params![iso_now(), name],
    )?;
    Ok(())
}
pub fn update_status(
conn: &Connection,
name: &str,
status: &str,
current_pwd: Option<&str>,
current_task_id: Option<&str>,
) -> Result<()> {
conn.execute(
"UPDATE workers SET status=?, current_pwd=COALESCE(?, current_pwd),
current_task_id=?, last_poll_at=? WHERE name=?",
params![status, current_pwd, current_task_id, iso_now(), name],
)?;
Ok(())
}
/// Returns every worker whose `revoked_at` is NULL, ordered by name ascending.
///
/// # Errors
/// Propagates SQLite errors from preparing, stepping, or decoding rows. The first
/// failing row aborts the whole listing.
pub fn list_all(conn: &Connection) -> Result<Vec<Worker>> {
    let mut stmt = conn.prepare(
        "SELECT id, name, status, current_pwd, current_task_id, last_poll_at\nFROM workers WHERE revoked_at IS NULL ORDER BY name ASC",
    )?;
    let mapped = stmt.query_map([], |row| {
        Ok(Worker {
            id: row.get(0)?,
            name: row.get(1)?,
            status: row.get(2)?,
            current_pwd: row.get(3)?,
            current_task_id: row.get(4)?,
            last_poll_at: row.get(5)?,
        })
    })?;
    let mut workers = Vec::new();
    for worker in mapped {
        workers.push(worker?);
    }
    Ok(workers)
}
/// Returns the client public key blob recorded at registration for a non-revoked worker.
///
/// Returns `Ok(None)` when the worker is unknown, revoked, or registered without a key.
///
/// # Errors
/// Propagates any SQLite error.
pub fn get_client_pubkey(conn: &Connection, name: &str) -> Result<Option<Vec<u8>>> {
    let lookup = conn
        .query_row(
            "SELECT client_pubkey FROM workers WHERE name=? AND revoked_at IS NULL",
            [name],
            |row| row.get::<_, Option<Vec<u8>>>(0),
        )
        .optional()?;
    // Outer `None`: no live row. Inner `None`: the row exists but its key column is NULL.
    Ok(lookup.unwrap_or(None))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database with the project schema applied.
    fn fresh() -> Connection {
        let c = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&c).unwrap();
        c
    }

    #[test]
    fn issue_then_finalize() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (token, wid) = finalize_register(&c, "w1", &code, None).unwrap();
        assert!(wid > 0);
        let w = lookup_by_token(&c, &token).unwrap().unwrap();
        assert_eq!(w.name, "w1");
        // The code is single-use.
        let err = finalize_register(&c, "w1", &code, None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::NoPending));
    }

    #[test]
    fn finalize_wrong_code_rejected() {
        let c = fresh();
        let _ = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let err = finalize_register(&c, "w1", "xxxxxxxxxxxxx", None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::BadCode));
    }

    #[test]
    fn finalize_expired_rejected() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", -1).unwrap();
        let err = finalize_register(&c, "w1", &code, None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::Expired));
    }

    #[test]
    fn finalize_unknown_worker_rejected() {
        let c = fresh();
        let err = finalize_register(&c, "nobody", "whatever", None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::UnknownWorker));
    }

    #[test]
    fn revoke_invalidates_token() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (token, _) = finalize_register(&c, "w1", &code, None).unwrap();
        revoke(&c, "w1").unwrap();
        assert!(lookup_by_token(&c, &token).unwrap().is_none());
    }

    #[test]
    fn revoke_hides_worker_from_name_lookup_and_listing() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let _ = finalize_register(&c, "w1", &code, None).unwrap();
        assert_eq!(lookup_by_name(&c, "w1").unwrap().unwrap().name, "w1");
        revoke(&c, "w1").unwrap();
        assert!(lookup_by_name(&c, "w1").unwrap().is_none());
        assert!(list_all(&c).unwrap().is_empty());
    }

    #[test]
    fn reissuing_code_reenables_revoked_worker() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (old_token, wid) = finalize_register(&c, "w1", &code, None).unwrap();
        revoke(&c, "w1").unwrap();
        let code2 = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (new_token, wid2) = finalize_register(&c, "w1", &code2, None).unwrap();
        assert_eq!(wid, wid2);
        assert!(lookup_by_token(&c, &old_token).unwrap().is_none());
        assert_eq!(lookup_by_token(&c, &new_token).unwrap().unwrap().id, wid);
    }

    #[test]
    fn update_status_and_list() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let _ = finalize_register(&c, "w1", &code, None).unwrap();
        update_status(&c, "w1", "IDLE", Some("/tmp"), None).unwrap();
        let listed = list_all(&c).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].status.as_deref(), Some("IDLE"));
        assert_eq!(listed[0].current_pwd.as_deref(), Some("/tmp"));
    }

    #[test]
    fn update_status_keeps_pwd_but_overwrites_task() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let _ = finalize_register(&c, "w1", &code, None).unwrap();
        update_status(&c, "w1", "BUSY", Some("/tmp"), Some("t1")).unwrap();
        update_status(&c, "w1", "IDLE", None, None).unwrap();
        let w = lookup_by_name(&c, "w1").unwrap().unwrap();
        assert_eq!(w.status.as_deref(), Some("IDLE"));
        assert_eq!(w.current_pwd.as_deref(), Some("/tmp"));
        assert_eq!(w.current_task_id, None);
        assert!(w.last_poll_at.is_some());
    }

    #[test]
    fn finalize_records_client_pubkey() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let blob = b"hello-pubkey";
        let _ = finalize_register(&c, "w1", &code, Some(blob)).unwrap();
        assert_eq!(
            get_client_pubkey(&c, "w1").unwrap().as_deref(),
            Some(&blob[..])
        );
    }

    #[test]
    fn client_pubkey_absent_when_not_provided_or_unknown() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let _ = finalize_register(&c, "w1", &code, None).unwrap();
        assert_eq!(get_client_pubkey(&c, "w1").unwrap(), None);
        assert_eq!(get_client_pubkey(&c, "nobody").unwrap(), None);
    }
}