#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use byteorder::{BigEndian, ByteOrder};
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use super::super::BlockCipher;
use crate::error::{validate, Result};
use crate::types::nonce::AesCtrCompatible;
use crate::types::Nonce;
use dcrypt_common::security::barrier;
/// Where the counter field is placed inside the counter block used by [`Ctr`].
///
/// The nonce occupies the block bytes that are not part of the counter field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CounterPosition {
    /// Counter occupies the first `counter_size` bytes of the block.
    Prefix,
    /// Counter occupies the last `counter_size` bytes of the block.
    Postfix,
    /// Counter starts at the given byte offset; `offset + counter_size` must not
    /// exceed the cipher block size.
    Custom(usize),
}
/// Counter-mode (CTR) keystream generator wrapped around a block cipher `B`.
///
/// Keystream blocks are produced by encrypting a counter block (nonce bytes plus
/// a big-endian counter field) and consumed one byte at a time. Buffers holding
/// keystream or counter material are kept in `Zeroizing`, and the whole value is
/// zeroized when dropped.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct Ctr<B: BlockCipher + Zeroize> {
    /// Block cipher used to encrypt each counter block.
    cipher: B,
    /// Current counter block: nonce bytes plus the counter field.
    counter_block: Zeroizing<Vec<u8>>,
    /// Width of the counter field in bytes.
    counter_size: usize,
    /// Byte offset of the counter field inside `counter_block`.
    counter_position: usize,
    /// Most recently generated keystream block.
    keystream: Zeroizing<Vec<u8>>,
    /// Index of the next unused byte in `keystream`.
    keystream_pos: usize,
}
impl<B: BlockCipher + Zeroize> Ctr<B> {
pub fn new<const N: usize>(cipher: B, nonce: &Nonce<N>) -> Result<Self>
where
Nonce<N>: AesCtrCompatible,
{
Self::with_counter_params(cipher, nonce, CounterPosition::Postfix, 4)
}
pub fn with_counter_params<const N: usize>(
cipher: B,
nonce: &Nonce<N>,
counter_pos: CounterPosition,
counter_size: usize,
) -> Result<Self>
where
Nonce<N>: AesCtrCompatible,
{
let block_size = B::block_size();
validate::parameter(
counter_size > 0 && counter_size <= 8,
"counter_size",
"Counter size must be between 1 and 8 bytes",
)?;
let position = match counter_pos {
CounterPosition::Prefix => 0,
CounterPosition::Postfix => block_size - counter_size,
CounterPosition::Custom(offset) => {
validate::parameter(
offset + counter_size <= block_size,
"counter_position",
"Counter with specified size doesn't fit at offset in block",
)?;
offset
}
};
let mut counter_block = Zeroizing::new(vec![0u8; block_size]);
let max_nonce_size = block_size - counter_size;
let effective_nonce = if N > max_nonce_size {
&nonce.as_ref()[0..max_nonce_size]
} else {
nonce.as_ref()
};
if position == 0 {
counter_block[counter_size..counter_size + effective_nonce.len()]
.copy_from_slice(effective_nonce);
} else {
counter_block[0..effective_nonce.len()].copy_from_slice(effective_nonce);
}
Ok(Self {
cipher,
counter_block,
counter_position: position,
counter_size,
keystream: Zeroizing::new(Vec::new()),
keystream_pos: 0,
})
}
fn generate_keystream(&mut self) -> Result<()> {
let block_size = B::block_size();
self.keystream = Zeroizing::new(vec![0u8; block_size]);
barrier::compiler_fence_seq_cst();
self.keystream.copy_from_slice(&self.counter_block);
self.cipher.encrypt_block(&mut self.keystream)?;
self.increment_counter();
self.keystream_pos = 0;
barrier::compiler_fence_seq_cst();
Ok(())
}
fn increment_counter(&mut self) {
match self.counter_size {
8 => {
let mut counter = [0u8; 8];
counter.copy_from_slice(
&self.counter_block[self.counter_position..self.counter_position + 8],
);
let value = BigEndian::read_u64(&counter);
BigEndian::write_u64(&mut counter, value.wrapping_add(1));
self.counter_block[self.counter_position..self.counter_position + 8]
.copy_from_slice(&counter);
counter.zeroize();
}
4 => {
let mut counter = [0u8; 4];
counter.copy_from_slice(
&self.counter_block[self.counter_position..self.counter_position + 4],
);
let value = BigEndian::read_u32(&counter);
BigEndian::write_u32(&mut counter, value.wrapping_add(1));
self.counter_block[self.counter_position..self.counter_position + 4]
.copy_from_slice(&counter);
counter.zeroize();
}
size => {
let mut value: u64 = 0;
for i in 0..size {
value = (value << 8) | (self.counter_block[self.counter_position + i] as u64);
}
value = value.wrapping_add(1);
for i in 0..size {
self.counter_block[self.counter_position + size - 1 - i] = (value & 0xff) as u8;
value >>= 8;
}
}
}
}
pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
let mut ciphertext = Vec::with_capacity(plaintext.len());
barrier::compiler_fence_seq_cst();
for &byte in plaintext {
if self.keystream_pos >= self.keystream.len() {
self.generate_keystream()?;
}
ciphertext.push(byte ^ self.keystream[self.keystream_pos]);
self.keystream_pos += 1;
}
barrier::compiler_fence_seq_cst();
Ok(ciphertext)
}
pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
self.encrypt(ciphertext)
}
pub fn process(&mut self, data: &mut [u8]) -> Result<()> {
barrier::compiler_fence_seq_cst();
for byte in data.iter_mut() {
if self.keystream_pos >= self.keystream.len() {
self.generate_keystream()?;
}
*byte ^= self.keystream[self.keystream_pos];
self.keystream_pos += 1;
}
barrier::compiler_fence_seq_cst();
Ok(())
}
pub fn keystream(&mut self, output: &mut [u8]) -> Result<()> {
for byte in output.iter_mut() {
*byte = 0;
}
self.keystream_pos = self.keystream.len();
self.process(output)
}
pub fn seek(&mut self, block_offset: u32) {
let mut counter_value = [0u8; 8];
BigEndian::write_u32(&mut counter_value[4..], block_offset.wrapping_add(1));
for i in 0..self.counter_size {
let idx = self.counter_position + self.counter_size - 1 - i;
self.counter_block[idx] = counter_value[7 - i];
}
self.keystream_pos = self.keystream.len();
self.keystream = Zeroizing::new(Vec::new());
counter_value.zeroize();
}
pub fn set_counter(&mut self, counter: u32) {
let counter_pos = self.counter_position;
let counter_bytes = counter.to_be_bytes();
let start_idx = 4 - self.counter_size;
for i in 0..self.counter_size {
if start_idx + i < 4 {
self.counter_block[counter_pos + i] = counter_bytes[start_idx + i];
}
}
self.keystream_pos = self.keystream.len();
}
pub fn reset<const N: usize>(&mut self, nonce: Option<&Nonce<N>>, counter: u32) -> Result<()>
where
Nonce<N>: AesCtrCompatible,
{
barrier::compiler_fence_seq_cst();
if let Some(new_nonce) = nonce {
let block_size = B::block_size();
let max_nonce_size = block_size - self.counter_size;
let effective_nonce = if N > max_nonce_size {
&new_nonce.as_ref()[0..max_nonce_size]
} else {
new_nonce.as_ref()
};
for b in &mut *self.counter_block {
*b = 0;
}
let counter_pos = match self.counter_position {
0 => self.counter_size, _ => 0, };
self.counter_block[counter_pos..counter_pos + effective_nonce.len()]
.copy_from_slice(effective_nonce);
}
self.set_counter(counter);
self.keystream = Zeroizing::new(Vec::new());
self.keystream_pos = 0;
barrier::compiler_fence_seq_cst();
Ok(())
}
}
#[cfg(test)]
mod tests;