use crate::crypto::aes_gcm::{aes256_gcm_decrypt, aes256_gcm_encrypt};
use crate::crypto::hmac::hmac_sha256;
use crate::crypto::os_random;
use crate::storage::encryption::argon2id::{derive_key, Argon2Params};
use crate::storage::encryption::key::SecureKey;
use crate::storage::engine::page::{Page, PageType, CONTENT_SIZE, HEADER_SIZE};
use crate::storage::engine::pager::Pager;
use super::{ApiKey, AuthError, Role, User, UserId};
// On-disk vault format.
//
// Header page (page id `VAULT_HEADER_PAGE`); offsets are relative to `Page::content()`
// (`write_header_page` starts writing at `HEADER_SIZE`):
//   [0..4)   VAULT_MAGIC "RDVT"
//   [4]      format version (VAULT_VERSION; VAULT_LEGACY_VERSION is rejected by `load`)
//   [5..21)  salt (VAULT_SALT_SIZE bytes)
//   [21..25) payload_len, u32 LE = NONCE_SIZE + total ciphertext length
//   [25..37) AES-256-GCM nonce
//   [37..41) chain_count, u32 LE = number of overflow data pages
//   [41..45) first data page id, u32 LE (0 when there is no chain)
//   [45..)   first ciphertext fragment (at most VAULT_HEADER_CIPHER_CAPACITY bytes)
// Data page:
//   [0..4)   VAULT_DATA_MAGIC "RDVD"
//   [4..8)   next data page id, u32 LE (0 terminates the chain)
//   [8..)    ciphertext fragment (at most VAULT_DATA_CIPHER_CAPACITY bytes)
// Logical export (hex-encoded as a whole): "RDVX" | version | salt | nonce | ciphertext.
const VAULT_MAGIC: &[u8; 4] = b"RDVT";
const VAULT_DATA_MAGIC: &[u8; 4] = b"RDVD";
// Current header format (header + chained data pages).
const VAULT_VERSION: u8 = 2;
// Older format; `load` refuses it and asks the operator to re-bootstrap.
const VAULT_LEGACY_VERSION: u8 = 1;
// Additional authenticated data bound to the on-page vault ciphertext.
const VAULT_AAD: &[u8] = b"reddb-vault";
const VAULT_LOGICAL_EXPORT_MAGIC: &[u8; 4] = b"RDVX";
const VAULT_LOGICAL_EXPORT_VERSION: u8 = 1;
// AAD for logical exports; differs from VAULT_AAD so the two ciphertext kinds cannot be swapped.
const VAULT_LOGICAL_EXPORT_AAD: &[u8] = b"reddb-vault-logical-export-v1";
const VAULT_MAGIC_SIZE: usize = 4;
const VAULT_VERSION_SIZE: usize = 1;
const VAULT_SALT_SIZE: usize = 16;
const VAULT_PAYLOAD_LEN_SIZE: usize = 4;
const VAULT_CHAIN_COUNT_SIZE: usize = 4;
const VAULT_FIRST_PAGE_ID_SIZE: usize = 4;
const NONCE_SIZE: usize = 12;
// magic + version + salt + payload_len (= 25 bytes); the nonce starts right after.
const VAULT_HEADER_PREAMBLE_SIZE: usize =
VAULT_MAGIC_SIZE + VAULT_VERSION_SIZE + VAULT_SALT_SIZE + VAULT_PAYLOAD_LEN_SIZE;
// Full fixed header (= 45 bytes); the first ciphertext fragment starts right after.
const VAULT_HEADER_META_SIZE: usize =
VAULT_HEADER_PREAMBLE_SIZE + NONCE_SIZE + VAULT_CHAIN_COUNT_SIZE + VAULT_FIRST_PAGE_ID_SIZE;
// Data-page prefix: magic + 4-byte next-page id.
const VAULT_DATA_PREFIX_SIZE: usize = VAULT_MAGIC_SIZE + 4;
// Fixed page id that always holds the vault header.
const VAULT_HEADER_PAGE: u32 = 2;
// Ciphertext bytes that fit in the header page after the fixed metadata.
const VAULT_HEADER_CIPHER_CAPACITY: usize = CONTENT_SIZE - VAULT_HEADER_META_SIZE;
// Ciphertext bytes that fit in one data page after its prefix.
const VAULT_DATA_CIPHER_CAPACITY: usize = CONTENT_SIZE - VAULT_DATA_PREFIX_SIZE;
/// A master secret together with the certificate deterministically derived from it.
pub struct KeyPair {
/// Raw secret bytes (32 bytes when produced by `generate`); keys `sign`/`verify`.
pub master_secret: Vec<u8>,
/// `HMAC-SHA256(master_secret, "reddb-certificate-v1")`; also the input for the vault key.
pub certificate: Vec<u8>,
}
impl KeyPair {
    /// Creates a key pair around 32 fresh bytes from the OS CSPRNG.
    ///
    /// # Panics
    /// Panics if the OS random source fails.
    pub fn generate() -> Self {
        let mut secret = vec![0u8; 32];
        os_random::fill_bytes(&mut secret).expect("CSPRNG failed during keypair generation");
        Self::from_master_secret(secret)
    }

    /// Wraps an existing master secret; the certificate is
    /// `HMAC-SHA256(master_secret, "reddb-certificate-v1")`, so it is deterministic.
    pub fn from_master_secret(master_secret: Vec<u8>) -> Self {
        let certificate = hmac_sha256(&master_secret, b"reddb-certificate-v1").to_vec();
        Self {
            master_secret,
            certificate,
        }
    }

    /// Derives the vault sealing key from a certificate via `derive_key` (Argon2id module)
    /// using the fixed salt `b"reddb-vault-seal"` and `vault_argon2_params()`.
    pub fn vault_key_from_certificate(certificate: &[u8]) -> SecureKey {
        SecureKey::new(&derive_key(
            certificate,
            b"reddb-vault-seal",
            &vault_argon2_params(),
        ))
    }

    /// Returns `HMAC-SHA256(master_secret, data)`.
    pub fn sign(&self, data: &[u8]) -> Vec<u8> {
        hmac_sha256(&self.master_secret, data).to_vec()
    }

    /// Checks `signature` against `sign(data)` without early exit on the first differing byte.
    pub fn verify(&self, data: &[u8], signature: &[u8]) -> bool {
        constant_time_eq(&self.sign(data), signature)
    }

    /// Lowercase hex encoding of the certificate.
    pub fn certificate_hex(&self) -> String {
        hex::encode(self.certificate.as_slice())
    }
}
/// Compares two byte slices, visiting every byte regardless of where they first differ.
///
/// Slices of different length compare unequal immediately (length is not secret here).
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    } else {
        false
    }
}
/// Errors raised while deriving vault keys or reading/writing the encrypted vault.
#[derive(Debug)]
pub enum VaultError {
    /// No certificate or passphrase was available (or the certificate was not valid hex).
    NoKey,
    /// The vault key could not be used for encryption.
    Encryption,
    /// Decryption failed: wrong key, tampered ciphertext, or unusable key material.
    Decryption,
    /// Underlying I/O failure; exposed through `Error::source`.
    Io(std::io::Error),
    /// Stored data failed a structural check; also used for CSPRNG failures.
    Corrupt(String),
    /// The pager reported an error.
    Pager(String),
}

