use arrayvec::ArrayVec;
use cipher::block_padding::Padding;
use cipher::{Block, BlockCipherEncrypt, BlockModeEncrypt, BlockSizeUser, KeyIvInit};
use std::io::{self, Write};
use std::marker::PhantomData;
/// A [`Write`] adapter that encrypts the bytes written to it with block cipher
/// `C` in CBC mode and forwards the resulting ciphertext to the inner writer.
///
/// Input that does not yet fill a whole cipher block is held back in `buf`.
/// Calling `finish` pads that tail with `P`, emits the last block and hands
/// the inner writer back.
pub(crate) struct CbcBlockCipherEncryptWriter<W, C, P>
where
C: BlockCipherEncrypt,
cbc::Encryptor<C>: BlockModeEncrypt,
P: Padding,
{
/// Inner writer that receives the encrypted blocks.
w: W,
/// CBC encryptor created from the key and IV passed to `new`.
c: cbc::Encryptor<C>,
/// Marker for the padding scheme; only `P::pad` is used, no value is stored.
padding: PhantomData<P>,
/// Plaintext bytes waiting for enough data to complete a block.
/// The capacity is fixed at 16 bytes; `new` debug-asserts a 16-byte block size.
buf: ArrayVec<u8, 16>,
}
impl<W, C, P> CbcBlockCipherEncryptWriter<W, C, P>
where
    W: Write,
    C: BlockCipherEncrypt,
    P: Padding,
    cbc::Encryptor<C>: BlockModeEncrypt + KeyIvInit,
{
    /// Creates an encrypting writer that sends its ciphertext to `w`.
    ///
    /// `key` and `iv` are passed to the CBC encryptor as byte slices.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error when the encryptor
    /// rejects `key` or `iv`.
    pub(crate) fn new(w: W, key: &[u8], iv: &[u8]) -> io::Result<Self> {
        // The staging buffer is sized for 16-byte blocks; catch mismatches in debug builds.
        debug_assert_eq!(cbc::Encryptor::<C>::block_size(), 16);

        let encryptor = match cbc::Encryptor::<C>::new_from_slices(key, iv) {
            Ok(encryptor) => encryptor,
            Err(err) => return Err(io::Error::new(io::ErrorKind::InvalidInput, err)),
        };

        Ok(Self {
            w,
            c: encryptor,
            padding: PhantomData,
            buf: ArrayVec::new(),
        })
    }
}
impl<W, C, P> CbcBlockCipherEncryptWriter<W, C, P>
where
    W: Write,
    C: BlockCipherEncrypt,
    cbc::Encryptor<C>: BlockModeEncrypt,
    P: Padding,
{
    /// Pads the buffered tail with `P`, encrypts it as the final block, writes
    /// it to the inner writer and returns that writer.
    ///
    /// # Errors
    ///
    /// Returns any error produced while writing the final block.
    ///
    /// # Panics
    ///
    /// Panics if the cipher's block size does not fit in the staging buffer.
    fn encrypt_write_with_padding(mut self) -> io::Result<W> {
        let block_size = cbc::Encryptor::<C>::block_size();
        let pos = self.buf.len();

        // Grow the buffer to a full block with zeros using the safe API.
        // This replaces an `unsafe` `set_len`, which exposed uninitialized
        // bytes to the padding and cipher code and could exceed the buffer's
        // capacity. `push` panics rather than writing out of bounds.
        for _ in pos..block_size {
            self.buf.push(0);
        }

        let block = <&mut Block<cbc::Encryptor<C>>>::try_from(self.buf.as_mut_slice())
            .expect("buf length equals block size");
        // `pos` is the number of real plaintext bytes at the start of the block.
        P::pad(block, pos);
        self.c.encrypt_block_inout(block.into());
        self.w.write_all(block.as_slice())?;
        Ok(self.w)
    }
}
impl<W, C, P> Write for CbcBlockCipherEncryptWriter<W, C, P>
where
    W: Write,
    C: BlockCipherEncrypt,
    cbc::Encryptor<C>: BlockModeEncrypt,
    P: Padding,
{
    /// Encrypts as many whole blocks as `buf` (plus any previously buffered
    /// bytes) provides and writes them to the inner writer; leftover bytes
    /// are kept for the next call.
    ///
    /// On success the returned count equals `buf.len()`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        const CAPACITY_MSG: &str =
            "buffer capacity exceeded: crypto block size invariant violated";

        let block_len = cbc::Encryptor::<C>::block_size();
        let pending = self.buf.len();

        // Too little data to complete a block: just stash it.
        if pending + buf.len() < block_len {
            self.buf.try_extend_from_slice(buf).expect(CAPACITY_MSG);
            return Ok(buf.len());
        }

        let mut consumed = 0;

        // Complete the partially filled block first, encrypting it in place.
        if pending != 0 {
            let fill = block_len - pending;
            self.buf
                .try_extend_from_slice(&buf[..fill])
                .expect(CAPACITY_MSG);
            consumed = fill;

            let staged = <&mut Block<cbc::Encryptor<C>>>::try_from(self.buf.as_mut_slice())
                .expect("buf length equals block size");
            self.c.encrypt_block_inout(staged.into());
            self.w.write_all(staged.as_slice())?;
            self.buf.clear();
        }

        // Encrypt every remaining whole block through a reusable output block.
        let mut scratch = Block::<cbc::Encryptor<C>>::default();
        let mut whole_blocks = buf[consumed..].chunks_exact(block_len);
        for chunk in whole_blocks.by_ref() {
            let plain = <&Block<cbc::Encryptor<C>>>::try_from(chunk)
                .expect("chunks_exact yields exact block-sized slices");
            self.c.encrypt_block_b2b(plain, &mut scratch);
            self.w.write_all(scratch.as_slice())?;
            consumed += chunk.len();
        }

        // Whatever is shorter than a block waits for the next call.
        let tail = whole_blocks.remainder();
        self.buf.try_extend_from_slice(tail).expect(CAPACITY_MSG);

        Ok(consumed + tail.len())
    }

    /// Flushes the inner writer; buffered partial-block bytes are not emitted.
    fn flush(&mut self) -> io::Result<()> {
        self.w.flush()
    }
}
impl<W, C, P> CbcBlockCipherEncryptWriter<W, C, P>
where
W: Write,
C: BlockCipherEncrypt,
cbc::Encryptor<C>: BlockModeEncrypt,
P: Padding,
{
/// Returns a mutable reference to the inner writer.
///
/// Anything written through this reference goes to the inner writer
/// directly, without passing through the cipher.
#[inline]
pub(crate) fn get_mut(&mut self) -> &mut W {
&mut self.w
}
/// Consumes the writer, emits the final padded block and returns the inner
/// writer.
///
/// Delegates to `encrypt_write_with_padding`; any error from writing the
/// last block is returned.
pub(crate) fn finish(self) -> io::Result<W> {
self.encrypt_write_with_padding()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use cipher::block_padding::Pkcs7;
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Key shared by all tests.
    const KEY: [u8; 16] = [0x42; 16];
    /// Initialization vector shared by all tests.
    const IV: [u8; 16] = [0x24; 16];
    /// Sample message; its length is deliberately not a multiple of 16.
    const PLAINTEXT: &[u8] = b"hello world! this is my plaintext.";

    /// Builds an AES-128 / PKCS#7 writer over an in-memory sink.
    fn new_writer() -> CbcBlockCipherEncryptWriter<Vec<u8>, aes::Aes128, Pkcs7> {
        CbcBlockCipherEncryptWriter::<_, aes::Aes128, Pkcs7>::new(Vec::new(), &KEY, &IV).unwrap()
    }

    /// Feeding the message in 7-byte pieces must produce the known ciphertext.
    #[test]
    fn write_encrypt() {
        let expected = [
            199u8, 254, 36, 126, 249, 123, 33, 240, 124, 189, 210, 108, 181, 211, 70, 191, 210,
            120, 103, 203, 0, 217, 72, 103, 35, 225, 89, 151, 143, 185, 165, 249, 20, 207, 178, 40,
            167, 16, 222, 65, 113, 227, 150, 231, 182, 207, 133, 158,
        ];

        let mut writer = new_writer();
        PLAINTEXT
            .chunks(7)
            .for_each(|piece| writer.write_all(piece).unwrap());

        let produced = writer.finish().unwrap();
        assert_eq!(&produced[..], &expected[..]);
    }

    /// `write` must report every input byte as consumed, whatever the chunking.
    #[test]
    fn write_len() {
        let data = PLAINTEXT.repeat(1024);

        let mut writer = new_writer();
        for piece in data.chunks(13) {
            let accepted = writer.write(piece).unwrap();
            assert_eq!(accepted, piece.len());
        }
        writer.finish().unwrap();
    }
}