use crate::CompressionMethod;
use crate::aes_ctr::AesCipher;
use crate::result::ZipResult;
use crate::types::{AesMode, AesVendorVersion};
use crate::{aes_ctr, result::ZipError};
use constant_time_eq::constant_time_eq;
use hmac::{KeyInit, Mac, SimpleHmacReset};
use sha1::Sha1;
use std::io::{self, Error, ErrorKind, Read, Write};
use zeroize::{Zeroize, Zeroizing};
/// Length in bytes of the password verification value stored right after the salt.
pub const PWD_VERIFY_LENGTH: usize = 2;
/// Length in bytes of the authentication code that follows the encrypted data:
/// the HMAC-SHA1 output truncated to its first 10 bytes.
const AUTH_CODE_LENGTH: usize = 10;
/// PBKDF2-HMAC-SHA1 iteration count used to derive key material from the password.
const ITERATION_COUNT: u32 = 1000;
/// AES-CTR key stream for one of the three supported AES key sizes.
///
/// NOTE(review): the streams are boxed, presumably to keep `Cipher` small — confirm.
#[derive(Debug)]
enum Cipher {
    Aes128(Box<aes_ctr::AesCtrZipKeyStream<aes_ctr::Aes128>>),
    Aes192(Box<aes_ctr::AesCtrZipKeyStream<aes_ctr::Aes192>>),
    Aes256(Box<aes_ctr::AesCtrZipKeyStream<aes_ctr::Aes256>>),
}
/// AES parameters of an encrypted entry: key size, stored password verification
/// value and salt.
///
/// NOTE(review): presumably populated from
/// `AesReader::get_verification_value_and_salt` — confirm against callers.
#[derive(Debug)]
pub struct AesInfo {
    /// AES key size of the entry.
    pub aes_mode: AesMode,
    /// Password verification value as stored in the entry.
    pub verification_value: [u8; crate::aes::PWD_VERIFY_LENGTH],
    /// Salt as stored in the entry (expected to be `aes_mode.salt_length()` bytes — TODO confirm).
    pub salt: Vec<u8>,
}
/// AES settings chosen for one entry when writing.
#[non_exhaustive]
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
pub(crate) struct AesModeOptions {
    /// AES key size.
    pub(crate) mode: AesMode,
    /// AE vendor version recorded for the entry.
    pub(crate) vendor_version: AesVendorVersion,
    /// Compression method applied to the data before encryption.
    pub(crate) actual_compression_method: CompressionMethod,
    /// Caller-supplied salt; `None` lets the writer generate one.
    pub(crate) custom_salt: Option<AesSalt>,
}
impl AesModeOptions {
    /// Bundles the AES settings for a single entry.
    pub(crate) fn new(
        mode: AesMode,
        vendor_version: AesVendorVersion,
        actual_compression_method: CompressionMethod,
        custom_salt: Option<AesSalt>,
    ) -> Self {
        Self {
            custom_salt,
            actual_compression_method,
            vendor_version,
            mode,
        }
    }

    /// Returns `(mode, vendor_version, actual_compression_method)`, dropping the
    /// custom salt.
    pub(crate) fn to_tuple(self) -> (AesMode, AesVendorVersion, CompressionMethod) {
        let Self {
            mode,
            vendor_version,
            actual_compression_method,
            ..
        } = self;
        (mode, vendor_version, actual_compression_method)
    }
}
/// A salt whose length is fixed by its AES mode: each variant holds exactly
/// `salt_length()` bytes for the corresponding `AesMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AesSalt {
    Aes128([u8; AesMode::Aes128.salt_length()]),
    Aes192([u8; AesMode::Aes192.salt_length()]),
    Aes256([u8; AesMode::Aes256.salt_length()]),
}
impl AesSalt {
    /// The AES mode this salt was sized for.
    pub(crate) fn mode(&self) -> AesMode {
        match *self {
            AesSalt::Aes128(..) => AesMode::Aes128,
            AesSalt::Aes192(..) => AesMode::Aes192,
            AesSalt::Aes256(..) => AesMode::Aes256,
        }
    }

