use crate::crypto::types::{CryptoError, CryptoResult};
/// Maximum number of plaintext bytes placed in a single stream chunk.
pub const ENCRYPT_CHUNK_SIZE: usize = 32_768;
/// Size in bytes of the stream header that follows the one-byte type prefix.
pub const ENCRYPT_HEADER: usize = 24;
/// Bytes added to every chunk: one tag byte plus a 16-byte authenticator.
pub const ENCRYPT_CHUNK_OVERHEAD: usize = 17;
/// Size in bytes of the encryption key.
pub const ENCRYPT_KEY_SIZE: usize = 32;
/// Size in bytes of a full encrypted chunk (plaintext plus per-chunk overhead).
pub const ENCRYPTED_CHUNK_TOTAL: usize = ENCRYPT_CHUNK_SIZE + ENCRYPT_CHUNK_OVERHEAD;
/// Largest plaintext size `encrypt` accepts unless the caller sets `allow_large`.
pub const MAX_REGULAR_SIZE: usize = 10_218_286;
/// Domain-separation selector for key derivation.
///
/// The `u8` discriminant is used as the one-byte BLAKE2b key in
/// `derive_nonce_key`, so identical seed and data yield different nonces and
/// keys for different domains.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Domain {
/// Domain byte `0x00`.
Attachment = 0x00,
/// Domain byte `0x01`.
ProfilePic = 0x01,
}
/// Chunk tag for every chunk that is not the last one of the stream.
const TAG_MESSAGE: u8 = 0;
/// Chunk tag marking the last chunk of the stream; it also triggers a rekey.
const TAG_FINAL: u8 = 3;
/// Running state of one encrypting or decrypting stream.
struct SecretStream {
// ChaCha20 key; initially the HChaCha20 subkey, replaced on rekey.
key: [u8; 32],
// ChaCha20 nonce: bytes 0..4 are a little-endian chunk counter, bytes 4..12
// are copied from header bytes 16..24.
nonce: [u8; 12],
}
/// Chunk-by-chunk authenticated encryption built from HChaCha20 (subkey
/// derivation), ChaCha20 (keystream) and Poly1305 (per-chunk authenticator).
impl SecretStream {
    /// Starts an encrypting stream.
    ///
    /// Copies `nonce_bytes` into `header_out` (the header the caller has to
    /// store in front of the chunks) and derives the stream state from that
    /// header exactly as [`SecretStream::init_pull`] does when decrypting.
    fn init_push(
        header_out: &mut [u8; ENCRYPT_HEADER],
        key: &[u8; ENCRYPT_KEY_SIZE],
        nonce_bytes: &[u8; ENCRYPT_HEADER],
    ) -> Self {
        header_out.copy_from_slice(nonce_bytes);
        // Both directions must derive identical state, so share one code path.
        Self::init_pull(header_out, key)
    }

    /// Starts a decrypting stream from a received `header` and the `key`.
    ///
    /// The first 16 header bytes are fed to HChaCha20 together with `key` to
    /// obtain the stream subkey; the last 8 header bytes become bytes 4..12 of
    /// the nonce. The chunk counter (nonce bytes 0..4, little-endian) starts
    /// at 1.
    fn init_pull(header: &[u8; ENCRYPT_HEADER], key: &[u8; ENCRYPT_KEY_SIZE]) -> Self {
        use chacha20::hchacha;
        use chacha20::R20;

        let hchacha_input: [u8; 16] = header[..16].try_into().unwrap();
        let subkey = hchacha::<R20>(&(*key).into(), &hchacha_input.into());
        let mut k = [0u8; 32];
        k.copy_from_slice(subkey.as_slice());

        let mut nonce = [0u8; 12];
        nonce[0] = 1;
        nonce[4..12].copy_from_slice(&header[16..24]);
        SecretStream { key: k, nonce }
    }

