use crate::aes_ctr;
use crate::types::AesMode;
use constant_time_eq::constant_time_eq;
use hmac::{Hmac, Mac};
use sha1::Sha1;
use std::io::{self, Read};
/// Length in bytes of the password verification value that follows the salt in the entry
/// header, and of the matching tail of the PBKDF2 output.
const PWD_VERIFY_LENGTH: usize = 2;
/// Number of leading bytes of the HMAC-SHA1 digest that are stored (and compared) as the
/// authentication code after the encrypted payload.
const AUTH_CODE_LENGTH: usize = 10;
/// PBKDF2 iteration count used when deriving the keys from the password.
const ITERATION_COUNT: u32 = 1000;
fn cipher_from_mode(aes_mode: AesMode, key: &[u8]) -> Box<dyn aes_ctr::AesCipher> {
match aes_mode {
AesMode::Aes128 => Box::new(aes_ctr::AesCtrZipKeyStream::<aes_ctr::Aes128>::new(key))
as Box<dyn aes_ctr::AesCipher>,
AesMode::Aes192 => Box::new(aes_ctr::AesCtrZipKeyStream::<aes_ctr::Aes192>::new(key))
as Box<dyn aes_ctr::AesCipher>,
AesMode::Aes256 => Box::new(aes_ctr::AesCtrZipKeyStream::<aes_ctr::Aes256>::new(key))
as Box<dyn aes_ctr::AesCipher>,
}
}
/// Reader for a WinZip-AES encrypted entry whose password has not been checked yet.
///
/// Call [`AesReader::validate`] with a password to obtain an [`AesReaderValid`] that
/// decrypts the entry.
pub struct AesReader<R> {
    reader: R,
    aes_mode: AesMode,
    /// Length of the encrypted payload: the compressed size minus the salt, the password
    /// verification value and the authentication code. `None` when the compressed size is
    /// too small to hold that overhead, i.e. the entry is malformed.
    data_length: Option<u64>,
}

impl<R: Read> AesReader<R> {
    /// Wraps `reader` for an entry of `compressed_size` bytes encrypted with `aes_mode`.
    ///
    /// `validate` reads the salt from the reader's current position, so `reader` is expected
    /// to be positioned at the start of the encrypted entry.
    ///
    /// This constructor is infallible; a `compressed_size` too small to contain the encryption
    /// overhead is reported as an error by [`AesReader::validate`] instead.
    pub fn new(reader: R, aes_mode: AesMode, compressed_size: u64) -> AesReader<R> {
        let overhead = (PWD_VERIFY_LENGTH + AUTH_CODE_LENGTH + aes_mode.salt_length()) as u64;
        // `compressed_size` comes from the archive and cannot be trusted: a plain subtraction
        // panics in debug builds and wraps to an enormous length in release builds.
        let data_length = compressed_size.checked_sub(overhead);
        Self {
            reader,
            aes_mode,
            data_length,
        }
    }

