use std::collections::BTreeMap;
use std::io::Write;
use aes_gcm::aead::{AeadCore, AeadInPlace, KeyInit, OsRng, rand_core::RngCore};
use aes_gcm::{Aes256Gcm, Key, Nonce, Tag};
use zeroize::Zeroize;
/// Magic bytes written at the very start of every encrypted vault.
const MAGIC: [u8; 4] = *b"QVLT";
/// Format version byte that follows the magic.
const VERSION: u8 = 0x01;
/// Length in bytes of the random salt passed to key derivation.
const SALT_LEN: usize = 16;
/// Length in bytes of the AES-GCM nonce stored in the header.
const NONCE_LEN: usize = 12;
/// Length in bytes of the detached authentication tag stored in the header.
const TAG_LEN: usize = 16;
/// Total header size: magic (4) + version (1) + salt + nonce + tag.
const HEADER_LEN: usize = 4 + 1 + SALT_LEN + NONCE_LEN + TAG_LEN;
/// PBKDF2 iteration count used by `derive_key`.
pub const ITERATIONS: u32 = 600_000;
/// Longest key, in bytes, accepted by `is_valid_key` and `deserialize`.
pub const MAX_KEY_LEN: usize = 256;
/// Longest value, in bytes, accepted by `deserialize`.
pub const MAX_VALUE_LEN: usize = 65536;
/// Errors reported while encrypting, decrypting, or decoding a vault.
#[derive(Debug)]
pub enum VaultError {
/// The input is shorter than the fixed-size header.
TooSmall,
/// The input does not start with the expected magic bytes.
BadMagic,
/// The header carries a version byte other than the supported one;
/// the payload is the byte that was found.
UnsupportedVersion(u8),
/// Authenticated decryption failed (wrong passphrase or altered data).
DecryptionFailed,
/// The cipher reported a failure while encrypting.
EncryptionFailed,
/// The vault contents do not conform to the expected format.
MalformedData,
/// An underlying I/O operation failed.
Io(std::io::Error),
}
impl std::fmt::Display for VaultError {
    /// Writes the user-facing description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants that carry data are formatted on the spot; every other
        // variant maps to a fixed message written once at the end.
        let message = match self {
            Self::UnsupportedVersion(v) => {
                return write!(f, "unsupported vault version: {v}");
            }
            Self::Io(e) => return write!(f, "I/O error: {e}"),
            Self::TooSmall => "vault file too small",
            Self::BadMagic => "invalid vault file (bad magic)",
            Self::DecryptionFailed => "decryption failed (wrong passphrase?)",
            Self::EncryptionFailed => "encryption failed",
            Self::MalformedData => "malformed vault data",
        };
        f.write_str(message)
    }
}
impl std::error::Error for VaultError {}
impl From<std::io::Error> for VaultError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
/// An in-memory collection of secret key/value pairs, kept sorted by key.
#[derive(Clone)]
pub struct Vault {
    /// Secret values indexed by their name.
    entries: BTreeMap<String, String>,
}

