dragoon-server 0.1.0

Public-relay server for the dragoon remote-executor: axum + rusqlite + ed25519 task signing + per-user message inbox.
Documentation
//! Worker registration & token bookkeeping.
//! Mirrors `python/.../server/workers_repo.py`.

use anyhow::Result;
use chrono::{Duration, Utc};
use rand::{rngs::OsRng, RngCore};
use rusqlite::{params, Connection, OptionalExtension};
use sha2::{Digest, Sha256};
use thiserror::Error;

/// Caller-facing view of one row of the `workers` table.
///
/// Only bookkeeping columns are carried here; the token and register-code
/// hashes are never loaded into this struct by the queries in this module.
/// `PartialEq`/`Eq` are derived so callers and tests can compare workers
/// directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Worker {
    /// Value of the row's `id` column.
    pub id: i64,
    /// Worker name; the key used by the name-based functions in this module.
    pub name: String,
    /// Last status string written by `update_status`; `None` when the column is NULL.
    pub status: Option<String>,
    /// Last non-NULL `current_pwd` written by `update_status`.
    /// NOTE(review): presumably the worker's working directory — confirm with the worker client.
    pub current_pwd: Option<String>,
    /// Task id written by the most recent `update_status` call; `None` when NULL.
    pub current_task_id: Option<String>,
    /// Timestamp of the most recent `update_status` call, as an RFC 3339 UTC
    /// string with microsecond precision; `None` when the column is NULL.
    pub last_poll_at: Option<String>,
}

/// Reasons `finalize_register` can refuse to exchange a register code for a
/// worker token.
#[derive(Debug, Error)]
pub enum InvalidRegisterCode {
    /// No `workers` row exists for the requested name. `finalize_register`
    /// also reports this variant when the underlying database call fails.
    #[error("unknown worker")]
    UnknownWorker,
    /// The worker exists but has no register code pending
    /// (its `register_code_hash` column is NULL).
    #[error("no pending registration")]
    NoPending,
    /// The SHA-256 of the supplied code does not match the stored hash.
    #[error("bad register code")]
    BadCode,
    /// The stored expiry is missing, cannot be parsed, or is not in the future.
    #[error("register code expired")]
    Expired,
}

fn iso_now() -> String {
    Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
}

fn iso_at(dt: chrono::DateTime<Utc>) -> String {
    dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
}

/// Draw `byte_len` bytes from the operating-system RNG and return them as
/// unpadded URL-safe base64 text.
fn token_urlsafe(byte_len: usize) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;

    let mut entropy = vec![0u8; byte_len];
    OsRng.fill_bytes(entropy.as_mut_slice());
    URL_SAFE_NO_PAD.encode(entropy)
}

/// Hex-encoded SHA-256 digest of the UTF-8 bytes of `s`.
fn sha256_hex(s: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(s.as_bytes());
    hex::encode(hasher.finalize())
}

/// Issue (or rotate) a one-shot register code for `name`. Returns the
/// **plaintext** code; only the hash is persisted.
pub fn create_or_replace_register_code(
    conn: &Connection,
    name: &str,
    ttl_sec: i64,
) -> Result<String> {
    let code = token_urlsafe(16);
    let h = sha256_hex(&code);
    let expires = iso_at(Utc::now() + Duration::seconds(ttl_sec));
    let exists: Option<i64> = conn
        .query_row("SELECT id FROM workers WHERE name=?", [name], |r| r.get(0))
        .optional()?;
    if exists.is_none() {
        conn.execute(
            "INSERT INTO workers (name, register_code_hash, register_code_expires, created_at)
             VALUES (?,?,?,?)",
            params![name, h, expires, iso_now()],
        )?;
    } else {
        conn.execute(
            "UPDATE workers SET register_code_hash=?, register_code_expires=?,
             token_hash=NULL, revoked_at=NULL WHERE name=?",
            params![h, expires, name],
        )?;
    }
    Ok(code)
}

