use chacha20poly1305::aead::{Aead, KeyInit};
use chacha20poly1305::XChaCha20Poly1305;
/// Number of bytes `config_encrypt` adds on top of the plaintext length: the 16-byte
/// Poly1305 authentication tag plus the 24-byte XChaCha20 nonce appended at the end.
pub const ENCRYPT_DATA_OVERHEAD: usize = 16 + 24;
/// Largest accepted encryption domain, in bytes (the smallest accepted is 1).
///
/// The nonce key is `NONCE_KEY_PREFIX` (32 bytes) followed by the domain, so with this limit it
/// is at most 56 bytes, within BLAKE2b's 64-byte maximum key length.
const DOMAIN_MAX_SIZE: usize = 24;
/// Fixed prefix of the BLAKE2b key used to derive the deterministic nonce; the domain is
/// appended to it.
const NONCE_KEY_PREFIX: &str = "libsessionutil-config-encrypted-";
#[derive(Debug, thiserror::Error)]
pub enum DecryptError {
#[error("Decryption failed: ciphertext is too short")]
CiphertextTooShort,
#[error("Message decryption failed")]
DecryptionFailed,
}
/// Derives the 32-byte XChaCha20-Poly1305 key for one message.
///
/// The key is an unkeyed BLAKE2b-256 digest of
/// `key_base || message_size (u64, big-endian) || domain`. The nonce is computed
/// deterministically from the message, so mixing the length and domain into the key means a
/// nonce collision can only reuse a (key, nonce) pair when both messages also have the same
/// length and domain.
///
/// Returns `Err` with a static description when `domain` is empty or longer than
/// `DOMAIN_MAX_SIZE` bytes.
fn make_encrypt_key(
    key_base: &[u8; 32],
    message_size: u64,
    domain: &str,
) -> Result<[u8; 32], &'static str> {
    if !(1..=DOMAIN_MAX_SIZE).contains(&domain.len()) {
        return Err("encrypt called with domain size not in [1, 24]");
    }

    let digest = blake2b_simd::Params::new()
        .hash_length(32)
        .to_state()
        .update(key_base)
        .update(&message_size.to_be_bytes())
        .update(domain.as_bytes())
        .finalize();

    let derived: [u8; 32] = digest
        .as_bytes()
        .try_into()
        .expect("BLAKE2b was configured for a 32-byte digest");
    Ok(derived)
}
/// Computes the deterministic 24-byte XChaCha20 nonce for `message`.
///
/// The nonce is BLAKE2b with a 24-byte output over the message, keyed with
/// `NONCE_KEY_PREFIX` followed by `domain`. It does not depend on the encryption key, so the
/// same message and domain always produce the same nonce.
///
/// Does not validate `domain` itself. `config_encrypt` calls `make_encrypt_key` first, and that
/// rejects domains longer than 24 bytes, which keeps the BLAKE2b key at most 56 bytes.
fn make_nonce(message: &[u8], domain: &str) -> [u8; 24] {
    let nonce_key = [NONCE_KEY_PREFIX, domain].concat();

    let digest = blake2b_simd::Params::new()
        .hash_length(24)
        .key(nonce_key.as_bytes())
        .hash(message);

    digest
        .as_bytes()
        .try_into()
        .expect("BLAKE2b was configured for a 24-byte digest")
}
/// Encrypts `message` with XChaCha20-Poly1305 under a key derived from `key_base`, the message
/// length and `domain`.
///
/// The nonce is derived deterministically from the message and domain, so encrypting the same
/// message with the same key and domain always yields the same bytes.
///
/// The output is `ciphertext || 16-byte tag || 24-byte nonce`, i.e.
/// `message.len() + ENCRYPT_DATA_OVERHEAD` bytes in total.
///
/// # Panics
///
/// Panics if `domain` is empty or longer than 24 bytes.
pub fn config_encrypt(message: &[u8], key_base: &[u8; 32], domain: &str) -> Vec<u8> {
    // Derive the key first: this also validates `domain`, which must happen before the domain is
    // used as part of the (length-limited) BLAKE2b key inside `make_nonce`.
    let derived_key = make_encrypt_key(key_base, message.len() as u64, domain)
        .expect("invalid domain for config_encrypt");
    let nonce = make_nonce(message, domain);

    let mut sealed = XChaCha20Poly1305::new((&derived_key).into())
        .encrypt(&chacha20poly1305::XNonce::from(nonce), message)
        .expect("XChaCha20-Poly1305 encryption should not fail");

    // The nonce travels with the ciphertext so the decryptor can recover it.
    sealed.extend(nonce);
    sealed
}
/// Decrypts a blob produced by `config_encrypt` with the same `key_base` and `domain`.
///
/// The input layout is `ciphertext || 16-byte tag || 24-byte nonce`; the trailing nonce is
/// split off and the remaining bytes are authenticated and decrypted.
///
/// # Errors
///
/// - [`DecryptError::CiphertextTooShort`] if `ciphertext` is shorter than
///   `ENCRYPT_DATA_OVERHEAD` bytes.
/// - [`DecryptError::DecryptionFailed`] if authentication fails (wrong key, wrong domain,
///   corrupted data) or if `domain` is empty or longer than 24 bytes.
pub fn config_decrypt(
    ciphertext: &[u8],
    key_base: &[u8; 32],
    domain: &str,
) -> Result<Vec<u8>, DecryptError> {
    let Some(message_len) = ciphertext.len().checked_sub(ENCRYPT_DATA_OVERHEAD) else {
        return Err(DecryptError::CiphertextTooShort);
    };

    let (sealed, nonce_tail) = ciphertext.split_at(ciphertext.len() - 24);
    let nonce_bytes: [u8; 24] = nonce_tail
        .try_into()
        .expect("split_at leaves exactly 24 trailing bytes");

    // The key depends on the plaintext length, which is implied by the blob length.
    let derived_key = make_encrypt_key(key_base, message_len as u64, domain)
        .map_err(|_| DecryptError::DecryptionFailed)?;

    XChaCha20Poly1305::new((&derived_key).into())
        .decrypt(&chacha20poly1305::XNonce::from(nonce_bytes), sealed)
        .map_err(|_| DecryptError::DecryptionFailed)
}
/// Returns the padded payload size for a payload of `s` bytes.
///
/// `s + overhead` is rounded up to a multiple of a chunk size that grows with the total:
/// 256 bytes below 5 KiB, 1 KiB below 20 KiB, 2 KiB below 40 KiB, and 5 KiB otherwise.
/// `overhead` is then subtracted again, so the result is the payload size (without overhead)
/// that makes the total land exactly on a chunk boundary. The result is never smaller than `s`.
pub const fn config_padded_size(s: usize, overhead: usize) -> usize {
    let total = s + overhead;
    let chunk = match total {
        0..=5119 => 256,
        5120..=20479 => 1024,
        20480..=40959 => 2048,
        _ => 5120,
    };
    let rounded_total = total.div_ceil(chunk) * chunk;
    rounded_total - overhead
}
/// Left-pads `data` with zero bytes up to `config_padded_size(data.len(), overhead)`.
///
/// The zeros are prepended, so the original bytes end up at the tail of the buffer. `data` is
/// left untouched when it is already at its padded size.
pub fn pad_message(data: &mut Vec<u8>, overhead: usize) {
    let original_len = data.len();
    let padding = config_padded_size(original_len, overhead).saturating_sub(original_len);
    if padding == 0 {
        return;
    }
    // Append the zeros, then rotate them to the front in place.
    data.resize(original_len + padding, 0);
    data.rotate_right(padding);
}
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;

    /// Key used for the known-answer vectors and most round-trip tests.
    const KEY1: [u8; 32] =
        hex!("abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789");
    /// A second, unrelated key.
    const KEY2: [u8; 32] =
        hex!("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");
    /// Plaintext shared by the known-answer vectors.
    const MESSAGE: &[u8] = b"some message 1";

    /// Returns the trailing 24-byte nonce of an encrypted blob.
    fn nonce_of(enc: &[u8]) -> &[u8] {
        &enc[enc.len() - 24..]
    }

    #[test]
    fn test_encrypt_known_vector() {
        let enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        assert_eq!(enc.len(), MESSAGE.len() + ENCRYPT_DATA_OVERHEAD);
        assert_eq!(
            hex::encode(&enc),
            "f14f242a26638f3305707d1035e734577f943cd7d28af58e32637e\
             0966dcaf2f4860cb4d0f8ba7e09d29e31f5e4a18f65847287a54a0"
        );
    }

    #[test]
    fn test_nonce_from_known_vector() {
        let enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        assert_eq!(
            hex::encode(nonce_of(&enc)),
            "af2f4860cb4d0f8ba7e09d29e31f5e4a18f65847287a54a0"
        );
    }

    #[test]
    fn test_different_domain_different_ciphertext() {
        let enc1 = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        let enc2 = config_encrypt(MESSAGE, &KEY1, "test-suite2");
        assert_ne!(enc1, enc2);
        assert_eq!(
            hex::encode(nonce_of(&enc2)),
            "277e639d36ba46470dfff509a68cb73d9a96386c51739bdd"
        );
    }

    #[test]
    fn test_different_key_same_nonce() {
        let enc1 = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        let enc3 = config_encrypt(MESSAGE, &KEY2, "test-suite1");
        assert_ne!(enc1, enc3);
        // The nonce depends only on the message and domain, not on the key.
        assert_eq!(nonce_of(&enc1), nonce_of(&enc3));
    }

    #[test]
    fn test_encrypt_is_deterministic() {
        assert_eq!(
            config_encrypt(MESSAGE, &KEY1, "test-suite1"),
            config_encrypt(MESSAGE, &KEY1, "test-suite1")
        );
    }

    #[test]
    fn test_roundtrip() {
        let enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        let dec = config_decrypt(&enc, &KEY1, "test-suite1").unwrap();
        assert_eq!(dec, MESSAGE);
    }

    #[test]
    fn test_roundtrip_empty_message() {
        let enc = config_encrypt(b"", &KEY1, "test-suite1");
        assert_eq!(enc.len(), ENCRYPT_DATA_OVERHEAD);
        let dec = config_decrypt(&enc, &KEY1, "test-suite1").unwrap();
        assert!(dec.is_empty());
    }

    #[test]
    fn test_roundtrip_max_length_domain() {
        let domain = "d".repeat(DOMAIN_MAX_SIZE);
        let enc = config_encrypt(MESSAGE, &KEY1, &domain);
        assert_eq!(config_decrypt(&enc, &KEY1, &domain).unwrap(), MESSAGE);
    }

    #[test]
    fn test_decrypt_wrong_domain() {
        let enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        assert!(matches!(
            config_decrypt(&enc, &KEY1, "test-suite2"),
            Err(DecryptError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_decrypt_wrong_key() {
        let enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        assert!(matches!(
            config_decrypt(&enc, &KEY2, "test-suite1"),
            Err(DecryptError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_decrypt_corrupted() {
        let mut enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        // XOR (rather than overwrite) so the byte is guaranteed to actually change.
        enc[3] ^= 0x42;
        assert!(matches!(
            config_decrypt(&enc, &KEY1, "test-suite1"),
            Err(DecryptError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_decrypt_corrupted_nonce() {
        let mut enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        let last = enc.len() - 1;
        enc[last] ^= 0x01;
        assert!(matches!(
            config_decrypt(&enc, &KEY1, "test-suite1"),
            Err(DecryptError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_decrypt_too_short() {
        let short = vec![0u8; ENCRYPT_DATA_OVERHEAD - 1];
        assert!(matches!(
            config_decrypt(&short, &KEY1, "test"),
            Err(DecryptError::CiphertextTooShort)
        ));
        assert!(matches!(
            config_decrypt(&[], &KEY1, "test"),
            Err(DecryptError::CiphertextTooShort)
        ));
    }

    #[test]
    fn test_decrypt_exactly_overhead_is_not_too_short() {
        // Exactly ENCRYPT_DATA_OVERHEAD bytes passes the length check but cannot authenticate.
        let blob = vec![0u8; ENCRYPT_DATA_OVERHEAD];
        assert!(matches!(
            config_decrypt(&blob, &KEY1, "test"),
            Err(DecryptError::DecryptionFailed)
        ));
    }

    #[test]
    fn test_decrypt_invalid_domain() {
        let enc = config_encrypt(MESSAGE, &KEY1, "test-suite1");
        let too_long = "d".repeat(DOMAIN_MAX_SIZE + 1);
        for domain in ["", too_long.as_str()] {
            assert!(matches!(
                config_decrypt(&enc, &KEY1, domain),
                Err(DecryptError::DecryptionFailed)
            ));
        }
    }

    #[test]
    #[should_panic(expected = "invalid domain for config_encrypt")]
    fn test_encrypt_rejects_empty_domain() {
        let _ = config_encrypt(MESSAGE, &KEY1, "");
    }

    #[test]
    #[should_panic(expected = "invalid domain for config_encrypt")]
    fn test_encrypt_rejects_long_domain() {
        let _ = config_encrypt(MESSAGE, &KEY1, &"d".repeat(DOMAIN_MAX_SIZE + 1));
    }

    #[test]
    fn test_padded_size() {
        assert_eq!(config_padded_size(1, 0), 256);
        assert_eq!(config_padded_size(1, 10), 256 - 10);
        assert_eq!(config_padded_size(246, 10), 256 - 10);
        assert_eq!(config_padded_size(247, 10), 512 - 10);
        assert_eq!(config_padded_size(247, 256), 256);
        assert_eq!(config_padded_size(3839, 96), 4000);
        assert_eq!(config_padded_size(3744, 96), 3744);
        assert_eq!(config_padded_size(3745, 96), 4000);
        assert_eq!(config_padded_size(4864, 0), 4864);
        assert_eq!(config_padded_size(4865, 0), 5120);
        assert_eq!(config_padded_size(5120 + 1, 0), 6144);
        assert_eq!(config_padded_size(9 * 1024, 0), 9 * 1024);
        assert_eq!(config_padded_size(9 * 1024 + 1, 0), 10 * 1024);
        assert_eq!(config_padded_size(10 * 1024 + 1, 0), 11 * 1024);
        assert_eq!(config_padded_size(20 * 1024, 0), 20 * 1024);
        assert_eq!(config_padded_size(20 * 1024 + 1, 0), 22 * 1024);
        assert_eq!(config_padded_size(38 * 1024, 0), 38 * 1024);
        assert_eq!(config_padded_size(38 * 1024 + 1, 0), 40 * 1024);
        assert_eq!(config_padded_size(40 * 1024 + 1, 0), 45 * 1024);
        assert_eq!(config_padded_size(45 * 1024 + 1, 0), 50 * 1024);
        assert_eq!(config_padded_size(70 * 1024, 0), 70 * 1024);
        assert_eq!(config_padded_size(70 * 1024 + 1, 0), 75 * 1024);
        assert_eq!(config_padded_size(75 * 1024, 0), 75 * 1024);
        assert_eq!(config_padded_size(75 * 1024 - 24, 24), 75 * 1024 - 24);
    }

    #[test]
    fn test_pad_message() {
        let mut data = vec![1u8, 2, 3];
        pad_message(&mut data, ENCRYPT_DATA_OVERHEAD);
        assert_eq!(data.len(), 216);
        assert!(data[..213].iter().all(|&b| b == 0));
        assert_eq!(&data[213..], &[1, 2, 3]);
    }

    #[test]
    fn test_pad_message_already_padded() {
        // 216 + 40 == 256 is already on a chunk boundary, so nothing is prepended.
        let mut data = vec![7u8; 216];
        pad_message(&mut data, ENCRYPT_DATA_OVERHEAD);
        assert_eq!(data, vec![7u8; 216]);
    }

    #[test]
    fn test_overhead_constant() {
        assert_eq!(ENCRYPT_DATA_OVERHEAD, 40);
    }
}