mod password;
mod properties;
use aes::Aes256;
use cbc::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit};
use sha2::{Digest, Sha256};
use std::io::{self, Read, Write};
use std::num::NonZeroUsize;
use std::sync::Mutex;
use crate::Result;
use crate::s3fifo::S3FifoCache;
pub use password::Password;
pub use properties::{AesProperties, NoncePolicy};
/// AES-256 wrapped in the `cbc` crate's decryptor.
type Aes256CbcDec = cbc::Decryptor<Aes256>;
/// AES-256 wrapped in the `cbc` crate's encryptor.
type Aes256CbcEnc = cbc::Encryptor<Aes256>;
/// Locks `mutex`, recovering the guard if a previous holder panicked.
///
/// A poisoned mutex is not treated as fatal: a warning is logged and the
/// inner guard is handed back so the caller can keep going.
fn lock_or_recover<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => {
            log::warn!("KeyCache mutex was poisoned, recovering");
            poisoned.into_inner()
        }
    }
}
/// Cipher block size in bytes; ciphertext is processed in multiples of this.
const BLOCK_SIZE: usize = 16;
/// Largest `num_cycles_power` that `derive_key` accepts; anything above it is
/// rejected with `Error::ResourceLimitExceeded` (2^30 hash rounds at most).
pub const MAX_NUM_CYCLES_POWER: u8 = 30;
/// Derives a 32-byte AES key from a password and salt.
///
/// A single SHA-256 state is fed `salt`, the UTF-16LE password bytes and a
/// 64-bit little-endian round counter, once per round, for
/// `2^num_cycles_power` rounds. The final digest is the key.
///
/// # Errors
///
/// Returns `Error::ResourceLimitExceeded` when `num_cycles_power` is greater
/// than [`MAX_NUM_CYCLES_POWER`]; a warning is logged in that case.
pub fn derive_key(password: &Password, salt: &[u8], num_cycles_power: u8) -> Result<[u8; 32]> {
    if num_cycles_power > MAX_NUM_CYCLES_POWER {
        // Only used for the error text; saturates when the shift would overflow.
        let required_rounds = 1u64
            .checked_shl(u32::from(num_cycles_power))
            .unwrap_or(u64::MAX);
        log::warn!(
            "Key derivation cycles_power {} exceeds maximum {}, rejecting",
            num_cycles_power,
            MAX_NUM_CYCLES_POWER
        );
        return Err(crate::Error::ResourceLimitExceeded(format!(
            "key derivation cycles_power {} exceeds maximum {} (would require {} iterations)",
            num_cycles_power, MAX_NUM_CYCLES_POWER, required_rounds
        )));
    }

    let rounds = 1u64 << num_cycles_power;
    let secret = password.as_utf16_le();
    let mut hasher = Sha256::new();
    for counter in 0..rounds {
        // Streaming the three parts separately yields the same digest as
        // hashing their concatenation.
        hasher.update(salt);
        hasher.update(&secret);
        hasher.update(counter.to_le_bytes());
    }
    Ok(hasher.finalize().into())
}
/// Lookup key for [`KeyCache`]: identifies one key-derivation request.
///
/// The password itself is not stored; only the SHA-256 digest of its
/// UTF-16LE encoding is kept.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// SHA-256 of the UTF-16LE password bytes.
    password_hash: [u8; 32],
    /// Copy of the salt used for derivation.
    salt: Vec<u8>,
    /// Cycle-count exponent used for derivation.
    num_cycles_power: u8,
}

