use std::io::{self, Read, Write};
use std::mem;
use chacha20poly1305::aead::stream::{DecryptorBE32, EncryptorBE32};
use chacha20poly1305::aead::{AeadInPlace, KeyInit};
use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
use zeroize::{Zeroize, Zeroizing};
fn scrub_tail_and_truncate(out: &mut Vec<u8>, start: usize, n: usize) {
let grown_len = start + n;
assert!(
out.len() >= grown_len,
"scrub precondition violated: len={} < start+n={}",
out.len(),
grown_len
);
out[start..grown_len].zeroize();
out.truncate(start);
}
/// Wire-format version written as the first byte of every single-shot ciphertext and stream.
pub const VERSION: u8 = 0x01;
/// Version bytes this build accepts when decrypting.
const ACCEPTED_VERSIONS: &[u8] = &[VERSION];

/// Length of the random nonce stored after the version byte in single-shot ciphertexts.
pub const NONCE_LEN: usize = 12;
/// Length of the Poly1305 tag appended to each single-shot ciphertext and to each stream chunk.
pub const TAG_LEN: usize = 16;

/// Length of the random nonce prefix written after the version byte of a stream and handed to
/// the STREAM BE32 encryptor/decryptor.
pub const STREAM_NONCE_LEN: usize = 7;
/// Maximum plaintext bytes per stream chunk (64 KiB).
pub const STREAM_CHUNK: usize = 1 << 16;
/// Errors returned by [`Vault`] operations and the hashing helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A reader or writer failed. Internal STREAM encryption failures are also surfaced
    /// through this variant, wrapped as an `io::Error`.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
    /// The input ended before the fixed header was complete. For single-shot ciphertexts this
    /// means shorter than version byte + nonce + tag.
    #[error("data too short to be valid ciphertext")]
    TooShort,
    /// AEAD verification failed. The variant deliberately does not say which input (key, AAD,
    /// or ciphertext) was wrong.
    #[error("authentication failed: wrong key, wrong AAD, or tampered data")]
    AuthenticationFailed,
    /// The leading version byte is not among the versions this build accepts. The byte is
    /// carried in the payload and printed in hex (e.g. `0xfe`).
    #[error("unsupported wire-format version: {0:#04x}")]
    UnsupportedVersion(u8),
    /// Single-shot encryption rejected the plaintext as too long for the cipher.
    #[error("plaintext exceeds ChaCha20-Poly1305 message-length limit")]
    PlaintextTooLarge,
    /// The operating system's random number generator failed (key or nonce generation).
    #[error("OS RNG failure: {0}")]
    Rng(getrandom::Error),
}
impl From<getrandom::Error> for Error {
fn from(e: getrandom::Error) -> Self {
Self::Rng(e)
}
}
impl PartialEq for Error {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Io(_), Self::Io(_))
| (Self::TooShort, Self::TooShort)
| (Self::AuthenticationFailed, Self::AuthenticationFailed)
| (Self::PlaintextTooLarge, Self::PlaintextTooLarge)
| (Self::Rng(_), Self::Rng(_)) => true,
(Self::UnsupportedVersion(a), Self::UnsupportedVersion(b)) => a == b,
_ => false,
}
}
}
/// Computes the BLAKE3 digest of an in-memory byte slice.
///
/// Feeds the whole slice to an incremental hasher once. The result therefore matches
/// [`hash_stream`] over a reader that yields the same bytes.
#[must_use]
pub fn hash_bytes(data: &[u8]) -> blake3::Hash {
    let mut hasher = blake3::Hasher::new();
    hasher.update(data);
    hasher.finalize()
}
/// Computes the BLAKE3 digest of everything `reader` yields until end of file.
///
/// # Errors
/// Returns [`Error::Io`] if reading from `reader` fails.
pub fn hash_stream(reader: &mut impl Read) -> Result<blake3::Hash, Error> {
    let digest = blake3::Hasher::new().update_reader(reader)?.finalize();
    Ok(digest)
}
/// Symmetric encryption handle built around a ChaCha20-Poly1305 cipher keyed with a 32-byte
/// master key.
///
/// Provides single-shot (`encrypt`/`decrypt`) and chunked streaming (`encrypt_stream`/
/// `decrypt_stream`) modes. Both bind the wire-format version byte into the AAD.
pub struct Vault {
    // Cipher initialized from the master key in `Vault::new`. It holds the key, which is why
    // the `Debug` impl does not print it.
    cipher: ChaCha20Poly1305,
}
impl std::fmt::Debug for Vault {
    /// Renders as `Vault { .. }`. The cipher (and therefore the key) is never printed.
    fn fmt(&self, fmtr: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = fmtr.debug_struct("Vault");
        repr.finish_non_exhaustive()
    }
}
// Compile-time guarantee that `Vault` can be moved to and shared between threads. The build
// fails here if a future field breaks `Send` or `Sync`.
const _: () = {
    const fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<Vault>();
};
impl Vault {
#[must_use]
pub fn new(master_key: [u8; 32]) -> Self {
let key = Zeroizing::new(master_key);
let cipher = ChaCha20Poly1305::new(Key::from_slice(key.as_ref()));
Self { cipher }
}
pub fn generate() -> Result<Self, Error> {
let mut key = Zeroizing::new([0u8; 32]);
getrandom::getrandom(key.as_mut())?;
Ok(Self::new(*key))
}
fn cipher(&self) -> ChaCha20Poly1305 {
self.cipher.clone()
}
pub fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>, Error> {
let mut out = Vec::with_capacity(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
self.encrypt_into(plaintext, aad, &mut out)?;
Ok(out)
}
pub fn encrypt_into(
&self,
plaintext: &[u8],
aad: &[u8],
out: &mut Vec<u8>,
) -> Result<(), Error> {
let mut nonce_bytes = [0u8; NONCE_LEN];
getrandom::getrandom(&mut nonce_bytes)?;
let nonce = Nonce::from_slice(&nonce_bytes);
let bound_aad = BoundAad::new(VERSION, aad);
let start = out.len();
out.reserve(1 + NONCE_LEN + plaintext.len() + TAG_LEN);
let ct_start = start + 1 + NONCE_LEN;
out.push(VERSION);
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(plaintext);
if let Ok(tag) =
self.cipher
.encrypt_in_place_detached(nonce, bound_aad.as_slice(), &mut out[ct_start..])
{
out.extend_from_slice(&tag);
Ok(())
} else {
let written = out.len() - start;
scrub_tail_and_truncate(out, start, written);
Err(Error::PlaintextTooLarge)
}
}
pub fn decrypt(&self, data: &[u8], aad: &[u8]) -> Result<Vec<u8>, Error> {
let mut out = Vec::new();
self.decrypt_into(data, aad, &mut out)?;
Ok(out)
}
pub fn decrypt_into(&self, data: &[u8], aad: &[u8], out: &mut Vec<u8>) -> Result<(), Error> {
if data.len() < 1 + NONCE_LEN + TAG_LEN {
return Err(Error::TooShort);
}
let version = data[0];
if !ACCEPTED_VERSIONS.contains(&version) {
return Err(Error::UnsupportedVersion(version));
}
#[allow(clippy::range_plus_one)]
let nonce = Nonce::from_slice(&data[1..1 + NONCE_LEN]);
let ct_and_tag = &data[1 + NONCE_LEN..];
let ct_len = ct_and_tag.len() - TAG_LEN;
let (ct, tag) = ct_and_tag.split_at(ct_len);
let tag = chacha20poly1305::Tag::from_slice(tag);
let bound_aad = BoundAad::new(version, aad);
let start = out.len();
out.reserve(ct_len);
out.extend_from_slice(ct);
if self
.cipher
.decrypt_in_place_detached(nonce, bound_aad.as_slice(), &mut out[start..], tag)
.is_err()
{
scrub_tail_and_truncate(out, start, ct_len);
return Err(Error::AuthenticationFailed);
}
Ok(())
}
pub fn encrypt_stream(
&self,
reader: &mut impl Read,
writer: &mut impl Write,
aad: &[u8],
) -> Result<(), Error> {
self.encrypt_stream_inner(reader, writer, aad, None)?;
Ok(())
}
pub fn hash_and_encrypt_stream(
&self,
reader: &mut impl Read,
writer: &mut impl Write,
aad: &[u8],
) -> Result<blake3::Hash, Error> {
let mut hasher = blake3::Hasher::new();
self.encrypt_stream_inner(reader, writer, aad, Some(&mut hasher))?;
Ok(hasher.finalize())
}
fn encrypt_stream_inner(
&self,
reader: &mut impl Read,
writer: &mut impl Write,
aad: &[u8],
mut plaintext_hasher: Option<&mut blake3::Hasher>,
) -> Result<(), Error> {
let mut header_nonce = [0u8; STREAM_NONCE_LEN];
getrandom::getrandom(&mut header_nonce)?;
writer.write_all(&[VERSION])?;
writer.write_all(&header_nonce)?;
let bound_aad = BoundAad::new(VERSION, aad);
let mut encryptor = EncryptorBE32::from_aead(self.cipher(), header_nonce.as_ref().into());
let mut buf: Vec<u8> = Vec::with_capacity(STREAM_CHUNK + TAG_LEN);
let mut next: Vec<u8> = Vec::with_capacity(STREAM_CHUNK + TAG_LEN);
buf.resize(STREAM_CHUNK, 0);
next.resize(STREAM_CHUNK, 0);
let mut pending = read_up_to(reader, &mut buf)?;
loop {
next.resize(STREAM_CHUNK, 0);
let next_len = read_up_to(reader, &mut next)?;
if let Some(h) = plaintext_hasher.as_deref_mut() {
h.update(&buf[..pending]);
}
buf.truncate(pending);
if next_len == 0 {
encryptor
.encrypt_last_in_place(bound_aad.as_slice(), &mut buf)
.map_err(|_| io::Error::other("STREAM encrypt_last failed"))?;
writer.write_all(&buf)?;
return Ok(());
}
encryptor
.encrypt_next_in_place(bound_aad.as_slice(), &mut buf)
.map_err(|_| io::Error::other("STREAM encrypt_next failed"))?;
writer.write_all(&buf)?;
mem::swap(&mut buf, &mut next);
pending = next_len;
}
}
pub fn decrypt_stream(
&self,
reader: &mut impl Read,
writer: &mut impl Write,
aad: &[u8],
) -> Result<(), Error> {
let mut version = [0u8; 1];
match reader.read_exact(&mut version) {
Ok(()) => {}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Err(Error::TooShort),
Err(e) => return Err(Error::Io(e)),
}
if !ACCEPTED_VERSIONS.contains(&version[0]) {
return Err(Error::UnsupportedVersion(version[0]));
}
let mut header_nonce = [0u8; STREAM_NONCE_LEN];
match reader.read_exact(&mut header_nonce) {
Ok(()) => {}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Err(Error::TooShort),
Err(e) => return Err(Error::Io(e)),
}
let bound_aad = BoundAad::new(version[0], aad);
let mut decryptor = DecryptorBE32::from_aead(self.cipher(), header_nonce.as_ref().into());
let frame = STREAM_CHUNK + TAG_LEN;
let mut current: Vec<u8> = Vec::with_capacity(frame);
let mut next: Vec<u8> = Vec::with_capacity(frame);
current.resize(frame, 0);
let mut current_len = read_up_to(reader, &mut current)?;
if current_len < TAG_LEN {
return Err(Error::AuthenticationFailed);
}
loop {
next.resize(frame, 0);
let next_len = read_up_to(reader, &mut next)?;
current.truncate(current_len);
if next_len == 0 {
decryptor
.decrypt_last_in_place(bound_aad.as_slice(), &mut current)
.map_err(|_| Error::AuthenticationFailed)?;
writer.write_all(¤t)?;
return Ok(());
}
if next_len < TAG_LEN {
return Err(Error::AuthenticationFailed);
}
decryptor
.decrypt_next_in_place(bound_aad.as_slice(), &mut current)
.map_err(|_| Error::AuthenticationFailed)?;
writer.write_all(¤t)?;
mem::swap(&mut current, &mut next);
current_len = next_len;
}
}
}
/// Largest bound AAD (version byte plus caller AAD) kept inline without a heap allocation.
const BOUND_AAD_INLINE: usize = 64;

/// The associated data actually fed to the AEAD: the wire-format version byte followed by
/// the caller's AAD.
///
/// Prefixing the version authenticates it along with the caller's context. Short inputs are
/// stored in a fixed-size array; longer ones spill to the heap.
enum BoundAad {
    /// `buf[..len]` holds the bound AAD; the remaining bytes are unused zeros.
    Inline {
        buf: [u8; BOUND_AAD_INLINE],
        len: usize,
    },
    /// Bound AAD too long for the inline array.
    Heap(Vec<u8>),
}

impl BoundAad {
    /// Builds `version || user_aad`, using inline storage whenever the result fits.
    fn new(version: u8, user_aad: &[u8]) -> Self {
        let total = user_aad.len() + 1;
        if total > BOUND_AAD_INLINE {
            let mut bytes = Vec::with_capacity(total);
            bytes.push(version);
            bytes.extend_from_slice(user_aad);
            return Self::Heap(bytes);
        }
        let mut buf = [0u8; BOUND_AAD_INLINE];
        let (prefix, rest) = buf.split_at_mut(1);
        prefix[0] = version;
        rest[..user_aad.len()].copy_from_slice(user_aad);
        Self::Inline { buf, len: total }
    }

