use aes_gcm_siv::aead::{Aead, Payload};
use aes_gcm_siv::{Aes256GcmSiv, KeyInit, Nonce};
use aes_kw::Kek;
#[cfg(not(feature = "test-vectors"))]
use rand::RngCore;
/// Length in bytes of every symmetric key in this module: AES-256-GCM-SIV keys,
/// AES-KW key-encryption keys, wrapped payload keys and the Argon2id output.
pub const KEY_SIZE: usize = 32;
/// Length in bytes of the Argon2id salt accepted by `derive_base_kek`.
pub const SALT_SIZE: usize = 32;
/// Length in bytes of the AES-GCM-SIV nonce prefixed to every ciphertext.
pub const NONCE_SIZE: usize = 12;
/// Length in bytes of an AES-KW wrapped key: the wrapped key plus one 8-byte
/// integrity block.
pub const WRAPPED_KEY_SIZE: usize = KEY_SIZE + 8;
/// Argon2id memory cost in KiB (128 MiB).
const ARGON2_MEM_KIB: u32 = 128 << 10;
/// Argon2id time cost (number of passes over memory).
const ARGON2_TIME: u32 = 4;
/// Argon2id degree of parallelism (number of lanes).
const ARGON2_LANES: u32 = 4;
/// Argon2id memory cost, in KiB, used by `derive_base_kek`.
///
/// Declared `const` so callers can also read it in constant contexts.
pub const fn argon2_mem_kib() -> u32 {
    ARGON2_MEM_KIB
}
/// Argon2id time cost (number of passes) used by `derive_base_kek`.
///
/// Declared `const` so callers can also read it in constant contexts.
pub const fn argon2_time() -> u32 {
    ARGON2_TIME
}
/// Argon2id parallelism (number of lanes) used by `derive_base_kek`.
///
/// Declared `const` so callers can also read it in constant contexts.
pub const fn argon2_lanes() -> u32 {
    ARGON2_LANES
}
#[cfg(not(feature = "test-vectors"))]
pub fn random_array<const N: usize>() -> [u8; N] {
let mut out = [0u8; N];
rand::rngs::OsRng.fill_bytes(&mut out);
out
}
/// Fixed prefix that the `test-vectors` build of `random_array` hashes together
/// with a little-endian block counter to produce reproducible "random" bytes.
/// Changing this literal changes every generated test vector.
#[cfg(feature = "test-vectors")]
const TEST_VECTOR_SEED: &[u8] = b"AEROVAULT3 test-vectors v1";
#[cfg(feature = "test-vectors")]
thread_local! {
/// Per-thread count of 32-byte blocks emitted so far by the deterministic
/// `random_array`; `reset_test_vectors` sets it back to zero.
static TEST_VECTOR_COUNTER: std::cell::Cell<u64> = const { std::cell::Cell::new(0) };
}
/// Rewinds this thread's deterministic RNG so the next `random_array` call
/// starts again from block counter 0.
#[cfg(feature = "test-vectors")]
pub fn reset_test_vectors() {
    TEST_VECTOR_COUNTER.with(|counter| {
        counter.set(0);
    });
}
/// Deterministic stand-in for the OS RNG, compiled only with `test-vectors`.
///
/// The output is filled in 32-byte pieces. Piece `i` is the BLAKE3 digest of
/// `TEST_VECTOR_SEED || ctr.to_le_bytes()`, where `ctr` is the thread-local
/// counter, which is post-incremented once per piece. The final piece is
/// truncated when `N` is not a multiple of 32. `N == 0` consumes no counter values.
#[cfg(feature = "test-vectors")]
pub fn random_array<const N: usize>() -> [u8; N] {
    let mut bytes = [0u8; N];
    for piece in bytes.chunks_mut(32) {
        let ctr = TEST_VECTOR_COUNTER.with(|counter| {
            let current = counter.get();
            counter.set(current + 1);
            current
        });
        // Incremental hashing of seed then counter equals hashing their concatenation.
        let digest = blake3::Hasher::new()
            .update(TEST_VECTOR_SEED)
            .update(&ctr.to_le_bytes())
            .finalize();
        let len = piece.len();
        piece.copy_from_slice(&digest.as_bytes()[..len]);
    }
    bytes
}
/// Derives the 32-byte base key-encryption key from `password` and `salt` using
/// Argon2id (version 0x13) with this module's memory/time/lane costs.
///
/// # Errors
/// Returns `"Argon2 params: …"` if the cost parameters are rejected, or
/// `"Argon2 derive: …"` if the hash computation itself fails.
pub fn derive_base_kek(password: &str, salt: &[u8; SALT_SIZE]) -> Result<[u8; KEY_SIZE], String> {
    let mut derived = [0u8; KEY_SIZE];
    let hasher = argon2::Params::new(ARGON2_MEM_KIB, ARGON2_TIME, ARGON2_LANES, Some(KEY_SIZE))
        .map(|cost| {
            argon2::Argon2::new(argon2::Algorithm::Argon2id, argon2::Version::V0x13, cost)
        })
        .map_err(|e| format!("Argon2 params: {e}"))?;
    hasher
        .hash_password_into(password.as_bytes(), salt, &mut derived)
        .map_err(|e| format!("Argon2 derive: {e}"))?;
    Ok(derived)
}
/// HKDF-SHA256 extract-then-expand of `ikm` with no salt (`None`) and `label` as
/// the `info` string, producing exactly `N` bytes.
///
/// # Errors
/// Returns `"HKDF expand failed"` when `expand` rejects the requested length;
/// HKDF-SHA256 output is capped at 255 × 32 = 8160 bytes.
pub fn hkdf_expand<const N: usize>(ikm: &[u8], label: &[u8]) -> Result<[u8; N], String> {
    let mut okm = [0u8; N];
    hkdf::Hkdf::<sha2::Sha256>::new(None, ikm)
        .expand(label, &mut okm)
        .map(|()| okm)
        .map_err(|_| "HKDF expand failed".to_string())
}
/// Wraps the 32-byte `key` under `kek` with AES-256 Key Wrap, yielding 40 bytes.
///
/// # Errors
/// Returns `"AES-KW wrap failed"` if the underlying `Kek::wrap` reports an error.
pub fn wrap_key(
    kek: &[u8; KEY_SIZE],
    key: &[u8; KEY_SIZE],
) -> Result<[u8; WRAPPED_KEY_SIZE], String> {
    let wrapper = Kek::from(*kek);
    let mut wrapped = [0u8; WRAPPED_KEY_SIZE];
    wrapper
        .wrap(key, &mut wrapped)
        .map(|()| wrapped)
        .map_err(|_| "AES-KW wrap failed".to_string())
}
/// Reverses `wrap_key`: unwraps a 40-byte AES-KW blob under `kek` back to the
/// 32-byte key.
///
/// # Errors
/// Returns `"AES-KW unwrap failed"` when `Kek::unwrap` rejects the input, e.g. a
/// wrong KEK or a corrupted blob that fails the integrity check.
pub fn unwrap_key(
    kek: &[u8; KEY_SIZE],
    wrapped: &[u8; WRAPPED_KEY_SIZE],
) -> Result<[u8; KEY_SIZE], String> {
    let wrapper = Kek::from(*kek);
    let mut plain = [0u8; KEY_SIZE];
    wrapper
        .unwrap(wrapped, &mut plain)
        .map(|()| plain)
        .map_err(|_| "AES-KW unwrap failed".to_string())
}
pub fn encrypt_with_aad(
key: &[u8; KEY_SIZE],
plaintext: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, String> {
let cipher = Aes256GcmSiv::new_from_slice(key).map_err(|e| format!("AES-GCM-SIV init: {e}"))?;
let nonce_bytes = random_array::<NONCE_SIZE>();
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad,
},
)
.map_err(|_| "AES-GCM-SIV encrypt failed".to_string())?;
let mut out = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(&ciphertext);
Ok(out)
}
/// Decrypts a blob produced by `encrypt_with_aad`, checking it against `aad`.
///
/// # Errors
/// Returns `"AES-GCM-SIV payload is too short"` when `encrypted` cannot hold a
/// nonce plus a 16-byte tag, `"AES-GCM-SIV init: …"` if the key is rejected, and
/// `"AES-GCM-SIV decrypt failed"` on any authentication failure (wrong key, AAD,
/// nonce or ciphertext).
pub fn decrypt_with_aad(
    key: &[u8; KEY_SIZE],
    encrypted: &[u8],
    aad: &[u8],
) -> Result<Vec<u8>, String> {
    // Length of the authentication tag appended by AES-GCM-SIV.
    const TAG_SIZE: usize = 16;
    if encrypted.len() < NONCE_SIZE + TAG_SIZE {
        return Err("AES-GCM-SIV payload is too short".to_string());
    }
    let aead = Aes256GcmSiv::new_from_slice(key).map_err(|e| format!("AES-GCM-SIV init: {e}"))?;
    let (nonce, sealed) = encrypted.split_at(NONCE_SIZE);
    aead.decrypt(Nonce::from_slice(nonce), Payload { msg: sealed, aad })
        .map_err(|_| "AES-GCM-SIV decrypt failed".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// AES-GCM-SIV round-trips, and any change to AAD, key, nonce or ciphertext is rejected.
    #[test]
    fn gcm_siv_round_trip_binds_aad() {
        let key = [7u8; KEY_SIZE];
        let msg = b"AeroCrypt shared codec round trip";
        let ct = encrypt_with_aad(&key, msg, b"domain-a").unwrap();
        assert_eq!(decrypt_with_aad(&key, &ct, b"domain-a").unwrap(), msg);
        assert!(decrypt_with_aad(&key, &ct, b"domain-b").is_err());
        assert!(decrypt_with_aad(&[8u8; KEY_SIZE], &ct, b"domain-a").is_err());

        // Flipping a bit in the nonce prefix or in the sealed body must fail authentication.
        let mut bad_nonce = ct.clone();
        bad_nonce[0] ^= 1;
        assert!(decrypt_with_aad(&key, &bad_nonce, b"domain-a").is_err());
        let mut bad_body = ct.clone();
        bad_body[NONCE_SIZE] ^= 1;
        assert!(decrypt_with_aad(&key, &bad_body, b"domain-a").is_err());
    }

    /// Inputs shorter than nonce + tag are rejected before any decryption.
    #[test]
    fn gcm_siv_rejects_short_payload() {
        assert!(decrypt_with_aad(&[0u8; KEY_SIZE], &[0u8; 4], b"").is_err());
    }

    /// AES-KW round-trips; a wrong KEK or a corrupted blob is rejected.
    #[test]
    fn aes_kw_wrap_unwrap_round_trip() {
        let kek = [3u8; KEY_SIZE];
        let key = [9u8; KEY_SIZE];
        let wrapped = wrap_key(&kek, &key).unwrap();
        assert_eq!(wrapped.len(), WRAPPED_KEY_SIZE);
        assert_eq!(unwrap_key(&kek, &wrapped).unwrap(), key);
        assert!(unwrap_key(&[4u8; KEY_SIZE], &wrapped).is_err());

        let mut corrupted = wrapped;
        corrupted[WRAPPED_KEY_SIZE - 1] ^= 1;
        assert!(unwrap_key(&kek, &corrupted).is_err());
    }

    /// Pins the HKDF label derivation and the AES-KW wire format to fixed vectors,
    /// so a silent change in either breaks the build rather than cross-open compat.
    #[test]
    fn hkdf_and_kw_known_answer() {
        let ikm = [0x11u8; KEY_SIZE];
        let chunk_id_key =
            hkdf_expand::<KEY_SIZE>(&ikm, b"AeroVault v3 keyed BLAKE3 chunk ids").unwrap();
        assert_eq!(
            hex_lower(&chunk_id_key),
            "7773bb7d76fa136062ea5d1a8c3747700298bb19a386edb12dd3db30520f4414",
            "HKDF chunk-id key drifted (one-byte label drift breaks cross-open)"
        );

        // RFC 3394 §4.6: wrap 256 bits of key data with a 256-bit KEK.
        // KEK = 00 01 02 ... 1f (index fits in u8, so the cast cannot truncate).
        let mut rfc_kek = [0u8; KEY_SIZE];
        for (i, byte) in rfc_kek.iter_mut().enumerate() {
            *byte = i as u8;
        }
        let rfc_key: [u8; KEY_SIZE] = [
            0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd,
            0xee, 0xff, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
            0x0c, 0x0d, 0x0e, 0x0f,
        ];
        let rfc_wrapped = wrap_key(&rfc_kek, &rfc_key).unwrap();
        assert_eq!(
            hex_lower(&rfc_wrapped),
            "28c9f404c4b810f4cbccb35cfb87f8263f5786e2d80ed326cbc7f0e71a99f43bfb988b9b7a02dd21",
            "AES-KW output no longer matches RFC 3394 section 4.6"
        );
        assert_eq!(unwrap_key(&rfc_kek, &rfc_wrapped).unwrap(), rfc_key);

        let kek = [0x22u8; KEY_SIZE];
        let key = [0x33u8; KEY_SIZE];
        let wrapped = wrap_key(&kek, &key).unwrap();
        assert_eq!(unwrap_key(&kek, &wrapped).unwrap(), key);
    }

    /// Lower-case hex encoding used to compare against known-answer strings.
    fn hex_lower(bytes: &[u8]) -> String {
        use std::fmt::Write as _;
        bytes
            .iter()
            .fold(String::with_capacity(bytes.len() * 2), |mut out, b| {
                // Writing into a String cannot fail.
                let _ = write!(out, "{b:02x}");
                out
            })
    }

    /// Argon2id and HKDF are pure functions of their inputs, and distinct inputs diverge.
    #[test]
    fn argon2_and_hkdf_are_deterministic() {
        let salt = [1u8; SALT_SIZE];
        let a = derive_base_kek("correct horse battery staple", &salt).unwrap();
        let b = derive_base_kek("correct horse battery staple", &salt).unwrap();
        assert_eq!(a, b);
        let c = derive_base_kek("different password", &salt).unwrap();
        assert_ne!(a, c);
        let k1 = hkdf_expand::<KEY_SIZE>(&a, b"label-1").unwrap();
        let k2 = hkdf_expand::<KEY_SIZE>(&a, b"label-1").unwrap();
        let k3 = hkdf_expand::<KEY_SIZE>(&a, b"label-2").unwrap();
        assert_eq!(k1, k2);
        assert_ne!(k1, k3);
    }
}