Skip to main content

dragoon_server/
workers_repo.rs

1//! Worker registration & token bookkeeping.
2//! Mirrors `python/.../server/workers_repo.py`.
3
4use anyhow::Result;
5use chrono::{Duration, Utc};
6use rand::{rngs::OsRng, RngCore};
7use rusqlite::{params, Connection, OptionalExtension};
8use sha2::{Digest, Sha256};
9use thiserror::Error;
10
/// A worker row as returned by the lookup/list functions in this module.
///
/// It carries only the bookkeeping columns; the token hash, register-code
/// hash/expiry and client public key stored in `workers` are not included.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Worker {
    /// Row id in the `workers` table.
    pub id: i64,
    /// Worker name (the key used by the name-based functions in this module).
    pub name: String,
    /// Last status written by `update_status`, if any.
    pub status: Option<String>,
    /// Last `current_pwd` written by `update_status`, if any.
    pub current_pwd: Option<String>,
    /// Task id written by the most recent `update_status` call, if any.
    pub current_task_id: Option<String>,
    /// RFC 3339 timestamp stored in `last_poll_at` (set by `update_status`).
    pub last_poll_at: Option<String>,
}
20
21#[derive(Debug, Error)]
22pub enum InvalidRegisterCode {
23    #[error("unknown worker")]
24    UnknownWorker,
25    #[error("no pending registration")]
26    NoPending,
27    #[error("bad register code")]
28    BadCode,
29    #[error("register code expired")]
30    Expired,
31}
32
33fn iso_now() -> String {
34    Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
35}
36
37fn iso_at(dt: chrono::DateTime<Utc>) -> String {
38    dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
39}
40
41fn token_urlsafe(byte_len: usize) -> String {
42    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
43    let mut bytes = vec![0u8; byte_len];
44    OsRng.fill_bytes(&mut bytes);
45    URL_SAFE_NO_PAD.encode(&bytes)
46}
47
48fn sha256_hex(s: &str) -> String {
49    hex::encode(Sha256::digest(s.as_bytes()))
50}
51
52/// Issue (or rotate) a one-shot register code for `name`. Returns the
53/// **plaintext** code; only the hash is persisted.
54pub fn create_or_replace_register_code(
55    conn: &Connection,
56    name: &str,
57    ttl_sec: i64,
58) -> Result<String> {
59    let code = token_urlsafe(16);
60    let h = sha256_hex(&code);
61    let expires = iso_at(Utc::now() + Duration::seconds(ttl_sec));
62    let exists: Option<i64> = conn
63        .query_row("SELECT id FROM workers WHERE name=?", [name], |r| r.get(0))
64        .optional()?;
65    if exists.is_none() {
66        conn.execute(
67            "INSERT INTO workers (name, register_code_hash, register_code_expires, created_at)
68             VALUES (?,?,?,?)",
69            params![name, h, expires, iso_now()],
70        )?;
71    } else {
72        conn.execute(
73            "UPDATE workers SET register_code_hash=?, register_code_expires=?,
74             token_hash=NULL, revoked_at=NULL WHERE name=?",
75            params![h, expires, name],
76        )?;
77    }
78    Ok(code)
79}
80
81/// Exchange a register code for a persistent worker token. Optionally
82/// records the worker-side public key (Phase 4 design addition).
83pub fn finalize_register(
84    conn: &Connection,
85    name: &str,
86    register_code: &str,
87    client_pubkey: Option<&[u8]>,
88) -> std::result::Result<(String, i64), InvalidRegisterCode> {
89    let row: Option<(i64, Option<String>, Option<String>, Option<String>)> = conn
90        .query_row(
91            "SELECT id, register_code_hash, register_code_expires, token_hash
92             FROM workers WHERE name=?",
93            [name],
94            |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?)),
95        )
96        .optional()
97        .map_err(|_| InvalidRegisterCode::UnknownWorker)?;
98
99    let Some((wid, register_code_hash, register_code_expires, _token_hash)) = row else {
100        return Err(InvalidRegisterCode::UnknownWorker);
101    };
102
103    let Some(stored_hash) = register_code_hash else {
104        return Err(InvalidRegisterCode::NoPending);
105    };
106    if stored_hash != sha256_hex(register_code) {
107        return Err(InvalidRegisterCode::BadCode);
108    }
109    let expires = register_code_expires.unwrap_or_else(|| iso_now());
110    let Ok(parsed) = chrono::DateTime::parse_from_rfc3339(
111        &expires.replace('Z', "+00:00"),
112    ) else {
113        return Err(InvalidRegisterCode::Expired);
114    };
115    if parsed.with_timezone(&Utc) <= Utc::now() {
116        return Err(InvalidRegisterCode::Expired);
117    }
118
119    let token = token_urlsafe(32);
120    let h = sha256_hex(&token);
121    conn.execute(
122        "UPDATE workers SET token_hash=?, register_code_hash=NULL,
123         register_code_expires=NULL, client_pubkey=? WHERE id=?",
124        params![h, client_pubkey, wid],
125    )
126    .map_err(|_| InvalidRegisterCode::UnknownWorker)?;
127    Ok((token, wid))
128}
129
130pub fn lookup_by_token(conn: &Connection, token: &str) -> Result<Option<Worker>> {
131    let h = sha256_hex(token);
132    let row = conn
133        .query_row(
134            "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at
135             FROM workers WHERE token_hash=?",
136            [h],
137            |r| {
138                Ok((
139                    r.get::<_, i64>(0)?,
140                    r.get::<_, String>(1)?,
141                    r.get::<_, Option<String>>(2)?,
142                    r.get::<_, Option<String>>(3)?,
143                    r.get::<_, Option<String>>(4)?,
144                    r.get::<_, Option<String>>(5)?,
145                    r.get::<_, Option<String>>(6)?,
146                ))
147            },
148        )
149        .optional()?;
150    let Some((id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at)) = row
151    else {
152        return Ok(None);
153    };
154    if revoked_at.is_some() {
155        return Ok(None);
156    }
157    Ok(Some(Worker {
158        id,
159        name,
160        status,
161        current_pwd,
162        current_task_id,
163        last_poll_at,
164    }))
165}
166
167pub fn lookup_by_name(conn: &Connection, name: &str) -> Result<Option<Worker>> {
168    let row = conn
169        .query_row(
170            "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at
171             FROM workers WHERE name=?",
172            [name],
173            |r| {
174                Ok((
175                    r.get::<_, i64>(0)?,
176                    r.get::<_, String>(1)?,
177                    r.get::<_, Option<String>>(2)?,
178                    r.get::<_, Option<String>>(3)?,
179                    r.get::<_, Option<String>>(4)?,
180                    r.get::<_, Option<String>>(5)?,
181                    r.get::<_, Option<String>>(6)?,
182                ))
183            },
184        )
185        .optional()?;
186    let Some((id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at)) = row
187    else {
188        return Ok(None);
189    };
190    if revoked_at.is_some() {
191        return Ok(None);
192    }
193    Ok(Some(Worker {
194        id,
195        name,
196        status,
197        current_pwd,
198        current_task_id,
199        last_poll_at,
200    }))
201}
202
203pub fn revoke(conn: &Connection, name: &str) -> Result<()> {
204    conn.execute(
205        "UPDATE workers SET revoked_at=?, token_hash=NULL WHERE name=?",
206        params![iso_now(), name],
207    )?;
208    Ok(())
209}
210
211pub fn update_status(
212    conn: &Connection,
213    name: &str,
214    status: &str,
215    current_pwd: Option<&str>,
216    current_task_id: Option<&str>,
217) -> Result<()> {
218    conn.execute(
219        "UPDATE workers SET status=?, current_pwd=COALESCE(?, current_pwd),
220         current_task_id=?, last_poll_at=? WHERE name=?",
221        params![status, current_pwd, current_task_id, iso_now(), name],
222    )?;
223    Ok(())
224}
225
226pub fn list_all(conn: &Connection) -> Result<Vec<Worker>> {
227    let mut stmt = conn.prepare(
228        "SELECT id, name, status, current_pwd, current_task_id, last_poll_at
229         FROM workers WHERE revoked_at IS NULL ORDER BY name ASC",
230    )?;
231    let rows = stmt
232        .query_map([], |r| {
233            Ok(Worker {
234                id: r.get(0)?,
235                name: r.get(1)?,
236                status: r.get(2)?,
237                current_pwd: r.get(3)?,
238                current_task_id: r.get(4)?,
239                last_poll_at: r.get(5)?,
240            })
241        })?
242        .collect::<rusqlite::Result<Vec<_>>>()?;
243    Ok(rows)
244}
245
246pub fn get_client_pubkey(conn: &Connection, name: &str) -> Result<Option<Vec<u8>>> {
247    let row: Option<Option<Vec<u8>>> = conn
248        .query_row(
249            "SELECT client_pubkey FROM workers WHERE name=? AND revoked_at IS NULL",
250            [name],
251            |r| r.get(0),
252        )
253        .optional()?;
254    Ok(row.flatten())
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database with the schema bootstrapped.
    fn fresh() -> Connection {
        let c = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&c).unwrap();
        c
    }

    /// Issue a code for `name` and immediately finalize it; returns the token.
    fn register(c: &Connection, name: &str) -> String {
        let code = create_or_replace_register_code(c, name, 600).unwrap();
        finalize_register(c, name, &code, None).unwrap().0
    }

    #[test]
    fn issue_then_finalize() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (token, wid) = finalize_register(&c, "w1", &code, None).unwrap();
        assert!(wid > 0);
        let w = lookup_by_token(&c, &token).unwrap().unwrap();
        assert_eq!(w.name, "w1");
        // re-using the code must fail
        let err = finalize_register(&c, "w1", &code, None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::NoPending));
    }

    #[test]
    fn finalize_wrong_code_rejected() {
        let c = fresh();
        let _ = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let err = finalize_register(&c, "w1", "xxxxxxxxxxxxx", None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::BadCode));
    }

    #[test]
    fn finalize_expired_rejected() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", -1).unwrap();
        let err = finalize_register(&c, "w1", &code, None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::Expired));
    }

    #[test]
    fn finalize_unknown_worker_rejected() {
        let c = fresh();
        let err = finalize_register(&c, "nobody", "whatever", None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::UnknownWorker));
    }

    #[test]
    fn rotating_code_invalidates_previous_one() {
        let c = fresh();
        let old = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let new = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let err = finalize_register(&c, "w1", &old, None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::BadCode));
        assert!(finalize_register(&c, "w1", &new, None).is_ok());
    }

    #[test]
    fn revoke_invalidates_token() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (token, _) = finalize_register(&c, "w1", &code, None).unwrap();
        revoke(&c, "w1").unwrap();
        assert!(lookup_by_token(&c, &token).unwrap().is_none());
    }

    #[test]
    fn revoked_worker_hidden_everywhere() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let _ = finalize_register(&c, "w1", &code, Some(b"pk".as_slice())).unwrap();
        revoke(&c, "w1").unwrap();
        assert!(lookup_by_name(&c, "w1").unwrap().is_none());
        assert!(list_all(&c).unwrap().is_empty());
        assert!(get_client_pubkey(&c, "w1").unwrap().is_none());
    }

    #[test]
    fn reissuing_code_reinstates_revoked_worker() {
        let c = fresh();
        let old_token = register(&c, "w1");
        revoke(&c, "w1").unwrap();
        let new_token = register(&c, "w1");
        assert!(lookup_by_name(&c, "w1").unwrap().is_some());
        assert!(lookup_by_token(&c, &old_token).unwrap().is_none());
        assert_eq!(lookup_by_token(&c, &new_token).unwrap().unwrap().name, "w1");
    }

    #[test]
    fn update_status_and_list() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let _ = finalize_register(&c, "w1", &code, None).unwrap();
        update_status(&c, "w1", "IDLE", Some("/tmp"), None).unwrap();
        let listed = list_all(&c).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].status.as_deref(), Some("IDLE"));
        assert_eq!(listed[0].current_pwd.as_deref(), Some("/tmp"));
    }

    #[test]
    fn update_status_keeps_pwd_but_clears_task() {
        let c = fresh();
        let _ = register(&c, "w1");
        update_status(&c, "w1", "BUSY", Some("/srv"), Some("t-1")).unwrap();
        update_status(&c, "w1", "IDLE", None, None).unwrap();
        let w = lookup_by_name(&c, "w1").unwrap().unwrap();
        assert_eq!(w.status.as_deref(), Some("IDLE"));
        // COALESCE keeps the previous pwd when None is passed...
        assert_eq!(w.current_pwd.as_deref(), Some("/srv"));
        // ...but current_task_id is overwritten unconditionally.
        assert_eq!(w.current_task_id, None);
        assert!(w.last_poll_at.is_some());
    }

    #[test]
    fn finalize_records_client_pubkey() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let blob = b"hello-pubkey";
        let _ = finalize_register(&c, "w1", &code, Some(blob)).unwrap();
        assert_eq!(
            get_client_pubkey(&c, "w1").unwrap().as_deref(),
            Some(&blob[..])
        );
    }
}