/// Exchange a one-shot register code for a persistent worker token.
///
/// On success the worker's register code is consumed, a fresh token is
/// generated, its SHA-256 hash is stored, and `client_pubkey` (the optional
/// worker-side public key, Phase 4 design addition) is written to the row —
/// `None` stores NULL.
///
/// Returns `(plaintext_token, worker_id)`. Only the token's hash is
/// persisted, so the returned plaintext is the sole copy.
///
/// # Errors
/// * [`InvalidRegisterCode::UnknownWorker`] — no worker named `name`; also
///   used when a database call fails, because the error type has no
///   dedicated variant for that.
/// * [`InvalidRegisterCode::NoPending`] — no register code is pending, or it
///   was consumed/rotated between verification and the final update.
/// * [`InvalidRegisterCode::BadCode`] — `register_code` does not hash to the
///   stored value.
/// * [`InvalidRegisterCode::Expired`] — the stored expiry is missing,
///   unparsable, or not in the future.
pub fn finalize_register(
    conn: &Connection,
    name: &str,
    register_code: &str,
    client_pubkey: Option<&[u8]>,
) -> std::result::Result<(String, i64), InvalidRegisterCode> {
    let pending: Option<(i64, Option<String>, Option<String>)> = conn
        .query_row(
            "SELECT id, register_code_hash, register_code_expires
             FROM workers WHERE name=?",
            [name],
            |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?)),
        )
        .optional()
        .map_err(|_| InvalidRegisterCode::UnknownWorker)?;

    let Some((wid, code_hash, code_expires)) = pending else {
        return Err(InvalidRegisterCode::UnknownWorker);
    };
    let Some(stored_hash) = code_hash else {
        return Err(InvalidRegisterCode::NoPending);
    };
    if stored_hash != sha256_hex(register_code) {
        return Err(InvalidRegisterCode::BadCode);
    }

    // A missing or unparsable expiry counts as expired. The RFC 3339 parser
    // accepts the `Z` suffix written by `iso_at`, so no rewriting is needed.
    let deadline = code_expires
        .as_deref()
        .and_then(|ts| chrono::DateTime::parse_from_rfc3339(ts).ok());
    match deadline {
        Some(d) if d.with_timezone(&Utc) > Utc::now() => {}
        _ => return Err(InvalidRegisterCode::Expired),
    }

    let token = token_urlsafe(32);
    // Guard on the hash that was just verified so the code is consumed
    // atomically: if another caller consumed or rotated it after the SELECT
    // above, no row matches and no second token is issued.
    let updated = conn
        .execute(
            "UPDATE workers SET token_hash=?, register_code_hash=NULL,
             register_code_expires=NULL, client_pubkey=?
             WHERE id=? AND register_code_hash=?",
            params![sha256_hex(&token), client_pubkey, wid, stored_hash],
        )
        .map_err(|_| InvalidRegisterCode::UnknownWorker)?;
    if updated == 0 {
        return Err(InvalidRegisterCode::NoPending);
    }

    Ok((token, wid))
}

pub fn lookup_by_token(conn: &Connection, token: &str) -> Result<Option<Worker>> {
    let h = sha256_hex(token);
    let row = conn
        .query_row(
            "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at
             FROM workers WHERE token_hash=?",
            [h],
            |r| {
                Ok((
                    r.get::<_, i64>(0)?,
                    r.get::<_, String>(1)?,
                    r.get::<_, Option<String>>(2)?,
                    r.get::<_, Option<String>>(3)?,
                    r.get::<_, Option<String>>(4)?,
                    r.get::<_, Option<String>>(5)?,
                    r.get::<_, Option<String>>(6)?,
                ))
            },
        )
        .optional()?;
    let Some((id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at)) = row
    else {
        return Ok(None);
    };
    if revoked_at.is_some() {
        return Ok(None);
    }
    Ok(Some(Worker {
        id,
        name,
        status,
        current_pwd,
        current_task_id,
        last_poll_at,
    }))
}

pub fn lookup_by_name(conn: &Connection, name: &str) -> Result<Option<Worker>> {
    let row = conn
        .query_row(
            "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at
             FROM workers WHERE name=?",
            [name],
            |r| {
                Ok((
                    r.get::<_, i64>(0)?,
                    r.get::<_, String>(1)?,
                    r.get::<_, Option<String>>(2)?,
                    r.get::<_, Option<String>>(3)?,
                    r.get::<_, Option<String>>(4)?,
                    r.get::<_, Option<String>>(5)?,
                    r.get::<_, Option<String>>(6)?,
                ))
            },
        )
        .optional()?;
    let Some((id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at)) = row
    else {
        return Ok(None);
    };
    if revoked_at.is_some() {
        return Ok(None);
    }
    Ok(Some(Worker {
        id,
        name,
        status,
        current_pwd,
        current_task_id,
        last_poll_at,
    }))
}

pub fn revoke(conn: &Connection, name: &str) -> Result<()> {
    conn.execute(
        "UPDATE workers SET revoked_at=?, token_hash=NULL WHERE name=?",
        params![iso_now(), name],
    )?;
    Ok(())
}

pub fn update_status(
    conn: &Connection,
    name: &str,
    status: &str,
    current_pwd: Option<&str>,
    current_task_id: Option<&str>,
) -> Result<()> {
    conn.execute(
        "UPDATE workers SET status=?, current_pwd=COALESCE(?, current_pwd),
         current_task_id=?, last_poll_at=? WHERE name=?",
        params![status, current_pwd, current_task_id, iso_now(), name],
    )?;
    Ok(())
}