    /// Encrypts one chunk and advances the stream state.
    ///
    /// Returns `masked tag byte || ciphertext || 16-byte authenticator`, i.e.
    /// `plaintext.len() + 17` bytes. When `tag` is `TAG_FINAL` the stream is
    /// rekeyed afterwards.
    fn push(&mut self, plaintext: &[u8], tag: u8) -> Vec<u8> {
        use chacha20::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
        use chacha20::ChaCha20;
        use poly1305::universal_hash::{KeyInit, UniversalHash};
        use poly1305::Poly1305;

        let mlen = plaintext.len();
        let mut out = vec![0u8; 1 + mlen + 16];

        // Keystream block 0 supplies the one-time Poly1305 key (first 32
        // bytes) and the mask for the tag byte (first byte).
        let mut block0 = [0u8; 64];
        let mut chacha = ChaCha20::new(&self.key.into(), &self.nonce.into());
        chacha.apply_keystream(&mut block0);
        let poly_key: [u8; 32] = block0[..32].try_into().unwrap();

        out[0] = tag;
        out[1..1 + mlen].copy_from_slice(plaintext);
        // The payload is encrypted with the keystream starting at byte 64.
        chacha.seek(64u32);
        chacha.apply_keystream(&mut out[1..1 + mlen]);

        // Authenticate the unmasked tag byte followed by the ciphertext, then
        // a 16-byte length block: bytes 0..8 stay zero, bytes 8..16 hold the
        // little-endian length of the authenticated data.
        let mut mac = Poly1305::new(&poly_key.into());
        mac.update_padded(&out[..1 + mlen]);
        let mut lengths = [0u8; 16];
        let clen = (1 + mlen) as u64;
        lengths[8..16].copy_from_slice(&clen.to_le_bytes());
        mac.update_padded(&lengths);
        let mac_tag = mac.finalize();

        // Mask the tag byte only after it has been authenticated in the clear.
        out[0] ^= block0[0];
        out[1 + mlen..].copy_from_slice(mac_tag.as_slice());

        self.increment_counter();
        if tag == TAG_FINAL {
            self.rekey();
        }
        out
    }

    /// Verifies and decrypts one chunk produced by [`SecretStream::push`].
    ///
    /// Returns the plaintext and the chunk tag. Returns `Err(())` when the
    /// chunk is shorter than `ENCRYPT_CHUNK_OVERHEAD` or its authenticator
    /// does not match; in both cases the stream state is left untouched.
    fn pull(&mut self, ciphertext: &[u8]) -> Result<(Vec<u8>, u8), ()> {
        use chacha20::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
        use chacha20::ChaCha20;
        use poly1305::universal_hash::{KeyInit, UniversalHash};
        use poly1305::Poly1305;

        if ciphertext.len() < ENCRYPT_CHUNK_OVERHEAD {
            return Err(());
        }
        let mlen = ciphertext.len() - ENCRYPT_CHUNK_OVERHEAD;

        let mut block0 = [0u8; 64];
        let mut chacha = ChaCha20::new(&self.key.into(), &self.nonce.into());
        chacha.apply_keystream(&mut block0);
        let poly_key: [u8; 32] = block0[..32].try_into().unwrap();

        // Rebuild the authenticated data: unmasked tag byte + ciphertext.
        let tag = ciphertext[0] ^ block0[0];
        let mut c = vec![0u8; 1 + mlen];
        c[0] = tag;
        c[1..].copy_from_slice(&ciphertext[1..1 + mlen]);
        let expected_mac = &ciphertext[1 + mlen..];

        let mut mac = Poly1305::new(&poly_key.into());
        mac.update_padded(&c);
        let mut lengths = [0u8; 16];
        let clen = (1 + mlen) as u64;
        lengths[8..16].copy_from_slice(&clen.to_le_bytes());
        mac.update_padded(&lengths);
        let computed_mac = mac.finalize();

        // Compare all 16 bytes without an early exit so the comparison time
        // does not reveal how many leading bytes of a forged authenticator
        // were correct. `black_box` discourages the optimizer from turning
        // the fold back into a short-circuiting comparison.
        let diff = computed_mac
            .as_slice()
            .iter()
            .zip(expected_mac)
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        if std::hint::black_box(diff) != 0 {
            return Err(());
        }

        let mut plaintext = ciphertext[1..1 + mlen].to_vec();
        chacha.seek(64u32);
        chacha.apply_keystream(&mut plaintext);

        self.increment_counter();
        if tag == TAG_FINAL {
            self.rekey();
        }
        Ok((plaintext, tag))
    }