impl CacheKey {
    /// Builds the cache key for the given derivation parameters.
    fn new(password: &Password, salt: &[u8], num_cycles_power: u8) -> Self {
        let utf16 = password.as_utf16_le();
        let mut hasher = Sha256::new();
        hasher.update(&utf16);
        Self {
            password_hash: hasher.finalize().into(),
            salt: salt.to_vec(),
            num_cycles_power,
        }
    }
}
/// Cache of derived AES keys, so repeated requests with the same password,
/// salt and cycle count skip the hash rounds in `derive_key`.
pub struct KeyCache {
/// Derived keys indexed by `CacheKey`.
cache: Mutex<S3FifoCache<CacheKey, [u8; 32]>>,
/// Hit/miss counters; guarded by its own mutex, separate from `cache`.
stats: Mutex<CacheStats>,
}
/// Counters describing how a [`KeyCache`] has been used.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups answered from the cache.
    pub hits: u64,
    /// Number of lookups that had to derive the key.
    pub misses: u64,
    /// Sum of `2^num_cycles_power` over all hits.
    pub iterations_saved: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_ratio(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
impl KeyCache {
    /// Creates a cache holding up to `capacity` keys.
    ///
    /// A `capacity` of zero is bumped to the smallest non-zero capacity.
    pub fn new(capacity: usize) -> Self {
        let slots = match NonZeroUsize::new(capacity) {
            Some(n) => n,
            None => NonZeroUsize::MIN,
        };
        Self {
            cache: Mutex::new(S3FifoCache::new(slots)),
            stats: Mutex::new(CacheStats::default()),
        }
    }

    /// Returns the key for the given parameters, deriving it only on a miss.
    ///
    /// On a hit the stored key is returned and `hits` / `iterations_saved`
    /// are updated. On a miss the key is derived without holding any lock,
    /// then inserted and counted as a miss.
    ///
    /// # Errors
    ///
    /// Propagates any error from the free function `derive_key`; nothing is
    /// cached or counted in that case.
    pub fn derive_key(
        &self,
        password: &Password,
        salt: &[u8],
        num_cycles_power: u8,
    ) -> Result<[u8; 32]> {
        let lookup = CacheKey::new(password, salt, num_cycles_power);

        {
            let mut entries = lock_or_recover(&self.cache);
            if let Some(cached) = entries.get(&lookup).copied() {
                // Stats lock is taken while the cache lock is still held.
                let mut counters = lock_or_recover(&self.stats);
                counters.hits += 1;
                counters.iterations_saved += 1u64 << num_cycles_power;
                return Ok(cached);
            }
        }

        let derived = derive_key(password, salt, num_cycles_power)?;

        {
            let mut entries = lock_or_recover(&self.cache);
            entries.insert(lookup, derived);
            let mut counters = lock_or_recover(&self.stats);
            counters.misses += 1;
        }

        Ok(derived)
    }

    /// Returns a snapshot of the current counters.
    pub fn stats(&self) -> CacheStats {
        let counters = lock_or_recover(&self.stats);
        CacheStats::clone(&counters)
    }

    /// Resets all counters to zero; cached keys are left untouched.
    pub fn reset_stats(&self) {
        let mut counters = lock_or_recover(&self.stats);
        *counters = CacheStats::default();
    }

    /// Removes every cached key; counters are left untouched.
    pub fn clear(&self) {
        let mut entries = lock_or_recover(&self.cache);
        entries.clear();
    }

    /// Number of keys currently cached.
    pub fn len(&self) -> usize {
        let entries = lock_or_recover(&self.cache);
        entries.len()
    }

    /// Whether the cache currently holds no keys.
    pub fn is_empty(&self) -> bool {
        let entries = lock_or_recover(&self.cache);
        entries.is_empty()
    }
}
pub fn validate_decrypted_header(decrypted_data: &[u8], compression_method: &[u8]) -> bool {
if decrypted_data.is_empty() {
return false;
}
const LZMA: &[u8] = &[0x03, 0x01, 0x01];
const LZMA2: &[u8] = &[0x21];
const DEFLATE: &[u8] = &[0x04, 0x01, 0x08];
const BZIP2: &[u8] = &[0x04, 0x02, 0x02];
const PPMD: &[u8] = &[0x03, 0x04, 0x01];
const COPY: &[u8] = &[0x00];
match compression_method {
LZMA => validate_lzma_header(decrypted_data),
LZMA2 => validate_lzma2_header(decrypted_data),
DEFLATE => validate_deflate_header(decrypted_data),
BZIP2 => validate_bzip2_header(decrypted_data),
PPMD => validate_ppmd_header(decrypted_data),
COPY => true, _ => true, }
}
/// Checks whether `data` starts with a plausible LZMA properties header.
///
/// The first byte is decoded as `(pb * 5 + lp) * 9 + lc` and rejected when
/// any component is out of range. When at least five bytes are present,
/// bytes 1..5 are read as a little-endian dictionary size, which must not
/// exceed 1 GiB. Empty input is rejected.
fn validate_lzma_header(data: &[u8]) -> bool {
    let Some(&props) = data.first() else {
        return false;
    };

    let lc = props % 9;
    let lp = (props / 9) % 5;
    let pb = props / 45;
    if lc > 8 || lp > 4 || pb > 4 {
        return false;
    }

    if let &[_, d0, d1, d2, d3, ..] = data {
        let dict_size = u32::from_le_bytes([d0, d1, d2, d3]);
        if dict_size > 1 << 30 {
            return false;
        }
    }
    true
}
/// Checks whether the first byte of `data` is an acceptable LZMA2 control byte.
///
/// Values `0x00..=0x02` and `0x80..=0xFF` are accepted; `0x03..=0x7F` and
/// empty input are rejected.
fn validate_lzma2_header(data: &[u8]) -> bool {
    match data.first() {
        None => false,
        Some(&control) => control < 0x03 || control >= 0x80,
    }
}
/// Checks whether the first byte of `data` could begin a Deflate block.
///
/// Bits 1 and 2 of the first byte hold the block type; the value 3 is
/// rejected. Empty input is rejected as well.
fn validate_deflate_header(data: &[u8]) -> bool {
    // Mask for the two block-type bits (bits 1..=2).
    const BTYPE_MASK: u8 = 0b0000_0110;
    match data.first() {
        Some(&head) => head & BTYPE_MASK != BTYPE_MASK,
        None => false,
    }
}
/// Checks whether `data` begins with the two bytes `B`, `Z`.
///
/// Input shorter than two bytes is rejected.
fn validate_bzip2_header(data: &[u8]) -> bool {
    data.starts_with(b"BZ")
}
/// Header check for PPMd data: the input is not inspected and every buffer is
/// accepted.
fn validate_ppmd_header(_data: &[u8]) -> bool {
true
}
/// Reader adaptor that decrypts AES-256-CBC data pulled from an inner reader.
pub struct Aes256Decoder<R> {
/// Source of ciphertext.
inner: R,
/// Most recently decrypted chunk of plaintext.
buffer: Vec<u8>,
/// Offset into `buffer` of the next byte to hand out.
pos: usize,
/// AES-256 key.
key: [u8; 32],
/// CBC chaining value: the IV for the next chunk to decrypt.
iv: [u8; 16],
/// Set once the inner reader has reported end of input.
finished: bool,
}
/// Prints only the type name; no field (key and IV included) is emitted.
impl<R> std::fmt::Debug for Aes256Decoder<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Aes256Decoder");
        repr.finish_non_exhaustive()
    }
}
impl<R: Read + Send> Aes256Decoder<R> {
    /// Creates a decoder from encoded AES coder properties and a password.
    ///
    /// The properties are parsed for salt, cycle count and IV; the key is
    /// derived with `derive_key`. An IV shorter than 16 bytes is zero-padded
    /// on the right and a longer one is truncated to 16 bytes.
    ///
    /// # Errors
    ///
    /// Returns an error when the properties cannot be parsed or key
    /// derivation is rejected.
    pub fn new(input: R, properties: &[u8], password: &Password) -> Result<Self> {
        let props = AesProperties::parse(properties)?;
        let key = derive_key(password, &props.salt, props.num_cycles_power)?;
        let mut iv = [0u8; 16];
        let iv_len = props.iv.len().min(16);
        iv[..iv_len].copy_from_slice(&props.iv[..iv_len]);
        Ok(Self::with_key_iv(input, key, iv))
    }

