use crate::aliases::HmacSha256;
use crate::aliases::{Block16, EncryptedSessionBlock48, Iv16, RingBuffer64};
use crate::error::AescryptError;
use crate::utilities::{read_until_full, xor_blocks}; use aes::cipher::BlockDecrypt;
use aes::{Aes256Dec, Block as AesBlock};
use hmac::Mac;
use secure_gate::{RevealSecret, RevealSecretMut};
use std::io::{Read, Write};
/// Streaming CBC decryption state built around a 64-byte (four 16-byte slot) ring buffer.
///
/// Every index below is a byte offset into `ring_buffer`; they advance in 16-byte steps
/// and wrap modulo 64.
pub(crate) struct DecryptionContext {
    /// Backing storage holding the previous ciphertext block (initially the IV), the block
    /// about to be decrypted, and read-ahead ciphertext.
    pub(crate) ring_buffer: RingBuffer64,
    /// Offset at which the next bytes read from the input are stored. It may sit at 64
    /// (one past the end) until it is wrapped back to 0 before the next refill.
    pub(crate) head_index: usize,
    /// Offset of the next ciphertext block to be authenticated and decrypted.
    pub(crate) current_index: usize,
    /// Offset of the previous ciphertext block (the IV for the first block), which is
    /// XORed into the next decrypted block per CBC.
    pub(crate) tail_index: usize,
    /// Most recently decrypted plaintext block.
    pub(crate) plaintext_block: Block16,
    /// `true` while `plaintext_block` holds plaintext that has not been written out yet.
    pub(crate) need_write_plaintext: bool,
}
impl DecryptionContext {
#[inline(always)]
pub(crate) fn new_with_iv(iv: &Iv16) -> Self {
let mut this = Self {
ring_buffer: RingBuffer64::new([0u8; 64]),
tail_index: 0,
current_index: 16,
head_index: 16,
plaintext_block: Block16::new([0u8; 16]),
need_write_plaintext: false,
};
iv.with_secret(|iv_bytes| {
this.ring_buffer
.with_secret_mut(|rb| rb[0..16].copy_from_slice(iv_bytes))
});
this
}
#[inline(always)]
fn write_at_head(&mut self, src: &[u8]) {
self.ring_buffer.with_secret_mut(|rb| {
rb[self.head_index..self.head_index + src.len()].copy_from_slice(src);
});
self.head_index += src.len();
}
#[inline(always)]
pub(crate) fn decrypt_cbc_loop<R, W>(
&mut self,
input: &mut R,
output: &mut W,
cipher: &Aes256Dec,
hmac: &mut HmacSha256,
) -> Result<(), AescryptError>
where
R: Read,
W: Write,
{
let mut initial_buffer = EncryptedSessionBlock48::new([0u8; 48]);
let bytes_read = initial_buffer
.with_secret_mut(|ib| read_until_full(input, ib))
.map_err(AescryptError::Io)?;
initial_buffer.with_secret(|ib| self.write_at_head(&ib[..bytes_read]));
if bytes_read == 48 {
loop {
if self.need_write_plaintext {
self.plaintext_block
.with_secret(|pb| output.write_all(pb))?;
}
self.ring_buffer.with_secret(|rb| {
hmac.update(&rb[self.current_index..self.current_index + 16])
});
let mut block_bytes = Block16::new([0u8; 16]);
block_bytes.with_secret_mut(|bb| {
self.ring_buffer.with_secret(|rb| {
bb.copy_from_slice(&rb[self.current_index..self.current_index + 16]);
});
});
let mut aes_block = block_bytes.with_secret(|bb| AesBlock::from(*bb));
cipher.decrypt_block(&mut aes_block);
self.ring_buffer.with_secret(|rb| {
self.plaintext_block.with_secret_mut(|pb| {
xor_blocks(
aes_block.as_ref(),
&rb[self.tail_index..self.tail_index + 16],
pb,
);
});
});
self.need_write_plaintext = true;
self.tail_index = (self.tail_index + 16) % 64;
self.current_index = (self.current_index + 16) % 64;
if self.head_index == 64 {
self.head_index = 0;
}
let mut next_block = Block16::new([0u8; 16]);
let n = next_block
.with_secret_mut(|nb| read_until_full(input, nb))
.map_err(AescryptError::Io)?;
next_block.with_secret(|nb| self.write_at_head(&nb[..n]));
if n < 16 {
break;
}
}
}
Ok(())
}
#[inline(always)]
pub(crate) fn advance_tail(&mut self) {
self.tail_index = (self.tail_index + 16) % 64;
}
#[inline(always)]
pub(crate) fn remaining(&self) -> usize {
if self.head_index >= self.tail_index {
self.head_index - self.tail_index
} else {
64 - self.tail_index + self.head_index
}
}
}