    /// Adds one (wrapping) to the little-endian counter in nonce bytes 0..4.
    fn increment_counter(&mut self) {
        let mut counter = u32::from_le_bytes(self.nonce[..4].try_into().unwrap());
        counter = counter.wrapping_add(1);
        self.nonce[..4].copy_from_slice(&counter.to_le_bytes());
    }

    /// Replaces the key with the first 32 bytes of a fresh keystream block and
    /// the counter (nonce bytes 0..4) with the following 4 bytes.
    fn rekey(&mut self) {
        use chacha20::cipher::{KeyIvInit, StreamCipher};
        use chacha20::ChaCha20;

        let mut block = [0u8; 64];
        let mut chacha = ChaCha20::new(&self.key.into(), &self.nonce.into());
        chacha.apply_keystream(&mut block);
        self.key.copy_from_slice(&block[..32]);
        self.nonce[..4].copy_from_slice(&block[32..36]);
        block.fill(0);
    }
}
/// Returns the largest power of two that is less than or equal to `x`,
/// or `0` when `x` is `0`.
fn bit_floor(x: usize) -> usize {
    match x.checked_ilog2() {
        // `checked_ilog2` yields the index of the highest set bit.
        Some(top_bit) => 1usize << top_bit,
        None => 0,
    }
}
/// Computes how many padding bytes are added to the plaintext stream for
/// `data_size` bytes of data.
///
/// The total encrypted size is rounded up to a multiple of 1/32 of the largest
/// power of two not exceeding that size, with the size clamped to at least
/// 131072 for this purpose (so the granule is never below 4096). The result
/// always includes the one mandatory padding byte, and is reduced by the
/// per-chunk overhead of any extra chunks the padding itself creates.
pub fn encrypted_padding(data_size: usize) -> usize {
    const FRAMING: usize = 1 + ENCRYPT_HEADER;
    const MIN_PADDING: usize = 1;

    let data_chunks = data_size.div_ceil(ENCRYPT_CHUNK_SIZE);
    let unpadded = data_size + FRAMING + MIN_PADDING + data_chunks * ENCRYPT_CHUNK_OVERHEAD;

    let granule = bit_floor(unpadded.max(131072)) >> 5;
    let target = unpadded.div_ceil(granule) * granule;
    let raw_padding = target - unpadded + MIN_PADDING;

    // Every full chunk of padding carries its own per-chunk overhead.
    let whole_chunk_overhead = (raw_padding / ENCRYPTED_CHUNK_TOTAL) * ENCRYPT_CHUNK_OVERHEAD;

    // Unused room in the chunks the data already needs; padding that fits in
    // there is free, anything beyond it opens one more chunk.
    let slack = data_chunks * ENCRYPT_CHUNK_SIZE - data_size;
    let spill_overhead = if raw_padding % ENCRYPTED_CHUNK_TOTAL > slack {
        ENCRYPT_CHUNK_OVERHEAD
    } else {
        0
    };

    raw_padding - (whole_chunk_overhead + spill_overhead)
}
/// Returns the exact number of bytes `encrypt` produces for a plaintext of
/// `plaintext_size` bytes: type prefix, header, padded plaintext and the
/// overhead of every chunk.
pub fn encrypted_size(plaintext_size: usize) -> usize {
    let stream_plain = plaintext_size + encrypted_padding(plaintext_size);
    let chunk_count = stream_plain.div_ceil(ENCRYPT_CHUNK_SIZE);
    1 + ENCRYPT_HEADER + stream_plain + chunk_count * ENCRYPT_CHUNK_OVERHEAD
}
/// Returns an upper bound on the plaintext size contained in `enc_size` bytes
/// of encrypted data, or `None` when `enc_size` is too small to hold the type
/// prefix, the header, one padding byte and one chunk's overhead.
pub fn decrypted_max_size(enc_size: usize) -> Option<usize> {
    // Type prefix + header + the single mandatory padding byte.
    const FRAMING: usize = 1 + ENCRYPT_HEADER + 1;

    if enc_size < FRAMING + ENCRYPT_CHUNK_OVERHEAD {
        return None;
    }
    let body = enc_size - FRAMING;
    let chunk_overhead = body.div_ceil(ENCRYPTED_CHUNK_TOTAL) * ENCRYPT_CHUNK_OVERHEAD;
    body.checked_sub(chunk_overhead)
        .filter(|&plain| plain <= enc_size)
}
/// Derives the stream header/nonce and the encryption key in one pass.
///
/// Computes a keyed BLAKE2b hash of `seed[..32] || data`, using the domain
/// byte as the hash key. The first `ENCRYPT_HEADER` bytes of the output are
/// used by `encrypt` as the nonce, the remaining `ENCRYPT_KEY_SIZE` bytes as
/// the key.
///
/// # Panics
///
/// Panics if `seed` is shorter than 32 bytes.
fn derive_nonce_key(
    seed: &[u8],
    data: &[u8],
    domain: Domain,
) -> [u8; ENCRYPT_HEADER + ENCRYPT_KEY_SIZE] {
    const OUT_LEN: usize = ENCRYPT_HEADER + ENCRYPT_KEY_SIZE;

    let digest = blake2b_simd::Params::new()
        .hash_length(OUT_LEN)
        .key(&[domain as u8])
        .to_state()
        .update(&seed[..32])
        .update(data)
        .finalize();

    let mut nonce_and_key = [0u8; OUT_LEN];
    nonce_and_key.copy_from_slice(digest.as_bytes());
    nonce_and_key
}
/// Encrypts `data` deterministically and returns the ciphertext together with
/// the key needed to decrypt it.
///
/// Nonce and key are derived from the first 32 bytes of `seed`, the data
/// itself and `domain`, so the same inputs always yield the same output. The
/// ciphertext layout is: the byte `S`, the 24-byte stream header, then the
/// encrypted chunks. The plaintext stream starts with zero padding terminated
/// by a single `0x01` byte, followed by the data.
///
/// # Errors
///
/// Returns `CryptoError::InvalidInput` when `seed` is shorter than 32 bytes,
/// or when `data` exceeds `MAX_REGULAR_SIZE` and `allow_large` is `false`.
pub fn encrypt(
    seed: &[u8],
    data: &[u8],
    domain: Domain,
    allow_large: bool,
) -> CryptoResult<(Vec<u8>, [u8; ENCRYPT_KEY_SIZE])> {
    if seed.len() < 32 {
        return Err(CryptoError::InvalidInput(
            "attachment::encrypt requires a 32-byte uploader seed".into(),
        ));
    }
    if !allow_large && data.len() > MAX_REGULAR_SIZE {
        return Err(CryptoError::InvalidInput(
            "data to encrypt is too large".into(),
        ));
    }

    let derived = derive_nonce_key(seed, data, domain);
    let (nonce_part, key_part) = derived.split_at(ENCRYPT_HEADER);
    let mut nonce = [0u8; ENCRYPT_HEADER];
    nonce.copy_from_slice(nonce_part);
    let mut enc_key = [0u8; ENCRYPT_KEY_SIZE];
    enc_key.copy_from_slice(key_part);

    let padding = encrypted_padding(data.len());
    debug_assert!(padding >= 1);
    let total_size = encrypted_size(data.len());

    let mut out = Vec::with_capacity(total_size);
    out.push(b'S');
    let mut header_buf = [0u8; ENCRYPT_HEADER];
    let mut stream = SecretStream::init_push(&mut header_buf, &enc_key, &nonce);
    out.extend_from_slice(&header_buf);

    // Phase 1: emit the padding. Chunks made purely of zeros come first; the
    // last padding chunk ends the padding with 0x01 and is topped up with as
    // much data as still fits into the chunk.
    let mut cursor = 0usize;
    let mut padding_left = padding;
    while padding_left > 0 {
        let chunk_plain: Vec<u8> = if padding_left > ENCRYPT_CHUNK_SIZE {
            padding_left -= ENCRYPT_CHUNK_SIZE;
            vec![0u8; ENCRYPT_CHUNK_SIZE]
        } else {
            let mut tail = vec![0u8; padding_left];
            tail[padding_left - 1] = 0x01;
            let room = ENCRYPT_CHUNK_SIZE - padding_left;
            let take = room.min(data.len() - cursor);
            tail.extend_from_slice(&data[cursor..cursor + take]);
            cursor += take;
            padding_left = 0;
            tail
        };
        let is_last = padding_left == 0 && cursor >= data.len();
        let tag = if is_last { TAG_FINAL } else { TAG_MESSAGE };
        out.extend_from_slice(&stream.push(&chunk_plain, tag));
    }

    // Phase 2: emit whatever data did not fit into the last padding chunk.
    while cursor < data.len() {
        let take = ENCRYPT_CHUNK_SIZE.min(data.len() - cursor);
        let piece = &data[cursor..cursor + take];
        cursor += take;
        let tag = if cursor >= data.len() {
            TAG_FINAL
        } else {
            TAG_MESSAGE
        };
        out.extend_from_slice(&stream.push(piece, tag));
    }

    debug_assert_eq!(out.len(), total_size);
    Ok((out, enc_key))
}
/// Decrypts data produced by `encrypt` and returns the original plaintext.
///
/// `ciphertext` must start with the byte `S`, followed by the 24-byte stream
/// header and the encrypted chunks. The decrypted stream must begin with zero
/// padding terminated by a single `0x01` byte; everything after that byte is
/// returned.
///
/// # Errors
///
/// Returns `CryptoError::DecryptionFailed` when
/// - the input is too short to contain a prefix, header and one chunk,
/// - the type prefix is not `S`,
/// - a chunk fails authentication (wrong key or corrupted data),
/// - the padding is malformed or never terminated by `0x01`,
/// - the final chunk is not the last one, or the data ends without one.
pub fn decrypt(
    ciphertext: &[u8],
    key: &[u8; ENCRYPT_KEY_SIZE],
) -> CryptoResult<Vec<u8>> {
    let failure = |msg: &str| CryptoError::DecryptionFailed(msg.into());

    let max_size = decrypted_max_size(ciphertext.len())
        .ok_or_else(|| failure("Attachment decryption failed: encrypted data too short"))?;

    let prefix = ciphertext.first().copied();
    if prefix != Some(b'S') {
        return Err(CryptoError::DecryptionFailed(format!(
            "Attachment decryption failed: unknown encryption type 0x{:02x}; expected 0x53 (S)",
            prefix.unwrap_or(0)
        )));
    }

    // The size check above guarantees the prefix and the full header exist.
    let header: [u8; ENCRYPT_HEADER] = ciphertext[1..1 + ENCRYPT_HEADER]
        .try_into()
        .expect("slice is exactly ENCRYPT_HEADER bytes long");
    let mut stream = SecretStream::init_pull(&header, key);
    let enc_data = &ciphertext[1 + ENCRYPT_HEADER..];

    let mut result = Vec::with_capacity(max_size);
    let mut pos = 0usize;
    let mut depadded = false;

    loop {
        if pos + ENCRYPT_CHUNK_OVERHEAD > enc_data.len() {
            return Err(failure(
                "Attachment decryption failed: data ended before end of stream",
            ));
        }
        // Every chunk is full-sized except possibly the last one.
        let chunk_total = (enc_data.len() - pos).min(ENCRYPTED_CHUNK_TOTAL);
        let chunk_data = &enc_data[pos..pos + chunk_total];
        pos += chunk_total;

        let (plaintext, tag) = stream.pull(chunk_data).map_err(|_| {
            failure("Attachment decryption failed: invalid key or corrupted data")
        })?;

        if depadded {
            result.extend_from_slice(&plaintext);
        } else if let Some(idx) = plaintext.iter().position(|&b| b != 0x00) {
            // The first non-zero byte must be the 0x01 padding terminator.
            if plaintext[idx] != 0x01 {
                return Err(failure("Attachment decryption failed: invalid padding"));
            }
            depadded = true;
            result.extend_from_slice(&plaintext[idx + 1..]);
        }
        // Otherwise the whole chunk was zero padding; keep looking.

        if tag == TAG_FINAL {
            if pos != enc_data.len() {
                return Err(failure(
                    "Attachment decryption failed: FINAL tag before end of the encrypted data",
                ));
            }
            // A well-formed stream always contains the padding terminator;
            // without this check an all-zero stream would be accepted and
            // reported as empty data.
            if !depadded {
                return Err(failure("Attachment decryption failed: invalid padding"));
            }
            return Ok(result);
        }
        if pos == enc_data.len() {
            return Err(failure(
                "Attachment decryption failed: end of data without FINAL tag",
            ));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;

    /// Builds `len` bytes of deterministic, non-constant test data.
    fn make_data(len: usize) -> Vec<u8> {
        (0..len).map(|i| ((i * 7) % 256) as u8).collect()
    }

    #[test]
    fn test_bit_floor() {
        assert_eq!(bit_floor(0), 0);
        assert_eq!(bit_floor(1), 1);
        assert_eq!(bit_floor(2), 2);
        assert_eq!(bit_floor(3), 2);
        assert_eq!(bit_floor(4), 4);
        assert_eq!(bit_floor(5), 4);
        assert_eq!(bit_floor(7), 4);
        assert_eq!(bit_floor(8), 8);
        assert_eq!(bit_floor(131072), 131072);
        assert_eq!(bit_floor(1_000_528), 524288);
    }

    #[test]
    fn test_encrypted_padding_sizes() {
        // (plaintext size, expected total encrypted size)
        let test_cases: Vec<(usize, usize)> = vec![
            (0, 4096),
            (1, 4096),
            (2, 4096),
            (10, 4096),
            (100, 4096),
            (1000, 4096),
            (2000, 4096),
            (4000, 4096),
            (4053, 4096),
            (4054, 8192),
            (8149, 8192),
            (8150, 12288),
            (33333, 36864),
            (261982, 262144),
            (261983, 270336),
            (523990, 524288),
            (523991, 540672),
            (6543210, 6553600),
            (10218286, 10223616),
        ];
        for (data_size, expected_enc_size) in test_cases {
            let actual = encrypted_size(data_size);
            assert_eq!(
                actual, expected_enc_size,
                "encrypted_size({}) = {} (expected {})",
                data_size, actual, expected_enc_size
            );
        }
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let seed = hex!("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        let data_sizes = [
            0, 1, 2, 10, 100, 1000, 2000, 4000, 4053, 4054, 8149, 8150,
            33333, 261982, 261983, 523990, 523991,
        ];
        for &data_size in &data_sizes {
            let data = make_data(data_size);
            let (enc, key) =
                encrypt(&seed, &data, Domain::Attachment, false).unwrap();
            let expected_size_val = encrypted_size(data_size);
            assert_eq!(
                enc.len(),
                expected_size_val,
                "data_size={}: enc.len()={} expected={}",
                data_size,
                enc.len(),
                expected_size_val,
            );
            let decr = decrypt(&enc, &key).unwrap();
            assert_eq!(
                decr, data,
                "data_size={}: roundtrip mismatch",
                data_size
            );
        }
    }

    #[test]
    fn test_encrypt_deterministic() {
        let seed = hex!("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        for data_size in [0, 1, 100, 1000, 33333] {
            let data = make_data(data_size);
            let (enc1, key1) =
                encrypt(&seed, &data, Domain::Attachment, false).unwrap();
            let (enc2, key2) =
                encrypt(&seed, &data, Domain::Attachment, false).unwrap();
            assert_eq!(key1, key2, "data_size={}: keys differ", data_size);
            assert_eq!(enc1, enc2, "data_size={}: ciphertexts differ", data_size);
        }
    }

    #[test]
    fn test_key_separation_different_seeds() {
        let seed1 = hex!("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        let seed2 = hex!("1123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        for data_size in [0, 20, 100, 1000, 33333] {
            let data = make_data(data_size);
            let (enc1, key1) =
                encrypt(&seed1, &data, Domain::Attachment, false).unwrap();
            let (enc2, key2) =
                encrypt(&seed2, &data, Domain::Attachment, false).unwrap();
            assert_ne!(key1, key2);
            assert_ne!(enc1, enc2);
            assert!(decrypt(&enc1, &key2).is_err());
            assert!(decrypt(&enc2, &key1).is_err());
        }
    }

    #[test]
    fn test_domain_separation() {
        let seed = hex!("2123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        for data_size in [0, 20, 100, 1000, 33333] {
            let data = make_data(data_size);
            let (enc1, key1) =
                encrypt(&seed, &data, Domain::Attachment, false).unwrap();
            let (enc2, key2) =
                encrypt(&seed, &data, Domain::ProfilePic, false).unwrap();
            assert_ne!(key1, key2);
            assert_ne!(enc1, enc2);
            assert!(decrypt(&enc1, &key2).is_err());
            assert!(decrypt(&enc2, &key1).is_err());
        }
    }

    #[test]
    fn test_content_separation() {
        let seed = hex!("3123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        let data = make_data(50000);
        let mut data2 = data.clone();
        data2[43210] = 0x42;
        let (enc1, key1) =
            encrypt(&seed, &data, Domain::Attachment, false).unwrap();
        let (enc2, key2) =
            encrypt(&seed, &data2, Domain::Attachment, false).unwrap();
        assert_ne!(key1, key2);
        assert_eq!(enc1.len(), enc2.len());
        assert_ne!(enc1, enc2);
        assert!(decrypt(&enc1, &key2).is_err());
        assert!(decrypt(&enc2, &key1).is_err());
    }

    #[test]
    fn test_too_large_rejected() {
        let seed = hex!("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        let data = vec![0u8; MAX_REGULAR_SIZE + 1];
        assert!(encrypt(&seed, &data, Domain::Attachment, false).is_err());
        let result = encrypt(&seed, &data, Domain::Attachment, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_seed_too_short() {
        let seed = [0u8; 16];
        let data = [0u8; 10];
        assert!(encrypt(&seed, &data, Domain::Attachment, false).is_err());
    }

    #[test]
    fn test_decrypted_max_size() {
        // The smallest padded output (4096 bytes) can hold at most 4053 bytes:
        // 4096 - prefix(1) - header(24) - min padding(1) - chunk overhead(17).
        let enc = encrypted_size(0);
        assert_eq!(enc, 4096);
        assert_eq!(decrypted_max_size(enc), Some(4053));

        // Anything below prefix + header + padding byte + one chunk overhead
        // (43 bytes) cannot be a valid ciphertext.
        assert!(decrypted_max_size(0).is_none());
        assert!(decrypted_max_size(1).is_none());
        assert!(decrypted_max_size(10).is_none());
        assert!(decrypted_max_size(42).is_none());
        assert_eq!(decrypted_max_size(43), Some(0));
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let seed = hex!("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        let data = make_data(1000);
        let (enc, key) =
            encrypt(&seed, &data, Domain::Attachment, false).unwrap();

        // Type prefix, first and last header byte, masked tag byte, a
        // ciphertext byte and the last authenticator byte.
        let positions = [
            0,
            1,
            ENCRYPT_HEADER,
            1 + ENCRYPT_HEADER,
            enc.len() / 2,
            enc.len() - 1,
        ];
        for idx in positions {
            let mut bad = enc.clone();
            bad[idx] ^= 0x01;
            assert!(
                decrypt(&bad, &key).is_err(),
                "flipping a bit in byte {} was not detected",
                idx
            );
        }

        // Dropping the trailing byte must be detected as well.
        assert!(decrypt(&enc[..enc.len() - 1], &key).is_err());
    }

    #[test]
    fn test_large_roundtrip() {
        let seed = hex!("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
        for data_size in [6543210] {
            let data = make_data(data_size);
            let (enc, key) =
                encrypt(&seed, &data, Domain::Attachment, true).unwrap();
            let expected = encrypted_size(data_size);
            assert_eq!(enc.len(), expected);
            let decr = decrypt(&enc, &key).unwrap();
            assert_eq!(decr, data);
        }
    }
}