    /// Creates a decoder from an already-derived key and IV.
    pub fn with_key_iv(input: R, key: [u8; 32], iv: [u8; 16]) -> Self {
        Self {
            inner: input,
            buffer: Vec::new(),
            pos: 0,
            key,
            iv,
            finished: false,
        }
    }

    /// Decrypts the first chunk if nothing is buffered yet and checks that it
    /// looks like valid input for `compression_method`.
    ///
    /// The check is applied to the start of the currently buffered chunk,
    /// independent of how much of it has already been read, so this is meant
    /// to be called before any data is consumed. Returns `Ok(true)` when the
    /// input is empty.
    ///
    /// # Errors
    ///
    /// Propagates I/O and decryption errors from filling the buffer.
    pub fn validate_first_block(&mut self, compression_method: &[u8]) -> io::Result<bool> {
        if self.buffer.is_empty() && !self.finished {
            self.decrypt_buffer()?;
        }
        if self.buffer.is_empty() {
            return Ok(true);
        }
        Ok(validate_decrypted_header(&self.buffer, compression_method))
    }

    /// Decrypted bytes that are buffered but have not been read yet.
    pub fn buffered_data(&self) -> &[u8] {
        &self.buffer[self.pos..]
    }

    /// Reads and decrypts the next chunk of ciphertext into `self.buffer`.
    ///
    /// The chunk is filled with repeated reads until it is full or the inner
    /// reader reports end of input, because a single `read` may legally
    /// return fewer bytes than requested. Decrypting after just one short
    /// read would drop the unaligned tail and desynchronise the stream.
    ///
    /// Bytes at end of input that do not form a whole block are ignored when
    /// at least one whole block precedes them; a chunk with no whole block is
    /// an error. If a read fails part-way through a chunk, the error is
    /// returned and the bytes read so far for that chunk are discarded.
    fn decrypt_buffer(&mut self) -> io::Result<()> {
        // Must stay a multiple of BLOCK_SIZE so full chunks are always aligned.
        const CHUNK_SIZE: usize = 4096;

        let mut chunk = vec![0u8; CHUNK_SIZE];
        let mut filled = 0;
        while filled < CHUNK_SIZE {
            match self.inner.read(&mut chunk[filled..]) {
                Ok(0) => {
                    self.finished = true;
                    break;
                }
                Ok(n) => filled += n,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
        if filled == 0 {
            // End of input with nothing pending.
            return Ok(());
        }

        let aligned_len = filled - filled % BLOCK_SIZE;
        if aligned_len == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "encrypted data not block-aligned",
            ));
        }
        chunk.truncate(aligned_len);