impl std::fmt::Display for VaultError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoKey => write!(
                f,
                "no vault key: set REDDB_CERTIFICATE (or REDDB_VAULT_KEY) or provide a certificate"
            ),
            Self::Encryption => write!(f, "vault encryption failed"),
            Self::Decryption => write!(f, "vault decryption failed (wrong key or corrupt data)"),
            Self::Io(err) => write!(f, "vault I/O error: {err}"),
            Self::Corrupt(msg) => write!(f, "vault corrupt: {msg}"),
            Self::Pager(msg) => write!(f, "vault pager error: {msg}"),
        }
    }
}

impl std::error::Error for VaultError {
    /// Exposes the wrapped `std::io::Error` so error-chain reporters can walk to the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::NoKey
            | Self::Encryption
            | Self::Decryption
            | Self::Corrupt(_)
            | Self::Pager(_) => None,
        }
    }
}
impl From<VaultError> for AuthError {
fn from(err: VaultError) -> Self {
AuthError::Internal(err.to_string())
}
}
/// Decrypted vault contents; encoded to text by `VaultState::serialize` before encryption.
#[derive(Debug, Default)]
pub struct VaultState {
/// Stored users (`USER` records).
pub users: Vec<User>,
/// API keys with their owner (`KEY` records); `deserialize` also copies each key into
/// the matching user's `api_keys`.
pub api_keys: Vec<(UserId, ApiKey)>,
/// Persisted as the `SEALED` record.
pub bootstrapped: bool,
/// Optional raw master secret (`MASTER_SECRET` record, hex inside the encrypted payload).
pub master_secret: Option<Vec<u8>>,
/// Free-form string pairs (`KV` records; values hex-encoded in the payload).
pub kv: std::collections::HashMap<String, String>,
}
impl VaultState {
pub fn serialize(&self) -> Vec<u8> {
let mut out = String::new();
if let Some(ref secret) = self.master_secret {
out.push_str(&format!("MASTER_SECRET:{}\n", hex::encode(secret)));
}
out.push_str(&format!("SEALED:{}\n", self.bootstrapped));
for user in &self.users {
let scram_field = match &user.scram_verifier {
Some(v) => format!(
"{}:{}:{}:{}",
hex::encode(&v.salt),
v.iter,
hex::encode(v.stored_key),
hex::encode(v.server_key),
),
None => String::new(),
};
let tenant_field = user.tenant_id.clone().unwrap_or_default();
out.push_str(&format!(
"USER:{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n",
user.username,
user.password_hash,
user.role.as_str(),
user.enabled,
user.created_at,
user.updated_at,
scram_field,
tenant_field,
));
}
for (owner, key) in &self.api_keys {
let tenant_field = owner.tenant.clone().unwrap_or_default();
out.push_str(&format!(
"KEY:{}\t{}\t{}\t{}\t{}\t{}\n",
owner.username,
key.key,
key.name,
key.role.as_str(),
key.created_at,
tenant_field,
));
}
for (k, v) in &self.kv {
out.push_str(&format!("KV:{}\t{}\n", k, hex::encode(v.as_bytes())));
}
out.into_bytes()
}
pub fn deserialize(data: &[u8]) -> Result<Self, VaultError> {
let text = std::str::from_utf8(data)
.map_err(|_| VaultError::Corrupt("payload is not valid UTF-8".into()))?;
let mut users = Vec::new();
let mut api_keys: Vec<(UserId, ApiKey)> = Vec::new();
let mut bootstrapped = false;
let mut master_secret: Option<Vec<u8>> = None;
let mut kv: std::collections::HashMap<String, String> = std::collections::HashMap::new();
for line in text.lines() {
if line.is_empty() {
continue;
}
if let Some(rest) = line.strip_prefix("MASTER_SECRET:") {
master_secret = Some(
hex::decode(rest)
.map_err(|_| VaultError::Corrupt("invalid MASTER_SECRET hex".into()))?,
);
} else if let Some(rest) = line.strip_prefix("SEALED:") {
bootstrapped = rest == "true";
} else if let Some(rest) = line.strip_prefix("USER:") {
let parts: Vec<&str> = rest.split('\t').collect();
if parts.len() != 7 && parts.len() != 8 {
return Err(VaultError::Corrupt(format!(
"USER line has {} fields, expected 7 or 8",
parts.len()
)));
}
let role = Role::from_str(parts[2])
.ok_or_else(|| VaultError::Corrupt(format!("unknown role: {}", parts[2])))?;
let enabled = parts[3] == "true";
let created_at: u128 = parts[4]
.parse()
.map_err(|_| VaultError::Corrupt("invalid created_at".into()))?;
let updated_at: u128 = parts[5]
.parse()
.map_err(|_| VaultError::Corrupt("invalid updated_at".into()))?;
let scram_verifier = parts
.get(6)
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(parse_scram_field)
.transpose()?;
let tenant_id = parts
.get(7)
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| s.to_string());
users.push(User {
username: parts[0].to_string(),
tenant_id,
password_hash: parts[1].to_string(),
scram_verifier,
role,
api_keys: Vec::new(), created_at,
updated_at,
enabled,
});
} else if let Some(rest) = line.strip_prefix("KEY:") {
let parts: Vec<&str> = rest.split('\t').collect();
if parts.len() != 5 && parts.len() != 6 {
return Err(VaultError::Corrupt(format!(
"KEY line has {} fields, expected 5 or 6",
parts.len()
)));
}
let role = Role::from_str(parts[3])
.ok_or_else(|| VaultError::Corrupt(format!("unknown role: {}", parts[3])))?;
let created_at: u128 = parts[4]
.parse()
.map_err(|_| VaultError::Corrupt("invalid key created_at".into()))?;
let tenant_id = parts
.get(5)
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| s.to_string());
api_keys.push((
UserId {
tenant: tenant_id,
username: parts[0].to_string(),
},
ApiKey {
key: parts[1].to_string(),
name: parts[2].to_string(),
role,
created_at,
},
));
} else if let Some(rest) = line.strip_prefix("KV:") {
let parts: Vec<&str> = rest.splitn(2, '\t').collect();
if parts.len() == 2 {
if let Ok(bytes) = hex::decode(parts[1]) {
if let Ok(value) = String::from_utf8(bytes) {
kv.insert(parts[0].to_string(), value);
}
}
}
} else {
}
}
for (owner, key) in &api_keys {
if let Some(user) = users
.iter_mut()
.find(|u| u.username == owner.username && u.tenant_id == owner.tenant)
{
user.api_keys.push(key.clone());
}
}
Ok(Self {
users,
api_keys,
bootstrapped,
master_secret,
kv,
})
}
}
/// Handle to the on-disk vault: the encryption key plus the salt recorded in the header page.
pub struct Vault {
/// AES-256-GCM key (must be 32 bytes, checked at use), derived from a passphrase or certificate.
key: SecureKey,
/// Salt written into the header page; passphrase keys are derived with it, while
/// certificate keys use a fixed salt instead.
salt: [u8; 16],
}
/// Argon2 cost parameters shared by every vault key derivation in this module.
///
/// These feed directly into `derive_key`, so changing any value changes every derived key
/// and existing vaults would no longer decrypt. `m_cost` is presumably in KiB (16 MiB) —
/// confirm against `Argon2Params`.
fn vault_argon2_params() -> Argon2Params {
    Argon2Params {
        tag_len: 32,
        p: 1,
        t_cost: 3,
        m_cost: 16 * 1024,
    }
}
impl Vault {
pub fn has_saved_state(pager: &Pager) -> bool {
pager
.read_page_no_checksum(VAULT_HEADER_PAGE)
.ok()
.map(|page| {
let content = page.content();
content.len() >= VAULT_MAGIC_SIZE && &content[0..VAULT_MAGIC_SIZE] == VAULT_MAGIC
})
.unwrap_or(false)
}
pub fn open(pager: &Pager, passphrase: Option<&str>) -> Result<Self, VaultError> {
if let Ok(cert_hex) = std::env::var("REDDB_CERTIFICATE") {
return Self::with_certificate(pager, &cert_hex);
}
let passphrase_str = std::env::var("REDDB_VAULT_KEY")
.ok()
.or_else(|| passphrase.map(|s| s.to_string()))
.ok_or(VaultError::NoKey)?;
let salt = match read_vault_salt_from_pager(pager) {
Ok(s) => s,
Err(_) => {
let mut salt = [0u8; 16];
let mut buf = [0u8; 16];
os_random::fill_bytes(&mut buf)
.map_err(|e| VaultError::Corrupt(format!("CSPRNG failed: {e}")))?;
salt.copy_from_slice(&buf);
salt
}
};
let key_bytes = derive_key(passphrase_str.as_bytes(), &salt, &vault_argon2_params());
let key = SecureKey::new(&key_bytes);
Ok(Self { key, salt })
}
pub fn with_certificate(pager: &Pager, certificate_hex: &str) -> Result<Self, VaultError> {
let certificate = hex::decode(certificate_hex).map_err(|_| VaultError::NoKey)?;
let key = KeyPair::vault_key_from_certificate(&certificate);
let salt = match read_vault_salt_from_pager(pager) {
Ok(s) => s,
Err(_) => {
let mut s = [0u8; 16];
os_random::fill_bytes(&mut s)
.map_err(|e| VaultError::Corrupt(format!("CSPRNG failed: {e}")))?;
s
}
};
Ok(Self { key, salt })
}
pub fn from_env(pager: &Pager) -> Result<Self, VaultError> {
if let Ok(cert_hex) = std::env::var("REDDB_CERTIFICATE") {
return Self::with_certificate(pager, &cert_hex);
}
if let Ok(passphrase) = std::env::var("REDDB_VAULT_KEY") {
return Self::open_with_passphrase(pager, &passphrase);
}
Err(VaultError::NoKey)
}
fn open_with_passphrase(pager: &Pager, passphrase: &str) -> Result<Self, VaultError> {
let salt = match read_vault_salt_from_pager(pager) {
Ok(s) => s,
Err(_) => {
let mut s = [0u8; 16];
os_random::fill_bytes(&mut s)
.map_err(|e| VaultError::Corrupt(format!("CSPRNG failed: {e}")))?;
s
}
};
let key_bytes = derive_key(passphrase.as_bytes(), &salt, &vault_argon2_params());
let key = SecureKey::new(&key_bytes);
Ok(Self { key, salt })
}
pub fn with_certificate_bytes(pager: &Pager, certificate: &[u8]) -> Result<Self, VaultError> {
let key = KeyPair::vault_key_from_certificate(certificate);
let salt = match read_vault_salt_from_pager(pager) {
Ok(s) => s,
Err(_) => {
let mut s = [0u8; 16];
os_random::fill_bytes(&mut s)
.map_err(|e| VaultError::Corrupt(format!("CSPRNG failed: {e}")))?;
s
}
};
Ok(Self { key, salt })
}
pub fn seal_logical_export(&self, state: &VaultState) -> Result<String, VaultError> {
let plaintext = state.serialize();
let mut nonce = [0u8; NONCE_SIZE];
os_random::fill_bytes(&mut nonce)
.map_err(|e| VaultError::Corrupt(format!("CSPRNG failed: {e}")))?;
let key_bytes: &[u8] = self.key.as_bytes();
let key_arr: &[u8; 32] = key_bytes.try_into().map_err(|_| VaultError::Encryption)?;
let ciphertext = aes256_gcm_encrypt(key_arr, &nonce, VAULT_LOGICAL_EXPORT_AAD, &plaintext);
let mut out = Vec::with_capacity(
VAULT_MAGIC_SIZE + VAULT_VERSION_SIZE + VAULT_SALT_SIZE + NONCE_SIZE + ciphertext.len(),
);
out.extend_from_slice(VAULT_LOGICAL_EXPORT_MAGIC);
out.push(VAULT_LOGICAL_EXPORT_VERSION);
out.extend_from_slice(&self.salt);
out.extend_from_slice(&nonce);
out.extend_from_slice(&ciphertext);
Ok(hex::encode(out))
}
pub fn unseal_logical_export(
blob_hex: &str,
passphrase: Option<&str>,
) -> Result<VaultState, VaultError> {
let (salt, nonce, ciphertext) = Self::decode_logical_export(blob_hex)?;
let key = if let Ok(cert_hex) = std::env::var("REDDB_CERTIFICATE") {
let certificate = hex::decode(cert_hex).map_err(|_| VaultError::NoKey)?;
KeyPair::vault_key_from_certificate(&certificate)
} else {
let passphrase_str = std::env::var("REDDB_VAULT_KEY")
.ok()
.or_else(|| passphrase.map(|s| s.to_string()))
.ok_or(VaultError::NoKey)?;
let key_bytes = derive_key(passphrase_str.as_bytes(), &salt, &vault_argon2_params());
SecureKey::new(&key_bytes)
};
Self::decrypt_logical_export(&key, &nonce, &ciphertext)
}
pub fn unseal_logical_export_with_passphrase(
blob_hex: &str,
passphrase: &str,
) -> Result<VaultState, VaultError> {
let (salt, nonce, ciphertext) = Self::decode_logical_export(blob_hex)?;
let key_bytes = derive_key(passphrase.as_bytes(), &salt, &vault_argon2_params());
let key = SecureKey::new(&key_bytes);
Self::decrypt_logical_export(&key, &nonce, &ciphertext)
}
fn decode_logical_export(
blob_hex: &str,
) -> Result<([u8; VAULT_SALT_SIZE], [u8; NONCE_SIZE], Vec<u8>), VaultError> {
let blob = hex::decode(blob_hex).map_err(|_| VaultError::Corrupt("bad hex".into()))?;
let min_len = VAULT_MAGIC_SIZE + VAULT_VERSION_SIZE + VAULT_SALT_SIZE + NONCE_SIZE + 16;
if blob.len() < min_len {
return Err(VaultError::Corrupt("logical vault export too short".into()));
}
if &blob[0..VAULT_MAGIC_SIZE] != VAULT_LOGICAL_EXPORT_MAGIC {
return Err(VaultError::Corrupt(
"bad logical vault export magic".to_string(),
));
}
let version = blob[VAULT_MAGIC_SIZE];
if version != VAULT_LOGICAL_EXPORT_VERSION {
return Err(VaultError::Corrupt(format!(
"unsupported logical vault export version: {version}"
)));
}
let mut off = VAULT_MAGIC_SIZE + VAULT_VERSION_SIZE;
let salt: [u8; VAULT_SALT_SIZE] = blob[off..off + VAULT_SALT_SIZE]
.try_into()
.map_err(|_| VaultError::Corrupt("bad logical export salt".into()))?;
off += VAULT_SALT_SIZE;
let nonce: [u8; NONCE_SIZE] = blob[off..off + NONCE_SIZE]
.try_into()
.map_err(|_| VaultError::Corrupt("bad logical export nonce".into()))?;
off += NONCE_SIZE;
Ok((salt, nonce, blob[off..].to_vec()))
}
fn decrypt_logical_export(
key: &SecureKey,
nonce: &[u8; NONCE_SIZE],
ciphertext: &[u8],
) -> Result<VaultState, VaultError> {
let key_bytes: &[u8] = key.as_bytes();
let key_arr: &[u8; 32] = key_bytes.try_into().map_err(|_| VaultError::Decryption)?;
let plaintext = aes256_gcm_decrypt(key_arr, nonce, VAULT_LOGICAL_EXPORT_AAD, ciphertext)
.map_err(|_| VaultError::Decryption)?;
VaultState::deserialize(&plaintext)
}
pub fn save(&self, pager: &Pager, state: &VaultState) -> Result<(), VaultError> {
let plaintext = state.serialize();
let mut nonce = [0u8; NONCE_SIZE];
os_random::fill_bytes(&mut nonce)
.map_err(|e| VaultError::Corrupt(format!("CSPRNG failed: {e}")))?;
let key_bytes: &[u8] = self.key.as_bytes();
let key_arr: &[u8; 32] = key_bytes.try_into().map_err(|_| VaultError::Encryption)?;
let ciphertext = aes256_gcm_encrypt(key_arr, &nonce, VAULT_AAD, &plaintext);
let cipher_total = ciphertext.len();
let payload_len = (NONCE_SIZE + cipher_total) as u32;
let header_chunk_len = cipher_total.min(VAULT_HEADER_CIPHER_CAPACITY);
let overflow = cipher_total.saturating_sub(header_chunk_len);
let chain_count = overflow.div_ceil(VAULT_DATA_CIPHER_CAPACITY);
while pager
.page_count()
.map_err(|e| VaultError::Pager(e.to_string()))?
<= VAULT_HEADER_PAGE
{
pager
.allocate_page(PageType::Vault)
.map_err(|e| VaultError::Pager(format!("reserve vault slot: {e}")))?;
}
let old_chain = self.read_existing_chain_ids(pager).unwrap_or_default();
let mut new_chain: Vec<u32> = Vec::with_capacity(chain_count);
for _ in 0..chain_count {
let page = pager
.allocate_page(PageType::Vault)
.map_err(|e| VaultError::Pager(format!("allocate vault data page: {e}")))?;
new_chain.push(page.page_id());
}
let mut cursor = header_chunk_len;
for i in 0..chain_count {
let next_id = if i + 1 < chain_count {
new_chain[i + 1]
} else {
0
};
let take = (cipher_total - cursor).min(VAULT_DATA_CIPHER_CAPACITY);
let frag = &ciphertext[cursor..cursor + take];
self.write_data_page(pager, new_chain[i], next_id, frag)?;
cursor += take;
}
debug_assert_eq!(cursor, cipher_total, "ciphertext spill accounting mismatch");
let first_data_page = new_chain.first().copied().unwrap_or(0);
self.write_header_page(
pager,
&nonce,
payload_len,
chain_count as u32,
first_data_page,
&ciphertext[..header_chunk_len],
)?;
pager
.flush()
.map_err(|e| VaultError::Pager(e.to_string()))?;
for &id in old_chain.iter() {
pager
.free_page(id)
.map_err(|e| VaultError::Pager(format!("free old vault page {id}: {e}")))?;
}
Ok(())
}
pub fn load(&self, pager: &Pager) -> Result<Option<VaultState>, VaultError> {
let page = match pager.read_page_no_checksum(VAULT_HEADER_PAGE) {
Ok(p) => p,
Err(_) => return Ok(None),
};
let page_content = page.content();
if page_content.len() < VAULT_HEADER_META_SIZE {
return Ok(None);
}
if &page_content[0..VAULT_MAGIC_SIZE] != VAULT_MAGIC {
return Ok(None); }
let version = page_content[4];
if version == VAULT_LEGACY_VERSION {
return Err(VaultError::Corrupt(
"vault was bootstrapped with the legacy 2-page format \
(pre-RedDB v0.3); re-bootstrap with `red bootstrap` to upgrade"
.to_string(),
));
}
if version != VAULT_VERSION {
return Err(VaultError::Corrupt(format!(
"unsupported vault version: {} (expected {})",
version, VAULT_VERSION
)));
}
let payload_len = u32::from_le_bytes(
page_content[21..25]
.try_into()
.map_err(|_| VaultError::Corrupt("bad payload length bytes".into()))?,
) as usize;
let nonce_start = VAULT_HEADER_PREAMBLE_SIZE;
let nonce: [u8; NONCE_SIZE] = page_content[nonce_start..nonce_start + NONCE_SIZE]
.try_into()
.map_err(|_| VaultError::Corrupt("bad nonce".into()))?;
let chain_count_off = nonce_start + NONCE_SIZE;
let chain_count = u32::from_le_bytes(
page_content[chain_count_off..chain_count_off + 4]
.try_into()
.map_err(|_| VaultError::Corrupt("bad chain_count bytes".into()))?,
) as usize;
let first_id_off = chain_count_off + 4;
let mut next_id = u32::from_le_bytes(
page_content[first_id_off..first_id_off + 4]
.try_into()
.map_err(|_| VaultError::Corrupt("bad first_data_page_id bytes".into()))?,
);
if payload_len < NONCE_SIZE {
return Err(VaultError::Corrupt("payload too short for nonce".into()));
}
let cipher_total = payload_len - NONCE_SIZE;
let mut cipher = Vec::with_capacity(cipher_total);
let header_chunk_len = cipher_total.min(VAULT_HEADER_CIPHER_CAPACITY);
let header_cipher_start = VAULT_HEADER_META_SIZE;
cipher.extend_from_slice(
&page_content[header_cipher_start..header_cipher_start + header_chunk_len],
);
let mut hops = 0usize;
while cipher.len() < cipher_total {
if hops >= chain_count {
return Err(VaultError::Corrupt(format!(
"vault chain shorter than declared: {} hops, expected {}",
hops, chain_count
)));
}
if next_id == 0 {
return Err(VaultError::Corrupt(
"vault chain ends prematurely (next_id == 0)".to_string(),
));
}
let dp = pager
.read_page_no_checksum(next_id)
.map_err(|e| VaultError::Pager(format!("vault data page {next_id}: {e}")))?;
let dp_content = dp.content();
if dp_content.len() < VAULT_DATA_PREFIX_SIZE {
return Err(VaultError::Corrupt(format!(
"vault data page {next_id} truncated"
)));
}
if &dp_content[0..VAULT_MAGIC_SIZE] != VAULT_DATA_MAGIC {
return Err(VaultError::Corrupt(format!(
"vault data page {next_id} has bad magic"
)));
}
let np = u32::from_le_bytes(
dp_content[VAULT_MAGIC_SIZE..VAULT_MAGIC_SIZE + 4]
.try_into()
.map_err(|_| VaultError::Corrupt("bad next_page_id bytes".into()))?,
);
let take = (cipher_total - cipher.len()).min(VAULT_DATA_CIPHER_CAPACITY);
let frag_start = VAULT_DATA_PREFIX_SIZE;
cipher.extend_from_slice(&dp_content[frag_start..frag_start + take]);
next_id = np;
hops += 1;
}
if cipher.len() != cipher_total {
return Err(VaultError::Corrupt(format!(
"vault truncated: expected {} cipher bytes, got {}",
cipher_total,
cipher.len()
)));
}
if hops != chain_count {
return Err(VaultError::Corrupt(format!(
"vault chain length mismatch: walked {} pages, header says {}",
hops, chain_count
)));
}
let key_bytes: &[u8] = self.key.as_bytes();
let key_arr: &[u8; 32] = key_bytes.try_into().map_err(|_| VaultError::Decryption)?;
let plaintext = aes256_gcm_decrypt(key_arr, &nonce, VAULT_AAD, &cipher)
.map_err(|_| VaultError::Decryption)?;
let state = VaultState::deserialize(&plaintext)?;
Ok(Some(state))
}
fn read_existing_chain_ids(&self, pager: &Pager) -> Result<Vec<u32>, VaultError> {
let header = pager
.read_page_no_checksum(VAULT_HEADER_PAGE)
.map_err(|e| VaultError::Pager(e.to_string()))?;
let content = header.content();
if content.len() < VAULT_HEADER_META_SIZE {
return Ok(Vec::new());
}
if &content[0..VAULT_MAGIC_SIZE] != VAULT_MAGIC {
return Ok(Vec::new());
}
let version = content[4];
if version != VAULT_VERSION {
return Ok(Vec::new());
}
let nonce_start = VAULT_HEADER_PREAMBLE_SIZE;
let chain_count_off = nonce_start + NONCE_SIZE;
let chain_count = u32::from_le_bytes(
content[chain_count_off..chain_count_off + 4]
.try_into()
.map_err(|_| VaultError::Corrupt("bad chain_count bytes".into()))?,
) as usize;
let first_id_off = chain_count_off + 4;
let mut id = u32::from_le_bytes(
content[first_id_off..first_id_off + 4]
.try_into()
.map_err(|_| VaultError::Corrupt("bad first_data_page_id bytes".into()))?,
);
let mut out = Vec::with_capacity(chain_count);
let mut hops = 0usize;
while id != 0 && hops < chain_count {
out.push(id);
match pager.read_page_no_checksum(id) {
Ok(dp) => {
let dc = dp.content();
if dc.len() < VAULT_DATA_PREFIX_SIZE
|| &dc[0..VAULT_MAGIC_SIZE] != VAULT_DATA_MAGIC
{
break;
}
id = u32::from_le_bytes(
dc[VAULT_MAGIC_SIZE..VAULT_MAGIC_SIZE + 4]
.try_into()
.map_err(|_| VaultError::Corrupt("bad next_id".into()))?,
);
}
Err(_) => break,
}
hops += 1;
}
Ok(out)
}
fn write_header_page(
&self,
pager: &Pager,
nonce: &[u8; NONCE_SIZE],
payload_len: u32,
chain_count: u32,
first_data_page_id: u32,
cipher_fragment: &[u8],
) -> Result<(), VaultError> {
debug_assert!(cipher_fragment.len() <= VAULT_HEADER_CIPHER_CAPACITY);
let mut page = Page::new(PageType::Vault, VAULT_HEADER_PAGE);
let bytes = page.as_bytes_mut();
let mut off = HEADER_SIZE;
bytes[off..off + VAULT_MAGIC_SIZE].copy_from_slice(VAULT_MAGIC);
off += VAULT_MAGIC_SIZE;
bytes[off] = VAULT_VERSION;
off += VAULT_VERSION_SIZE;
bytes[off..off + VAULT_SALT_SIZE].copy_from_slice(&self.salt);
off += VAULT_SALT_SIZE;
bytes[off..off + 4].copy_from_slice(&payload_len.to_le_bytes());
off += VAULT_PAYLOAD_LEN_SIZE;
bytes[off..off + NONCE_SIZE].copy_from_slice(nonce);
off += NONCE_SIZE;
bytes[off..off + 4].copy_from_slice(&chain_count.to_le_bytes());
off += VAULT_CHAIN_COUNT_SIZE;
bytes[off..off + 4].copy_from_slice(&first_data_page_id.to_le_bytes());
off += VAULT_FIRST_PAGE_ID_SIZE;
debug_assert_eq!(off, HEADER_SIZE + VAULT_HEADER_META_SIZE);
bytes[off..off + cipher_fragment.len()].copy_from_slice(cipher_fragment);
pager
.write_page_no_checksum(VAULT_HEADER_PAGE, page)
.map_err(|e| VaultError::Pager(e.to_string()))?;
Ok(())
}
fn write_data_page(
&self,
pager: &Pager,
page_id: u32,
next_page_id: u32,
cipher_fragment: &[u8],
) -> Result<(), VaultError> {
debug_assert!(cipher_fragment.len() <= VAULT_DATA_CIPHER_CAPACITY);
let mut page = Page::new(PageType::Vault, page_id);
let bytes = page.as_bytes_mut();
let mut off = HEADER_SIZE;
bytes[off..off + VAULT_MAGIC_SIZE].copy_from_slice(VAULT_DATA_MAGIC);
off += VAULT_MAGIC_SIZE;
bytes[off..off + 4].copy_from_slice(&next_page_id.to_le_bytes());
off += 4;
bytes[off..off + cipher_fragment.len()].copy_from_slice(cipher_fragment);
pager
.write_page_no_checksum(page_id, page)
.map_err(|e| VaultError::Pager(e.to_string()))?;
Ok(())
}
}
/// Parses a serialized SCRAM verifier `salt_hex:iter:stored_key_hex:server_key_hex`.
///
/// # Errors
/// `VaultError::Corrupt` in any of these cases:
/// - the segment count is not 4;
/// - a hex segment is invalid;
/// - the iteration count does not parse or is below `crate::auth::scram::MIN_ITER`;
/// - either key is not exactly 32 bytes.
fn parse_scram_field(field: &str) -> Result<crate::auth::scram::ScramVerifier, VaultError> {
    let segments: Vec<&str> = field.split(':').collect();
    let &[salt_hex, iter_text, stored_hex, server_hex] = segments.as_slice() else {
        return Err(VaultError::Corrupt(format!(
            "SCRAM verifier has {} segments, expected 4",
            segments.len()
        )));
    };

    let salt = hex::decode(salt_hex)
        .map_err(|_| VaultError::Corrupt("invalid SCRAM salt hex".into()))?;
    let iter = iter_text
        .parse::<u32>()
        .map_err(|_| VaultError::Corrupt("invalid SCRAM iter".into()))?;
    let min_iter = crate::auth::scram::MIN_ITER;
    if iter < min_iter {
        return Err(VaultError::Corrupt(format!(
            "SCRAM iter {} below minimum {}",
            iter, min_iter
        )));
    }

    // Both hex segments are decoded before either length is checked (error precedence).
    let stored_bytes = hex::decode(stored_hex)
        .map_err(|_| VaultError::Corrupt("invalid SCRAM stored_key hex".into()))?;
    let server_bytes = hex::decode(server_hex)
        .map_err(|_| VaultError::Corrupt("invalid SCRAM server_key hex".into()))?;
    let stored_key: [u8; 32] = stored_bytes
        .try_into()
        .map_err(|_| VaultError::Corrupt("SCRAM stored_key must be 32 bytes".into()))?;
    let server_key: [u8; 32] = server_bytes
        .try_into()
        .map_err(|_| VaultError::Corrupt("SCRAM server_key must be 32 bytes".into()))?;

    Ok(crate::auth::scram::ScramVerifier {
        salt,
        iter,
        stored_key,
        server_key,
    })
}
fn read_vault_salt_from_pager(pager: &Pager) -> Result<[u8; 16], VaultError> {
let page = pager
.read_page_no_checksum(VAULT_HEADER_PAGE)
.map_err(|e| VaultError::Pager(format!("vault page read: {e}")))?;
let content = page.content();
if content.len() < VAULT_HEADER_PREAMBLE_SIZE {
return Err(VaultError::Corrupt("vault page too short".into()));
}
if &content[0..VAULT_MAGIC_SIZE] != VAULT_MAGIC {
return Err(VaultError::Corrupt("bad magic bytes".into()));
}
let mut salt = [0u8; VAULT_SALT_SIZE];
salt.copy_from_slice(&content[5..21]);
Ok(salt)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::{now_ms, ApiKey, Role, User};
use crate::storage::engine::pager::PagerConfig;
fn sample_state() -> VaultState {
let now = now_ms();
VaultState {
users: vec![
User {
username: "alice".into(),
tenant_id: None,
password_hash: "argon2id$aabbccdd$eeff0011".into(),
scram_verifier: None,
role: Role::Admin,
api_keys: vec![ApiKey {
key: "rk_abc123".into(),
name: "ci-token".into(),
role: Role::Write,
created_at: now,
}],
created_at: now,
updated_at: now,
enabled: true,
},
User {
username: "bob".into(),
tenant_id: None,
password_hash: "argon2id$11223344$55667788".into(),
scram_verifier: None,
role: Role::Read,
api_keys: vec![],
created_at: now,
updated_at: now,
enabled: false,
},
],
api_keys: vec![(
UserId::platform("alice"),
ApiKey {
key: "rk_abc123".into(),
name: "ci-token".into(),
role: Role::Write,
created_at: now,
},
)],
bootstrapped: true,
master_secret: None,
kv: std::collections::HashMap::new(),
}
}
fn temp_pager() -> (Pager, std::path::PathBuf) {
use std::sync::atomic::{AtomicU64, Ordering};
static COUNTER: AtomicU64 = AtomicU64::new(0);
let id = COUNTER.fetch_add(1, Ordering::Relaxed);
let tmp_dir =
std::env::temp_dir().join(format!("reddb_vault_test_{}_{}", std::process::id(), id));
std::fs::create_dir_all(&tmp_dir).unwrap();
let db_path = tmp_dir.join("test.rdb");
let pager = Pager::open(&db_path, PagerConfig::default()).unwrap();
(pager, tmp_dir)
}
#[test]
fn test_vault_state_serialize_deserialize_roundtrip() {
let state = sample_state();
let serialized = state.serialize();
let text = std::str::from_utf8(&serialized).unwrap();
assert!(text.contains("SEALED:true"));
assert!(text.contains("USER:alice\t"));
assert!(text.contains("USER:bob\t"));
assert!(text.contains("KEY:alice\trk_abc123\t"));
let restored = VaultState::deserialize(&serialized).unwrap();
assert!(restored.bootstrapped);
assert_eq!(restored.users.len(), 2);
let alice = restored
.users
.iter()
.find(|u| u.username == "alice")
.unwrap();
assert_eq!(alice.role, Role::Admin);
assert!(alice.enabled);
assert_eq!(alice.password_hash, "argon2id$aabbccdd$eeff0011");
assert_eq!(alice.api_keys.len(), 1);
assert_eq!(alice.api_keys[0].key, "rk_abc123");
assert_eq!(alice.api_keys[0].name, "ci-token");
assert_eq!(alice.api_keys[0].role, Role::Write);
let bob = restored.users.iter().find(|u| u.username == "bob").unwrap();
assert_eq!(bob.role, Role::Read);
assert!(!bob.enabled);
assert!(bob.api_keys.is_empty());
assert_eq!(restored.api_keys.len(), 1);
assert_eq!(restored.api_keys[0].0.username, "alice");
assert!(restored.api_keys[0].0.tenant.is_none());
}
#[test]
fn test_vault_state_empty() {
let state = VaultState {
users: vec![],
api_keys: vec![],
bootstrapped: false,
master_secret: None,
kv: std::collections::HashMap::new(),
};
let serialized = state.serialize();
let restored = VaultState::deserialize(&serialized).unwrap();
assert!(!restored.bootstrapped);
assert!(restored.users.is_empty());
assert!(restored.api_keys.is_empty());
}
#[test]
fn test_vault_state_deserialize_invalid_utf8() {
let bad_data = vec![0xFF, 0xFE, 0xFD];
let result = VaultState::deserialize(&bad_data);
assert!(result.is_err());
}
#[test]
fn test_vault_state_deserialize_bad_user_line() {
let bad = b"USER:only_two\tfields\n";
let result = VaultState::deserialize(bad);
assert!(result.is_err());
}
#[test]
fn test_vault_state_deserialize_bad_key_line() {
let bad = b"KEY:too\tfew\n";
let result = VaultState::deserialize(bad);
assert!(result.is_err());
}
#[test]
fn test_vault_state_deserialize_unknown_line_skipped() {
let data = b"SEALED:false\nFUTURE:some_data\n";
let result = VaultState::deserialize(data).unwrap();
assert!(!result.bootstrapped);
}
#[test]
fn test_vault_pager_save_load_roundtrip() {
let (pager, tmp_dir) = temp_pager();
let vault = Vault::open(&pager, Some("test-passphrase-42")).unwrap();
let loaded = vault.load(&pager).unwrap();
assert!(loaded.is_none());
let state = sample_state();
vault.save(&pager, &state).unwrap();
let restored = vault.load(&pager).unwrap().unwrap();
assert!(restored.bootstrapped);
assert_eq!(restored.users.len(), 2);
assert_eq!(restored.api_keys.len(), 1);
let alice = restored
.users
.iter()
.find(|u| u.username == "alice")
.unwrap();
assert_eq!(alice.role, Role::Admin);
assert_eq!(alice.api_keys.len(), 1);
let vault2 = Vault::open(&pager, Some("test-passphrase-42")).unwrap();
let restored2 = vault2.load(&pager).unwrap().unwrap();
assert!(restored2.bootstrapped);
assert_eq!(restored2.users.len(), 2);
drop(pager);
let _ = std::fs::remove_dir_all(&tmp_dir);
}
#[test]
fn test_vault_wrong_key_fails_decryption() {
let (pager, tmp_dir) = temp_pager();
let vault = Vault::open(&pager, Some("correct-key")).unwrap();
let state = VaultState {
users: vec![],
api_keys: vec![],
bootstrapped: true,
master_secret: None,
kv: std::collections::HashMap::new(),
};
vault.save(&pager, &state).unwrap();
let vault2 = Vault::open(&pager, Some("wrong-key")).unwrap();
let result = vault2.load(&pager);
assert!(result.is_err());
drop(pager);
let _ = std::fs::remove_dir_all(&tmp_dir);
}
#[test]
fn test_vault_no_key_error() {
let (pager, tmp_dir) = temp_pager();
let result = Vault::open(&pager, None);
let has_env_key =
std::env::var("REDDB_VAULT_KEY").is_ok() || std::env::var("REDDB_CERTIFICATE").is_ok();
match has_env_key {
true => {
assert!(result.is_ok());
}
false => {
assert!(matches!(result, Err(VaultError::NoKey)));
}
}
drop(pager);
let _ = std::fs::remove_dir_all(&tmp_dir);
}
#[test]
fn test_vault_passphrase_argument() {
let (pager, tmp_dir) = temp_pager();
let vault = Vault::open(&pager, Some("my-passphrase")).unwrap();
let state = VaultState {
users: vec![],
api_keys: vec![],
bootstrapped: false,
master_secret: None,
kv: std::collections::HashMap::new(),
};
vault.save(&pager, &state).unwrap();
let vault2 = Vault::open(&pager, Some("my-passphrase")).unwrap();
let loaded = vault2.load(&pager).unwrap().unwrap();
assert!(!loaded.bootstrapped);
drop(pager);
let _ = std::fs::remove_dir_all(&tmp_dir);
}
#[test]
fn test_keypair_generate_deterministic_certificate() {
let kp = KeyPair::generate();
assert_eq!(kp.master_secret.len(), 32);
assert_eq!(kp.certificate.len(), 32);
let kp2 = KeyPair::from_master_secret(kp.master_secret.clone());
assert_eq!(kp.certificate, kp2.certificate);
}
#[test]
fn test_keypair_sign_verify() {
let kp = KeyPair::generate();
let data = b"session:abc123";
let sig = kp.sign(data);
assert!(kp.verify(data, &sig));
assert!(!kp.verify(b"session:wrong", &sig));
let mut bad_sig = sig.clone();
bad_sig[0] ^= 0xFF;
assert!(!kp.verify(data, &bad_sig));
}
#[test]
fn test_keypair_certificate_hex() {
let kp = KeyPair::generate();
let hex_str = kp.certificate_hex();
assert_eq!(hex_str.len(), 64); let decoded = hex::decode(&hex_str).unwrap();
assert_eq!(decoded, kp.certificate);
}
/// State sealed with raw certificate bytes can be reopened via the hex-encoded
/// certificate. The master secret stored in that state re-derives the same
/// certificate.
#[test]
fn test_vault_certificate_seal_roundtrip() {
    let (pager, tmp_dir) = temp_pager();
    let keypair = KeyPair::generate();

    let sealing_vault = Vault::with_certificate_bytes(&pager, &keypair.certificate).unwrap();
    let sealed_state = VaultState {
        users: Vec::new(),
        api_keys: Vec::new(),
        bootstrapped: true,
        master_secret: Some(keypair.master_secret.clone()),
        kv: std::collections::HashMap::new(),
    };
    sealing_vault.save(&pager, &sealed_state).unwrap();

    // Open through the hex form to cover both certificate entry points.
    let opening_vault = Vault::with_certificate(&pager, &keypair.certificate_hex()).unwrap();
    let unsealed = opening_vault.load(&pager).unwrap().unwrap();
    assert!(unsealed.bootstrapped);
    assert_eq!(unsealed.master_secret, Some(keypair.master_secret.clone()));

    let recovered_secret = unsealed.master_secret.unwrap();
    let rederived = KeyPair::from_master_secret(recovered_secret);
    assert_eq!(keypair.certificate, rederived.certificate);

    drop(pager);
    let _ = std::fs::remove_dir_all(&tmp_dir);
}
/// Loading vault state with a certificate other than the one it was sealed
/// with must return an error.
#[test]
fn test_vault_certificate_wrong_cert_fails() {
    let (pager, tmp_dir) = temp_pager();
    let owner = KeyPair::generate();

    let owner_vault = Vault::with_certificate_bytes(&pager, &owner.certificate).unwrap();
    let sealed_state = VaultState {
        users: Vec::new(),
        api_keys: Vec::new(),
        bootstrapped: true,
        master_secret: Some(owner.master_secret.clone()),
        kv: std::collections::HashMap::new(),
    };
    owner_vault.save(&pager, &sealed_state).unwrap();

    // A freshly generated, unrelated key pair stands in for a wrong certificate.
    let intruder = KeyPair::generate();
    let intruder_vault = Vault::with_certificate_bytes(&pager, &intruder.certificate).unwrap();
    let outcome = intruder_vault.load(&pager);
    assert!(outcome.is_err());

    drop(pager);
    let _ = std::fs::remove_dir_all(&tmp_dir);
}
/// `serialize` writes the master secret hex-encoded after a `MASTER_SECRET:` tag.
/// `deserialize` restores the secret together with the bootstrap flag.
#[test]
fn test_vault_state_master_secret_serialization() {
    let secret = vec![0xAAu8; 32];
    let original = VaultState {
        users: Vec::new(),
        api_keys: Vec::new(),
        bootstrapped: true,
        master_secret: Some(secret.clone()),
        kv: std::collections::HashMap::new(),
    };

    let encoded = original.serialize();
    let text = std::str::from_utf8(&encoded).unwrap();
    let expected_hex = hex::encode(&secret);
    assert!(text.contains("MASTER_SECRET:"));
    assert!(text.contains(&expected_hex));

    let decoded = VaultState::deserialize(&encoded).unwrap();
    assert_eq!(decoded.master_secret, Some(secret));
    assert!(decoded.bootstrapped);
}
/// A state blob that has no `MASTER_SECRET:` line still deserializes, leaving
/// the master secret unset while honoring the `SEALED` flag.
#[test]
fn test_vault_state_no_master_secret_backward_compat() {
    let legacy_blob: &[u8] = b"SEALED:true\n";
    let parsed = VaultState::deserialize(legacy_blob).unwrap();
    assert!(parsed.master_secret.is_none());
    assert!(parsed.bootstrapped);
}
/// A user's SCRAM verifier survives a serialize/deserialize round trip with its
/// salt, iteration count, stored key and server key unchanged.
#[test]
fn test_vault_state_scram_verifier_roundtrip() {
    use crate::auth::scram::{ScramVerifier, DEFAULT_ITER};

    let expected = ScramVerifier::from_password(
        "hunter2",
        b"reddb-vault-test-salt".to_vec(),
        DEFAULT_ITER,
    );
    let timestamp = now_ms();
    let carol = User {
        username: "carol".into(),
        tenant_id: None,
        password_hash: "argon2id$abc$def".into(),
        scram_verifier: Some(expected.clone()),
        role: Role::Admin,
        api_keys: Vec::new(),
        created_at: timestamp,
        updated_at: timestamp,
        enabled: true,
    };
    let original = VaultState {
        users: vec![carol],
        api_keys: Vec::new(),
        bootstrapped: true,
        master_secret: None,
        kv: std::collections::HashMap::new(),
    };

    let encoded = original.serialize();
    let decoded = VaultState::deserialize(&encoded).unwrap();
    let restored_carol = decoded
        .users
        .iter()
        .find(|candidate| candidate.username == "carol")
        .unwrap();
    let roundtripped = restored_carol
        .scram_verifier
        .as_ref()
        .expect("verifier round-trips");

    assert_eq!(roundtripped.salt, expected.salt);
    assert_eq!(roundtripped.iter, expected.iter);
    assert_eq!(roundtripped.stored_key, expected.stored_key);
    assert_eq!(roundtripped.server_key, expected.server_key);
}
/// A `USER:` line without a tenant column and with an empty SCRAM field still
/// parses. The resulting user has neither a SCRAM verifier nor a tenant id.
#[test]
fn test_vault_state_pre_tenant_user_line_still_parses() {
    let ts = now_ms();
    let legacy = format!("USER:dave\targon2id$x$y\tread\ttrue\t{ts}\t{ts}\t\nSEALED:false\n");

    let parsed = VaultState::deserialize(legacy.as_bytes()).unwrap();
    let dave = parsed
        .users
        .iter()
        .find(|candidate| candidate.username == "dave")
        .unwrap();

    assert!(dave.scram_verifier.is_none());
    assert!(dave.tenant_id.is_none());
}
/// A tenant-scoped user is serialized with its tenant as a trailing
/// tab-separated field and comes back with the same tenant id.
#[test]
fn test_vault_state_user_line_with_tenant_roundtrip() {
    let timestamp = now_ms();
    let alice = User {
        username: "alice".into(),
        tenant_id: Some("acme".into()),
        password_hash: "argon2id$x$y".into(),
        scram_verifier: None,
        role: Role::Write,
        api_keys: Vec::new(),
        created_at: timestamp,
        updated_at: timestamp,
        enabled: true,
    };
    let original = VaultState {
        users: vec![alice],
        api_keys: Vec::new(),
        bootstrapped: true,
        master_secret: None,
        kv: std::collections::HashMap::new(),
    };

    let encoded = original.serialize();
    let rendered = std::str::from_utf8(&encoded).unwrap();
    assert!(rendered.contains("\tacme\n"));

    let decoded = VaultState::deserialize(&encoded).unwrap();
    let restored_alice = decoded
        .users
        .iter()
        .find(|candidate| candidate.username == "alice")
        .unwrap();
    assert_eq!(restored_alice.tenant_id.as_deref(), Some("acme"));
}
/// Two users share a username ("alice") across tenants "acme" and "globex".
/// After a round trip, each API key must still be owned by the user in its
/// own tenant.
#[test]
fn test_vault_state_key_line_with_tenant_reattaches_correctly() {
    let now = now_ms();

    let acme_alice = User {
        username: "alice".into(),
        tenant_id: Some("acme".into()),
        password_hash: "argon2id$x$y".into(),
        scram_verifier: None,
        role: Role::Write,
        api_keys: Vec::new(),
        created_at: now,
        updated_at: now,
        enabled: true,
    };
    let globex_alice = User {
        username: "alice".into(),
        tenant_id: Some("globex".into()),
        password_hash: "argon2id$a$b".into(),
        scram_verifier: None,
        role: Role::Read,
        api_keys: Vec::new(),
        created_at: now,
        updated_at: now,
        enabled: true,
    };

    let acme_entry = (
        UserId::scoped("acme", "alice"),
        ApiKey {
            key: "rk_acme_key".into(),
            name: "deploy".into(),
            role: Role::Write,
            created_at: now,
        },
    );
    let globex_entry = (
        UserId::scoped("globex", "alice"),
        ApiKey {
            key: "rk_globex_key".into(),
            name: "ci".into(),
            role: Role::Read,
            created_at: now,
        },
    );

    let original = VaultState {
        users: vec![acme_alice, globex_alice],
        api_keys: vec![acme_entry, globex_entry],
        bootstrapped: true,
        master_secret: None,
        kv: std::collections::HashMap::new(),
    };

    let encoded = original.serialize();
    let decoded = VaultState::deserialize(&encoded).unwrap();
    assert_eq!(decoded.api_keys.len(), 2);

    // Look up the (owner, key) entry whose owner is scoped to `tenant`.
    let entry_for_tenant = |tenant: &str| {
        decoded
            .api_keys
            .iter()
            .find(|(owner, _)| owner.tenant.as_deref() == Some(tenant))
            .unwrap()
    };
    assert_eq!(entry_for_tenant("acme").1.key, "rk_acme_key");
    assert_eq!(entry_for_tenant("globex").1.key, "rk_globex_key");
}
/// A SCRAM field whose iteration count (1024 in this line) is below the accepted
/// minimum must make deserialization fail with `VaultError::Corrupt`. The error
/// message must mention "below minimum".
#[test]
fn test_vault_state_scram_iter_below_min_rejected() {
    let now = now_ms();
    let stored_hex = "00".repeat(32);
    let server_hex = "11".repeat(32);
    let user_line = format!(
        "USER:eve\targon2id$x$y\tread\ttrue\t{now}\t{now}\tdeadbeef:1024:{stored_hex}:{server_hex}\n"
    );

    let outcome = VaultState::deserialize(user_line.as_bytes());
    match outcome {
        Ok(_) => panic!("expected Corrupt iter-floor error, got Ok"),
        Err(VaultError::Corrupt(msg)) => assert!(msg.contains("below minimum")),
        Err(other) => panic!("expected Corrupt iter-floor error, got {other:?}"),
    }
}
/// `constant_time_eq` returns true only when both byte slices have the same
/// length and identical contents.
#[test]
fn test_constant_time_eq_function() {
    assert!(constant_time_eq(b"hello", b"hello"));
    assert!(!constant_time_eq(b"hello", b"world"));
    assert!(!constant_time_eq(b"short", b"longer"));
    assert!(constant_time_eq(b"", b""));
    // "short"/"longer" already differ at byte 0, so the case above cannot catch an
    // implementation that compares only the common prefix and skips the length
    // check. These inputs share a prefix and differ only in length.
    assert!(!constant_time_eq(b"hello", b"hello world"));
    assert!(!constant_time_eq(b"", b"a"));
    // A mismatch confined to the final byte must still be detected.
    assert!(!constant_time_eq(b"hello", b"hellp"));
}
}