/// Return every non-revoked worker, ordered by name ascending.
///
/// # Errors
/// Propagates any database error from preparing the statement, running it,
/// or decoding a row.
pub fn list_all(conn: &Connection) -> Result<Vec<Worker>> {
    let mut query = conn.prepare(
        "SELECT id, name, status, current_pwd, current_task_id, last_poll_at
         FROM workers WHERE revoked_at IS NULL ORDER BY name ASC",
    )?;
    let decoded = query.query_map([], |row| {
        let worker = Worker {
            id: row.get(0)?,
            name: row.get(1)?,
            status: row.get(2)?,
            current_pwd: row.get(3)?,
            current_task_id: row.get(4)?,
            last_poll_at: row.get(5)?,
        };
        Ok(worker)
    })?;

    let mut workers = Vec::new();
    for worker in decoded {
        // The first row that fails to decode aborts the listing.
        workers.push(worker?);
    }
    Ok(workers)
}

/// Fetch the public key stored for the non-revoked worker called `name`.
///
/// Returns `Ok(None)` both when no such non-revoked worker exists and when
/// the worker exists but its `client_pubkey` column is NULL.
///
/// # Errors
/// Propagates any database error other than "no rows".
pub fn get_client_pubkey(conn: &Connection, name: &str) -> Result<Option<Vec<u8>>> {
    let lookup = conn
        .query_row(
            "SELECT client_pubkey FROM workers WHERE name=? AND revoked_at IS NULL",
            [name],
            |row| row.get::<_, Option<Vec<u8>>>(0),
        )
        .optional()?;

    match lookup {
        Some(Some(key)) => Ok(Some(key)),
        // Either no matching row, or a row whose key column is NULL.
        Some(None) | None => Ok(None),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Worker name shared by every test.
    const WORKER: &str = "w1";

    /// In-memory database with the schema bootstrapped.
    fn fresh() -> Connection {
        let conn = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&conn).unwrap();
        conn
    }

    /// Issue a register code for `WORKER` with the given lifetime.
    fn issue(conn: &Connection, ttl_sec: i64) -> String {
        create_or_replace_register_code(conn, WORKER, ttl_sec).unwrap()
    }

    /// Issue a 600-second code and immediately exchange it for a token.
    fn enroll(conn: &Connection, pubkey: Option<&[u8]>) -> (String, i64) {
        let code = issue(conn, 600);
        finalize_register(conn, WORKER, &code, pubkey).unwrap()
    }

    #[test]
    fn issue_then_finalize() {
        let db = fresh();
        let code = issue(&db, 600);

        let (token, wid) = finalize_register(&db, WORKER, &code, None).unwrap();
        assert!(wid > 0);

        let found = lookup_by_token(&db, &token).unwrap().unwrap();
        assert_eq!(found.name, "w1");

        // A register code is one-shot: the second exchange must be refused.
        let reuse = finalize_register(&db, WORKER, &code, None);
        assert!(matches!(reuse, Err(InvalidRegisterCode::NoPending)));
    }

    #[test]
    fn finalize_wrong_code_rejected() {
        let db = fresh();
        let _unused_code = issue(&db, 600);

        let outcome = finalize_register(&db, WORKER, "xxxxxxxxxxxxx", None);
        assert!(matches!(outcome, Err(InvalidRegisterCode::BadCode)));
    }

    #[test]
    fn finalize_expired_rejected() {
        let db = fresh();
        // A negative lifetime puts the expiry in the past.
        let code = issue(&db, -1);

        let outcome = finalize_register(&db, WORKER, &code, None);
        assert!(matches!(outcome, Err(InvalidRegisterCode::Expired)));
    }

    #[test]
    fn revoke_invalidates_token() {
        let db = fresh();
        let (token, _wid) = enroll(&db, None);

        revoke(&db, WORKER).unwrap();

        let after = lookup_by_token(&db, &token).unwrap();
        assert!(after.is_none());
    }

    #[test]
    fn update_status_and_list() {
        let db = fresh();
        let _ = enroll(&db, None);

        update_status(&db, WORKER, "IDLE", Some("/tmp"), None).unwrap();

        let workers = list_all(&db).unwrap();
        assert_eq!(workers.len(), 1);
        let only = &workers[0];
        assert_eq!(only.status.as_deref(), Some("IDLE"));
        assert_eq!(only.current_pwd.as_deref(), Some("/tmp"));
    }

    #[test]
    fn finalize_records_client_pubkey() {
        let db = fresh();
        let pubkey: &[u8] = b"hello-pubkey";
        let _ = enroll(&db, Some(pubkey));

        let stored = get_client_pubkey(&db, WORKER).unwrap();
        assert_eq!(stored.as_deref(), Some(pubkey));
    }
}