use std::{
cmp,
convert::TryFrom,
io::{Error, ErrorKind, Read, Result, Seek, SeekFrom, Write},
ops::Neg,
};
use hmac::{Mac, NewMac};
use aes::cipher::{NewCipher, StreamCipher};
use rand::{thread_rng, Rng};
/// Chunk size, in bytes, used when streaming data through fixed-size scratch buffers
/// (for example the MAC verification pass in `AesReader::new`).
const BUFFER_SIZE: usize = 8 * 1024;
/// Encrypting writer producing a stream laid out as `IV || ciphertext || MAC tag`.
///
/// `E` is the stream cipher, `M` the MAC computed over the ciphertext, and `W` the
/// sink that receives the encrypted bytes.
pub struct AesWriter<E, M, W>
where
    E: NewCipher + StreamCipher,
    M: Mac + NewMac,
    W: Write,
{
    /// Sink receiving the IV, the ciphertext and finally the MAC tag.
    writer: W,
    /// Stream cipher keyed with the caller's key and the generated IV.
    enc: E,
    /// MAC updated with every ciphertext byte that is written.
    mac: M,
    /// Set by `finalize`; once true, `encrypt_write` rejects further data.
    finalized: bool,
}
impl<E: NewCipher + StreamCipher, M: Mac + NewMac, W: Write> AesWriter<E, M, W> {
    /// Creates an encrypting writer and immediately writes a freshly generated IV to
    /// `writer`.
    ///
    /// The resulting stream is laid out as `IV || ciphertext || MAC tag`. Ciphertext is
    /// appended through the [`Write`] impl and the tag by [`finalize`](Self::finalize).
    /// Only the ciphertext is fed to the MAC; the IV itself is not authenticated.
    ///
    /// Only the first `iv_size / 2` bytes of the IV are random; the remaining bytes stay
    /// zero. NOTE(review): presumably this reserves the low half as a block counter that
    /// starts at zero for CTR mode — confirm against the cipher `E` in use.
    ///
    /// # Arguments
    ///
    /// * `writer` - sink receiving the encrypted stream.
    /// * `key` - cipher key, passed to `E::new_from_slices` together with the IV.
    /// * `mac_key` - MAC key, passed to `M::new_from_slice`.
    /// * `iv_size` - IV length in bytes; must be accepted by `E` as its nonce length.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Other` if IV generation fails or if the MAC or cipher rejects
    /// the key/IV length, and propagates any I/O error from writing the IV.
    pub fn new(
        mut writer: W,
        key: &[u8],
        mac_key: &[u8],
        iv_size: usize,
    ) -> Result<AesWriter<E, M, W>> {
        let mut iv = vec![0u8; iv_size];
        let mut rng = thread_rng();
        rng.try_fill(&mut iv[0..iv_size / 2])
            .map_err(|e| Error::new(ErrorKind::Other, format!("error generating iv: {:?}", e)))?;
        let mac = M::new_from_slice(mac_key)
            .map_err(|e| Error::new(ErrorKind::Other, format!("error creating mac: {:?}", e)))?;
        let enc = E::new_from_slices(key, &iv).map_err(|e| {
            Error::new(
                ErrorKind::Other,
                format!("error initializing cipher: {:?}", e),
            )
        })?;
        writer.write_all(&iv)?;
        Ok(AesWriter {
            writer,
            enc,
            mac,
            finalized: false,
        })
    }

    /// Encrypts `buf` in place, writes the ciphertext and feeds it to the MAC.
    ///
    /// Returns `buf.len()` on success. Fails if the stream is already finalized or the
    /// cipher's keystream is exhausted; I/O errors from the sink are propagated.
    fn encrypt_write(&mut self, buf: &mut [u8]) -> Result<usize> {
        if self.finalized {
            return Err(Error::new(
                ErrorKind::Other,
                "File has been already finalized",
            ));
        }
        self.enc.try_apply_keystream(buf).map_err(|_| {
            Error::new(
                ErrorKind::Other,
                "Encryption error, reached end of the keystream.",
            )
        })?;
        self.writer.write_all(buf)?;
        // Encrypt-then-MAC: the MAC covers the ciphertext now held in `buf`.
        self.mac.update(buf);
        Ok(buf.len())
    }