        // CBC chains on ciphertext: remember the last encrypted block before
        // it is overwritten by in-place decryption.
        let mut next_iv = [0u8; BLOCK_SIZE];
        next_iv.copy_from_slice(&chunk[aligned_len - BLOCK_SIZE..]);

        let decryptor = Aes256CbcDec::new(&self.key.into(), &self.iv.into());
        let plain_len = decryptor
            .decrypt_padded_mut::<cbc::cipher::block_padding::NoPadding>(&mut chunk)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?
            .len();
        chunk.truncate(plain_len);

        self.iv = next_iv;
        self.buffer = chunk;
        self.pos = 0;
        Ok(())
    }
}
impl<R: Read + Send> Read for Aes256Decoder<R> {
    /// Copies decrypted bytes into `buf`, decrypting another chunk first when
    /// the current one has been fully consumed.
    ///
    /// Returns `Ok(0)` once no decrypted data is left.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let drained = self.pos >= self.buffer.len();
        if drained && !self.finished {
            self.decrypt_buffer()?;
        }

        let pending = match self.buffer.get(self.pos..) {
            Some(rest) if !rest.is_empty() => rest,
            _ => return Ok(0),
        };

        let count = pending.len().min(buf.len());
        buf[..count].copy_from_slice(&pending[..count]);
        self.pos += count;
        Ok(count)
    }
}
/// Writer adaptor that encrypts data with AES-256-CBC before passing it on.
///
/// Plaintext is buffered and written to the inner writer in whole 16-byte
/// blocks; [`Aes256Encoder::finish`] pads and emits the final block.
pub struct Aes256Encoder<W> {
    /// Destination for ciphertext.
    inner: W,
    /// Plaintext accepted by `write` that has not been encrypted yet.
    buffer: Vec<u8>,
    /// AES-256 key.
    key: [u8; 32],
    /// CBC chaining value: the initial IV until the first flush, afterwards
    /// the last ciphertext block written.
    iv: [u8; 16],
    /// IV the stream was started with. Kept separately because `iv` is
    /// overwritten while encrypting, yet `properties` must report this one.
    initial_iv: [u8; 16],
}