impl std::fmt::Debug for Vault {
    /// Lists the stored keys but never the values, so that logging or
    /// `{:?}`-printing a vault cannot leak secrets.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Vault")
            .field("keys", &self.entries.keys().collect::<Vec<_>>())
            .finish_non_exhaustive()
    }
}
impl Vault {
pub fn new() -> Self {
Self {
entries: BTreeMap::new(),
}
}
pub fn from_map(entries: BTreeMap<String, String>) -> Self {
Self { entries }
}
pub fn get(&self, key: &str) -> Option<&str> {
self.entries.get(key).map(|s| s.as_str())
}
pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> Option<String> {
self.entries.insert(key.into(), value.into())
}
pub fn delete(&mut self, key: &str) -> Option<String> {
self.entries.remove(key)
}
pub fn keys(&self) -> impl Iterator<Item = &str> {
self.entries.keys().map(|s| s.as_str())
}
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.entries.iter().map(|(k, v)| (k.as_str(), v.as_str()))
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn entries_mut(&mut self) -> &mut BTreeMap<String, String> {
&mut self.entries
}
pub fn to_map(&self) -> BTreeMap<String, String> {
self.entries.clone()
}
pub fn encrypt(&self, passphrase: &str) -> Result<Vec<u8>, VaultError> {
let mut plaintext = serialize(&self.entries);
let mut salt = [0u8; SALT_LEN];
OsRng.fill_bytes(&mut salt);
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let key = derive_key(passphrase, &salt);
let cipher = Aes256Gcm::new(&key);
let tag = cipher
.encrypt_in_place_detached(&nonce, b"", &mut plaintext)
.map_err(|_| VaultError::EncryptionFailed)?;
drop(cipher);
let mut out = Vec::with_capacity(HEADER_LEN + plaintext.len());
out.write_all(&MAGIC)?;
out.write_all(&[VERSION])?;
out.write_all(&salt)?;
out.write_all(nonce.as_slice())?;
out.write_all(tag.as_slice())?;
out.write_all(&plaintext)?;
Ok(out)
}
pub fn decrypt(data: &[u8], passphrase: &str) -> Result<Self, VaultError> {
if data.len() < HEADER_LEN {
return Err(VaultError::TooSmall);
}
if data[0..4] != MAGIC {
return Err(VaultError::BadMagic);
}
if data[4] != VERSION {
return Err(VaultError::UnsupportedVersion(data[4]));
}
let salt = &data[5..5 + SALT_LEN];
let nonce_bytes = &data[5 + SALT_LEN..5 + SALT_LEN + NONCE_LEN];
let tag_bytes = &data[5 + SALT_LEN + NONCE_LEN..HEADER_LEN];
let ciphertext = &data[HEADER_LEN..];
let key = derive_key(passphrase, salt);
let cipher = Aes256Gcm::new(&key);
let nonce = Nonce::from_slice(nonce_bytes);
let tag = Tag::from_slice(tag_bytes);
let mut buf = ciphertext.to_vec();
cipher
.decrypt_in_place_detached(nonce, b"", &mut buf, tag)
.map_err(|_| VaultError::DecryptionFailed)?;
let entries = deserialize(&buf);
buf.zeroize();
Ok(Self { entries })
}
pub fn to_shell_exports(&self) -> String {
let mut out = String::new();
for (key, value) in &self.entries {
let escaped = value.replace('\'', "'\\''");
out.push_str(&format!("export {key}='{escaped}'\n"));
}
out
}
pub fn to_json(&self) -> String {
let mut out = String::from("{\n");
let len = self.entries.len();
for (i, (key, value)) in self.entries.iter().enumerate() {
let escaped = value.replace('\\', "\\\\").replace('"', "\\\"");
out.push_str(&format!(" \"{key}\": \"{escaped}\""));
if i + 1 < len {
out.push(',');
}
out.push('\n');
}
out.push_str("}\n");
out
}
}
impl Default for Vault {
fn default() -> Self {
Self::new()
}
}
impl Drop for Vault {
    /// Wipes every stored value before the map's memory is released.
    ///
    /// Keys are left untouched: only the values are treated as secret here.
    fn drop(&mut self) {
        for value in self.entries.values_mut() {
            // `String` implements `Zeroize`, so no `unsafe` byte access is
            // needed to overwrite the secret.
            value.zeroize();
        }
    }
}
/// Reports whether `key` is an acceptable entry name: between 1 and
/// [`MAX_KEY_LEN`] bytes, made up solely of ASCII letters, ASCII digits,
/// and underscores.
pub fn is_valid_key(key: &str) -> bool {
    if key.is_empty() || key.len() > MAX_KEY_LEN {
        return false;
    }
    // Every byte of a non-ASCII character is >= 0x80 and therefore fails
    // this check, so testing bytes is equivalent to testing chars.
    key.bytes().all(|b| b == b'_' || b.is_ascii_alphanumeric())
}
/// Parses `.env`-style text into `(key, value)` pairs, in input order.
///
/// For each line:
/// * blank lines and lines starting with `#` are ignored;
/// * a leading `export ` is dropped;
/// * the line is split at the first `=`; lines without one are ignored;
/// * key and value are trimmed, and a single pair of matching surrounding
///   quotes (`"…"` or `'…'`) is removed from the value. Quote characters
///   that are part of the value itself, or that are unbalanced, are kept.
///
/// Pairs whose key fails [`is_valid_key`] or whose value ends up empty are
/// dropped.
pub fn parse_env_lines(input: &str) -> Vec<(String, String)> {
    /// Removes one pair of matching quotes wrapped around `raw`, if present.
    fn unquote(raw: &str) -> &str {
        for quote in ['"', '\''] {
            let inner = raw
                .strip_prefix(quote)
                .and_then(|rest| rest.strip_suffix(quote));
            if let Some(inner) = inner {
                return inner;
            }
        }
        raw
    }

    input
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        .filter_map(|line| {
            let assignment = line.strip_prefix("export ").unwrap_or(line);
            let (key, value) = assignment.split_once('=')?;
            let key = key.trim();
            let value = unquote(value.trim());
            (is_valid_key(key) && !value.is_empty())
                .then(|| (key.to_string(), value.to_string()))
        })
        .collect()
}
/// Encodes `entries` into the vault's plaintext wire format.
///
/// Each entry becomes `key_len (u16, big-endian) | key bytes |
/// value_len (u32, big-endian) | value bytes`, written in map order. A
/// two-byte zero marker closes the stream.
fn serialize(entries: &BTreeMap<String, String>) -> Vec<u8> {
    let mut wire = Vec::new();
    for (name, secret) in entries.iter() {
        // The `as` casts keep only the low 16 / 32 bits of each length.
        let name_len = (name.len() as u16).to_be_bytes();
        let secret_len = (secret.len() as u32).to_be_bytes();
        wire.extend_from_slice(&name_len);
        wire.extend_from_slice(name.as_bytes());
        wire.extend_from_slice(&secret_len);
        wire.extend_from_slice(secret.as_bytes());
    }
    // End-of-data marker: a key length of zero.
    wire.extend_from_slice(&[0x00, 0x00]);
    wire
}
/// Decodes the plaintext wire format back into a map.
///
/// Records are read front to back as `key_len (u16, big-endian) | key |
/// value_len (u32, big-endian) | value`. Parsing stops, keeping whatever
/// was decoded so far, at the first of:
/// * a zero key length (the end-of-data marker);
/// * a key longer than [`MAX_KEY_LEN`] or a value longer than
///   [`MAX_VALUE_LEN`];
/// * a record that runs past the end of `data`.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD. A repeated key
/// keeps the value read last.
fn deserialize(data: &[u8]) -> BTreeMap<String, String> {
    /// Detaches the first `n` bytes of `rest`; yields `None` and leaves
    /// `rest` unchanged when fewer than `n` bytes remain.
    fn take<'a>(rest: &mut &'a [u8], n: usize) -> Option<&'a [u8]> {
        let whole: &'a [u8] = *rest;
        if whole.len() < n {
            return None;
        }
        let (head, tail) = whole.split_at(n);
        *rest = tail;
        Some(head)
    }

    let mut parsed = BTreeMap::new();
    let mut rest = data;
    while let Some(prefix) = take(&mut rest, 2) {
        let key_len = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
        if key_len == 0 || key_len > MAX_KEY_LEN {
            break;
        }
        let key_bytes = match take(&mut rest, key_len) {
            Some(bytes) => bytes,
            None => break,
        };
        let len_bytes = match take(&mut rest, 4) {
            Some(bytes) => bytes,
            None => break,
        };
        let value_len =
            u32::from_be_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]) as usize;
        if value_len > MAX_VALUE_LEN {
            break;
        }
        let value_bytes = match take(&mut rest, value_len) {
            Some(bytes) => bytes,
            None => break,
        };
        parsed.insert(
            String::from_utf8_lossy(key_bytes).into_owned(),
            String::from_utf8_lossy(value_bytes).into_owned(),
        );
    }
    parsed
}
/// Derives the 256-bit AES key from `passphrase` and `salt` using
/// PBKDF2-HMAC-SHA256 with [`ITERATIONS`] rounds.
///
/// The intermediate output array is wiped before returning; the caller
/// owns the returned key and is responsible for wiping it after use.
fn derive_key(passphrase: &str, salt: &[u8]) -> Key<Aes256Gcm> {
    let mut raw =
        pbkdf2::pbkdf2_hmac_array::<sha2::Sha256, 32>(passphrase.as_bytes(), salt, ITERATIONS);
    let key = *Key::<Aes256Gcm>::from_slice(&raw);
    // Do not leave a second copy of the key material on the stack.
    raw.zeroize();
    key
}
// Unit tests covering the encrypt/decrypt round trip, the export format,
// key validation, and env-file parsing.
#[cfg(test)]
mod tests {
use super::*;
// Encrypting and then decrypting with the same passphrase returns the
// original entries.
#[test]
fn round_trip() {
let mut vault = Vault::new();
vault.set("API_KEY", "sk-secret-123");
vault.set("DB_URL", "postgres://localhost/mydb");
let encrypted = vault.encrypt("test-pass").unwrap();
let decrypted = Vault::decrypt(&encrypted, "test-pass").unwrap();
assert_eq!(decrypted.get("API_KEY"), Some("sk-secret-123"));
assert_eq!(decrypted.get("DB_URL"), Some("postgres://localhost/mydb"));
assert_eq!(decrypted.len(), 2);
}
// A different passphrase must surface as `DecryptionFailed`.
#[test]
fn wrong_passphrase() {
let vault = Vault::new();
let encrypted = vault.encrypt("correct").unwrap();
assert!(matches!(
Vault::decrypt(&encrypted, "wrong"),
Err(VaultError::DecryptionFailed)
));
}
// Flipping the last ciphertext byte must make decryption fail.
#[test]
fn tamper_detection() {
let mut vault = Vault::new();
vault.set("KEY", "value");
let mut encrypted = vault.encrypt("pass").unwrap();
if let Some(last) = encrypted.last_mut() {
*last ^= 0xFF;
}
assert!(Vault::decrypt(&encrypted, "pass").is_err());
}
// Two encryptions of the same vault with the same passphrase must differ,
// since salt and nonce are freshly generated on each call.
#[test]
fn fresh_nonce_per_encrypt() {
let vault = Vault::new();
let a = vault.encrypt("pass").unwrap();
let b = vault.encrypt("pass").unwrap();
assert_ne!(a, b);
}
// Single quotes inside a value are escaped for the shell; double quotes
// pass through unchanged.
#[test]
fn shell_escaping() {
let mut vault = Vault::new();
vault.set("KEY", "it's a \"test\"");
let exports = vault.to_shell_exports();
assert!(exports.contains("'it'\\''s a \"test\"'"));
}
// Keys may contain only ASCII letters, digits, and underscores.
#[test]
fn valid_keys() {
assert!(is_valid_key("API_KEY"));
assert!(is_valid_key("key123"));
assert!(!is_valid_key(""));
assert!(!is_valid_key("has space"));
assert!(!is_valid_key("has-dash"));
}
// Comments and empty-valued assignments are skipped; `export` prefixes and
// surrounding quotes are removed.
#[test]
fn parse_env() {
let input = r#"
export API_KEY="sk-123"
DB_URL=postgres://localhost
# comment
export EMPTY=
BARE=value
"#;
let pairs = parse_env_lines(input);
assert_eq!(pairs.len(), 3);
assert_eq!(pairs[0], ("API_KEY".into(), "sk-123".into()));
assert_eq!(pairs[1], ("DB_URL".into(), "postgres://localhost".into()));
assert_eq!(pairs[2], ("BARE".into(), "value".into()));
}
}