Skip to main content

spg_engine/
users.rs

//! User table + RBAC types (introduced in v4.1; SCRAM-SHA-256 verifiers
//! added in v4.8).
//!
//! Three roles, narrow on purpose:
//!
//! - `Admin` — full read+write + can manage other users
//! - `ReadWrite` — full read+write, no user-mgmt
//! - `ReadOnly` — SELECT / SHOW only
//!
//! Passwords are stored as BLAKE3(salt || password). The salt is a random
//! 16-byte value per user, kept inline in the user record, so verifying a
//! password takes a single hash computation. The hash is not designed to
//! resist a determined offline attack on the snapshot file; file permissions
//! in the docker-compose deployment are the defence there. It only ensures
//! that the snapshot itself does not contain plaintext, and that an in-memory
//! dump cannot trivially reveal a typed password. Since v4.8 a user may
//! additionally carry SCRAM-SHA-256 verifiers for PG-wire SASL auth.
16
17use alloc::collections::BTreeMap;
18use alloc::string::{String, ToString};
19use alloc::vec::Vec;
20
/// Length in bytes of the per-user salt mixed into the BLAKE3 password hash.
const SALT_LEN: usize = 16;
/// Length in bytes of every digest/key kept in a user record: the BLAKE3
/// password hash and the SCRAM `StoredKey` / `ServerKey`.
const HASH_LEN: usize = 32;
23
/// Access level attached to every user account.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Full read + write, plus user management (`CREATE USER` / `DROP USER`).
    Admin,
    /// Full read + write, no user management.
    ReadWrite,
    /// Read-only access (SELECT / SHOW).
    ReadOnly,
}

impl Role {
    /// Canonical lowercase name; always accepted back by [`Role::parse`].
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Admin => "admin",
            Self::ReadWrite => "readwrite",
            Self::ReadOnly => "readonly",
        }
    }

    /// Parse a role name, ignoring ASCII case.
    ///
    /// Accepts the canonical names (`admin`, `readwrite`, `readonly`) and the
    /// short aliases `rw` / `ro`. Surrounding whitespace is *not* trimmed;
    /// anything else yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // `eq_ignore_ascii_case` compares in place, so unlike the previous
        // `to_ascii_lowercase()` this does not allocate a `String` per call.
        let is = |name: &str| s.eq_ignore_ascii_case(name);
        if is("admin") {
            Some(Self::Admin)
        } else if is("readwrite") || is("rw") {
            Some(Self::ReadWrite)
        } else if is("readonly") || is("ro") {
            Some(Self::ReadOnly)
        } else {
            None
        }
    }

    /// Read access — every role qualifies.
    pub const fn can_read(self) -> bool {
        true
    }

    /// Write access (INSERT / DDL on user tables).
    pub const fn can_write(self) -> bool {
        matches!(self, Self::Admin | Self::ReadWrite)
    }

    /// User-management DDL (`CREATE USER`, `DROP USER`).
    pub const fn can_manage_users(self) -> bool {
        matches!(self, Self::Admin)
    }
}
64
65#[derive(Debug, Clone)]
66pub struct UserRecord {
67    pub role: Role,
68    salt: [u8; SALT_LEN],
69    hash: [u8; HASH_LEN],
70    /// v4.8: SCRAM-SHA-256 verifier. Computed alongside the
71    /// BLAKE3 hash at user creation so PG-wire SASL auth can
72    /// verify without re-running PBKDF2 per attempt. `None`
73    /// means the user predates v4.8 (loaded from an older
74    /// snapshot); the PG-wire layer falls back to
75    /// `CleartextPassword` for those users.
76    scram: Option<ScramSecrets>,
77}
78
79/// SCRAM-SHA-256 stored credentials per RFC 5802 §5.
80/// `salt` and `iters` are sent to the client in server-first;
81/// `stored_key` and `server_key` are kept secret and used in the
82/// final-message verification.
83#[derive(Debug, Clone)]
84pub struct ScramSecrets {
85    pub iters: u32,
86    pub salt: [u8; SCRAM_SALT_LEN],
87    pub stored_key: [u8; HASH_LEN],
88    pub server_key: [u8; HASH_LEN],
89}
90
/// Length in bytes of the SCRAM salt (independent of the BLAKE3 salt).
pub const SCRAM_SALT_LEN: usize = 16;
/// Default PBKDF2 iteration count for SCRAM verifiers. 4096 is the minimum
/// RFC 7677 recommends for SCRAM-SHA-256.
pub const SCRAM_DEFAULT_ITERS: u32 = 4096;
93
94impl UserRecord {
95    pub fn verify(&self, password: &str) -> bool {
96        let candidate = derive_hash(&self.salt, password);
97        constant_time_eq(&candidate, &self.hash)
98    }
99
100    pub const fn scram(&self) -> Option<&ScramSecrets> {
101        self.scram.as_ref()
102    }
103}
104
105#[derive(Debug, Clone, Default)]
106pub struct UserStore {
107    users: BTreeMap<String, UserRecord>,
108}
109
/// Failures reported by [`UserStore`] operations.
#[derive(Debug, PartialEq, Eq)]
pub enum UserError {
    /// `create` was asked for a name that is already taken.
    Exists,
    /// `drop` / `enable_scram` named a user that does not exist.
    NotFound,
    /// A role string was not recognised. Not produced inside this module;
    /// presumably raised by callers when `Role::parse` fails (confirm).
    InvalidRole,
    /// `create` was given an empty username.
    EmptyName,
    /// `create` was given an empty password.
    EmptyPassword,
}