    /// Borrows the bound AAD bytes regardless of where they are stored.
    fn as_slice(&self) -> &[u8] {
        match self {
            Self::Heap(bytes) => bytes.as_slice(),
            Self::Inline { buf, len } => &buf[..*len],
        }
    }
}
/// Reads from `reader` until `buf` is full or the reader reports end of file.
///
/// Unlike [`Read::read_exact`], reaching EOF early is not an error. The function returns how
/// many bytes were placed at the front of `buf`, so a result smaller than `buf.len()` means the
/// reader returned `Ok(0)`. `ErrorKind::Interrupted` is retried; any other error is returned
/// immediately.
fn read_up_to(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let capacity = buf.len();
    let mut filled = 0;
    while filled < capacity {
        let got = match reader.read(&mut buf[filled..]) {
            Ok(0) => break,
            Ok(got) => got,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        filled += got;
    }
    Ok(filled)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the single-shot, streaming, and hashing APIs.
    //!
    //! `file_roundtrip` additionally requires a `test.pdf` fixture in the crate root.
    use std::io::Cursor;

    use super::*;

    const KEY: [u8; 32] = [0x42; 32];
    const AAD: &[u8] = b"test-context";

    fn vault() -> Vault {
        Vault::new(KEY)
    }

    #[test]
    fn into_variants_roundtrip() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_into(b"hello", AAD, &mut ct).unwrap();
        let mut pt = Vec::new();
        v.decrypt_into(&ct, AAD, &mut pt).unwrap();
        assert_eq!(pt, b"hello");
        let ct2 = v.encrypt(b"hello", AAD).unwrap();
        assert_eq!(v.decrypt(&ct2, AAD).unwrap(), b"hello");
    }

    #[test]
    fn into_variants_append() {
        let v = vault();
        let mut ct = vec![0xAA, 0xBB, 0xCC];
        v.encrypt_into(b"payload", AAD, &mut ct).unwrap();
        assert_eq!(&ct[..3], &[0xAA, 0xBB, 0xCC]);
        assert_eq!(v.decrypt(&ct[3..], AAD).unwrap(), b"payload");
    }

    #[test]
    fn decrypt_into_reuses_buffer() {
        let v = vault();
        let msgs: [&[u8]; 3] = [b"one", b"two-two", b"three-three-three"];
        let mut pt = Vec::with_capacity(64);
        let initial_cap = pt.capacity();
        for msg in msgs {
            let ct = v.encrypt(msg, AAD).unwrap();
            pt.clear();
            v.decrypt_into(&ct, AAD, &mut pt).unwrap();
            assert_eq!(pt, msg);
        }
        assert_eq!(pt.capacity(), initial_cap, "no realloc expected");
    }

    #[test]
    fn decrypt_into_leaves_buffer_unchanged_on_failure() {
        let v = vault();
        let mut ct = v.encrypt(b"secret", AAD).unwrap();
        *ct.last_mut().unwrap() ^= 0xff;
        let mut out = vec![0x11, 0x22, 0x33];
        let before = out.clone();
        assert!(matches!(
            v.decrypt_into(&ct, AAD, &mut out),
            Err(Error::AuthenticationFailed)
        ));
        assert_eq!(out, before);
    }

    #[test]
    fn roundtrip_empty() {
        let v = vault();
        let ct = v.encrypt(b"", AAD).unwrap();
        assert_eq!(v.decrypt(&ct, AAD).unwrap(), b"");
    }

    #[test]
    fn roundtrip_small() {
        let v = vault();
        let ct = v.encrypt(b"hello, nahui!", AAD).unwrap();
        assert_eq!(v.decrypt(&ct, AAD).unwrap(), b"hello, nahui!");
    }

    #[test]
    fn roundtrip_large() {
        let v = vault();
        let msg: Vec<u8> = (0u32..)
            .map(|i| u8::try_from(i & 0xFF).unwrap())
            .take(10 * 1024 * 1024)
            .collect();
        let ct = v.encrypt(&msg, AAD).unwrap();
        assert_eq!(v.decrypt(&ct, AAD).unwrap(), msg);
    }

    #[test]
    fn output_length() {
        let v = vault();
        let ct = v.encrypt(b"four", AAD).unwrap();
        assert_eq!(ct.len(), 1 + NONCE_LEN + 4 + TAG_LEN);
    }

    #[test]
    fn nonce_is_random() {
        let v = vault();
        let ct1 = v.encrypt(b"same", AAD).unwrap();
        let ct2 = v.encrypt(b"same", AAD).unwrap();
        assert_ne!(ct1, ct2);
    }

    #[test]
    fn wrong_key_fails() {
        let v = vault();
        let ct = v.encrypt(b"secret", AAD).unwrap();
        let bad = Vault::new([0u8; 32]);
        assert!(matches!(
            bad.decrypt(&ct, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn wrong_aad_fails() {
        let v = vault();
        let ct = v.encrypt(b"secret", b"context-A").unwrap();
        assert!(matches!(
            v.decrypt(&ct, b"context-B"),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let v = vault();
        let mut ct = v.encrypt(b"secret", AAD).unwrap();
        ct[1 + NONCE_LEN] ^= 0xff;
        assert!(matches!(
            v.decrypt(&ct, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn tampered_tag_fails() {
        let v = vault();
        let mut ct = v.encrypt(b"secret", AAD).unwrap();
        let last = ct.len() - 1;
        ct[last] ^= 0xff;
        assert!(matches!(
            v.decrypt(&ct, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn tampered_nonce_fails() {
        let v = vault();
        let mut ct = v.encrypt(b"secret", AAD).unwrap();
        ct[1] ^= 0xff;
        assert!(matches!(
            v.decrypt(&ct, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn unknown_version_fails() {
        let v = vault();
        let mut ct = v.encrypt(b"secret", AAD).unwrap();
        ct[0] = 0xFE;
        assert!(matches!(
            v.decrypt(&ct, AAD),
            Err(Error::UnsupportedVersion(0xFE))
        ));
    }

    #[test]
    fn too_short_fails() {
        let v = vault();
        assert!(matches!(v.decrypt(&[], AAD), Err(Error::TooShort)));
        assert!(matches!(v.decrypt(&[0u8; 10], AAD), Err(Error::TooShort)));
        let ct = v.encrypt(b"", AAD).unwrap();
        assert_eq!(ct.len(), 1 + NONCE_LEN + TAG_LEN);
        assert!(v.decrypt(&ct, AAD).is_ok());
    }

    #[test]
    fn truncated_ciphertext_fails() {
        let v = vault();
        let ct = v.encrypt(b"truncation test", AAD).unwrap();
        assert!(v.decrypt(&ct[..ct.len() - 1], AAD).is_err());
        assert!(v.decrypt(&ct[..ct.len() - TAG_LEN], AAD).is_err());
        assert!(v.decrypt(&ct[..=NONCE_LEN], AAD).is_err());
    }

    #[test]
    fn appended_bytes_fails() {
        let v = vault();
        let mut ct = v.encrypt(b"append test", AAD).unwrap();
        ct.push(0x00);
        assert!(matches!(
            v.decrypt(&ct, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn stream_roundtrip_empty() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"".as_slice()), &mut ct, AAD)
            .unwrap();
        let mut pt = Vec::new();
        v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD).unwrap();
        assert!(pt.is_empty());
    }

    #[test]
    fn stream_roundtrip_small() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(
            &mut Cursor::new(b"streaming hello".as_slice()),
            &mut ct,
            AAD,
        )
        .unwrap();
        let mut pt = Vec::new();
        v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD).unwrap();
        assert_eq!(pt, b"streaming hello");
    }

    #[test]
    fn stream_roundtrip_large() {
        let v = vault();
        let msg: Vec<u8> = (0u32..)
            .map(|i| u8::try_from(i & 0xFF).unwrap())
            .take(5 * 1024 * 1024)
            .collect();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(&msg), &mut ct, AAD).unwrap();
        let mut pt = Vec::new();
        v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD).unwrap();
        assert_eq!(pt, msg);
    }

    #[test]
    fn stream_roundtrip_at_chunk_boundaries() {
        let v = vault();
        for size in [
            STREAM_CHUNK - 1,
            STREAM_CHUNK,
            STREAM_CHUNK + 1,
            2 * STREAM_CHUNK,
            2 * STREAM_CHUNK + 1,
            5 * STREAM_CHUNK,
        ] {
            let msg: Vec<u8> = (0u8..=255).cycle().take(size).collect();
            let mut ct = Vec::new();
            v.encrypt_stream(&mut Cursor::new(&msg), &mut ct, AAD).unwrap();
            let mut pt = Vec::new();
            v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD).unwrap();
            assert_eq!(pt, msg, "failed at size {size}");
        }
    }

    #[test]
    fn stream_wrong_key_fails() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"secret".as_slice()), &mut ct, AAD)
            .unwrap();
        let bad = Vault::new([0u8; 32]);
        let mut pt = Vec::new();
        assert!(matches!(
            bad.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn stream_wrong_aad_fails() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"secret".as_slice()), &mut ct, b"A")
            .unwrap();
        let mut pt = Vec::new();
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(ct), &mut pt, b"B"),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn stream_tampered_ciphertext_fails() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"secret".as_slice()), &mut ct, AAD)
            .unwrap();
        ct[1 + STREAM_NONCE_LEN] ^= 0xff;
        let mut pt = Vec::new();
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn stream_tampered_tag_fails() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"mac test".as_slice()), &mut ct, AAD)
            .unwrap();
        let last = ct.len() - 1;
        ct[last] ^= 0xff;
        let mut pt = Vec::new();
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn stream_tampered_header_nonce_fails() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"hello".as_slice()), &mut ct, AAD)
            .unwrap();
        ct[1] ^= 0xff;
        let mut pt = Vec::new();
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn stream_unknown_version_fails() {
        let v = vault();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"hi".as_slice()), &mut ct, AAD)
            .unwrap();
        ct[0] = 0xFE;
        let mut pt = Vec::new();
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD),
            Err(Error::UnsupportedVersion(0xFE))
        ));
    }