    /// Copies the salt bytes into a `Vec`.
    pub(crate) fn into_inner(self) -> Vec<u8> {
        let bytes: &[u8] = match &self {
            AesSalt::Aes128(salt) => salt,
            AesSalt::Aes192(salt) => salt,
            AesSalt::Aes256(salt) => salt,
        };
        bytes.to_vec()
    }

    /// Builds the `InvalidInput` error reported when a salt has the wrong length
    /// for `mode`.
    pub(crate) fn salt_error(mode: AesMode, err: std::array::TryFromSliceError) -> std::io::Error {
        let expected = mode.salt_length();
        let message = format!("Salt for {mode} must be {} bytes long: {err}", expected);
        std::io::Error::new(std::io::ErrorKind::InvalidInput, message)
    }

    /// Wraps `salt` as the variant for `mode`.
    ///
    /// # Errors
    /// `InvalidInput` if `salt.len()` differs from `mode.salt_length()`.
    pub fn try_new(mode: AesMode, salt: &[u8]) -> Result<Self, std::io::Error> {
        let wrong_length = |e| Self::salt_error(mode, e);
        Ok(match mode {
            AesMode::Aes128 => Self::Aes128(salt.try_into().map_err(wrong_length)?),
            AesMode::Aes192 => Self::Aes192(salt.try_into().map_err(wrong_length)?),
            AesMode::Aes256 => Self::Aes256(salt.try_into().map_err(wrong_length)?),
        })
    }
}
impl Cipher {
    /// Builds the AES-CTR key stream that matches `aes_mode` from `key`.
    ///
    /// # Errors
    /// Whatever `AesCtrZipKeyStream::new` reports (converted into a `ZipError`),
    /// presumably a key of the wrong length for the mode — TODO confirm.
    fn from_mode(aes_mode: AesMode, key: &[u8]) -> ZipResult<Self> {
        let cipher = match aes_mode {
            AesMode::Aes128 => {
                let stream = aes_ctr::AesCtrZipKeyStream::<aes_ctr::Aes128>::new(key)?;
                Cipher::Aes128(Box::new(stream))
            }
            AesMode::Aes192 => {
                let stream = aes_ctr::AesCtrZipKeyStream::<aes_ctr::Aes192>::new(key)?;
                Cipher::Aes192(Box::new(stream))
            }
            AesMode::Aes256 => {
                let stream = aes_ctr::AesCtrZipKeyStream::<aes_ctr::Aes256>::new(key)?;
                Cipher::Aes256(Box::new(stream))
            }
        };
        Ok(cipher)
    }

