use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes256Gcm, Nonce};
use base64::Engine;
use base64::engine::general_purpose::STANDARD as B64_STANDARD;
use rand::RngCore;
use crate::error::Error;
/// Bundle of values describing the outcome of a vault encryption.
///
/// NOTE(review): nothing in this section constructs this type; it is presumably
/// filled from a remote vault encrypt call — confirm against its producer.
#[derive(Debug, Clone)]
pub struct VaultEncryptResult {
    /// Encrypted payload as a string (presumably base64, like the output of
    /// [`local_encrypt`] — confirm).
    pub encrypted_data: String,
    /// Key/value context recorded alongside the encryption.
    pub context: std::collections::HashMap<String, String>,
    /// Encrypted key material as a string (presumably base64, matching the
    /// `encrypted_keys_b64` argument of [`local_encrypt`] — confirm).
    pub encrypted_keys: String,
}
pub fn local_encrypt(
data: &str,
data_key_b64: &str,
encrypted_keys_b64: &str,
associated_data: &str,
) -> Result<String, Error> {
let raw_key = B64_STANDARD
.decode(data_key_b64)
.map_err(|e| Error::VaultCrypto(format!("decode data key: {e}")))?;
if raw_key.len() != 32 {
return Err(Error::VaultCrypto(format!(
"data key must be 32 bytes; got {}",
raw_key.len()
)));
}
let encrypted_keys = B64_STANDARD
.decode(encrypted_keys_b64)
.map_err(|e| Error::VaultCrypto(format!("decode encrypted keys: {e}")))?;
let cipher = Aes256Gcm::new_from_slice(&raw_key)
.map_err(|e| Error::VaultCrypto(format!("init AES-GCM: {e}")))?;
let mut nonce_bytes = [0u8; 12];
rand::rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(
nonce,
Payload {
msg: data.as_bytes(),
aad: associated_data.as_bytes(),
},
)
.map_err(|e| Error::VaultCrypto(format!("encrypt: {e}")))?;
let prefix = encode_leb128(encrypted_keys.len() as u32);
let mut buf = Vec::with_capacity(
prefix.len() + encrypted_keys.len() + nonce_bytes.len() + ciphertext.len(),
);
buf.extend_from_slice(&prefix);
buf.extend_from_slice(&encrypted_keys);
buf.extend_from_slice(&nonce_bytes);
buf.extend_from_slice(&ciphertext);
Ok(B64_STANDARD.encode(buf))
}
pub fn local_decrypt(
encrypted_data: &str,
data_key_b64: &str,
associated_data: &str,
) -> Result<String, Error> {
let raw = B64_STANDARD
.decode(encrypted_data)
.map_err(|e| Error::VaultCrypto(format!("base64 decode: {e}")))?;
let (keys_len, bytes_read) = decode_leb128(&raw)?;
let offset = bytes_read + keys_len as usize;
if offset + 12 > raw.len() {
return Err(Error::VaultCrypto(
"encrypted data too short: missing nonce".to_string(),
));
}
let nonce = &raw[offset..offset + 12];
let ciphertext = &raw[offset + 12..];
if ciphertext.is_empty() {
return Err(Error::VaultCrypto(
"encrypted data too short: missing ciphertext".to_string(),
));
}
let raw_key = B64_STANDARD
.decode(data_key_b64)
.map_err(|e| Error::VaultCrypto(format!("decode data key: {e}")))?;
let cipher = Aes256Gcm::new_from_slice(&raw_key)
.map_err(|e| Error::VaultCrypto(format!("init AES-GCM: {e}")))?;
let plaintext = cipher
.decrypt(
Nonce::from_slice(nonce),
Payload {
msg: ciphertext,
aad: associated_data.as_bytes(),
},
)
.map_err(|e| Error::VaultCrypto(format!("decrypt: {e}")))?;
String::from_utf8(plaintext).map_err(|e| Error::VaultCrypto(format!("utf-8: {e}")))
}
pub fn extract_encrypted_keys(encrypted_data: &str) -> Result<String, Error> {
let raw = B64_STANDARD
.decode(encrypted_data)
.map_err(|e| Error::VaultCrypto(format!("base64 decode: {e}")))?;
let (keys_len, bytes_read) = decode_leb128(&raw)?;
let total = bytes_read + keys_len as usize;
if raw.len() < total {
return Err(Error::VaultCrypto(
"encrypted data too short for declared key length".to_string(),
));
}
Ok(B64_STANDARD.encode(&raw[bytes_read..total]))
}
/// Encodes `n` as unsigned LEB128.
///
/// Each output byte carries 7 bits of `n`, least-significant group first. The high
/// bit is set on every byte except the last. Zero encodes as the single byte `0x00`.
fn encode_leb128(mut n: u32) -> Vec<u8> {
    // A u32 needs at most five 7-bit groups.
    let mut encoded = Vec::with_capacity(5);
    loop {
        let low_bits = (n & 0x7f) as u8;
        n >>= 7;
        if n == 0 {
            // Final group: continuation bit stays clear.
            encoded.push(low_bits);
            return encoded;
        }
        encoded.push(low_bits | 0x80);
    }
}
fn decode_leb128(buf: &[u8]) -> Result<(u32, usize), Error> {
let mut result: u32 = 0;
let mut shift: u32 = 0;
for (i, &b) in buf.iter().enumerate() {
let chunk = (b & 0x7f) as u32;
if shift == 28 && (chunk >> 4) != 0 {
return Err(Error::VaultCrypto(
"LEB128 value too large for uint32".to_string(),
));
}
result |= chunk << shift;
if b & 0x80 == 0 {
return Ok((result, i + 1));
}
shift += 7;
if shift >= 35 {
return Err(Error::VaultCrypto(
"LEB128 value too large for uint32".to_string(),
));
}
}
Err(Error::VaultCrypto(
"unexpected end of LEB128 data".to_string(),
))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed test material: a 32-byte data key and a 5-byte encrypted-keys blob,
    /// both standard-base64 encoded.
    fn make_key_material() -> (String, String) {
        let key = [9u8; 32];
        (
            B64_STANDARD.encode(key),
            B64_STANDARD.encode([1u8, 2, 3, 4, 5]),
        )
    }

    #[test]
    fn leb128_round_trip() {
        for n in [0u32, 1, 127, 128, 16383, 16384, 1_234_567, u32::MAX] {
            let bytes = encode_leb128(n);
            let (decoded, consumed) = decode_leb128(&bytes).unwrap();
            assert_eq!(decoded, n);
            assert_eq!(consumed, bytes.len());
        }
    }

    #[test]
    fn leb128_rejects_values_wider_than_u32() {
        // Fifth byte carries bit 32.
        assert!(decode_leb128(&[0xff, 0xff, 0xff, 0xff, 0x1f]).is_err());
        // Fifth byte still has its continuation bit set.
        assert!(decode_leb128(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x00]).is_err());
    }

    #[test]
    fn leb128_rejects_truncated_input() {
        assert!(decode_leb128(&[]).is_err());
        assert!(decode_leb128(&[0x80]).is_err());
        assert!(decode_leb128(&[0xff, 0xff]).is_err());
    }

    #[test]
    fn local_encrypt_decrypt_round_trip() {
        let (data_key, encrypted_keys) = make_key_material();
        let plaintext = "hello world";
        let aad = "ctx:env_1";
        let sealed = local_encrypt(plaintext, &data_key, &encrypted_keys, aad).unwrap();
        let opened = local_decrypt(&sealed, &data_key, aad).unwrap();
        assert_eq!(opened, plaintext);
    }

    #[test]
    fn local_round_trip_with_empty_plaintext_and_keys() {
        let (data_key, _) = make_key_material();
        let sealed = local_encrypt("", &data_key, "", "aad").unwrap();
        assert_eq!(local_decrypt(&sealed, &data_key, "aad").unwrap(), "");
        assert_eq!(extract_encrypted_keys(&sealed).unwrap(), "");
    }

    #[test]
    fn local_decrypt_rejects_wrong_aad() {
        let (data_key, encrypted_keys) = make_key_material();
        let sealed = local_encrypt("secret", &data_key, &encrypted_keys, "ctx-a").unwrap();
        assert!(local_decrypt(&sealed, &data_key, "ctx-b").is_err());
    }

    #[test]
    fn local_decrypt_rejects_wrong_key() {
        let (data_key, encrypted_keys) = make_key_material();
        let sealed = local_encrypt("secret", &data_key, &encrypted_keys, "aad").unwrap();
        let other_key = B64_STANDARD.encode([7u8; 32]);
        assert!(local_decrypt(&sealed, &other_key, "aad").is_err());
    }

    #[test]
    fn local_encrypt_rejects_short_key() {
        let (_, encrypted_keys) = make_key_material();
        let short_key = B64_STANDARD.encode([9u8; 16]);
        assert!(local_encrypt("data", &short_key, &encrypted_keys, "aad").is_err());
    }

    #[test]
    fn local_decrypt_rejects_tampered_ciphertext() {
        let (data_key, encrypted_keys) = make_key_material();
        let sealed = local_encrypt("secret", &data_key, &encrypted_keys, "aad").unwrap();
        let mut raw = B64_STANDARD.decode(&sealed).unwrap();
        let last = raw.len() - 1;
        raw[last] ^= 0x01;
        let tampered = B64_STANDARD.encode(&raw);
        assert!(local_decrypt(&tampered, &data_key, "aad").is_err());
    }

    #[test]
    fn truncated_blobs_are_rejected() {
        let (data_key, encrypted_keys) = make_key_material();
        let sealed = local_encrypt("data", &data_key, &encrypted_keys, "aad").unwrap();
        let mut raw = B64_STANDARD.decode(&sealed).unwrap();
        // Layout: 1-byte prefix + 5 key bytes + 12 nonce bytes + ciphertext.
        // Keep the keys but only part of the nonce.
        raw.truncate(10);
        let partial_nonce = B64_STANDARD.encode(&raw);
        assert!(local_decrypt(&partial_nonce, &data_key, "aad").is_err());
        assert_eq!(extract_encrypted_keys(&partial_nonce).unwrap(), encrypted_keys);
        // Cut into the keys region itself.
        raw.truncate(3);
        let partial_keys = B64_STANDARD.encode(&raw);
        assert!(extract_encrypted_keys(&partial_keys).is_err());
        assert!(local_decrypt(&partial_keys, &data_key, "aad").is_err());
    }

    #[test]
    fn extract_encrypted_keys_round_trip() {
        let (data_key, encrypted_keys) = make_key_material();
        let sealed = local_encrypt("data", &data_key, &encrypted_keys, "aad").unwrap();
        let extracted = extract_encrypted_keys(&sealed).unwrap();
        assert_eq!(extracted, encrypted_keys);
    }
}