impl core::fmt::Display for UserError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::Exists => "user already exists",
            Self::NotFound => "user not found",
            Self::InvalidRole => "invalid role (expected admin / readwrite / readonly)",
            Self::EmptyName => "username must be non-empty",
            Self::EmptyPassword => "password must be non-empty",
        };
        f.write_str(message)
    }
}
132
133impl UserStore {
134    pub fn new() -> Self {
135        Self::default()
136    }
137
138    pub fn len(&self) -> usize {
139        self.users.len()
140    }
141
142    pub fn is_empty(&self) -> bool {
143        self.users.is_empty()
144    }
145
146    pub fn contains(&self, name: &str) -> bool {
147        self.users.contains_key(name)
148    }
149
150    /// Stable iteration in name order — used by SHOW USERS and the
151    /// snapshot writer.
152    pub fn iter(&self) -> impl Iterator<Item = (&str, &UserRecord)> {
153        self.users.iter().map(|(k, v)| (k.as_str(), v))
154    }
155
156    pub fn create(
157        &mut self,
158        name: &str,
159        password: &str,
160        role: Role,
161        salt: [u8; SALT_LEN],
162    ) -> Result<(), UserError> {
163        if name.is_empty() {
164            return Err(UserError::EmptyName);
165        }
166        if password.is_empty() {
167            return Err(UserError::EmptyPassword);
168        }
169        if self.users.contains_key(name) {
170            return Err(UserError::Exists);
171        }
172        let hash = derive_hash(&salt, password);
173        self.users.insert(
174            name.to_string(),
175            UserRecord {
176                role,
177                salt,
178                hash,
179                scram: None,
180            },
181        );
182        Ok(())
183    }
184
185    pub fn drop(&mut self, name: &str) -> Result<(), UserError> {
186        self.users
187            .remove(name)
188            .map(|_| ())
189            .ok_or(UserError::NotFound)
190    }
191
192    /// v4.8: attach SCRAM-SHA-256 verifier to an existing user.
193    /// Called by the engine right after `create` so new users have
194    /// both auth paths (legacy BLAKE3 + SCRAM) available. The salt
195    /// here is independent of the BLAKE3 hash salt — they serve
196    /// different purposes.
197    pub fn enable_scram(
198        &mut self,
199        name: &str,
200        password: &str,
201        salt: [u8; SCRAM_SALT_LEN],
202        iters: u32,
203    ) -> Result<(), UserError> {
204        let rec = self.users.get_mut(name).ok_or(UserError::NotFound)?;
205        rec.scram = Some(compute_scram_secrets(password, salt, iters));
206        Ok(())
207    }
208
209    pub fn verify(&self, name: &str, password: &str) -> Option<Role> {
210        let rec = self.users.get(name)?;
211        if rec.verify(password) {
212            Some(rec.role)
213        } else {
214            None
215        }
216    }
217}
218
219fn derive_hash(salt: &[u8; SALT_LEN], password: &str) -> [u8; HASH_LEN] {
220    let mut buf = Vec::with_capacity(SALT_LEN + password.len());
221    buf.extend_from_slice(salt);
222    buf.extend_from_slice(password.as_bytes());
223    spg_crypto::hash(&buf)
224}
225
226/// v4.8: derive SCRAM-SHA-256 stored credentials per RFC 5802 §3.
227///
228/// `SaltedPassword` = `PBKDF2(password, salt, iters)`
229/// `ClientKey`      = `HMAC(SaltedPassword, "Client Key")`
230/// `StoredKey`      = `SHA-256(ClientKey)`
231/// `ServerKey`      = `HMAC(SaltedPassword, "Server Key")`
232///
233/// PG-wire keeps the `StoredKey` + `ServerKey` on disk; verifying a
234/// client SCRAM proof needs only the `StoredKey` (no plaintext
235/// password ever stored).
236pub fn compute_scram_secrets(
237    password: &str,
238    salt: [u8; SCRAM_SALT_LEN],
239    iters: u32,
240) -> ScramSecrets {
241    let salted = spg_crypto::pbkdf2::pbkdf2_sha256_32(password.as_bytes(), &salt, iters);
242    let client_key = spg_crypto::hmac::hmac_sha256(&salted, b"Client Key");
243    let stored_key = spg_crypto::sha256::hash(&client_key);
244    let server_key = spg_crypto::hmac::hmac_sha256(&salted, b"Server Key");
245    ScramSecrets {
246        iters,
247        salt,
248        stored_key,
249        server_key,
250    }
251}
252
/// Compare two equal-length byte arrays without data-dependent branching.
///
/// Each byte pair is XOR-ed and OR-ed into one accumulator, which is inspected
/// only at the end. The running time therefore does not reveal how long a
/// matching prefix was.
///
/// The function is generic over the length, so any fixed-size secret (password
/// hash, SCRAM key, ...) can use it. Existing `[u8; HASH_LEN]` callers infer
/// `N` automatically.
///
/// NOTE(review): this relies on the optimizer not turning the fold into an
/// early-exit loop; there is no volatile/`black_box` barrier.
fn constant_time_eq<const N: usize>(a: &[u8; N], b: &[u8; N]) -> bool {
    a.iter()
        .zip(b.iter())
        .fold(0u8, |diff, (x, y)| diff | (x ^ y))
        == 0
}
262
// ---- snapshot encoding ----
//
// Layout (after a magic + version envelope handled at Engine level).
// All integers are little-endian.
//
// v1 (v4.1.0 — original):
//   [u32 user_count]
//   for each user:
//     [u16 name_len][name][u8 role][16 salt][32 hash]
//
// v2 (v4.8.0 — adds SCRAM):
//   [u8 format_marker = 0xff]
//   [u32 user_count]
//   for each user:
//     [u16 name_len][name][u8 role][16 salt][32 hash]
//     [u8 scram_present]       // 0 or 1
//     if scram_present:
//       [u32 iters][16 scram_salt][32 stored_key][32 server_key]
//
// The 0xff marker is a hint, not an unambiguous version switch. A v1 blob
// starts with the LOW byte of its little-endian user_count, and that byte is
// 0xff whenever user_count ≡ 255 (mod 256): 255, 511, ... users. A reader
// must therefore confirm that a blob starting with 0xff really parses as v2
// (e.g. that the parse consumes the buffer exactly) before ruling out v1.
286
/// First byte of a v2 (SCRAM-aware) user blob. It does not identify v2 on its
/// own: a v1 blob whose user count has low byte 0xff starts with it too.
const SCRAM_FORMAT_MARKER: u8 = 0xff;
288
289pub(crate) fn serialize_users(store: &UserStore) -> Vec<u8> {
290    let per_user_floor = 2 + 16 + 1 + SALT_LEN + HASH_LEN + 1;
291    let mut out = Vec::with_capacity(1 + 4 + store.len() * per_user_floor);
292    out.push(SCRAM_FORMAT_MARKER);
293    out.extend_from_slice(
294        &u32::try_from(store.users.len())
295            .expect("≤ 4G users")
296            .to_le_bytes(),
297    );
298    for (name, rec) in &store.users {
299        let nl = u16::try_from(name.len()).expect("≤ 65k name");
300        out.extend_from_slice(&nl.to_le_bytes());
301        out.extend_from_slice(name.as_bytes());
302        out.push(match rec.role {
303            Role::Admin => 0,
304            Role::ReadWrite => 1,
305            Role::ReadOnly => 2,
306        });
307        out.extend_from_slice(&rec.salt);
308        out.extend_from_slice(&rec.hash);
309        match &rec.scram {
310            None => out.push(0),
311            Some(s) => {
312                out.push(1);
313                out.extend_from_slice(&s.iters.to_le_bytes());
314                out.extend_from_slice(&s.salt);
315                out.extend_from_slice(&s.stored_key);
316                out.extend_from_slice(&s.server_key);
317            }
318        }
319    }
320    out
321}
322
/// Why a user blob from a snapshot could not be decoded.
#[derive(Debug, PartialEq, Eq)]
pub enum UserDeserializeError {
    /// The blob ended before a field was complete. The reader also uses it as
    /// its generic "malformed" error, e.g. for bytes left over after the last
    /// user.
    Truncated,
    /// A role byte outside `0..=2`.
    BadRole(u8),
    /// A username that is not valid UTF-8.
    InvalidUtf8,
}