    /// Applies the key stream to `target` in place. The same call serves both
    /// encryption (writer) and decryption (reader).
    fn crypt_in_place(&mut self, target: &mut [u8]) {
        match self {
            Cipher::Aes128(stream) => stream.crypt_in_place(target),
            Cipher::Aes192(stream) => stream.crypt_in_place(target),
            Cipher::Aes256(stream) => stream.crypt_in_place(target),
        }
    }
}
pub struct AesReader<R> {
reader: R,
aes_mode: AesMode,
data_length: u64,
}
impl<R: Read> AesReader<R> {
pub const fn new(reader: R, aes_mode: AesMode, compressed_size: u64) -> AesReader<R> {
let data_length = compressed_size
- (PWD_VERIFY_LENGTH + AUTH_CODE_LENGTH + aes_mode.salt_length()) as u64;
Self {
reader,
aes_mode,
data_length,
}
}
pub fn validate(mut self, password: &[u8]) -> Result<AesReaderValid<R>, ZipError> {
let salt_length = self.aes_mode.salt_length();
let key_length = self.aes_mode.key_length();
let mut salt = vec![0; salt_length];
self.reader.read_exact(&mut salt)?;
let mut pwd_verification_value = vec![0; PWD_VERIFY_LENGTH];
self.reader.read_exact(&mut pwd_verification_value)?;
let derived_key_len = 2 * key_length + PWD_VERIFY_LENGTH;
let mut derived_key: Box<[u8]> = vec![0; derived_key_len].into_boxed_slice();
pbkdf2::pbkdf2::<SimpleHmacReset<Sha1>>(password, &salt, ITERATION_COUNT, &mut derived_key)
.map_err(|e| Error::new(ErrorKind::InvalidInput, e))?;
let decrypt_key = &derived_key[0..key_length];
let hmac_key = &derived_key[key_length..key_length * 2];
let pwd_verify = &derived_key[derived_key_len - 2..];
if pwd_verification_value != pwd_verify {
return Err(ZipError::InvalidPassword);
}
let cipher = Cipher::from_mode(self.aes_mode, decrypt_key)?;
let hmac = SimpleHmacReset::<Sha1>::new_from_slice(hmac_key).map_err(|e| {
ZipError::Io(std::io::Error::other(format!(
"Cannot create hmac with key: {e}"
)))
})?;
Ok(AesReaderValid {
reader: self.reader,
data_remaining: self.data_length,
cipher,
hmac,
finalized: false,
})
}
pub fn get_verification_value_and_salt(
mut self,
) -> io::Result<([u8; PWD_VERIFY_LENGTH], Vec<u8>)> {
let salt_length = self.aes_mode.salt_length();
let mut salt = vec![0; salt_length];
self.reader.read_exact(&mut salt)?;
let mut pwd_verification_value = [0; PWD_VERIFY_LENGTH];
self.reader.read_exact(&mut pwd_verification_value)?;
Ok((pwd_verification_value, salt))
}
}
/// Reader over the decrypted payload of an AES entry whose password has been
/// validated. The trailing authentication code is checked once the whole
/// payload has been read.
#[derive(Debug)]
pub struct AesReaderValid<R: Read> {
    reader: R,
    /// Encrypted payload bytes not yet read from `reader`.
    data_remaining: u64,
    cipher: Cipher,
    /// HMAC-SHA1 over the ciphertext read so far.
    hmac: SimpleHmacReset<Sha1>,
    /// Set once the HMAC has been finalized and compared with the stored code.
    finalized: bool,
}
impl<R: Read> Read for AesReaderValid<R> {
    /// Reads and decrypts up to `buf.len()` bytes of the encrypted payload.
    ///
    /// Ciphertext is fed to the HMAC before it is decrypted. On the call that
    /// consumes the last payload byte, the trailing authentication code is read
    /// and verified before that call returns.
    ///
    /// # Errors
    /// - `UnexpectedEof` if the inner reader ends before the payload (or the
    ///   authentication code) has been fully read. Returning `Ok(0)` there would
    ///   let a truncated entry look complete and skip authentication.
    /// - `InvalidData` if the authentication code does not match.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.data_remaining == 0 {
            // An empty payload still carries an authentication code; check it once.
            if !self.finalized {
                self.verify_auth_code()?;
            }
            return Ok(0);
        }
        let bytes_to_read = self.data_remaining.min(buf.len() as u64) as usize;
        let read = self.reader.read(&mut buf[..bytes_to_read])?;
        if read == 0 && bytes_to_read > 0 {
            return Err(Error::new(
                ErrorKind::UnexpectedEof,
                "AES-encrypted data ended before the expected length",
            ));
        }
        self.data_remaining -= read as u64;
        self.hmac.update(&buf[..read]);
        self.cipher.crypt_in_place(&mut buf[..read]);
        if self.data_remaining == 0 {
            self.verify_auth_code()?;
        }
        Ok(read)
    }
}

impl<R: Read> AesReaderValid<R> {
    /// Returns the inner reader.
    pub fn into_inner(self) -> R {
        self.reader
    }

