use crate::aliases::HmacSha256;
use crate::aliases::{Block16, EncryptedSessionBlock48, Iv16, RingBuffer64};
use crate::error::AescryptError;
use crate::utilities::xor_blocks; use aes::cipher::BlockDecrypt;
use aes::{Aes256Dec, Block as AesBlock};
use hmac::Mac;
use secure_gate::{RevealSecret, RevealSecretMut};
use std::io::{Read, Write};
/// Streaming state for CBC decryption over a 64-byte ring buffer of ciphertext.
///
/// The ring buffer is addressed in 16-byte blocks. `tail_index` marks the block
/// XORed into the next decrypted block (the IV first, then the previous
/// ciphertext block), `current_index` marks the next ciphertext block to
/// decrypt, and `head_index` marks where the next input bytes are stored.
pub struct DecryptionContext {
/// 64-byte ring buffer; bytes `0..16` are seeded with the IV by `new_with_iv`.
pub ring_buffer: RingBuffer64,
/// Offset of the block XORed with the next decrypted block (IV or previous ciphertext).
pub tail_index: usize,
/// Offset of the next ciphertext block to decrypt.
pub current_index: usize,
/// Offset just past the last stored input byte; may briefly equal 64 before the
/// decrypt loop wraps it back to 0.
pub head_index: usize,
/// Most recently decrypted plaintext block.
pub plaintext_block: Block16,
/// `true` when `plaintext_block` holds a decrypted block not yet written out.
pub need_write_plaintext: bool,
}
impl DecryptionContext {
    /// Creates a context whose ring buffer is seeded with `iv`.
    ///
    /// The IV is copied into ring-buffer bytes `0..16` and `tail_index` points
    /// at it, so it serves as the "previous ciphertext block" for the first CBC
    /// step. `current_index` and `head_index` both start at 16 (the first free
    /// slot), and no plaintext is pending.
    #[inline(always)]
    pub fn new_with_iv(iv: &Iv16) -> Self {
        let mut this = Self {
            ring_buffer: RingBuffer64::new([0u8; 64]),
            tail_index: 0,
            current_index: 16,
            head_index: 16,
            plaintext_block: Block16::new([0u8; 16]),
            need_write_plaintext: false,
        };
        iv.with_secret(|iv_bytes| {
            this.ring_buffer
                .with_secret_mut(|rb| rb[0..16].copy_from_slice(iv_bytes))
        });
        this
    }

