use super::*;
use crate::error::KeyRingError;
use aes_gcm::{
Aes256Gcm,
aead::{Aead, KeyInit, Payload},
};
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::{fs, io};
use zeroize::Zeroizing;
/// Domain-separation prefix placed at the start of the HMAC input that derives each
/// entry's encryption key.
const CONTEXT: &[u8] = b"cryptex-keyring";
/// Domain-separation prefix hashed in front of the random inputs during nonce generation.
const NONCE_DST: &[u8] = b"cryptex-nonce";
/// Format version stamped into newly written entries; it is also bound into key derivation
/// and the AEAD associated data.
const ENTRY_VERSION: u8 = 1;
/// Key-management backend used by [`KmsKeyRing`].
///
/// The keyring only asks the backend for randomness and for HMAC-SHA256 tags; the trait has
/// no method that exposes raw key material. Implementations must be `Send + Sync`.
pub trait KmsBackend: Send + Sync {
    /// Static name of the backend; used as a path component of the storage directory.
    fn backend_name(&self) -> &'static str;
    /// Identifier of the backend key; recorded in every entry and mixed into key derivation.
    fn key_id(&self) -> &str;
    /// 16-byte device identifier; recorded in every entry and mixed into key derivation.
    fn device_id(&self) -> [u8; 16];
    /// Requests `n` random bytes from the backend. The keyring tolerates an empty result,
    /// because the bytes are hashed together with OS randomness.
    fn get_random(&self, n: usize) -> Result<Vec<u8>>;
    /// Computes HMAC-SHA256 over `msg` with a backend-held key (presumably the key named by
    /// `key_id()` — confirm in implementations).
    fn hmac_sha256(&self, msg: Vec<u8>) -> Result<[u8; 32]>;
}
/// One encrypted secret as stored on disk (see [`Entry::to_bytes`] for the byte layout).
#[derive(Clone)]
pub struct Entry {
    /// Format version of the entry; bound into key derivation and AAD.
    pub version: u8,
    /// Backend key identifier captured when the entry was written.
    pub key_id: String,
    /// Device identifier captured when the entry was written.
    pub device_id: [u8; 16],
    /// 96-bit AES-GCM nonce; also an input to the per-entry key derivation.
    pub nonce: [u8; 12],
    /// Output of the AES-256-GCM `encrypt` call (ciphertext followed by the GCM tag).
    pub ciphertext: Vec<u8>,
}
impl Entry {
pub fn to_bytes(&self) -> Vec<u8> {
let key_id_bytes = self.key_id.as_bytes();
let key_id_len = key_id_bytes.len() as u16;
let ct_len = self.ciphertext.len() as u32;
let mut buf =
Vec::with_capacity(1 + 2 + key_id_bytes.len() + 16 + 12 + 4 + self.ciphertext.len());
buf.push(self.version);
buf.extend_from_slice(&key_id_len.to_le_bytes());
buf.extend_from_slice(key_id_bytes);
buf.extend_from_slice(&self.device_id);
buf.extend_from_slice(&self.nonce);
buf.extend_from_slice(&ct_len.to_le_bytes());
buf.extend_from_slice(&self.ciphertext);
buf
}
pub fn from_bytes(b: &[u8]) -> Result<Self> {
if b.len() < 3 {
return Err(corrupt());
}
let version = b[0];
let key_id_len = u16::from_le_bytes([b[1], b[2]]) as usize;
let key_id_end = 3 + key_id_len;
if b.len() < key_id_end + 32 {
return Err(corrupt());
}
let key_id = String::from_utf8(b[3..key_id_end].to_vec()).map_err(|_| corrupt())?;
let mut device_id = [0u8; 16];
device_id.copy_from_slice(&b[key_id_end..key_id_end + 16]);
let mut nonce = [0u8; 12];
nonce.copy_from_slice(&b[key_id_end + 16..key_id_end + 28]);
let ct_len_off = key_id_end + 28;
let ct_len = u32::from_le_bytes([
b[ct_len_off],
b[ct_len_off + 1],
b[ct_len_off + 2],
b[ct_len_off + 3],
]) as usize;
let ct_start = ct_len_off + 4;
if b.len() < ct_start + ct_len {
return Err(corrupt());
}
let ciphertext = b[ct_start..ct_start + ct_len].to_vec();
Ok(Entry {
version,
key_id,
device_id,
nonce,
ciphertext,
})
}
}
/// File-backed keyring whose entries are encrypted with keys derived through a [`KmsBackend`].
///
/// Each secret is stored as one `<hex sha256(id)>.bin` file inside `storage_dir`.
pub struct KmsKeyRing<B: KmsBackend> {
    /// Backend that supplies randomness and the HMAC used for key derivation.
    pub(crate) backend: B,
    /// Directory holding the entry files for this backend/service pair.
    pub(crate) storage_dir: PathBuf,
}
impl<B: KmsBackend> KmsKeyRing<B> {
pub fn open(backend: B, service: &str) -> Result<Self> {
let storage_dir = entry_dir(backend.backend_name(), service)?;
Ok(Self {
backend,
storage_dir,
})
}
pub fn list_secrets(&self) -> Result<Vec<BTreeMap<String, String>>> {
let mut results = Vec::new();
let entries = fs::read_dir(&self.storage_dir).map_err(io_err)?;
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) != Some("bin") {
continue;
}
if let Ok((id, e)) = read_entry_file(&path) {
let mut map = BTreeMap::new();
map.insert("id".to_string(), id);
map.insert("key_id".to_string(), e.key_id);
map.insert("device_id".to_string(), hex::encode(e.device_id));
results.push(map);
}
}
Ok(results)
}
fn entry_path(&self, id: &str) -> PathBuf {
self.storage_dir.join(entry_filename(id))
}
pub(crate) fn generate_nonce(&self) -> Result<[u8; 12]> {
let mut os_rand = Zeroizing::new([0u8; 32]);
getrandom::getrandom(os_rand.as_mut()).map_err(|e| KeyRingError::GeneralError {
msg: format!("OS RNG failed: {e}"),
})?;
let backend_rand = self.backend.get_random(32)?;
let mut hasher = Sha256::new();
hasher.update(NONCE_DST);
hasher.update(os_rand.as_ref());
if !backend_rand.is_empty() {
hasher.update(&backend_rand);
}
let digest = hasher.finalize();
let mut nonce = [0u8; 12];
nonce.copy_from_slice(&digest[..12]);
Ok(nonce)
}
fn derive_key(&self, entry: &Entry) -> Result<Zeroizing<[u8; 32]>> {
let key_id_bytes = entry.key_id.as_bytes();
let mut hmac_input = Vec::with_capacity(CONTEXT.len() + 1 + key_id_bytes.len() + 16 + 12);
hmac_input.extend_from_slice(CONTEXT);
hmac_input.push(entry.version);
hmac_input.extend_from_slice(key_id_bytes);
hmac_input.extend_from_slice(&entry.device_id);
hmac_input.extend_from_slice(&entry.nonce);
let raw = self.backend.hmac_sha256(hmac_input)?;
let mut k_enc = Zeroizing::new([0u8; 32]);
*k_enc = raw;
Ok(k_enc)
}
fn build_aad(entry: &Entry) -> Vec<u8> {
let key_id_bytes = entry.key_id.as_bytes();
let mut aad = Vec::with_capacity(1 + key_id_bytes.len() + 16 + 12);
aad.push(entry.version);
aad.extend_from_slice(key_id_bytes);
aad.extend_from_slice(&entry.device_id);
aad.extend_from_slice(&entry.nonce);
aad
}
fn encrypt_entry(&self, plaintext: &[u8], nonce: [u8; 12]) -> Result<Entry> {
let entry = Entry {
version: ENTRY_VERSION,
key_id: self.backend.key_id().to_string(),
device_id: self.backend.device_id(),
nonce,
ciphertext: Vec::new(),
};
let k_enc = self.derive_key(&entry)?;
let cipher =
Aes256Gcm::new_from_slice(k_enc.as_ref()).map_err(|_| KeyRingError::GeneralError {
msg: "invalid key length for AES-256-GCM".to_string(),
})?;
let aad = Self::build_aad(&entry);
let gcm_nonce = aes_gcm::Nonce::from_slice(&entry.nonce);
let ciphertext = cipher
.encrypt(
gcm_nonce,
Payload {
msg: plaintext,
aad: &aad,
},
)
.map_err(|_| KeyRingError::GeneralError {
msg: "AES-256-GCM encryption failed".to_string(),
})?;
Ok(Entry {
ciphertext,
..entry
})
}
fn decrypt_entry(&self, entry: &Entry) -> Result<Vec<u8>> {
let k_enc = self.derive_key(entry)?;
let cipher =
Aes256Gcm::new_from_slice(k_enc.as_ref()).map_err(|_| KeyRingError::GeneralError {
msg: "invalid key length for AES-256-GCM".to_string(),
})?;
let aad = Self::build_aad(entry);
let gcm_nonce = aes_gcm::Nonce::from_slice(&entry.nonce);
cipher
.decrypt(
gcm_nonce,
Payload {
msg: &entry.ciphertext,
aad: &aad,
},
)
.map_err(|_| KeyRingError::GeneralError {
msg: "AES-256-GCM decryption failed (wrong key or corrupted data)".to_string(),
})
}
}
impl<B: KmsBackend> DynKeyRing for KmsKeyRing<B> {
fn get_secret(&mut self, id: &str) -> Result<KeyRingSecret> {
let path = self.entry_path(id);
if !path.exists() {
return Err(KeyRingError::ItemNotFound);
}
let (_stored_id, entry) = read_entry_file(&path)?;
let plaintext = self.decrypt_entry(&entry)?;
Ok(KeyRingSecret(plaintext))
}
fn set_secret(&mut self, id: &str, secret: &[u8]) -> Result<()> {
let nonce = self.generate_nonce()?;
let entry = self.encrypt_entry(secret, nonce)?;
let path = self.entry_path(id);
write_entry_file(&path, id, &entry)
}
fn delete_secret(&mut self, id: &str) -> Result<()> {
let path = self.entry_path(id);
if !path.exists() {
return Err(KeyRingError::ItemNotFound);
}
fs::remove_file(&path).map_err(io_err)
}
}
/// File name used for the entry of secret `id`: the hex-encoded SHA-256 of the id followed by
/// `.bin`.
///
/// Hashing keeps arbitrary ids (including path separators) out of the file-system namespace.
pub(crate) fn entry_filename(id: &str) -> String {
    let digest = Sha256::digest(id.as_bytes());
    let mut name = hex::encode(digest);
    name.push_str(".bin");
    name
}
pub(crate) fn write_entry_file(path: &Path, id: &str, entry: &Entry) -> Result<()> {
let id_bytes = id.as_bytes();
let id_len = id_bytes.len() as u16;
let mut data = Vec::new();
data.extend_from_slice(&id_len.to_le_bytes());
data.extend_from_slice(id_bytes);
data.extend_from_slice(&entry.to_bytes());
let tmp = path.with_extension("tmp");
fs::write(&tmp, &data).map_err(io_err)?;
fs::rename(&tmp, path).map_err(io_err)
}
/// Reads an entry file written by `write_entry_file` and returns `(id, entry)`.
///
/// Layout: `id_len: u16 LE || id bytes (UTF-8) || Entry bytes`.
///
/// # Errors
/// - I/O failures are reported as general errors.
/// - A truncated header or a non-UTF-8 id yields the `corrupt()` error.
/// - Errors from [`Entry::from_bytes`] are propagated.
pub(crate) fn read_entry_file(path: &Path) -> Result<(String, Entry)> {
    let raw = fs::read(path).map_err(io_err)?;
    if raw.len() < 2 {
        return Err(corrupt());
    }
    let (len_prefix, rest) = raw.split_at(2);
    let id_len = u16::from_le_bytes([len_prefix[0], len_prefix[1]]) as usize;
    if rest.len() < id_len {
        return Err(corrupt());
    }
    let (id_bytes, entry_bytes) = rest.split_at(id_len);
    let id = String::from_utf8(id_bytes.to_vec()).map_err(|_| corrupt())?;
    Entry::from_bytes(entry_bytes).map(|entry| (id, entry))
}
/// Returns `<home>/.cryptex/<backend_name>/<sanitized service>`, creating it and any missing
/// parents if needed.
///
/// # Errors
/// Fails if the home directory cannot be determined or the directory cannot be created.
pub(crate) fn entry_dir(backend_name: &str, service: &str) -> Result<PathBuf> {
    let mut dir = match dirs::home_dir() {
        Some(home) => home,
        None => {
            return Err(KeyRingError::GeneralError {
                msg: "could not determine home directory".to_string(),
            });
        }
    };
    dir.push(".cryptex");
    dir.push(backend_name);
    dir.push(sanitize_name(service));
    fs::create_dir_all(&dir).map_err(io_err)?;
    Ok(dir)
}
/// Turns an arbitrary service name into a single, safe path component.
///
/// Alphanumeric characters (Unicode-aware), `-`, `_` and `.` are kept. Every other character,
/// including path separators, becomes `_`. The special components `.` and `..` are also
/// rewritten to underscores. Otherwise a service name could resolve to the backend directory
/// itself, or escape into its parent when joined onto a path.
///
/// Notes:
/// - The mapping is lossy, so distinct inputs such as `a/b` and `a_b` share a directory.
/// - An empty input yields an empty string.
pub(crate) fn sanitize_name(s: &str) -> String {
    let sanitized: String = s
        .chars()
        .map(|c| {
            if c.is_alphanumeric() || matches!(c, '-' | '_' | '.') {
                c
            } else {
                '_'
            }
        })
        .collect();
    if sanitized == "." || sanitized == ".." {
        "_".repeat(sanitized.len())
    } else {
        sanitized
    }
}
/// Error returned whenever an entry file cannot be parsed.
pub(crate) fn corrupt() -> KeyRingError {
    let msg = String::from("corrupted KMS entry file");
    KeyRingError::GeneralError { msg }
}
/// Wraps an I/O error as a general keyring error carrying the error's display text.
pub(crate) fn io_err(e: io::Error) -> KeyRingError {
    let msg = format!("{e}");
    KeyRingError::GeneralError { msg }
}
#[cfg(test)]
mod tests {
    use super::Entry;