/// Prints only the type name; no field (key and IVs included) is emitted.
impl<W> std::fmt::Debug for Aes256Encoder<W> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Aes256Encoder").finish_non_exhaustive()
    }
}

impl<W: Write + Send> Aes256Encoder<W> {
    /// Creates an encoder whose salt and IV come from `nonce_policy` and
    /// whose key is derived from `password`.
    ///
    /// # Errors
    ///
    /// Returns an error when nonce generation or key derivation fails.
    pub fn new(output: W, password: &Password, nonce_policy: &NoncePolicy) -> Result<Self> {
        let (salt, iv) = nonce_policy.generate()?;
        let key = derive_key(password, &salt, nonce_policy.num_cycles_power())?;
        Ok(Self::with_key_iv(output, key, iv))
    }

    /// Creates an encoder from an already-derived key and IV.
    pub fn with_key_iv(output: W, key: [u8; 32], iv: [u8; 16]) -> Self {
        Self {
            inner: output,
            buffer: Vec::new(),
            key,
            iv,
            initial_iv: iv,
        }
    }

    /// Encodes the AES coder properties for this stream.
    ///
    /// The IV reported is always the one the stream was started with, no
    /// matter how much data has been written since. The salt generated by
    /// `new` is not retained by the encoder, so the caller supplies `salt`
    /// and `num_cycles_power`; they must be the values used for key
    /// derivation.
    pub fn properties(&self, salt: &[u8], num_cycles_power: u8) -> Vec<u8> {
        AesProperties::encode(num_cycles_power, salt, &self.initial_iv)
    }

    /// Encrypts every complete block in the buffer and writes it out,
    /// leaving any trailing partial block buffered.
    fn flush_buffer(&mut self) -> io::Result<()> {
        let complete_blocks = (self.buffer.len() / BLOCK_SIZE) * BLOCK_SIZE;
        if complete_blocks == 0 {
            return Ok(());
        }

        // Encrypt a copy so the plaintext buffer and chaining IV stay intact
        // if the write below fails.
        let mut to_encrypt = self.buffer[..complete_blocks].to_vec();
        let encryptor = Aes256CbcEnc::new(&self.key.into(), &self.iv.into());
        let encrypted = encryptor
            .encrypt_padded_mut::<cbc::cipher::block_padding::NoPadding>(
                &mut to_encrypt,
                complete_blocks,
            )
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
        self.inner.write_all(encrypted)?;

        // The last ciphertext block chains into the next batch.
        if encrypted.len() >= BLOCK_SIZE {
            self.iv
                .copy_from_slice(&encrypted[encrypted.len() - BLOCK_SIZE..]);
        }
        self.buffer.drain(..complete_blocks);
        Ok(())
    }