    /// Checks `password` against the entry's password verification value.
    ///
    /// Reads the salt and the password verification value from the reader. It then derives
    /// the decryption key, the HMAC key and the expected verification value with
    /// PBKDF2-HMAC-SHA1, and compares the two verification values.
    ///
    /// Returns `Ok(None)` if the password is wrong. The verification value is only
    /// `PWD_VERIFY_LENGTH` bytes long, so some wrong passwords pass this check. Those are
    /// rejected by the authentication code check at the end of the data.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error if the compressed size given to [`AesReader::new`] is
    /// too small for an AES entry. Propagates I/O errors from reading the header.
    pub fn validate(mut self, password: &[u8]) -> io::Result<Option<AesReaderValid<R>>> {
        let data_length = self.data_length.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "Compressed size is too small for an AES encrypted entry",
            )
        })?;

        let salt_length = self.aes_mode.salt_length();
        let key_length = self.aes_mode.key_length();

        let mut salt = vec![0; salt_length];
        self.reader.read_exact(&mut salt)?;
        let mut pwd_verification_value = [0; PWD_VERIFY_LENGTH];
        self.reader.read_exact(&mut pwd_verification_value)?;

        // The derived key is laid out as [decryption key | HMAC key | password verifier].
        let derived_key_len = 2 * key_length + PWD_VERIFY_LENGTH;
        let mut derived_key: Vec<u8> = vec![0; derived_key_len];
        pbkdf2::pbkdf2::<Hmac<Sha1>>(password, &salt, ITERATION_COUNT, &mut derived_key);
        let decrypt_key = &derived_key[0..key_length];
        let hmac_key = &derived_key[key_length..key_length * 2];
        let pwd_verify = &derived_key[derived_key_len - PWD_VERIFY_LENGTH..];

        if pwd_verify != pwd_verification_value.as_slice() {
            return Ok(None);
        }

        let cipher = cipher_from_mode(self.aes_mode, decrypt_key);
        let hmac = Hmac::<Sha1>::new_from_slice(hmac_key)
            .expect("HMAC accepts keys of any length");
        Ok(Some(AesReaderValid {
            reader: self.reader,
            data_remaining: data_length,
            cipher,
            hmac,
            finalized: false,
        }))
    }
}
pub struct AesReaderValid<R: Read> {
reader: R,
data_remaining: u64,
cipher: Box<dyn aes_ctr::AesCipher>,
hmac: Hmac<Sha1>,
finalized: bool,
}
impl<R: Read> Read for AesReaderValid<R> {
    /// Reads and decrypts up to `buf.len()` bytes of the entry's payload.
    ///
    /// The HMAC is updated with the *encrypted* bytes before they are decrypted in place.
    /// When the last payload byte has been consumed, the trailing authentication code is read
    /// from the underlying reader. It is compared, in constant time, with the first
    /// `AUTH_CODE_LENGTH` bytes of the computed HMAC-SHA1. A zero-length payload is
    /// authenticated on the first call.
    ///
    /// # Errors
    ///
    /// * Any error from the underlying reader.
    /// * `UnexpectedEof` if the underlying reader ends before the whole payload has been read.
    ///   Reporting that as a normal EOF would hand out truncated, unauthenticated data.
    /// * `InvalidData` if the authentication code does not match. The bytes decrypted into
    ///   `buf` by that call are then not reported as read, and later calls return `Ok(0)`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // The authentication code is checked exactly once; afterwards the stream is at EOF.
        if self.finalized {
            return Ok(0);
        }

        // Compare as u64 so payloads larger than usize::MAX are handled on 32-bit targets.
        let bytes_to_read = self.data_remaining.min(buf.len() as u64) as usize;
        let read = self.reader.read(&mut buf[..bytes_to_read])?;
        if read == 0 && bytes_to_read > 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "Unexpected end of AES encrypted data",
            ));
        }
        self.data_remaining -= read as u64;

        // Encrypt-then-MAC: authenticate the ciphertext, then decrypt it in place.
        self.hmac.update(&buf[..read]);
        self.cipher.crypt_in_place(&mut buf[..read]);

        if self.data_remaining == 0 {
            // Mark first so that, even if reading or checking the code fails, the finalized
            // HMAC is never touched again.
            self.finalized = true;
            let mut read_auth_code = [0; AUTH_CODE_LENGTH];
            self.reader.read_exact(&mut read_auth_code)?;
            let computed_auth_code = &self.hmac.finalize_reset().into_bytes()[0..AUTH_CODE_LENGTH];
            // Constant-time comparison avoids leaking how many leading bytes matched.
            if !constant_time_eq(computed_auth_code, &read_auth_code) {
                return Err(
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "Invalid authentication code, this could be due to an invalid password or errors in the data"
                    )
                );
            }
        }
        Ok(read)
    }
}
impl<R: Read> AesReaderValid<R> {
    /// Consumes this decrypting reader and returns the wrapped reader.
    ///
    /// The returned reader is left wherever decryption stopped. No further data is read or
    /// verified by this call.
    pub fn into_inner(self) -> R {
        let Self { reader, .. } = self;
        reader
    }
}