    /// Appends the MAC tag and marks the stream finalized.
    ///
    /// After this call, further writes (and further `finalize` calls) return an error.
    /// If writing the tag fails, the stream still counts as finalized. The MAC state has
    /// already been consumed at that point, so retrying could only emit a wrong tag.
    ///
    /// # Errors
    ///
    /// Fails if the stream was already finalized, or with any I/O error from writing
    /// the tag.
    pub fn finalize(&mut self) -> Result<()> {
        // Zero-length write: reuses the "already finalized" check in `encrypt_write`.
        self.encrypt_write(&mut [])?;
        let mac_result = self.mac.finalize_reset().into_bytes();
        // `finalize_reset` has reset the MAC, so mark the stream finalized *before* the
        // fallible write. Otherwise a retry (or `Drop`) would write the tag of an
        // empty message.
        self.finalized = true;
        self.writer.write_all(mac_result.as_slice())
    }
}
impl<E: NewCipher + StreamCipher, M: Mac + NewMac, W: Write> Write for AesWriter<E, M, W> {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
let mut buf = buf.to_owned();
let written = self.encrypt_write(&mut buf)?;
Ok(written)
}
fn flush(&mut self) -> Result<()> {
self.writer.flush()
}
}
impl<E: NewCipher + StreamCipher, M: Mac + NewMac, W: Write> Drop for AesWriter<E, M, W> {
    /// Appends the MAC tag if the writer was never finalized explicitly.
    ///
    /// A finalization error panics. The exception is when the thread is already
    /// unwinding: then the error is discarded, because a second panic during unwinding
    /// would abort the process.
    fn drop(&mut self) {
        if !self.finalized {
            let outcome = self.finalize();
            if !std::thread::panicking() {
                outcome.unwrap();
            }
        }
    }
}
/// Decrypting reader for streams laid out as `IV || ciphertext || MAC tag`, as
/// produced by [`AesWriter`].
///
/// `D` is the stream cipher and `R` the seekable source of the encrypted stream.
pub struct AesReader<D, R>
where
    D: NewCipher + StreamCipher,
    R: Read + Seek + Clone,
{
    /// Source of the encrypted stream.
    reader: R,
    /// Stream cipher keyed with the caller's key and the IV read from the stream.
    dec: D,
    /// Total length of the stream in bytes, measured by seeking to its end.
    pub(crate) length: u64,
    /// Size in bytes of the trailing MAC tag.
    pub(crate) mac_length: u64,
}
impl<D: NewCipher + StreamCipher, R: Read + Seek + Clone> AesReader<D, R> {
    /// Opens a stream laid out as `IV || ciphertext || MAC tag` and verifies its MAC
    /// before any plaintext can be read.
    ///
    /// The ciphertext is read once, in `BUFFER_SIZE` chunks, to compute the MAC. Only
    /// the ciphertext is authenticated; the IV is not fed to the MAC. On success the
    /// reader is positioned at the first ciphertext byte, ready for [`Read`].
    ///
    /// NOTE(review): assumes `reader` starts at offset 0. The IV is read from the
    /// current position, but every later seek is absolute (`Start`/`End`).
    ///
    /// # Arguments
    ///
    /// * `reader` - seekable source holding the whole encrypted stream.
    /// * `key` - cipher key, passed to `D::new_from_slices` together with the IV.
    /// * `mac_key` - MAC key, passed to `M::new_from_slice`.
    /// * `iv_size` - number of IV bytes at the start of the stream.
    /// * `mac_size` - number of MAC tag bytes at the end of the stream.
    ///
    /// # Errors
    ///
    /// I/O errors from `reader` are propagated. `ErrorKind::Other` is returned in these
    /// cases:
    /// * the MAC or cipher cannot be initialized;
    /// * the sizes do not fit the offset types;
    /// * the stream is shorter than `iv_size + mac_size`;
    /// * the MAC does not match.
    pub fn new<M: Mac + NewMac>(
        mut reader: R,
        key: &[u8],
        mac_key: &[u8],
        iv_size: usize,
        mac_size: usize,
    ) -> Result<AesReader<D, R>> {
        let iv_length = iv_size;
        let mut mac = M::new_from_slice(mac_key)
            .map_err(|e| Error::new(ErrorKind::Other, format!("error creating mac: {:?}", e)))?;
        let mac_length = mac_size;
        let u_iv_length = u64::try_from(iv_length)
            .map_err(|_| Error::new(ErrorKind::Other, "IV length is too big"))?;
        let u_mac_length = u64::try_from(mac_length)
            .map_err(|_| Error::new(ErrorKind::Other, "MAC length is too big"))?;
        // Signed copy is needed for `SeekFrom::End`, which takes an i64 offset.
        let i_mac_length = i64::try_from(mac_length)
            .map_err(|_| Error::new(ErrorKind::Other, "MAC length is too big"))?;
        let mut iv = vec![0u8; iv_length];
        let mut expected_mac = vec![0u8; mac_length];
        reader.read_exact(&mut iv)?;
        let end = reader.seek(SeekFrom::End(0))?;
        if end < (u_iv_length + u_mac_length) {
            return Err(Error::new(
                ErrorKind::Other,
                "File doesn't contain a valid IV or MAC",
            ));
        }
        // The tag occupies the last `mac_length` bytes of the stream.
        let seek_back = i_mac_length.neg();
        reader.seek(SeekFrom::End(seek_back))?;
        reader.read_exact(&mut expected_mac)?;
        // MAC the ciphertext, i.e. everything between the IV and the tag.
        reader.seek(SeekFrom::Start(u_iv_length))?;
        let mut buffer = [0u8; BUFFER_SIZE];
        loop {
            let read =
                AesReader::<D, R>::read_until_mac(&mut buffer, &mut reader, end, u_mac_length)?;
            if read == 0 {
                break;
            }
            mac.update(&buffer[..read]);
        }
        if mac.verify(&expected_mac).is_err() {
            return Err(Error::new(ErrorKind::Other, "Invalid MAC"));
        }
        // Rewind to the first ciphertext byte for the decryption pass.
        reader.seek(SeekFrom::Start(u_iv_length))?;
        let dec = D::new_from_slices(key, &iv).map_err(|e| {
            Error::new(
                ErrorKind::Other,
                format!("couldn't initialize cipher {:?}", e),
            )
        })?;
        Ok(AesReader {
            reader,
            dec,
            length: end,
            mac_length: u_mac_length,
        })
    }