    #[test]
    fn stream_truncated_at_boundary_fails() {
        let v = vault();
        let msg: Vec<u8> = (0u8..=255).cycle().take(STREAM_CHUNK + 1024).collect();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(&msg), &mut ct, AAD).unwrap();
        let truncated = &ct[..1 + STREAM_NONCE_LEN + STREAM_CHUNK + TAG_LEN];
        let mut pt = Vec::new();
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(truncated), &mut pt, AAD),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn stream_too_short_fails() {
        let v = vault();
        let mut pt = Vec::new();
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(Vec::new()), &mut pt, AAD),
            Err(Error::TooShort)
        ));
        let short = vec![0u8; STREAM_NONCE_LEN];
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(short), &mut pt, AAD),
            Err(Error::TooShort | Error::UnsupportedVersion(_))
        ));
        // A valid version byte followed by a truncated nonce prefix must report TooShort.
        let mut truncated_header = vec![VERSION];
        truncated_header.extend_from_slice(&[0u8; STREAM_NONCE_LEN - 1]);
        assert!(matches!(
            v.decrypt_stream(&mut Cursor::new(truncated_header), &mut pt, AAD),
            Err(Error::TooShort)
        ));
    }

    #[test]
    fn hash_bytes_is_deterministic() {
        assert_eq!(hash_bytes(b"hello"), hash_bytes(b"hello"));
    }

    #[test]
    fn hash_bytes_and_stream_agree() {
        let data: Vec<u8> = (0u8..=255).cycle().take(200_000).collect();
        let h_mem = hash_bytes(&data);
        let h_stream = hash_stream(&mut Cursor::new(&data)).unwrap();
        assert_eq!(h_mem, h_stream);
    }

    #[test]
    fn hash_and_encrypt_stream_hash_matches_plaintext() {
        let v = vault();
        let msg: Vec<u8> = (0u8..=255).cycle().take(300_000).collect();
        let expected = hash_bytes(&msg);
        let mut ct = Vec::new();
        let got = v
            .hash_and_encrypt_stream(&mut Cursor::new(&msg), &mut ct, AAD)
            .unwrap();
        assert_eq!(expected, got);
    }

    #[test]
    fn hash_and_encrypt_stream_output_is_decryptable() {
        let v = vault();
        let mut ct = Vec::new();
        v.hash_and_encrypt_stream(&mut Cursor::new(b"attest".as_slice()), &mut ct, AAD)
            .unwrap();
        let mut pt = Vec::new();
        v.decrypt_stream(&mut Cursor::new(ct), &mut pt, AAD).unwrap();
        assert_eq!(pt, b"attest");
    }

    #[test]
    fn hash_to_hex_is_64_chars() {
        assert_eq!(hash_bytes(b"test").to_hex().len(), 64);
    }

    #[test]
    fn wrong_key_does_not_reveal_plaintext() {
        let v = vault();
        let ct = v.encrypt(b"super secret", AAD).unwrap();
        for i in 0u8..=255 {
            let raw = [i; 32];
            if raw == KEY {
                continue;
            }
            assert!(
                Vault::new(raw).decrypt(&ct, AAD).is_err(),
                "key {i} must fail"
            );
        }
    }

    #[test]
    fn single_bit_key_difference_fails() {
        let v = vault();
        let ct = v.encrypt(b"bit-flip", AAD).unwrap();
        for byte_idx in 0..32 {
            for bit in 0..8 {
                let mut raw = KEY;
                raw[byte_idx] ^= 1 << bit;
                assert!(
                    Vault::new(raw).decrypt(&ct, AAD).is_err(),
                    "bit {bit} of byte {byte_idx} must fail"
                );
            }
        }
    }

    #[test]
    fn all_zero_key_works() {
        let v = Vault::new([0u8; 32]);
        let ct = v.encrypt(b"zero-key", AAD).unwrap();
        assert_eq!(v.decrypt(&ct, AAD).unwrap(), b"zero-key");
    }

    #[test]
    fn all_ff_key_works() {
        let v = Vault::new([0xFFu8; 32]);
        let ct = v.encrypt(b"ff-key", AAD).unwrap();
        assert_eq!(v.decrypt(&ct, AAD).unwrap(), b"ff-key");
    }

    #[test]
    fn empty_aad_works() {
        let v = vault();
        let ct = v.encrypt(b"hello", b"").unwrap();
        assert_eq!(v.decrypt(&ct, b"").unwrap(), b"hello");
    }

    #[test]
    fn long_aad_works() {
        let v = vault();
        let aad = vec![0xAB; 4096];
        let ct = v.encrypt(b"hello", &aad).unwrap();
        assert_eq!(v.decrypt(&ct, &aad).unwrap(), b"hello");
    }

    /// Reader yielding `remaining` bytes of `0xAB`, then EOF forever.
    struct BoundedReader {
        remaining: usize,
    }

    impl Read for BoundedReader {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            if self.remaining == 0 {
                return Ok(0);
            }
            let n = buf.len().min(self.remaining);
            for byte in &mut buf[..n] {
                *byte = 0xAB;
            }
            self.remaining -= n;
            Ok(n)
        }
    }

    /// Writer that discards data but records the total and the largest single write.
    struct MeasuringWriter {
        total: usize,
        max_single: usize,
    }

    impl Write for MeasuringWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.total += buf.len();
            self.max_single = self.max_single.max(buf.len());
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn encrypt_stream_writes_in_bounded_chunks() {
        let v = vault();
        let total = 10 * 1024 * 1024;
        let mut r = BoundedReader { remaining: total };
        let mut w = MeasuringWriter {
            total: 0,
            max_single: 0,
        };
        v.encrypt_stream(&mut r, &mut w, AAD).unwrap();
        assert!(
            w.max_single <= STREAM_CHUNK + TAG_LEN,
            "max single write {} exceeded chunk+tag {}",
            w.max_single,
            STREAM_CHUNK + TAG_LEN
        );
        let chunks = total.div_ceil(STREAM_CHUNK).max(1);
        assert_eq!(w.total, 1 + STREAM_NONCE_LEN + total + chunks * TAG_LEN);
    }

    #[test]
    fn decrypt_stream_writes_in_bounded_chunks() {
        let v = vault();
        let total = 5 * 1024 * 1024;
        let msg: Vec<u8> = (0u8..=255).cycle().take(total).collect();
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(&msg), &mut ct, AAD).unwrap();
        let mut w = MeasuringWriter {
            total: 0,
            max_single: 0,
        };
        v.decrypt_stream(&mut Cursor::new(ct), &mut w, AAD).unwrap();
        assert!(w.max_single <= STREAM_CHUNK);
        assert_eq!(w.total, total);
    }

    #[test]
    fn file_roundtrip() {
        use std::fs;
        use std::io::BufWriter;
        use std::path::Path;

        let v = vault();
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        let input_path = manifest_dir.join("test.pdf");
        let out_dir = manifest_dir.join("target").join("test-artifacts");
        fs::create_dir_all(&out_dir).unwrap();
        let encrypted_path = out_dir.join("test.pdf.enc");
        let decrypted_path = out_dir.join("test.pdf.dec");
        let original = fs::read(&input_path).expect("test.pdf must exist");
        {
            let mut src = fs::File::open(&input_path).unwrap();
            let mut dst = BufWriter::new(fs::File::create(&encrypted_path).unwrap());
            v.encrypt_stream(&mut src, &mut dst, b"file:test.pdf")
                .unwrap();
            // Flush explicitly: BufWriter's Drop swallows any write error.
            dst.flush().unwrap();
        }
        {
            let mut src = fs::File::open(&encrypted_path).unwrap();
            let mut dst = BufWriter::new(fs::File::create(&decrypted_path).unwrap());
            v.decrypt_stream(&mut src, &mut dst, b"file:test.pdf")
                .unwrap();
            dst.flush().unwrap();
        }
        assert_eq!(original, fs::read(&decrypted_path).unwrap());
        let _ = fs::remove_file(&decrypted_path);
        let _ = fs::remove_file(&encrypted_path);
    }

    // Not a true known-answer test: the nonce is random, so this pins format invariants
    // (version byte, length, AAD binding) rather than exact ciphertext bytes.
    #[test]
    fn kat_single_shot_decrypts() {
        let v = Vault::new([0x42; 32]);
        let pt = b"nahui-kat-v1";
        let aad = b"kat-aad";
        let ct = v.encrypt(pt, aad).unwrap();
        assert_eq!(ct[0], VERSION);
        assert_eq!(ct.len(), 1 + NONCE_LEN + pt.len() + TAG_LEN);
        assert_eq!(v.decrypt(&ct, aad).unwrap(), pt);
        assert!(v.decrypt(&ct, b"different").is_err());
    }

    #[test]
    fn kat_stream_format_invariants() {
        let v = Vault::new([0x42; 32]);
        let mut ct = Vec::new();
        v.encrypt_stream(&mut Cursor::new(b"abc".as_slice()), &mut ct, b"kat")
            .unwrap();
        assert_eq!(ct[0], VERSION);
        assert_eq!(ct.len(), 1 + STREAM_NONCE_LEN + 3 + TAG_LEN);
    }

    #[test]
    fn bound_aad_binds_received_version_not_build_version() {
        use chacha20poly1305::aead::AeadInPlace;
        let v = Vault::new(KEY);
        let ct = v.encrypt(b"payload", AAD).unwrap();
        let nonce = Nonce::from_slice(&ct[1..=NONCE_LEN]);
        let ct_and_tag = &ct[1 + NONCE_LEN..];
        let ct_len = ct_and_tag.len() - TAG_LEN;
        let (cipher_bytes, tag_bytes) = ct_and_tag.split_at(ct_len);
        let tag = chacha20poly1305::Tag::from_slice(tag_bytes);
        let wrong_version_aad = BoundAad::new(0x99, AAD);
        let mut buf = cipher_bytes.to_vec();
        let cipher = ChaCha20Poly1305::new(Key::from_slice(&KEY));
        assert!(
            cipher
                .decrypt_in_place_detached(nonce, wrong_version_aad.as_slice(), &mut buf, tag)
                .is_err(),
            "AEAD must reject ciphertext authenticated under a different version"
        );
        let mut buf = cipher_bytes.to_vec();
        let right_version_aad = BoundAad::new(VERSION, AAD);
        cipher
            .decrypt_in_place_detached(nonce, right_version_aad.as_slice(), &mut buf, tag)
            .expect("AEAD must accept ciphertext with the correct bound version");
    }

    #[test]
    fn generate_produces_distinct_working_vaults() {
        let v1 = Vault::generate().unwrap();
        let v2 = Vault::generate().unwrap();
        let ct = v1.encrypt(b"hi", AAD).unwrap();
        assert_eq!(v1.decrypt(&ct, AAD).unwrap(), b"hi");
        assert!(v2.decrypt(&ct, AAD).is_err());
    }

    #[test]
    fn version_constant_is_one() {
        assert_eq!(VERSION, 0x01);
    }
}