1use alloc::collections::BTreeMap;
18use alloc::string::{String, ToString};
19use alloc::vec::Vec;
20
/// Length in bytes of the per-user salt prepended to a password before hashing.
const SALT_LEN: usize = 16;
/// Length in bytes of a password hash, and of the SCRAM stored/server keys.
const HASH_LEN: usize = 32;
23
/// Privilege level attached to a user account.
///
/// What each level allows is answered by [`Role::can_read`],
/// [`Role::can_write`] and [`Role::can_manage_users`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// May read, write and manage users.
    Admin,
    /// May read and write, but not manage users.
    ReadWrite,
    /// May only read.
    ReadOnly,
}
30
31impl Role {
32 pub const fn as_str(self) -> &'static str {
33 match self {
34 Self::Admin => "admin",
35 Self::ReadWrite => "readwrite",
36 Self::ReadOnly => "readonly",
37 }
38 }
39
40 pub fn parse(s: &str) -> Option<Self> {
41 match s.to_ascii_lowercase().as_str() {
42 "admin" => Some(Self::Admin),
43 "readwrite" | "rw" => Some(Self::ReadWrite),
44 "readonly" | "ro" => Some(Self::ReadOnly),
45 _ => None,
46 }
47 }
48
49 pub const fn can_read(self) -> bool {
51 true
52 }
53
54 pub const fn can_write(self) -> bool {
56 matches!(self, Self::Admin | Self::ReadWrite)
57 }
58
59 pub const fn can_manage_users(self) -> bool {
61 matches!(self, Self::Admin)
62 }
63}
64
65#[derive(Debug, Clone)]
66pub struct UserRecord {
67 pub role: Role,
68 salt: [u8; SALT_LEN],
69 hash: [u8; HASH_LEN],
70 scram: Option<ScramSecrets>,
77}
78
79#[derive(Debug, Clone)]
84pub struct ScramSecrets {
85 pub iters: u32,
86 pub salt: [u8; SCRAM_SALT_LEN],
87 pub stored_key: [u8; HASH_LEN],
88 pub server_key: [u8; HASH_LEN],
89}
90
/// Length in bytes of the salt stored in [`ScramSecrets::salt`].
pub const SCRAM_SALT_LEN: usize = 16;
/// Default PBKDF2 iteration count for SCRAM secrets.
///
/// 4096 is the minimum RFC 7677 says a server SHOULD announce.
pub const SCRAM_DEFAULT_ITERS: u32 = 4096;
93
94impl UserRecord {
95 pub fn verify(&self, password: &str) -> bool {
96 let candidate = derive_hash(&self.salt, password);
97 constant_time_eq(&candidate, &self.hash)
98 }
99
100 pub const fn scram(&self) -> Option<&ScramSecrets> {
101 self.scram.as_ref()
102 }
103}
104
105#[derive(Debug, Clone, Default)]
106pub struct UserStore {
107 users: BTreeMap<String, UserRecord>,
108}
109
/// Reasons a [`UserStore`] mutation can be refused.
#[derive(Debug, PartialEq, Eq)]
pub enum UserError {
    /// A user with the requested name is already registered.
    Exists,
    /// No user with the requested name is registered.
    NotFound,
    /// A role name could not be parsed.
    InvalidRole,
    /// The supplied user name was empty.
    EmptyName,
    /// The supplied password was empty.
    EmptyPassword,
}
118
119impl core::fmt::Display for UserError {
120 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121 match self {
122 Self::Exists => f.write_str("user already exists"),
123 Self::NotFound => f.write_str("user not found"),
124 Self::InvalidRole => {
125 f.write_str("invalid role (expected admin / readwrite / readonly)")
126 }
127 Self::EmptyName => f.write_str("username must be non-empty"),
128 Self::EmptyPassword => f.write_str("password must be non-empty"),
129 }
130 }
131}
132
133impl UserStore {
134 pub fn new() -> Self {
135 Self::default()
136 }
137
138 pub fn len(&self) -> usize {
139 self.users.len()
140 }
141
142 pub fn is_empty(&self) -> bool {
143 self.users.is_empty()
144 }
145
146 pub fn contains(&self, name: &str) -> bool {
147 self.users.contains_key(name)
148 }
149
150 pub fn iter(&self) -> impl Iterator<Item = (&str, &UserRecord)> {
153 self.users.iter().map(|(k, v)| (k.as_str(), v))
154 }
155
156 pub fn create(
157 &mut self,
158 name: &str,
159 password: &str,
160 role: Role,
161 salt: [u8; SALT_LEN],
162 ) -> Result<(), UserError> {
163 if name.is_empty() {
164 return Err(UserError::EmptyName);
165 }
166 if password.is_empty() {
167 return Err(UserError::EmptyPassword);
168 }
169 if self.users.contains_key(name) {
170 return Err(UserError::Exists);
171 }
172 let hash = derive_hash(&salt, password);
173 self.users.insert(
174 name.to_string(),
175 UserRecord {
176 role,
177 salt,
178 hash,
179 scram: None,
180 },
181 );
182 Ok(())
183 }
184
185 pub fn drop(&mut self, name: &str) -> Result<(), UserError> {
186 self.users
187 .remove(name)
188 .map(|_| ())
189 .ok_or(UserError::NotFound)
190 }
191
192 pub fn enable_scram(
198 &mut self,
199 name: &str,
200 password: &str,
201 salt: [u8; SCRAM_SALT_LEN],
202 iters: u32,
203 ) -> Result<(), UserError> {
204 let rec = self.users.get_mut(name).ok_or(UserError::NotFound)?;
205 rec.scram = Some(compute_scram_secrets(password, salt, iters));
206 Ok(())
207 }
208
209 pub fn verify(&self, name: &str, password: &str) -> Option<Role> {
210 let rec = self.users.get(name)?;
211 if rec.verify(password) {
212 Some(rec.role)
213 } else {
214 None
215 }
216 }
217}
218
219fn derive_hash(salt: &[u8; SALT_LEN], password: &str) -> [u8; HASH_LEN] {
220 let mut buf = Vec::with_capacity(SALT_LEN + password.len());
221 buf.extend_from_slice(salt);
222 buf.extend_from_slice(password.as_bytes());
223 spg_crypto::hash(&buf)
224}
225
226pub fn compute_scram_secrets(
237 password: &str,
238 salt: [u8; SCRAM_SALT_LEN],
239 iters: u32,
240) -> ScramSecrets {
241 let salted = spg_crypto::pbkdf2::pbkdf2_sha256_32(password.as_bytes(), &salt, iters);
242 let client_key = spg_crypto::hmac::hmac_sha256(&salted, b"Client Key");
243 let stored_key = spg_crypto::sha256::hash(&client_key);
244 let server_key = spg_crypto::hmac::hmac_sha256(&salted, b"Server Key");
245 ScramSecrets {
246 iters,
247 salt,
248 stored_key,
249 server_key,
250 }
251}
252
253fn constant_time_eq(a: &[u8; HASH_LEN], b: &[u8; HASH_LEN]) -> bool {
256 let mut diff: u8 = 0;
257 for i in 0..HASH_LEN {
258 diff |= a[i] ^ b[i];
259 }
260 diff == 0
261}
262
/// First byte of a blob written by [`serialize_users`] in the SCRAM-aware
/// layout.
///
/// Legacy blobs have no marker and start directly with the little-endian `u32`
/// user count. A legacy blob whose count has a low byte of `0xff` (255, 511, …
/// users) therefore also begins with this value.
const SCRAM_FORMAT_MARKER: u8 = 0xff;
288
289pub(crate) fn serialize_users(store: &UserStore) -> Vec<u8> {
290 let per_user_floor = 2 + 16 + 1 + SALT_LEN + HASH_LEN + 1;
291 let mut out = Vec::with_capacity(1 + 4 + store.len() * per_user_floor);
292 out.push(SCRAM_FORMAT_MARKER);
293 out.extend_from_slice(
294 &u32::try_from(store.users.len())
295 .expect("≤ 4G users")
296 .to_le_bytes(),
297 );
298 for (name, rec) in &store.users {
299 let nl = u16::try_from(name.len()).expect("≤ 65k name");
300 out.extend_from_slice(&nl.to_le_bytes());
301 out.extend_from_slice(name.as_bytes());
302 out.push(match rec.role {
303 Role::Admin => 0,
304 Role::ReadWrite => 1,
305 Role::ReadOnly => 2,
306 });
307 out.extend_from_slice(&rec.salt);
308 out.extend_from_slice(&rec.hash);
309 match &rec.scram {
310 None => out.push(0),
311 Some(s) => {
312 out.push(1);
313 out.extend_from_slice(&s.iters.to_le_bytes());
314 out.extend_from_slice(&s.salt);
315 out.extend_from_slice(&s.stored_key);
316 out.extend_from_slice(&s.server_key);
317 }
318 }
319 }
320 out
321}
322
/// Reasons a user snapshot blob can fail to decode.
#[derive(Debug)]
pub enum UserDeserializeError {
    /// The blob ended before a field was complete, or had bytes left over
    /// after the last record.
    Truncated,
    /// A role byte was not one of the known encodings (0, 1 or 2).
    BadRole(u8),
    /// A user name was not valid UTF-8.
    InvalidUtf8,
}
329
330impl core::fmt::Display for UserDeserializeError {
331 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
332 match self {
333 Self::Truncated => f.write_str("user blob truncated"),
334 Self::BadRole(b) => write!(f, "unknown role byte: {b}"),
335 Self::InvalidUtf8 => f.write_str("username not valid UTF-8"),
336 }
337 }
338}
339
340fn take<'a>(p: &mut usize, n: usize, buf: &'a [u8]) -> Result<&'a [u8], UserDeserializeError> {
341 if *p + n > buf.len() {
342 return Err(UserDeserializeError::Truncated);
343 }
344 let s = &buf[*p..*p + n];
345 *p += n;
346 Ok(s)
347}
348
349pub(crate) fn deserialize_users(buf: &[u8]) -> Result<UserStore, UserDeserializeError> {
350 let mut p = 0usize;
351 let scram_present_inline = if !buf.is_empty() && buf[0] == SCRAM_FORMAT_MARKER {
356 p += 1;
357 true
358 } else {
359 false
360 };
361 let count_bytes = take(&mut p, 4, buf)?;
362 let count = u32::from_le_bytes(count_bytes.try_into().unwrap()) as usize;
363 let mut store = UserStore::new();
364 for _ in 0..count {
365 let nl_bytes = take(&mut p, 2, buf)?;
366 let nl = u16::from_le_bytes(nl_bytes.try_into().unwrap()) as usize;
367 let name_bytes = take(&mut p, nl, buf)?;
368 let name = core::str::from_utf8(name_bytes)
369 .map_err(|_| UserDeserializeError::InvalidUtf8)?
370 .to_string();
371 let role_byte = take(&mut p, 1, buf)?[0];
372 let role = match role_byte {
373 0 => Role::Admin,
374 1 => Role::ReadWrite,
375 2 => Role::ReadOnly,
376 b => return Err(UserDeserializeError::BadRole(b)),
377 };
378 let mut salt = [0u8; SALT_LEN];
379 salt.copy_from_slice(take(&mut p, SALT_LEN, buf)?);
380 let mut hash = [0u8; HASH_LEN];
381 hash.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
382 let scram = if scram_present_inline {
383 let flag = take(&mut p, 1, buf)?[0];
384 if flag == 1 {
385 let iters_bytes = take(&mut p, 4, buf)?;
386 let iters = u32::from_le_bytes(iters_bytes.try_into().unwrap());
387 let mut s_salt = [0u8; SCRAM_SALT_LEN];
388 s_salt.copy_from_slice(take(&mut p, SCRAM_SALT_LEN, buf)?);
389 let mut stored_key = [0u8; HASH_LEN];
390 stored_key.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
391 let mut server_key = [0u8; HASH_LEN];
392 server_key.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
393 Some(ScramSecrets {
394 iters,
395 salt: s_salt,
396 stored_key,
397 server_key,
398 })
399 } else {
400 None
401 }
402 } else {
403 None
404 };
405 store.users.insert(
406 name,
407 UserRecord {
408 role,
409 salt,
410 hash,
411 scram,
412 },
413 );
414 }
415 if p != buf.len() {
416 return Err(UserDeserializeError::Truncated);
417 }
418 Ok(store)
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_then_verify_succeeds_with_right_password_only() {
        let mut s = UserStore::new();
        s.create("alice", "hunter2", Role::Admin, [1; SALT_LEN])
            .unwrap();
        assert_eq!(s.verify("alice", "hunter2"), Some(Role::Admin));
        assert_eq!(s.verify("alice", "wrong"), None);
        assert_eq!(s.verify("bob", "hunter2"), None);
    }

    #[test]
    fn create_duplicate_user_is_rejected() {
        let mut s = UserStore::new();
        s.create("a", "p", Role::ReadOnly, [0; SALT_LEN]).unwrap();
        assert_eq!(
            s.create("a", "p2", Role::Admin, [0; SALT_LEN]),
            Err(UserError::Exists)
        );
    }

    #[test]
    fn create_rejects_empty_name_and_password() {
        let mut s = UserStore::new();
        assert_eq!(
            s.create("", "pw", Role::Admin, [0; SALT_LEN]),
            Err(UserError::EmptyName)
        );
        assert_eq!(
            s.create("a", "", Role::Admin, [0; SALT_LEN]),
            Err(UserError::EmptyPassword)
        );
        assert!(s.is_empty());
    }

    #[test]
    fn drop_user_removes_them() {
        let mut s = UserStore::new();
        s.create("a", "p", Role::Admin, [0; SALT_LEN]).unwrap();
        s.drop("a").unwrap();
        assert!(s.is_empty());
        assert_eq!(s.drop("a"), Err(UserError::NotFound));
    }

    #[test]
    fn role_parse_accepts_aliases() {
        assert_eq!(Role::parse("ADMIN"), Some(Role::Admin));
        assert_eq!(Role::parse("rw"), Some(Role::ReadWrite));
        assert_eq!(Role::parse("ro"), Some(Role::ReadOnly));
        assert_eq!(Role::parse("god"), None);
    }

    #[test]
    fn role_names_round_trip_and_permissions() {
        for role in [Role::Admin, Role::ReadWrite, Role::ReadOnly] {
            assert_eq!(Role::parse(role.as_str()), Some(role));
            assert!(role.can_read());
        }
        assert!(Role::Admin.can_write());
        assert!(Role::ReadWrite.can_write());
        assert!(!Role::ReadOnly.can_write());
        assert!(Role::Admin.can_manage_users());
        assert!(!Role::ReadWrite.can_manage_users());
        assert!(!Role::ReadOnly.can_manage_users());
    }

    #[test]
    fn enable_scram_on_unknown_user_is_not_found() {
        let mut s = UserStore::new();
        assert_eq!(
            s.enable_scram("ghost", "pw", [0; SCRAM_SALT_LEN], SCRAM_DEFAULT_ITERS),
            Err(UserError::NotFound)
        );
    }

    #[test]
    fn snapshot_round_trip_preserves_users_and_verify() {
        let mut s = UserStore::new();
        s.create("alice", "pw1", Role::Admin, [7; SALT_LEN])
            .unwrap();
        s.create("bob", "pw2", Role::ReadOnly, [13; SALT_LEN])
            .unwrap();
        let bytes = serialize_users(&s);
        let s2 = deserialize_users(&bytes).unwrap();
        assert_eq!(s2.len(), 2);
        assert_eq!(s2.verify("alice", "pw1"), Some(Role::Admin));
        assert_eq!(s2.verify("bob", "pw2"), Some(Role::ReadOnly));
        assert_eq!(s2.verify("bob", "wrong"), None);
    }

    #[test]
    fn users_without_scram_round_trip_as_none() {
        let mut s = UserStore::new();
        s.create("carol", "pw", Role::ReadWrite, [5; SALT_LEN])
            .unwrap();
        let s2 = deserialize_users(&serialize_users(&s)).unwrap();
        let (name, rec) = s2.iter().next().unwrap();
        assert_eq!(name, "carol");
        assert_eq!(rec.role, Role::ReadWrite);
        assert!(rec.scram().is_none());
        assert_eq!(s2.verify("carol", "pw"), Some(Role::ReadWrite));
    }

    #[test]
    fn empty_store_round_trip() {
        let s = UserStore::new();
        let bytes = serialize_users(&s);
        assert_eq!(bytes, [0xff, 0, 0, 0, 0]);
        let s2 = deserialize_users(&bytes).unwrap();
        assert!(s2.is_empty());
    }

    #[test]
    fn old_v1_user_blob_still_loads() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&3u16.to_le_bytes());
        buf.extend_from_slice(b"bob");
        buf.push(0);
        buf.extend_from_slice(&[7u8; SALT_LEN]);
        buf.extend_from_slice(&[42u8; HASH_LEN]);
        let s = deserialize_users(&buf).expect("v1 blob must still load");
        assert_eq!(s.len(), 1);
        let (n, rec) = s.iter().next().unwrap();
        assert_eq!(n, "bob");
        assert_eq!(rec.role, Role::Admin);
        assert!(rec.scram().is_none(), "v1 users have no SCRAM secrets");
    }

    #[test]
    fn scram_round_trip_preserves_iters_salt_keys() {
        let mut s = UserStore::new();
        s.create("alice", "pw", Role::Admin, [3; SALT_LEN]).unwrap();
        s.enable_scram("alice", "pw", [9; SCRAM_SALT_LEN], 4096)
            .unwrap();
        let bytes = serialize_users(&s);
        let s2 = deserialize_users(&bytes).unwrap();
        let (_, rec) = s2.iter().next().unwrap();
        let scram = rec.scram().expect("scram must round-trip");
        assert_eq!(scram.iters, 4096);
        assert_eq!(scram.salt, [9u8; SCRAM_SALT_LEN]);
        let expected = compute_scram_secrets("pw", [9; SCRAM_SALT_LEN], 4096);
        assert_eq!(scram.stored_key, expected.stored_key);
        assert_eq!(scram.server_key, expected.server_key);
    }

    #[test]
    fn deserialize_truncation_is_caught() {
        assert!(deserialize_users(&[]).is_err());
        assert!(deserialize_users(&[0, 0, 0]).is_err());
    }

    #[test]
    fn deserialize_rejects_trailing_bytes() {
        let mut bytes = serialize_users(&UserStore::new());
        bytes.push(0);
        assert!(matches!(
            deserialize_users(&bytes),
            Err(UserDeserializeError::Truncated)
        ));
    }

    #[test]
    fn deserialize_rejects_unknown_role_byte() {
        let mut buf = Vec::new();
        buf.push(SCRAM_FORMAT_MARKER);
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes());
        buf.push(b'x');
        buf.push(9);
        assert!(matches!(
            deserialize_users(&buf),
            Err(UserDeserializeError::BadRole(9))
        ));
    }

    #[test]
    fn deserialize_rejects_invalid_utf8_name() {
        let mut buf = Vec::new();
        buf.push(SCRAM_FORMAT_MARKER);
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&2u16.to_le_bytes());
        buf.extend_from_slice(&[0xff, 0xfe]);
        assert!(matches!(
            deserialize_users(&buf),
            Err(UserDeserializeError::InvalidUtf8)
        ));
    }
}
530}