1use anyhow::Result;
5use chrono::{Duration, Utc};
6use rand::{rngs::OsRng, RngCore};
7use rusqlite::{params, Connection, OptionalExtension};
8use sha2::{Digest, Sha256};
9use thiserror::Error;
10
/// A worker row from the `workers` table, as returned by [`lookup_by_token`],
/// [`lookup_by_name`] and [`list_all`]; those functions never return revoked workers.
#[derive(Debug, Clone)]
pub struct Worker {
    /// `workers.id` for this row.
    pub id: i64,
    /// Worker name; the key used by the name-based functions in this module.
    pub name: String,
    /// Last status string reported via [`update_status`]; `None` if never reported.
    pub status: Option<String>,
    /// Last working directory reported via [`update_status`].
    pub current_pwd: Option<String>,
    /// Task id from the most recent [`update_status`] call (cleared when that call passed `None`).
    pub current_task_id: Option<String>,
    /// RFC 3339 UTC timestamp (microsecond precision) of the last [`update_status`] call.
    pub last_poll_at: Option<String>,
}
20
/// Reasons [`finalize_register`] can reject a registration attempt.
#[derive(Debug, Error)]
pub enum InvalidRegisterCode {
    /// No worker row has the given name. `finalize_register` also maps database
    /// errors to this variant, so it does not strictly mean "not found".
    #[error("unknown worker")]
    UnknownWorker,
    /// The worker exists but has no pending register code (never issued or already redeemed).
    #[error("no pending registration")]
    NoPending,
    /// The SHA-256 hash of the presented code does not match the stored hash.
    #[error("bad register code")]
    BadCode,
    /// The stored expiry is not in the future (or could not be parsed / was absent).
    #[error("register code expired")]
    Expired,
}
32
33fn iso_now() -> String {
34 Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
35}
36
37fn iso_at(dt: chrono::DateTime<Utc>) -> String {
38 dt.to_rfc3339_opts(chrono::SecondsFormat::Micros, true)
39}
40
41fn token_urlsafe(byte_len: usize) -> String {
42 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
43 let mut bytes = vec![0u8; byte_len];
44 OsRng.fill_bytes(&mut bytes);
45 URL_SAFE_NO_PAD.encode(&bytes)
46}
47
48fn sha256_hex(s: &str) -> String {
49 hex::encode(Sha256::digest(s.as_bytes()))
50}
51
52pub fn create_or_replace_register_code(
55 conn: &Connection,
56 name: &str,
57 ttl_sec: i64,
58) -> Result<String> {
59 let code = token_urlsafe(16);
60 let h = sha256_hex(&code);
61 let expires = iso_at(Utc::now() + Duration::seconds(ttl_sec));
62 let exists: Option<i64> = conn
63 .query_row("SELECT id FROM workers WHERE name=?", [name], |r| r.get(0))
64 .optional()?;
65 if exists.is_none() {
66 conn.execute(
67 "INSERT INTO workers (name, register_code_hash, register_code_expires, created_at)
68 VALUES (?,?,?,?)",
69 params![name, h, expires, iso_now()],
70 )?;
71 } else {
72 conn.execute(
73 "UPDATE workers SET register_code_hash=?, register_code_expires=?,
74 token_hash=NULL, revoked_at=NULL WHERE name=?",
75 params![h, expires, name],
76 )?;
77 }
78 Ok(code)
79}
80
81pub fn finalize_register(
84 conn: &Connection,
85 name: &str,
86 register_code: &str,
87 client_pubkey: Option<&[u8]>,
88) -> std::result::Result<(String, i64), InvalidRegisterCode> {
89 let row: Option<(i64, Option<String>, Option<String>, Option<String>)> = conn
90 .query_row(
91 "SELECT id, register_code_hash, register_code_expires, token_hash
92 FROM workers WHERE name=?",
93 [name],
94 |r| Ok((r.get(0)?, r.get(1)?, r.get(2)?, r.get(3)?)),
95 )
96 .optional()
97 .map_err(|_| InvalidRegisterCode::UnknownWorker)?;
98
99 let Some((wid, register_code_hash, register_code_expires, _token_hash)) = row else {
100 return Err(InvalidRegisterCode::UnknownWorker);
101 };
102
103 let Some(stored_hash) = register_code_hash else {
104 return Err(InvalidRegisterCode::NoPending);
105 };
106 if stored_hash != sha256_hex(register_code) {
107 return Err(InvalidRegisterCode::BadCode);
108 }
109 let expires = register_code_expires.unwrap_or_else(|| iso_now());
110 let Ok(parsed) = chrono::DateTime::parse_from_rfc3339(
111 &expires.replace('Z', "+00:00"),
112 ) else {
113 return Err(InvalidRegisterCode::Expired);
114 };
115 if parsed.with_timezone(&Utc) <= Utc::now() {
116 return Err(InvalidRegisterCode::Expired);
117 }
118
119 let token = token_urlsafe(32);
120 let h = sha256_hex(&token);
121 conn.execute(
122 "UPDATE workers SET token_hash=?, register_code_hash=NULL,
123 register_code_expires=NULL, client_pubkey=? WHERE id=?",
124 params![h, client_pubkey, wid],
125 )
126 .map_err(|_| InvalidRegisterCode::UnknownWorker)?;
127 Ok((token, wid))
128}
129
130pub fn lookup_by_token(conn: &Connection, token: &str) -> Result<Option<Worker>> {
131 let h = sha256_hex(token);
132 let row = conn
133 .query_row(
134 "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at
135 FROM workers WHERE token_hash=?",
136 [h],
137 |r| {
138 Ok((
139 r.get::<_, i64>(0)?,
140 r.get::<_, String>(1)?,
141 r.get::<_, Option<String>>(2)?,
142 r.get::<_, Option<String>>(3)?,
143 r.get::<_, Option<String>>(4)?,
144 r.get::<_, Option<String>>(5)?,
145 r.get::<_, Option<String>>(6)?,
146 ))
147 },
148 )
149 .optional()?;
150 let Some((id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at)) = row
151 else {
152 return Ok(None);
153 };
154 if revoked_at.is_some() {
155 return Ok(None);
156 }
157 Ok(Some(Worker {
158 id,
159 name,
160 status,
161 current_pwd,
162 current_task_id,
163 last_poll_at,
164 }))
165}
166
167pub fn lookup_by_name(conn: &Connection, name: &str) -> Result<Option<Worker>> {
168 let row = conn
169 .query_row(
170 "SELECT id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at
171 FROM workers WHERE name=?",
172 [name],
173 |r| {
174 Ok((
175 r.get::<_, i64>(0)?,
176 r.get::<_, String>(1)?,
177 r.get::<_, Option<String>>(2)?,
178 r.get::<_, Option<String>>(3)?,
179 r.get::<_, Option<String>>(4)?,
180 r.get::<_, Option<String>>(5)?,
181 r.get::<_, Option<String>>(6)?,
182 ))
183 },
184 )
185 .optional()?;
186 let Some((id, name, status, current_pwd, current_task_id, last_poll_at, revoked_at)) = row
187 else {
188 return Ok(None);
189 };
190 if revoked_at.is_some() {
191 return Ok(None);
192 }
193 Ok(Some(Worker {
194 id,
195 name,
196 status,
197 current_pwd,
198 current_task_id,
199 last_poll_at,
200 }))
201}
202
203pub fn revoke(conn: &Connection, name: &str) -> Result<()> {
204 conn.execute(
205 "UPDATE workers SET revoked_at=?, token_hash=NULL WHERE name=?",
206 params![iso_now(), name],
207 )?;
208 Ok(())
209}
210
211pub fn update_status(
212 conn: &Connection,
213 name: &str,
214 status: &str,
215 current_pwd: Option<&str>,
216 current_task_id: Option<&str>,
217) -> Result<()> {
218 conn.execute(
219 "UPDATE workers SET status=?, current_pwd=COALESCE(?, current_pwd),
220 current_task_id=?, last_poll_at=? WHERE name=?",
221 params![status, current_pwd, current_task_id, iso_now(), name],
222 )?;
223 Ok(())
224}
225
226pub fn list_all(conn: &Connection) -> Result<Vec<Worker>> {
227 let mut stmt = conn.prepare(
228 "SELECT id, name, status, current_pwd, current_task_id, last_poll_at
229 FROM workers WHERE revoked_at IS NULL ORDER BY name ASC",
230 )?;
231 let rows = stmt
232 .query_map([], |r| {
233 Ok(Worker {
234 id: r.get(0)?,
235 name: r.get(1)?,
236 status: r.get(2)?,
237 current_pwd: r.get(3)?,
238 current_task_id: r.get(4)?,
239 last_poll_at: r.get(5)?,
240 })
241 })?
242 .collect::<rusqlite::Result<Vec<_>>>()?;
243 Ok(rows)
244}
245
246pub fn get_client_pubkey(conn: &Connection, name: &str) -> Result<Option<Vec<u8>>> {
247 let row: Option<Option<Vec<u8>>> = conn
248 .query_row(
249 "SELECT client_pubkey FROM workers WHERE name=? AND revoked_at IS NULL",
250 [name],
251 |r| r.get(0),
252 )
253 .optional()?;
254 Ok(row.flatten())
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory connection prepared by `crate::db::bootstrap`
    /// (presumably creates the `workers` table; confirm in `crate::db`).
    fn fresh() -> Connection {
        let c = crate::db::connect_in_memory().unwrap();
        crate::db::bootstrap(&c).unwrap();
        c
    }

    /// Happy path: a fresh code redeems to a working token, and the code is single-use.
    #[test]
    fn issue_then_finalize() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (token, wid) = finalize_register(&c, "w1", &code, None).unwrap();
        assert!(wid > 0);
        let w = lookup_by_token(&c, &token).unwrap().unwrap();
        assert_eq!(w.name, "w1");
        // Redeeming the same code again fails: finalize cleared the pending code hash.
        let err = finalize_register(&c, "w1", &code, None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::NoPending));
    }

    /// A code whose hash does not match the stored one is rejected as BadCode.
    #[test]
    fn finalize_wrong_code_rejected() {
        let c = fresh();
        let _ = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let err = finalize_register(&c, "w1", "xxxxxxxxxxxxx", None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::BadCode));
    }

    /// A negative TTL produces an expiry in the past, so redemption reports Expired.
    #[test]
    fn finalize_expired_rejected() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", -1).unwrap();
        let err = finalize_register(&c, "w1", &code, None).unwrap_err();
        assert!(matches!(err, InvalidRegisterCode::Expired));
    }

    /// After revocation, a previously valid token no longer resolves to a worker.
    #[test]
    fn revoke_invalidates_token() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let (token, _) = finalize_register(&c, "w1", &code, None).unwrap();
        revoke(&c, "w1").unwrap();
        assert!(lookup_by_token(&c, &token).unwrap().is_none());
    }

    /// A status heartbeat is visible through list_all.
    #[test]
    fn update_status_and_list() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let _ = finalize_register(&c, "w1", &code, None).unwrap();
        update_status(&c, "w1", "IDLE", Some("/tmp"), None).unwrap();
        let listed = list_all(&c).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].status.as_deref(), Some("IDLE"));
        assert_eq!(listed[0].current_pwd.as_deref(), Some("/tmp"));
    }

    /// The pubkey passed to finalize_register round-trips through get_client_pubkey.
    #[test]
    fn finalize_records_client_pubkey() {
        let c = fresh();
        let code = create_or_replace_register_code(&c, "w1", 600).unwrap();
        let blob = b"hello-pubkey";
        let _ = finalize_register(&c, "w1", &code, Some(blob)).unwrap();
        assert_eq!(
            get_client_pubkey(&c, "w1").unwrap().as_deref(),
            Some(&blob[..])
        );
    }
}