    /// Reads up to `buffer.len()` bytes from `reader`, never reading into the trailing
    /// MAC tag.
    ///
    /// `total_length` is the full stream length and `mac_length` the tag size, so the tag
    /// starts at `total_length - mac_length`. Returns `Ok(0)` once the current position
    /// has reached the tag. Like [`Read::read`], it may return fewer bytes than requested.
    fn read_until_mac(
        buffer: &mut [u8],
        reader: &mut R,
        total_length: u64,
        mac_length: u64,
    ) -> Result<usize> {
        let current_pos = reader.seek(SeekFrom::Current(0))?;
        let mac_start = total_length - mac_length;
        if current_pos >= mac_start {
            return Ok(0);
        }
        // Saturate instead of an `as` cast. On 32-bit targets a remaining span of exactly
        // 2^32 bytes would otherwise truncate to 0 and be mistaken for end of data.
        let remaining = usize::try_from(mac_start - current_pos).unwrap_or(usize::MAX);
        let max_to_read = cmp::min(buffer.len(), remaining);
        reader.read(&mut buffer[..max_to_read])
    }

    /// Reads the next ciphertext bytes into `buf` and decrypts them in place.
    ///
    /// Returns the number of plaintext bytes now at the start of `buf`; bytes past that
    /// count are left untouched.
    fn read_decrypt(&mut self, buf: &mut [u8]) -> Result<usize> {
        let read =
            AesReader::<D, R>::read_until_mac(buf, &mut self.reader, self.length, self.mac_length)?;
        // Decrypt only the bytes actually read. Applying the keystream to the whole
        // buffer would advance the cipher past the data on short reads and garble every
        // subsequent read.
        self.dec.try_apply_keystream(&mut buf[..read]).map_err(|_| {
            Error::new(
                ErrorKind::Other,
                "Decryption error, reached end of the keystream.",
            )
        })?;
        Ok(read)
    }
}
impl<D, R> Read for AesReader<D, R>
where
    D: NewCipher + StreamCipher,
    R: Read + Seek + Clone,
{
    /// Reads ciphertext into `buf` and decrypts it, stopping before the trailing MAC tag.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.read_decrypt(buf)
    }
}
#[cfg(test)]
mod test {
    use aes::Aes128Ctr;
    use hmac::Hmac;
    use sha2::Sha256;
    use std::io::{Cursor, Read, Seek, Write};
    use super::{AesReader, AesWriter, BUFFER_SIZE};