    /// Decrypts `input` into `output` block by block, keeping 48 bytes of
    /// not-yet-decrypted input buffered ahead of each step.
    ///
    /// Up to 48 bytes are first read into the ring buffer after the IV. If
    /// fewer than 48 bytes are available, nothing is decrypted: the bytes are
    /// left in the ring buffer (see [`remaining`](Self::remaining)) and the
    /// function returns `Ok(())`.
    ///
    /// Otherwise each iteration:
    /// 1. writes the plaintext block produced by the previous iteration, if any;
    /// 2. feeds the ciphertext block at `current_index` into `hmac`;
    /// 3. decrypts that block with `cipher` and XORs it with the block at
    ///    `tail_index` into `plaintext_block`;
    /// 4. advances `tail_index` and `current_index` by one block (mod 64);
    /// 5. reads up to one more block of input at `head_index`.
    ///
    /// The loop ends on the first refill shorter than 16 bytes, which signals
    /// end of input. The block decrypted in that last iteration is NOT written:
    /// it stays in `plaintext_block` with `need_write_plaintext == true`, and the
    /// undecrypted trailing bytes stay in the ring buffer for the caller.
    ///
    /// # Errors
    /// Returns [`AescryptError::Io`] if reading `input` fails. Reads that fail
    /// with `ErrorKind::Interrupted` are retried rather than reported. A failure
    /// writing `output` is converted into an `AescryptError` via `?`.
    #[inline(always)]
    pub fn decrypt_cbc_loop<R, W>(
        &mut self,
        input: &mut R,
        output: &mut W,
        cipher: &Aes256Dec,
        hmac: &mut HmacSha256,
    ) -> Result<(), AescryptError>
    where
        R: Read,
        W: Write,
    {
        let mut initial_buffer = EncryptedSessionBlock48::new([0u8; 48]);
        let bytes_read = initial_buffer
            .with_secret_mut(|ib| Self::fill_from_reader(&mut *input, &mut ib[..]))
            .map_err(AescryptError::Io)?;
        self.ring_buffer.with_secret_mut(|rb| {
            initial_buffer.with_secret(|ib| {
                rb[self.head_index..self.head_index + bytes_read]
                    .copy_from_slice(&ib[..bytes_read]);
            });
        });
        self.head_index += bytes_read;

        if bytes_read < 48 {
            // Not enough input to start the look-ahead loop; the caller
            // inspects what was buffered.
            return Ok(());
        }

        loop {
            // Flush the block decrypted on the previous iteration. Writing is
            // deferred by one block so the final block is left for the caller.
            if self.need_write_plaintext {
                self.plaintext_block
                    .with_secret(|pb| output.write_all(pb))?;
            }

            self.ring_buffer.with_secret(|rb| {
                hmac.update(&rb[self.current_index..self.current_index + 16])
            });

            let mut block_bytes = Block16::new([0u8; 16]);
            block_bytes.with_secret_mut(|bb| {
                self.ring_buffer.with_secret(|rb| {
                    bb.copy_from_slice(&rb[self.current_index..self.current_index + 16]);
                });
            });
            let mut aes_block = block_bytes.with_secret(|bb| AesBlock::from(*bb));
            cipher.decrypt_block(&mut aes_block);

            // CBC: plaintext = D(C_i) XOR C_{i-1} (or the IV for the first block).
            self.ring_buffer.with_secret(|rb| {
                self.plaintext_block.with_secret_mut(|pb| {
                    xor_blocks(
                        aes_block.as_ref(),
                        &rb[self.tail_index..self.tail_index + 16],
                        pb,
                    );
                });
            });
            self.need_write_plaintext = true;

            self.tail_index = (self.tail_index + 16) % 64;
            self.current_index = (self.current_index + 16) % 64;
            if self.head_index == 64 {
                self.head_index = 0;
            }

            let mut next_block = Block16::new([0u8; 16]);
            let n = next_block
                .with_secret_mut(|nb| Self::fill_from_reader(&mut *input, &mut nb[..]))
                .map_err(AescryptError::Io)?;
            self.ring_buffer.with_secret_mut(|rb| {
                next_block.with_secret(|nb| {
                    rb[self.head_index..self.head_index + n].copy_from_slice(&nb[..n]);
                });
            });
            self.head_index += n;

            // A short refill means the input is exhausted.
            if n < 16 {
                break;
            }
        }
        Ok(())
    }

    /// Moves `tail_index` forward by one 16-byte block, wrapping at 64.
    #[inline(always)]
    pub fn advance_tail(&mut self) {
        self.tail_index = (self.tail_index + 16) % 64;
    }

    /// Returns the number of buffered bytes from `tail_index` up to
    /// `head_index`, accounting for wrap-around of the 64-byte ring buffer.
    ///
    /// NOTE(review): when `head_index == tail_index` this reports 0, which is
    /// indistinguishable from a completely full (64-byte) wrapped buffer.
    #[inline(always)]
    pub fn remaining(&self) -> usize {
        if self.head_index >= self.tail_index {
            self.head_index - self.tail_index
        } else {
            64 - self.tail_index + self.head_index
        }
    }

    /// Reads from `reader` until `buf` is full or end of input is reached, and
    /// returns how many bytes were stored at the front of `buf`.
    ///
    /// A short count is not an error; it signals EOF. Reads failing with
    /// `ErrorKind::Interrupted` are retried, as the `Read` contract recommends.
    fn fill_from_reader<R: Read>(reader: &mut R, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut filled = 0usize;
        while filled < buf.len() {
            match reader.read(&mut buf[filled..]) {
                Ok(0) => break,
                Ok(k) => filled += k,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
        Ok(filled)
    }
}