impl core::fmt::Display for UserDeserializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Truncated => f.write_str("user blob truncated"),
            Self::BadRole(b) => write!(f, "unknown role byte: {b}"),
            Self::InvalidUtf8 => f.write_str("username not valid UTF-8"),
        }
    }
}

/// Return the next `n` bytes of `buf` starting at `*p` and advance `*p` past
/// them.
///
/// On failure `*p` is left unchanged and `Truncated` is returned. Bounds are
/// computed with `checked_add` + `get`, so even absurd `n` values produce an
/// error instead of an overflow or slice-index panic.
fn take<'a>(p: &mut usize, n: usize, buf: &'a [u8]) -> Result<&'a [u8], UserDeserializeError> {
    let end = p.checked_add(n).ok_or(UserDeserializeError::Truncated)?;
    let bytes = buf.get(*p..end).ok_or(UserDeserializeError::Truncated)?;
    *p = end;
    Ok(bytes)
}
348
349pub(crate) fn deserialize_users(buf: &[u8]) -> Result<UserStore, UserDeserializeError> {
350    let mut p = 0usize;
351    // v1 → starts with a u32 user_count (LO byte rarely 0xff in
352    // practice). v2 → starts with the 0xff marker, then u32 count.
353    // Probing the first byte before advancing keeps the v1 path
354    // intact for old snapshots.
355    let scram_present_inline = if !buf.is_empty() && buf[0] == SCRAM_FORMAT_MARKER {
356        p += 1;
357        true
358    } else {
359        false
360    };
361    let count_bytes = take(&mut p, 4, buf)?;
362    let count = u32::from_le_bytes(count_bytes.try_into().unwrap()) as usize;
363    let mut store = UserStore::new();
364    for _ in 0..count {
365        let nl_bytes = take(&mut p, 2, buf)?;
366        let nl = u16::from_le_bytes(nl_bytes.try_into().unwrap()) as usize;
367        let name_bytes = take(&mut p, nl, buf)?;
368        let name = core::str::from_utf8(name_bytes)
369            .map_err(|_| UserDeserializeError::InvalidUtf8)?
370            .to_string();
371        let role_byte = take(&mut p, 1, buf)?[0];
372        let role = match role_byte {
373            0 => Role::Admin,
374            1 => Role::ReadWrite,
375            2 => Role::ReadOnly,
376            b => return Err(UserDeserializeError::BadRole(b)),
377        };
378        let mut salt = [0u8; SALT_LEN];
379        salt.copy_from_slice(take(&mut p, SALT_LEN, buf)?);
380        let mut hash = [0u8; HASH_LEN];
381        hash.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
382        let scram = if scram_present_inline {
383            let flag = take(&mut p, 1, buf)?[0];
384            if flag == 1 {
385                let iters_bytes = take(&mut p, 4, buf)?;
386                let iters = u32::from_le_bytes(iters_bytes.try_into().unwrap());
387                let mut s_salt = [0u8; SCRAM_SALT_LEN];
388                s_salt.copy_from_slice(take(&mut p, SCRAM_SALT_LEN, buf)?);
389                let mut stored_key = [0u8; HASH_LEN];
390                stored_key.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
391                let mut server_key = [0u8; HASH_LEN];
392                server_key.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
393                Some(ScramSecrets {
394                    iters,
395                    salt: s_salt,
396                    stored_key,
397                    server_key,
398                })
399            } else {
400                None
401            }
402        } else {
403            None
404        };
405        store.users.insert(
406            name,
407            UserRecord {
408                role,
409                salt,
410                hash,
411                scram,
412            },
413        );
414    }
415    if p != buf.len() {
416        return Err(UserDeserializeError::Truncated);
417    }
418    Ok(store)
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_then_verify_succeeds_with_right_password_only() {
        let mut s = UserStore::new();
        s.create("alice", "hunter2", Role::Admin, [1; SALT_LEN])
            .unwrap();
        assert_eq!(s.verify("alice", "hunter2"), Some(Role::Admin));
        assert_eq!(s.verify("alice", "wrong"), None);
        assert_eq!(s.verify("bob", "hunter2"), None);
    }

    #[test]
    fn create_duplicate_user_is_rejected() {
        let mut s = UserStore::new();
        s.create("a", "p", Role::ReadOnly, [0; SALT_LEN]).unwrap();
        assert_eq!(
            s.create("a", "p2", Role::Admin, [0; SALT_LEN]),
            Err(UserError::Exists)
        );
    }

    #[test]
    fn create_rejects_empty_name_and_password() {
        let mut s = UserStore::new();
        // The name is validated before the password.
        assert_eq!(
            s.create("", "", Role::Admin, [0; SALT_LEN]),
            Err(UserError::EmptyName)
        );
        assert_eq!(
            s.create("", "pw", Role::Admin, [0; SALT_LEN]),
            Err(UserError::EmptyName)
        );
        assert_eq!(
            s.create("a", "", Role::Admin, [0; SALT_LEN]),
            Err(UserError::EmptyPassword)
        );
        assert!(s.is_empty());
    }

    #[test]
    fn drop_user_removes_them() {
        let mut s = UserStore::new();
        s.create("a", "p", Role::Admin, [0; SALT_LEN]).unwrap();
        s.drop("a").unwrap();
        assert!(s.is_empty());
        assert_eq!(s.drop("a"), Err(UserError::NotFound));
    }

    #[test]
    fn enable_scram_on_unknown_user_is_not_found() {
        let mut s = UserStore::new();
        assert_eq!(
            s.enable_scram("ghost", "pw", [0; SCRAM_SALT_LEN], SCRAM_DEFAULT_ITERS),
            Err(UserError::NotFound)
        );
    }

    #[test]
    fn role_parse_accepts_aliases() {
        assert_eq!(Role::parse("ADMIN"), Some(Role::Admin));
        assert_eq!(Role::parse("rw"), Some(Role::ReadWrite));
        assert_eq!(Role::parse("ro"), Some(Role::ReadOnly));
        assert_eq!(Role::parse("god"), None);
    }

    #[test]
    fn role_as_str_round_trips_through_parse() {
        for &r in &[Role::Admin, Role::ReadWrite, Role::ReadOnly] {
            assert_eq!(Role::parse(r.as_str()), Some(r));
        }
    }

    #[test]
    fn snapshot_round_trip_preserves_users_and_verify() {
        let mut s = UserStore::new();
        s.create("alice", "pw1", Role::Admin, [7; SALT_LEN])
            .unwrap();
        s.create("bob", "pw2", Role::ReadOnly, [13; SALT_LEN])
            .unwrap();
        let bytes = serialize_users(&s);
        let s2 = deserialize_users(&bytes).unwrap();
        assert_eq!(s2.len(), 2);
        assert_eq!(s2.verify("alice", "pw1"), Some(Role::Admin));
        assert_eq!(s2.verify("bob", "pw2"), Some(Role::ReadOnly));
        assert_eq!(s2.verify("bob", "wrong"), None);
    }

    #[test]
    fn users_without_scram_round_trip_as_none() {
        let mut s = UserStore::new();
        s.create("carol", "pw", Role::ReadWrite, [5; SALT_LEN])
            .unwrap();
        let (_, rec) = s.iter().next().unwrap();
        assert!(rec.scram().is_none(), "create alone attaches no SCRAM");
        let s2 = deserialize_users(&serialize_users(&s)).unwrap();
        let (name, rec) = s2.iter().next().unwrap();
        assert_eq!(name, "carol");
        assert_eq!(rec.role, Role::ReadWrite);
        assert!(rec.scram().is_none());
    }

    #[test]
    fn empty_store_round_trip() {
        // v4.8: format prefix byte (0xff) + zero u32 count.
        let s = UserStore::new();
        let bytes = serialize_users(&s);
        assert_eq!(bytes, [0xff, 0, 0, 0, 0]);
        let s2 = deserialize_users(&bytes).unwrap();
        assert!(s2.is_empty());
    }

    /// Hand-built v1 blob holding one user with the given raw name bytes and
    /// role byte (zeroed salt and hash).
    fn v1_blob(name: &[u8], role: u8) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&u16::try_from(name.len()).unwrap().to_le_bytes());
        buf.extend_from_slice(name);
        buf.push(role);
        buf.extend_from_slice(&[0u8; SALT_LEN]);
        buf.extend_from_slice(&[0u8; HASH_LEN]);
        buf
    }

    #[test]
    fn old_v1_user_blob_still_loads() {
        // Hand-constructed v1 blob: 1 user, no SCRAM byte.
        // [u32 count=1][u16 name_len=3]["bob"][u8 role=0][16 salt][32 hash]
        let mut buf = Vec::new();
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&3u16.to_le_bytes());
        buf.extend_from_slice(b"bob");
        buf.push(0); // role = admin
        buf.extend_from_slice(&[7u8; SALT_LEN]);
        buf.extend_from_slice(&[42u8; HASH_LEN]);
        let s = deserialize_users(&buf).expect("v1 blob must still load");
        assert_eq!(s.len(), 1);
        let (n, rec) = s.iter().next().unwrap();
        assert_eq!(n, "bob");
        assert_eq!(rec.role, Role::Admin);
        assert!(rec.scram().is_none(), "v1 users have no SCRAM secrets");
    }

    #[test]
    fn scram_round_trip_preserves_iters_salt_keys() {
        let mut s = UserStore::new();
        s.create("alice", "pw", Role::Admin, [3; SALT_LEN]).unwrap();
        s.enable_scram("alice", "pw", [9; SCRAM_SALT_LEN], 4096)
            .unwrap();
        let bytes = serialize_users(&s);
        let s2 = deserialize_users(&bytes).unwrap();
        let (_, rec) = s2.iter().next().unwrap();
        let scram = rec.scram().expect("scram must round-trip");
        assert_eq!(scram.iters, 4096);
        assert_eq!(scram.salt, [9u8; SCRAM_SALT_LEN]);
        // StoredKey and ServerKey are deterministic given (password,
        // salt, iters); verify by recomputing.
        let expected = compute_scram_secrets("pw", [9; SCRAM_SALT_LEN], 4096);
        assert_eq!(scram.stored_key, expected.stored_key);
        assert_eq!(scram.server_key, expected.server_key);
    }

    #[test]
    fn deserialize_truncation_is_caught() {
        assert!(deserialize_users(&[]).is_err());
        assert!(deserialize_users(&[0, 0, 0]).is_err());
    }

    #[test]
    fn deserialize_rejects_unknown_role_byte() {
        let buf = v1_blob(b"x", 3);
        assert!(matches!(
            deserialize_users(&buf),
            Err(UserDeserializeError::BadRole(3))
        ));
    }

    #[test]
    fn deserialize_rejects_non_utf8_name() {
        let buf = v1_blob(&[0xff], 0);
        assert!(matches!(
            deserialize_users(&buf),
            Err(UserDeserializeError::InvalidUtf8)
        ));
    }

    #[test]
    fn deserialize_rejects_trailing_bytes() {
        let mut s = UserStore::new();
        s.create("alice", "pw", Role::Admin, [1; SALT_LEN]).unwrap();
        let mut v2 = serialize_users(&s);
        v2.push(0);
        assert!(deserialize_users(&v2).is_err());

        let mut v1 = v1_blob(b"bob", 0);
        v1.push(0);
        assert!(deserialize_users(&v1).is_err());
    }
}