    /// Encrypts `data` with a single `write_all`, all-zero keys and a 16-byte IV. The
    /// MAC tag is appended when the writer is dropped.
    fn encrypt(data: &[u8]) -> Vec<u8> {
        let key = [0u8; 16];
        let hmac_key = [0u8; 16];
        let mut enc = Vec::new();
        {
            let mut aes =
                AesWriter::<Aes128Ctr, Hmac<Sha256>, _>::new(&mut enc, &key, &hmac_key, 16)
                    .unwrap();
            aes.write_all(data).unwrap();
        }
        enc
    }

    /// Verifies and decrypts a stream produced by `encrypt`, using `read_to_end`.
    fn decrypt<R: Read + Seek + Clone>(data: R) -> Vec<u8> {
        let key = [0u8; 16];
        let mut dec = Vec::new();
        let mut aes =
            AesReader::<Aes128Ctr, _>::new::<Hmac<Sha256>>(data, &key, &key, 16, 32).unwrap();
        aes.read_to_end(&mut dec).unwrap();
        dec
    }

    #[test]
    fn enc_unaligned() {
        let orig = [0u8; 16];
        let key = [0u8; 16];
        let hmac_key = [0u8; 16];
        let mut enc = Vec::new();
        {
            let mut aes =
                AesWriter::<Aes128Ctr, Hmac<Sha256>, _>::new(&mut enc, &key, &hmac_key, 16)
                    .unwrap();
            for chunk in orig.chunks(3) {
                aes.write_all(chunk).unwrap();
            }
        }
        let dec = decrypt(Cursor::new(&enc));
        assert_eq!(dec, &orig);
    }

    #[test]
    fn enc_dec_single() {
        let orig = [0u8; 16];
        let enc = encrypt(&orig);
        let dec = decrypt(Cursor::new(&enc));
        assert_eq!(dec, &orig);
    }

    #[test]
    fn enc_dec_single_full() {
        // Several full internal buffers plus a partial one, with non-constant data, so
        // chunking or keystream-position mistakes show up in the plaintext.
        let orig: Vec<u8> = (0..3 * BUFFER_SIZE + 7).map(|i| (i % 251) as u8).collect();
        let enc = encrypt(&orig);
        let dec = decrypt(Cursor::new(&enc));
        assert_eq!(dec, orig);
    }

    #[test]
    fn enc_dec_empty() {
        let enc = encrypt(&[]);
        let dec = decrypt(Cursor::new(&enc));
        assert!(dec.is_empty());
    }

    #[test]
    fn dec_rejects_tampered_ciphertext() {
        let mut enc = encrypt(&[0u8; 16]);
        // First ciphertext byte, immediately after the 16-byte IV.
        enc[16] ^= 1;
        let key = [0u8; 16];
        let res =
            AesReader::<Aes128Ctr, _>::new::<Hmac<Sha256>>(Cursor::new(&enc), &key, &key, 16, 32);
        assert!(res.is_err());
    }

    #[test]
    fn write_after_finalize_fails() {
        let key = [0u8; 16];
        let mut enc = Vec::new();
        let mut aes =
            AesWriter::<Aes128Ctr, Hmac<Sha256>, _>::new(&mut enc, &key, &key, 16).unwrap();
        aes.finalize().unwrap();
        assert!(aes.write_all(b"x").is_err());
        assert!(aes.write(&[]).is_err());
        assert!(aes.finalize().is_err());
    }

    #[test]
    fn dec_read_unaligned() {
        let orig = [0u8; 16];
        let enc = encrypt(&orig);
        let key = [0u8; 16];
        let mut dec: Vec<u8> = Vec::new();
        let mut aes =
            AesReader::<Aes128Ctr, _>::new::<Hmac<Sha256>>(Cursor::new(&enc), &key, &key, 16, 32)
                .unwrap();
        loop {
            let mut buf = [0u8; 3];
            let read = aes.read(&mut buf).unwrap();
            dec.extend(&buf[..read]);
            if read == 0 {
                break;
            }
        }
        assert_eq!(dec, &orig);
    }
}