use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
/// Length in bytes of the per-user password salt (128 bits).
const SALT_LEN: usize = 128 / 8;
/// Length in bytes of a password hash and of each SCRAM key (256 bits).
const HASH_LEN: usize = 256 / 8;
/// Privilege level attached to a user account.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    Admin,
    ReadWrite,
    ReadOnly,
}
impl Role {
    /// Canonical lowercase name of the role (the form accepted first by [`Role::parse`]).
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Admin => "admin",
            Self::ReadWrite => "readwrite",
            Self::ReadOnly => "readonly",
        }
    }
    /// Parses a role name, ignoring ASCII case.
    ///
    /// Accepts the canonical names plus the short aliases `rw` and `ro`; anything
    /// else yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Alias table; comparison is ASCII-case-insensitive, so no lowercase copy is needed.
        const ALIASES: [(&str, Role); 5] = [
            ("admin", Role::Admin),
            ("readwrite", Role::ReadWrite),
            ("rw", Role::ReadWrite),
            ("readonly", Role::ReadOnly),
            ("ro", Role::ReadOnly),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| alias.eq_ignore_ascii_case(s))
            .map(|&(_, role)| role)
    }
    /// Every role may read.
    pub const fn can_read(self) -> bool {
        true
    }
    /// Admins and read-write users may write; read-only users may not.
    pub const fn can_write(self) -> bool {
        match self {
            Self::Admin | Self::ReadWrite => true,
            Self::ReadOnly => false,
        }
    }
    /// Only admins may create or drop users.
    pub const fn can_manage_users(self) -> bool {
        match self {
            Self::Admin => true,
            Self::ReadWrite | Self::ReadOnly => false,
        }
    }
}
/// Credentials and role for one user, as stored in `UserStore`.
///
/// `hash` is `derive_hash(&salt, password)`. `scram` is `None` until
/// `UserStore::enable_scram` is called, and for users loaded from a legacy
/// (pre-SCRAM) blob.
///
/// NOTE(review): the derived `Debug` prints the salt and hash bytes; avoid
/// logging records.
#[derive(Debug, Clone)]
pub struct UserRecord {
    /// Privilege level; public, so callers can read and change it directly.
    pub role: Role,
    /// Per-user salt prepended to the password before hashing.
    salt: [u8; SALT_LEN],
    /// Stored password hash, compared by `UserRecord::verify`.
    hash: [u8; HASH_LEN],
    /// Optional SCRAM verifier material for this user.
    scram: Option<ScramSecrets>,
}
/// Server-side SCRAM verifier for one user, as built by `compute_scram_secrets`.
///
/// Only derived keys are kept; the password itself is not stored.
#[derive(Debug, Clone)]
pub struct ScramSecrets {
    /// PBKDF2 iteration count used to derive the salted password.
    pub iters: u32,
    /// PBKDF2 salt; a separate value from the `UserRecord` password-hash salt.
    pub salt: [u8; SCRAM_SALT_LEN],
    /// `sha256(hmac_sha256(salted_password, "Client Key"))`.
    pub stored_key: [u8; HASH_LEN],
    /// `hmac_sha256(salted_password, "Server Key")`.
    pub server_key: [u8; HASH_LEN],
}
/// Length in bytes of the SCRAM PBKDF2 salt (128 bits).
pub const SCRAM_SALT_LEN: usize = 128 / 8;
/// Default PBKDF2 iteration count for new SCRAM secrets; 4096 is the minimum
/// iteration count RFC 7677 says a SCRAM-SHA-256 server SHOULD announce.
pub const SCRAM_DEFAULT_ITERS: u32 = 4096;
impl UserRecord {
    /// Checks `password` against this record's stored hash.
    ///
    /// The hash is re-derived from the record's salt and compared with
    /// `constant_time_eq`, which visits every byte instead of stopping at the
    /// first mismatch.
    pub fn verify(&self, password: &str) -> bool {
        let expected = &self.hash;
        let recomputed = derive_hash(&self.salt, password);
        constant_time_eq(&recomputed, expected)
    }
    /// SCRAM secrets for this user, or `None` if none are attached.
    pub const fn scram(&self) -> Option<&ScramSecrets> {
        Option::as_ref(&self.scram)
    }
}
/// In-memory registry of users keyed by username.
///
/// Backed by a `BTreeMap`, so iteration (and therefore the serialized blob) is
/// in ascending username order, and names are compared exactly (case-sensitive).
#[derive(Debug, Clone, Default)]
pub struct UserStore {
    users: BTreeMap<String, UserRecord>,
}
/// Reasons a user-management operation on `UserStore` can be rejected.
#[derive(Debug, PartialEq, Eq)]
pub enum UserError {
    /// A user with that name is already registered.
    Exists,
    /// No user with that name is registered.
    NotFound,
    /// A role name was not one of the recognised roles.
    InvalidRole,
    /// The username was empty.
    EmptyName,
    /// The password was empty.
    EmptyPassword,
}
impl core::fmt::Display for UserError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::Exists => "user already exists",
            Self::NotFound => "user not found",
            Self::InvalidRole => "invalid role (expected admin / readwrite / readonly)",
            Self::EmptyName => "username must be non-empty",
            Self::EmptyPassword => "password must be non-empty",
        };
        f.write_str(message)
    }
}
impl UserStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }
    /// Number of registered users.
    pub fn len(&self) -> usize {
        self.users.len()
    }
    /// Returns `true` when no users are registered.
    pub fn is_empty(&self) -> bool {
        self.users.is_empty()
    }
    /// Returns `true` if a user with exactly this name exists.
    pub fn contains(&self, name: &str) -> bool {
        self.users.contains_key(name)
    }
    /// Iterates over `(username, record)` pairs in ascending username order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &UserRecord)> {
        self.users.iter().map(|(k, v)| (k.as_str(), v))
    }
    /// Registers a new user whose password hash is derived from `password` and
    /// the caller-supplied `salt`. The new user has no SCRAM secrets.
    ///
    /// # Errors
    /// - [`UserError::EmptyName`] if `name` is empty.
    /// - [`UserError::EmptyPassword`] if `password` is empty.
    /// - [`UserError::Exists`] if `name` is already registered.
    ///
    /// NOTE(review): names longer than `u16::MAX` bytes are accepted here but make
    /// `serialize_users` panic.
    pub fn create(
        &mut self,
        name: &str,
        password: &str,
        role: Role,
        salt: [u8; SALT_LEN],
    ) -> Result<(), UserError> {
        if name.is_empty() {
            return Err(UserError::EmptyName);
        }
        if password.is_empty() {
            return Err(UserError::EmptyPassword);
        }
        // Checked before allocating the owned key, so a duplicate costs no allocation.
        if self.users.contains_key(name) {
            return Err(UserError::Exists);
        }
        let hash = derive_hash(&salt, password);
        self.users.insert(
            name.to_string(),
            UserRecord {
                role,
                salt,
                hash,
                scram: None,
            },
        );
        Ok(())
    }
    /// Removes the user called `name`.
    ///
    /// # Errors
    /// [`UserError::NotFound`] if there is no such user.
    pub fn drop(&mut self, name: &str) -> Result<(), UserError> {
        self.users
            .remove(name)
            .map(|_| ())
            .ok_or(UserError::NotFound)
    }
    /// Derives SCRAM secrets from `password`, `salt` and `iters` and attaches
    /// them to user `name`, replacing any existing secrets.
    ///
    /// # Errors
    /// - [`UserError::EmptyPassword`] if `password` is empty. This is the same
    ///   rule [`UserStore::create`] enforces. Without it, SCRAM would accept an
    ///   empty password.
    /// - [`UserError::NotFound`] if there is no such user.
    ///
    /// NOTE(review): `password` is not checked against the stored password
    /// hash, so the SCRAM and plain credentials can diverge if callers pass a
    /// different password.
    pub fn enable_scram(
        &mut self,
        name: &str,
        password: &str,
        salt: [u8; SCRAM_SALT_LEN],
        iters: u32,
    ) -> Result<(), UserError> {
        if password.is_empty() {
            return Err(UserError::EmptyPassword);
        }
        let rec = self.users.get_mut(name).ok_or(UserError::NotFound)?;
        rec.scram = Some(compute_scram_secrets(password, salt, iters));
        Ok(())
    }
    /// Returns the user's role if `name` exists and `password` matches its
    /// stored hash, otherwise `None`.
    pub fn verify(&self, name: &str, password: &str) -> Option<Role> {
        match self.users.get(name) {
            Some(rec) if rec.verify(password) => Some(rec.role),
            Some(_) => None,
            None => {
                // Unknown user: still derive a hash, so this path costs about as
                // much as a wrong password. That makes probing for valid usernames
                // by response time harder. Best effort only: the result is unused,
                // so an optimiser that can see through the hash could drop it.
                let _ = derive_hash(&[0; SALT_LEN], password);
                None
            }
        }
    }
}
/// Computes the stored password hash: `spg_crypto::hash` over the salt bytes
/// followed by the UTF-8 bytes of `password`.
///
/// NOTE(review): this is a single call to `spg_crypto::hash`. If that is a fast
/// general-purpose digest, stored hashes are cheap to brute-force. Consider a
/// slow KDF, but note that changing it invalidates every existing hash.
fn derive_hash(salt: &[u8; SALT_LEN], password: &str) -> [u8; HASH_LEN] {
    // Both halves report exact lengths, so `collect` allocates once.
    let salted_input: Vec<u8> = salt.iter().copied().chain(password.bytes()).collect();
    spg_crypto::hash(&salted_input)
}
/// Derives the server-side SCRAM-SHA-256 secrets for `password`.
///
/// Follows the RFC 5802 definitions:
/// `SaltedPassword = PBKDF2-HMAC-SHA-256(password, salt, iters)`,
/// `StoredKey = SHA-256(HMAC(SaltedPassword, "Client Key"))`,
/// `ServerKey = HMAC(SaltedPassword, "Server Key")`.
///
/// NOTE(review): the password bytes are used as-is, with no SASLprep
/// normalization. Clients that normalize may derive different keys for
/// non-ASCII passwords.
pub fn compute_scram_secrets(
    password: &str,
    salt: [u8; SCRAM_SALT_LEN],
    iters: u32,
) -> ScramSecrets {
    let salted_password =
        spg_crypto::pbkdf2::pbkdf2_sha256_32(password.as_bytes(), &salt, iters);
    let server_key = spg_crypto::hmac::hmac_sha256(&salted_password, b"Server Key");
    let stored_key = {
        let client_key = spg_crypto::hmac::hmac_sha256(&salted_password, b"Client Key");
        spg_crypto::sha256::hash(&client_key)
    };
    ScramSecrets {
        iters,
        salt,
        stored_key,
        server_key,
    }
}
/// Compares two digests without an early exit.
///
/// Every byte pair is XORed and the results ORed together, so the number of
/// steps does not depend on where, or whether, the inputs differ. This is best
/// effort at the source level, not a guarantee about the generated machine code.
fn constant_time_eq(a: &[u8; HASH_LEN], b: &[u8; HASH_LEN]) -> bool {
    let accumulated = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    accumulated == 0
}
/// Leading byte of the current (SCRAM-capable) user-blob layout.
const SCRAM_FORMAT_MARKER: u8 = 0xff;
/// Encodes every user in `store` into the current, marker-prefixed layout.
///
/// Layout (integers little-endian): `SCRAM_FORMAT_MARKER`, a `u32` user count,
/// then one record per user in ascending name order:
/// - `u16` name length, then the name bytes;
/// - role byte (`0` admin, `1` readwrite, `2` readonly);
/// - salt, then hash;
/// - SCRAM flag byte: `0` for none, or `1` followed by `u32` iterations, SCRAM
///   salt, stored key and server key.
///
/// # Panics
/// If the store holds more than `u32::MAX` users, or any username is longer
/// than `u16::MAX` bytes.
pub(crate) fn serialize_users(store: &UserStore) -> Vec<u8> {
    // Capacity hint only: assumes ~16-byte names and no SCRAM block per user.
    const GUESSED_NAME_LEN: usize = 16;
    let record_floor = 2 + GUESSED_NAME_LEN + 1 + SALT_LEN + HASH_LEN + 1;
    let mut out = Vec::with_capacity(1 + 4 + store.len() * record_floor);

    out.push(SCRAM_FORMAT_MARKER);
    let user_count = u32::try_from(store.users.len()).expect("≤ 4G users");
    out.extend_from_slice(&user_count.to_le_bytes());

    for (name, rec) in store.iter() {
        let name_len = u16::try_from(name.len()).expect("≤ 65k name");
        let role_byte: u8 = match rec.role {
            Role::Admin => 0,
            Role::ReadWrite => 1,
            Role::ReadOnly => 2,
        };
        out.extend_from_slice(&name_len.to_le_bytes());
        out.extend_from_slice(name.as_bytes());
        out.push(role_byte);
        out.extend_from_slice(&rec.salt);
        out.extend_from_slice(&rec.hash);
        if let Some(secrets) = &rec.scram {
            out.push(1);
            out.extend_from_slice(&secrets.iters.to_le_bytes());
            out.extend_from_slice(&secrets.salt);
            out.extend_from_slice(&secrets.stored_key);
            out.extend_from_slice(&secrets.server_key);
        } else {
            out.push(0);
        }
    }
    out
}
/// Reasons a user blob can fail to decode.
#[derive(Debug)]
pub enum UserDeserializeError {
    /// The blob ended inside a field. The decoder also reports leftover bytes
    /// after the last record this way.
    Truncated,
    /// A role byte other than `0`, `1` or `2`.
    BadRole(u8),
    /// A username whose bytes are not valid UTF-8.
    InvalidUtf8,
}
impl core::fmt::Display for UserDeserializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::BadRole(byte) => write!(f, "unknown role byte: {byte}"),
            Self::Truncated => f.write_str("user blob truncated"),
            Self::InvalidUtf8 => f.write_str("username not valid UTF-8"),
        }
    }
}
/// Borrows the next `n` bytes of `buf` at cursor `*p` and advances the cursor.
///
/// # Errors
/// [`UserDeserializeError::Truncated`] if fewer than `n` bytes remain; the
/// cursor is left unchanged in that case.
fn take<'a>(p: &mut usize, n: usize, buf: &'a [u8]) -> Result<&'a [u8], UserDeserializeError> {
    let chunk = buf
        .get(*p..)
        .and_then(|rest| rest.get(..n))
        .ok_or(UserDeserializeError::Truncated)?;
    *p += n;
    Ok(chunk)
}
pub(crate) fn deserialize_users(buf: &[u8]) -> Result<UserStore, UserDeserializeError> {
let mut p = 0usize;
let scram_present_inline = if !buf.is_empty() && buf[0] == SCRAM_FORMAT_MARKER {
p += 1;
true
} else {
false
};
let count_bytes = take(&mut p, 4, buf)?;
let count = u32::from_le_bytes(count_bytes.try_into().unwrap()) as usize;
let mut store = UserStore::new();
for _ in 0..count {
let nl_bytes = take(&mut p, 2, buf)?;
let nl = u16::from_le_bytes(nl_bytes.try_into().unwrap()) as usize;
let name_bytes = take(&mut p, nl, buf)?;
let name = core::str::from_utf8(name_bytes)
.map_err(|_| UserDeserializeError::InvalidUtf8)?
.to_string();
let role_byte = take(&mut p, 1, buf)?[0];
let role = match role_byte {
0 => Role::Admin,
1 => Role::ReadWrite,
2 => Role::ReadOnly,
b => return Err(UserDeserializeError::BadRole(b)),
};
let mut salt = [0u8; SALT_LEN];
salt.copy_from_slice(take(&mut p, SALT_LEN, buf)?);
let mut hash = [0u8; HASH_LEN];
hash.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
let scram = if scram_present_inline {
let flag = take(&mut p, 1, buf)?[0];
if flag == 1 {
let iters_bytes = take(&mut p, 4, buf)?;
let iters = u32::from_le_bytes(iters_bytes.try_into().unwrap());
let mut s_salt = [0u8; SCRAM_SALT_LEN];
s_salt.copy_from_slice(take(&mut p, SCRAM_SALT_LEN, buf)?);
let mut stored_key = [0u8; HASH_LEN];
stored_key.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
let mut server_key = [0u8; HASH_LEN];
server_key.copy_from_slice(take(&mut p, HASH_LEN, buf)?);
Some(ScramSecrets {
iters,
salt: s_salt,
stored_key,
server_key,
})
} else {
None
}
} else {
None
};
store.users.insert(
name,
UserRecord {
role,
salt,
hash,
scram,
},
);
}
if p != buf.len() {
return Err(UserDeserializeError::Truncated);
}
Ok(store)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store from `(name, password, role, salt fill byte)` tuples.
    fn store_of(users: &[(&str, &str, Role, u8)]) -> UserStore {
        let mut store = UserStore::new();
        for &(name, password, role, fill) in users {
            store.create(name, password, role, [fill; SALT_LEN]).unwrap();
        }
        store
    }

    #[test]
    fn create_then_verify_succeeds_with_right_password_only() {
        let store = store_of(&[("alice", "hunter2", Role::Admin, 1)]);
        assert_eq!(store.verify("alice", "hunter2"), Some(Role::Admin));
        assert_eq!(store.verify("alice", "wrong"), None);
        assert_eq!(store.verify("bob", "hunter2"), None);
    }

    #[test]
    fn create_duplicate_user_is_rejected() {
        let mut store = store_of(&[("a", "p", Role::ReadOnly, 0)]);
        let second = store.create("a", "p2", Role::Admin, [0; SALT_LEN]);
        assert_eq!(second, Err(UserError::Exists));
    }

    #[test]
    fn drop_user_removes_them() {
        let mut store = store_of(&[("a", "p", Role::Admin, 0)]);
        store.drop("a").unwrap();
        assert!(store.is_empty());
        assert_eq!(store.drop("a"), Err(UserError::NotFound));
    }

    #[test]
    fn role_parse_accepts_aliases() {
        let cases = [
            ("ADMIN", Some(Role::Admin)),
            ("rw", Some(Role::ReadWrite)),
            ("ro", Some(Role::ReadOnly)),
            ("god", None),
        ];
        for (input, expected) in cases {
            assert_eq!(Role::parse(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn snapshot_round_trip_preserves_users_and_verify() {
        let original = store_of(&[
            ("alice", "pw1", Role::Admin, 7),
            ("bob", "pw2", Role::ReadOnly, 13),
        ]);
        let restored = deserialize_users(&serialize_users(&original)).unwrap();
        assert_eq!(restored.len(), 2);
        assert_eq!(restored.verify("alice", "pw1"), Some(Role::Admin));
        assert_eq!(restored.verify("bob", "pw2"), Some(Role::ReadOnly));
        assert_eq!(restored.verify("bob", "wrong"), None);
    }

    #[test]
    fn empty_store_round_trip() {
        let bytes = serialize_users(&UserStore::new());
        assert_eq!(bytes, [0xff, 0, 0, 0, 0]);
        assert!(deserialize_users(&bytes).unwrap().is_empty());
    }

    #[test]
    fn old_v1_user_blob_still_loads() {
        // Legacy layout: u32 count, then per user u16 name length, name, role
        // byte, salt, hash. There is no marker byte and no SCRAM flag.
        let mut blob = Vec::new();
        blob.extend_from_slice(&1u32.to_le_bytes());
        blob.extend_from_slice(&3u16.to_le_bytes());
        blob.extend_from_slice(b"bob");
        blob.push(0);
        blob.extend_from_slice(&[7u8; SALT_LEN]);
        blob.extend_from_slice(&[42u8; HASH_LEN]);

        let store = deserialize_users(&blob).expect("v1 blob must still load");
        assert_eq!(store.len(), 1);
        let (name, rec) = store.iter().next().unwrap();
        assert_eq!(name, "bob");
        assert_eq!(rec.role, Role::Admin);
        assert!(rec.scram().is_none(), "v1 users have no SCRAM secrets");
    }

    #[test]
    fn scram_round_trip_preserves_iters_salt_keys() {
        let mut store = store_of(&[("alice", "pw", Role::Admin, 3)]);
        store
            .enable_scram("alice", "pw", [9; SCRAM_SALT_LEN], 4096)
            .unwrap();
        let restored = deserialize_users(&serialize_users(&store)).unwrap();
        let (_, rec) = restored.iter().next().unwrap();
        let scram = rec.scram().expect("scram must round-trip");
        let expected = compute_scram_secrets("pw", [9; SCRAM_SALT_LEN], 4096);
        assert_eq!(scram.iters, 4096);
        assert_eq!(scram.salt, [9u8; SCRAM_SALT_LEN]);
        assert_eq!(scram.stored_key, expected.stored_key);
        assert_eq!(scram.server_key, expected.server_key);
    }

    #[test]
    fn deserialize_truncation_is_caught() {
        let blobs: [&[u8]; 2] = [&[], &[0, 0, 0]];
        for blob in blobs {
            assert!(deserialize_users(blob).is_err());
        }
    }
}