    /// Reads the stored authentication code and compares it, in constant time,
    /// with the first `AUTH_CODE_LENGTH` bytes of the HMAC over all ciphertext
    /// read so far.
    ///
    /// # Panics
    /// If called twice: the HMAC can only be finalized once.
    fn verify_auth_code(&mut self) -> io::Result<()> {
        assert!(
            !self.finalized,
            "Tried to use an already finalized HMAC. This is a bug!"
        );
        self.finalized = true;
        let mut read_auth_code = [0; AUTH_CODE_LENGTH];
        self.reader.read_exact(&mut read_auth_code)?;
        let computed = self.hmac.finalize_reset().into_bytes();
        if !constant_time_eq(&computed[..AUTH_CODE_LENGTH], &read_auth_code) {
            return Err(Error::new(
                ErrorKind::InvalidData,
                "Invalid authentication code, this could be due to an invalid password or errors in the data",
            ));
        }
        Ok(())
    }
}
/// Writer that AES-encrypts and authenticates everything written to it.
pub struct AesWriter<W> {
    writer: W,
    cipher: Cipher,
    /// HMAC-SHA1 over the ciphertext written so far.
    hmac: SimpleHmacReset<Sha1>,
    /// Scratch buffer used to encrypt each `write` call; zeroized after use.
    buffer: Zeroizing<Vec<u8>>,
    /// Salt followed by the password verification value. Emitted before the first
    /// encrypted byte (or by `finish`), then `None`.
    encrypted_file_header: Option<Vec<u8>>,
}
impl<W: Write> AesWriter<W> {
pub(crate) fn new_with_options(
writer: W,
aes_mode: AesMode,
password: &[u8],
custom_salt: Option<AesSalt>,
) -> ZipResult<Self> {
let salt_length = aes_mode.salt_length();
let key_length = aes_mode.key_length();
let mut encrypted_file_header = Vec::with_capacity(salt_length + 2);
let salt = if let Some(customized_salt) = custom_salt {
customized_salt.into_inner()
} else {
let mut salt = vec![0; salt_length];
getrandom::fill(&mut salt).map_err(|e| ZipError::Io(e.into()))?;
salt
};
encrypted_file_header.write_all(&salt)?;
let derived_key_len = 2 * key_length + PWD_VERIFY_LENGTH;
let mut derived_key: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; derived_key_len]);
pbkdf2::pbkdf2::<SimpleHmacReset<Sha1>>(password, &salt, ITERATION_COUNT, &mut derived_key)
.map_err(|e| Error::new(ErrorKind::InvalidInput, e))?;
let encryption_key = &derived_key[0..key_length];
let hmac_key = &derived_key[key_length..key_length * 2];
let pwd_verify = derived_key[derived_key_len - 2..].to_vec();
encrypted_file_header.write_all(&pwd_verify)?;
let cipher = Cipher::from_mode(aes_mode, encryption_key)?;
let hmac = SimpleHmacReset::<Sha1>::new_from_slice(hmac_key)
.map_err(|e| std::io::Error::other(format!("Cannot create hmac with key: {e}")))?;
Ok(Self {
writer,
cipher,
hmac,
buffer: Zeroizing::default(),
encrypted_file_header: Some(encrypted_file_header),
})
}
pub fn get_ref(&self) -> &W {
&self.writer
}
pub unsafe fn get_mut(&mut self) -> &mut W {
&mut self.writer
}
pub fn finish(mut self) -> io::Result<W> {
self.write_encrypted_file_header()?;
let computed_auth_code = &self.hmac.finalize_reset().into_bytes()[0..AUTH_CODE_LENGTH];
self.writer.write_all(computed_auth_code)?;
Ok(self.writer)
}
fn write_encrypted_file_header(&mut self) -> io::Result<()> {
if let Some(header) = self.encrypted_file_header.take() {
self.writer.write_all(&header)?;
}
Ok(())
}
}
impl<W: Write> Write for AesWriter<W> {
    /// Encrypts `buf`, feeds the ciphertext to the HMAC and writes it out in full.
    /// On success it always returns `buf.len()` (never a short write).
    ///
    /// If the inner write fails, the entry should be considered unusable: the HMAC
    /// (and presumably the key stream) have already consumed these bytes.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_encrypted_file_header()?;
        // `buf` is borrowed immutably, but the cipher works in place, so copy it
        // into the scratch buffer first.
        self.buffer.extend_from_slice(buf);
        self.cipher.crypt_in_place(&mut self.buffer[..]);
        self.hmac.update(&self.buffer[..]);
        let result = self.writer.write_all(&self.buffer[..]);
        // Wipe and empty the scratch buffer on every path, including the error
        // path. Otherwise a later `write` would re-encrypt and re-emit stale bytes.
        self.buffer.zeroize();
        self.buffer.clear();
        result?;
        Ok(buf.len())
    }

    /// Flushes the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}
#[cfg(all(test, feature = "aes-crypto"))]
mod tests {
    use std::io::{self, Read, Write};