    /// Pads and encrypts the remaining data, flushes the inner writer and
    /// returns it.
    ///
    /// Padding is PKCS#7-style: between 1 and 16 bytes are appended, each
    /// holding the pad length, so a full padding block is added when the data
    /// is already block-aligned.
    ///
    /// # Errors
    ///
    /// Propagates encryption and I/O errors.
    pub fn finish(mut self) -> io::Result<W> {
        self.flush_buffer()?;
        let pad_len = BLOCK_SIZE - (self.buffer.len() % BLOCK_SIZE);
        self.buffer
            .extend(std::iter::repeat_n(pad_len as u8, pad_len));
        let buffer_len = self.buffer.len();
        let encryptor = Aes256CbcEnc::new(&self.key.into(), &self.iv.into());
        let encrypted = encryptor
            .encrypt_padded_mut::<cbc::cipher::block_padding::NoPadding>(
                &mut self.buffer,
                buffer_len,
            )
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
        self.inner.write_all(encrypted)?;
        self.inner.flush()?;
        Ok(self.inner)
    }
}
impl<W: Write + Send> Write for Aes256Encoder<W> {
    /// Buffers `buf` and encrypts complete blocks once at least 4096 bytes
    /// have accumulated. Always reports the whole of `buf` as written.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.buffer.extend_from_slice(buf);
        if self.buffer.len() >= 4096 {
            self.flush_buffer()?;
        }
        Ok(buf.len())
    }

    /// Encrypts and writes every complete buffered block, then flushes the
    /// inner writer so the ciphertext actually reaches its destination.
    ///
    /// A trailing partial block cannot be emitted without padding and stays
    /// buffered until more data arrives or `finish` is called.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_buffer()?;
        self.inner.flush()
    }
}
// Unit tests for key derivation, the key cache, header validation and the
// AES encoder/decoder pair.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
// Same inputs give the same 32-byte key; a different password gives a different key.
#[test]
fn test_derive_key() {
let password = Password::new("test");
let salt = b"saltsalt";
let key = derive_key(&password, salt, 10).unwrap();
assert_eq!(key.len(), 32);
let key2 = derive_key(&password, salt, 10).unwrap();
assert_eq!(key, key2);
let password2 = Password::new("test2");
let key3 = derive_key(&password2, salt, 10).unwrap();
assert_ne!(key, key3);
}
// A cycles power one above the maximum is rejected with ResourceLimitExceeded.
#[test]
fn test_derive_key_max_cycles_power() {
let password = Password::new("test");
let salt = b"saltsalt";
let key = derive_key(&password, salt, 10).unwrap();
assert_eq!(key.len(), 32);
let result = derive_key(&password, salt, MAX_NUM_CYCLES_POWER + 1);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, crate::Error::ResourceLimitExceeded(_)));
}
// Encrypt then decrypt with a fixed key/IV; the trailing padding is stripped
// by hand using the value of the last decrypted byte.
#[test]
fn test_aes_roundtrip() {
let data = b"Hello, World! This is test data for AES encryption.";
let key = [0u8; 32];
let iv = [0u8; 16];
let mut encrypted = Vec::new();
{
let mut encoder = Aes256Encoder::with_key_iv(Cursor::new(&mut encrypted), key, iv);
encoder.write_all(data).unwrap();
encoder.finish().unwrap();
}
let mut decoder = Aes256Decoder::with_key_iv(Cursor::new(&encrypted), key, iv);
let mut decrypted = Vec::new();
decoder.read_to_end(&mut decrypted).unwrap();
if let Some(&pad_len) = decrypted.last() {
if (pad_len as usize) <= BLOCK_SIZE {
decrypted.truncate(decrypted.len() - pad_len as usize);
}
}
assert_eq!(&decrypted[..], &data[..]);
}
// ASCII password bytes are widened to UTF-16 little-endian code units.
#[test]
fn test_password_utf16le() {
let password = Password::new("test");
let bytes = password.as_utf16_le();
assert_eq!(bytes, vec![0x74, 0x00, 0x65, 0x00, 0x73, 0x00, 0x74, 0x00]);
}
// First lookup is a miss, second is a hit that saves 2^5 = 32 iterations.
#[test]
fn test_key_cache_basic() {
let cache = KeyCache::new(4);
let password = Password::new("test");
let salt = b"saltsalt";
let key1 = cache.derive_key(&password, salt, 5).unwrap();
let stats1 = cache.stats();
assert_eq!(stats1.misses, 1);
assert_eq!(stats1.hits, 0);
assert_eq!(cache.len(), 1);
let key2 = cache.derive_key(&password, salt, 5).unwrap();
let stats2 = cache.stats();
assert_eq!(stats2.misses, 1);
assert_eq!(stats2.hits, 1);
assert_eq!(stats2.iterations_saved, 32);
assert_eq!(key1, key2);
}
// Different salts are distinct cache entries and yield distinct keys.
#[test]
fn test_key_cache_different_params() {
let cache = KeyCache::new(4);
let password = Password::new("test");
let salt1 = b"salt1111";
let salt2 = b"salt2222";
let key1 = cache.derive_key(&password, salt1, 5).unwrap();
let key2 = cache.derive_key(&password, salt2, 5).unwrap();
assert_ne!(key1, key2);
let stats = cache.stats();
assert_eq!(stats.misses, 2);
assert_eq!(stats.hits, 0);
}
// Clearing empties the cache, so the next lookup is a miss again.
#[test]
fn test_key_cache_clear() {
let cache = KeyCache::new(4);
let password = Password::new("test");
let salt = b"saltsalt";
cache.derive_key(&password, salt, 5).unwrap();
assert_eq!(cache.len(), 1);
cache.clear();
assert!(cache.is_empty());
cache.derive_key(&password, salt, 5).unwrap();
let stats = cache.stats();
assert_eq!(stats.misses, 2); }
// reset_stats zeroes the hit and miss counters.
#[test]
fn test_key_cache_stats_reset() {
let cache = KeyCache::new(4);
let password = Password::new("test");
let salt = b"saltsalt";
cache.derive_key(&password, salt, 5).unwrap();
cache.derive_key(&password, salt, 5).unwrap();
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
cache.reset_stats();
let stats = cache.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 0);
}
// hit_ratio is hits / (hits + misses), and 0.0 when nothing was recorded.
#[test]
fn test_cache_stats_hit_ratio() {
let stats = CacheStats {
hits: 3,
misses: 1,
iterations_saved: 1000,
};
assert!((stats.hit_ratio() - 0.75).abs() < f64::EPSILON);
let empty = CacheStats::default();
assert!((empty.hit_ratio() - 0.0).abs() < f64::EPSILON);
}
// LZMA: valid properties bytes pass; 0xE1 (pb = 5) and empty input fail.
#[test]
fn test_validate_lzma_header() {
assert!(validate_lzma_header(&[0x5D, 0x00, 0x00, 0x10, 0x00]));
assert!(validate_lzma_header(&[0x00, 0x00, 0x00, 0x01, 0x00]));
assert!(!validate_lzma_header(&[0xE1]));
assert!(!validate_lzma_header(&[]));
}
// LZMA2: control bytes 0x00..=0x02 and 0x80..=0xFF pass; 0x03..=0x7F and empty input fail.
#[test]
fn test_validate_lzma2_header() {
assert!(validate_lzma2_header(&[0x00]));
assert!(validate_lzma2_header(&[0x01]));
assert!(validate_lzma2_header(&[0x02]));
assert!(validate_lzma2_header(&[0x80]));
assert!(validate_lzma2_header(&[0xFF]));
assert!(!validate_lzma2_header(&[0x03]));
assert!(!validate_lzma2_header(&[0x50]));
assert!(!validate_lzma2_header(&[0x7F]));
assert!(!validate_lzma2_header(&[]));
}
// Deflate: block types 0, 1 and 2 (bits 1..=2) pass; block type 3 and empty input fail.
#[test]
fn test_validate_deflate_header() {
assert!(validate_deflate_header(&[0b00000000])); assert!(validate_deflate_header(&[0b00000001]));
assert!(validate_deflate_header(&[0b00000010])); assert!(validate_deflate_header(&[0b00000011]));
assert!(validate_deflate_header(&[0b00000100])); assert!(validate_deflate_header(&[0b00000101]));
assert!(!validate_deflate_header(&[0b00000110])); assert!(!validate_deflate_header(&[0b00000111]));
assert!(!validate_deflate_header(&[]));
}
// BZip2: input must be at least two bytes and start with "BZ".
#[test]
fn test_validate_bzip2_header() {
assert!(validate_bzip2_header(b"BZh9"));
assert!(!validate_bzip2_header(b"PK"));
assert!(!validate_bzip2_header(b"7z"));
assert!(!validate_bzip2_header(b"B"));
assert!(!validate_bzip2_header(&[]));
}
// Dispatch by method id; Copy and unknown method ids accept any non-empty data.
#[test]
fn test_validate_decrypted_header() {
const LZMA: &[u8] = &[0x03, 0x01, 0x01];
const LZMA2: &[u8] = &[0x21];
const DEFLATE: &[u8] = &[0x04, 0x01, 0x08];
const BZIP2: &[u8] = &[0x04, 0x02, 0x02];
const COPY: &[u8] = &[0x00];
assert!(validate_decrypted_header(
&[0x5D, 0x00, 0x00, 0x10, 0x00],
LZMA
));
assert!(validate_decrypted_header(&[0x80], LZMA2));
assert!(!validate_decrypted_header(&[0x50], LZMA2));
assert!(validate_decrypted_header(&[0x00], DEFLATE));
assert!(!validate_decrypted_header(&[0x06], DEFLATE));
assert!(validate_decrypted_header(b"BZh9data", BZIP2));
assert!(validate_decrypted_header(&[0xFF, 0xFF, 0xFF], COPY));
assert!(validate_decrypted_header(&[0xFF], &[0x99, 0x99]));
}
// Five different 16-byte salts all derive successfully and give pairwise distinct keys.
#[test]
fn test_derive_key_with_varied_salts() {
let password = Password::new("test_password");
let cycles_power = 10;
let salt_patterns: [([u8; 16], &str); 5] = [
([0u8; 16], "all zeros"),
([0xFFu8; 16], "all ones"),
(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
"sequential",
),
(
[
0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE, 0x12, 0x34, 0x56, 0x78, 0x9A,
0xBC, 0xDE, 0xF0,
],
"mixed bytes",
),
(
[
0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x01,
],
"boundary values",
),
];
let mut derived_keys = Vec::new();
for (salt, pattern_name) in &salt_patterns {
let result = derive_key(&password, salt, cycles_power);
assert!(
result.is_ok(),
"Key derivation should succeed with {} salt",
pattern_name
);
let key = result.unwrap();
assert_eq!(
key.len(),
32,
"Derived key should be 32 bytes for {} salt",
pattern_name
);
derived_keys.push((key, pattern_name));
}
// Compare every pair of derived keys once.
for i in 0..derived_keys.len() {
for j in (i + 1)..derived_keys.len() {
assert_ne!(
derived_keys[i].0, derived_keys[j].0,
"Salt '{}' and '{}' should produce different keys",
derived_keys[i].1, derived_keys[j].1
);
}
}
}
// Two derivations with identical inputs give identical keys.
#[test]
fn test_derive_key_deterministic() {
let password = Password::new("determinism_test");
let salt = [0x42u8; 16];
let cycles_power = 10;
let key1 = derive_key(&password, &salt, cycles_power).unwrap();
let key2 = derive_key(&password, &salt, cycles_power).unwrap();
assert_eq!(key1, key2, "Same inputs should produce same key");
}
// Cycle powers far above the maximum (62, 63) are rejected.
#[test]
fn test_derive_key_extreme_values_rejected() {
let password = Password::new("test");
let salt = [0u8; 16];
let result = derive_key(&password, &salt, 62);
assert!(result.is_err(), "cycles_power=62 should be rejected");
let result = derive_key(&password, &salt, 63);
assert!(result.is_err(), "cycles_power=63 should be rejected");
}
}