    /// Builds a small entry with fixed device id and nonce for serialization tests.
    fn sample_entry(key_id: &str, ciphertext: Vec<u8>) -> Entry {
        Entry {
            version: 1,
            key_id: key_id.to_string(),
            device_id: [0xEFu8; 16],
            nonce: [0xCDu8; 12],
            ciphertext,
        }
    }

    #[test]
    fn test_entry_round_trip() {
        let entry = sample_entry(
            "mrk-1234abcd-12ab-34cd-56ef-1234567890ab",
            vec![1, 2, 3, 4, 5],
        );
        let bytes = entry.to_bytes();
        let decoded = Entry::from_bytes(&bytes).expect("decode entry");
        assert_eq!(decoded.version, entry.version);
        assert_eq!(decoded.key_id, entry.key_id);
        assert_eq!(decoded.device_id, entry.device_id);
        assert_eq!(decoded.nonce, entry.nonce);
        assert_eq!(decoded.ciphertext, entry.ciphertext);
    }

    #[test]
    fn test_entry_round_trip_empty_fields() {
        let entry = sample_entry("", Vec::new());
        let decoded = Entry::from_bytes(&entry.to_bytes()).expect("decode entry");
        assert!(decoded.key_id.is_empty());
        assert!(decoded.ciphertext.is_empty());
        assert_eq!(decoded.nonce, entry.nonce);
    }

    #[test]
    fn test_entry_rejects_short_input() {
        assert!(Entry::from_bytes(&[]).is_err());
        assert!(Entry::from_bytes(&[0u8; 2]).is_err());
        assert!(Entry::from_bytes(&[1u8, 0u8, 0u8, 0u8]).is_err());
    }

    #[test]
    fn test_entry_rejects_truncated_ciphertext() {
        let bytes = sample_entry("k", vec![9u8; 8]).to_bytes();
        assert!(Entry::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }

    #[test]
    fn test_entry_rejects_oversized_ciphertext_len() {
        let mut bytes = sample_entry("k", vec![9u8; 8]).to_bytes();
        // ct_len follows version(1) + key_id_len(2) + key_id(1) + device_id(16) + nonce(12).
        let ct_len_off = 1 + 2 + 1 + 16 + 12;
        bytes[ct_len_off..ct_len_off + 4].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(Entry::from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_entry_rejects_invalid_utf8_key_id() {
        let mut bytes = sample_entry("ab", Vec::new()).to_bytes();
        // First key-id byte sits right after version(1) + key_id_len(2).
        bytes[3] = 0xFF;
        assert!(Entry::from_bytes(&bytes).is_err());
    }
}