    use super::AUTH_CODE_LENGTH;
    use crate::{
        aes::{AesReader, AesWriter},
        result::ZipError,
        types::AesMode,
    };

    const PASSWORD: &[u8] = b"some super secret password";
    const FIVE_BYTES: &[u8] = b"asdf\n";
    const FORTY_BYTES: &[u8] = b"Lorem ipsum dolor sit amet, consectetur\n";

    /// Encrypts `plaintext` and returns the whole entry payload: salt, password
    /// verification value, ciphertext and authentication code.
    fn encrypt(aes_mode: AesMode, password: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, ZipError> {
        let mut buf = io::Cursor::new(Vec::new());
        let mut writer = AesWriter::new_with_options(&mut buf, aes_mode, password, None)?;
        writer.write_all(plaintext)?;
        writer.finish()?;
        Ok(buf.into_inner())
    }

    /// Validates `password` against `payload` and reads the decrypted data to the end.
    fn decrypt(aes_mode: AesMode, password: &[u8], payload: &[u8]) -> Result<Vec<u8>, ZipError> {
        let mut plaintext = Vec::new();
        AesReader::new(payload, aes_mode, payload.len() as u64)
            .validate(password)?
            .read_to_end(&mut plaintext)?;
        Ok(plaintext)
    }

    fn roundtrip(aes_mode: AesMode, password: &[u8], plaintext: &[u8]) -> Result<bool, ZipError> {
        let payload = encrypt(aes_mode, password, plaintext)?;
        let decrypted = decrypt(aes_mode, password, &payload)?;
        Ok(decrypted == plaintext)
    }

    #[test]
    fn crypt_aes_256_0_byte() {
        assert!(roundtrip(AesMode::Aes256, PASSWORD, &[]).expect("could encrypt and decrypt"));
    }

    #[test]
    fn crypt_aes_128_5_byte() {
        assert!(roundtrip(AesMode::Aes128, PASSWORD, FIVE_BYTES).expect("could encrypt and decrypt"));
    }

    #[test]
    fn crypt_aes_192_5_byte() {
        assert!(roundtrip(AesMode::Aes192, PASSWORD, FIVE_BYTES).expect("could encrypt and decrypt"));
    }

    #[test]
    fn crypt_aes_256_5_byte() {
        assert!(roundtrip(AesMode::Aes256, PASSWORD, FIVE_BYTES).expect("could encrypt and decrypt"));
    }

    #[test]
    fn crypt_aes_128_40_byte() {
        assert!(roundtrip(AesMode::Aes128, PASSWORD, FORTY_BYTES).expect("could encrypt and decrypt"));
    }

    #[test]
    fn crypt_aes_192_40_byte() {
        assert!(roundtrip(AesMode::Aes192, PASSWORD, FORTY_BYTES).expect("could encrypt and decrypt"));
    }

    #[test]
    fn crypt_aes_256_40_byte() {
        assert!(roundtrip(AesMode::Aes256, PASSWORD, FORTY_BYTES).expect("could encrypt and decrypt"));
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        for &aes_mode in &[AesMode::Aes128, AesMode::Aes192, AesMode::Aes256] {
            let mut payload = encrypt(aes_mode, PASSWORD, FORTY_BYTES).expect("could encrypt");
            let last_data_byte = payload.len() - AUTH_CODE_LENGTH - 1;
            payload[last_data_byte] ^= 0x01;
            assert!(decrypt(aes_mode, PASSWORD, &payload).is_err());
        }
    }

    #[test]
    fn tampered_auth_code_is_rejected() {
        let mut payload = encrypt(AesMode::Aes256, PASSWORD, FORTY_BYTES).expect("could encrypt");
        *payload.last_mut().expect("payload ends with the auth code") ^= 0x01;
        assert!(decrypt(AesMode::Aes256, PASSWORD, &payload).is_err());
    }

    #[test]
    fn wrong_password_is_rejected() {
        let payload = encrypt(AesMode::Aes256, PASSWORD, FORTY_BYTES).expect("could encrypt");
        // Rejected by the 2-byte verifier or, if that collides by chance, by the HMAC.
        assert!(decrypt(AesMode::Aes256, b"not the password", &payload).is_err());
    }
}