use std::io;
use std::cmp;
use std::fmt;
use crate::{Error, Result};
use crate::SymmetricAlgorithm;
use crate::vec_resize;
use crate::{
crypto::SessionKey,
parse::Cookie,
};
use buffered_reader::BufferedReader;
/// Block cipher mode of operation.
///
/// Selects how consecutive blocks are chained when a stream is
/// encrypted or decrypted.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum BlockCipherMode {
    /// Cipher feedback mode.
    CFB,
    /// Cipher block chaining mode.
    CBC,
    /// Electronic codebook mode.
    ECB,
}

impl BlockCipherMode {
    /// Returns `true` if this mode can only process whole blocks.
    ///
    /// For such modes (CBC and ECB) input whose length is not a
    /// multiple of the block size has to be padded. CFB can handle a
    /// trailing partial block and therefore returns `false`.
    pub fn requires_padding(&self) -> bool {
        // Kept exhaustive on purpose: adding a mode must force a
        // decision here.
        match *self {
            Self::CBC | Self::ECB => true,
            Self::CFB => false,
        }
    }
}
/// Padding applied to the final partial block when encrypting.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum PaddingMode {
    /// Do not pad. A trailing partial block is rejected for modes
    /// that need whole blocks (see
    /// [`BlockCipherMode::requires_padding`]).
    None,
}
/// Padding removal applied to the decrypted stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum UnpaddingMode {
    /// Do not strip anything. A trailing partial block is rejected
    /// for modes that need whole blocks (see
    /// [`BlockCipherMode::requires_padding`]).
    None,
}
/// Backend-provided block cipher state used by the streaming
/// [`Encryptor`] and decryptor in this file.
///
/// The `Send + Sync` supertraits let the types that store a
/// `Box<dyn Context>` be `Send + Sync` themselves.
pub(crate) trait Context: Send + Sync {
    /// Encrypts `src` and writes the ciphertext to `dst`.
    ///
    /// Within this file, `dst` and `src` always have the same length.
    fn encrypt(&mut self, dst: &mut [u8], src: &[u8]) -> Result<()>;

    /// Decrypts `src` and writes the plaintext to `dst`.
    ///
    /// Within this file, `dst` and `src` always have the same length.
    fn decrypt(&mut self, dst: &mut [u8], src: &[u8]) -> Result<()>;
}
/// Streaming block cipher decryptor that pulls ciphertext from a
/// [`BufferedReader`] and hands out plaintext through [`io::Read`].
pub(crate) struct InternalDecryptor<'a> {
    /// Source of the ciphertext.
    source: Box<dyn BufferedReader<Cookie> + 'a>,
    /// Mode of operation; decides whether partial blocks are accepted.
    mode: BlockCipherMode,
    /// How padding is removed from the plaintext.
    padding: UnpaddingMode,
    /// Backend cipher state.
    dec: Box<dyn Context>,
    /// Block size of the symmetric algorithm, in bytes.
    block_size: usize,
    /// Plaintext from the last decrypted block that did not fit into
    /// the caller's buffer during the previous `read`.
    buffer: Vec<u8>,
}

assert_send_and_sync!(InternalDecryptor<'_>);

impl<'a> InternalDecryptor<'a> {
    /// Creates a decryptor for `algo` in `mode`, keyed with `key`,
    /// reading ciphertext from `source`.
    ///
    /// `iv` is forwarded unchanged to the backend.
    ///
    /// # Errors
    ///
    /// Fails if the block size of `algo` cannot be determined, or if
    /// the backend cannot create a decryptor for these parameters.
    pub fn new<R>(algo: SymmetricAlgorithm,
                  mode: BlockCipherMode,
                  padding: UnpaddingMode,
                  key: &SessionKey,
                  iv: Option<&[u8]>,
                  source: R)
                  -> Result<Self>
    where
        R: BufferedReader<Cookie> + 'a,
    {
        use crate::crypto::backend::{Backend, interface::Symmetric};

        // The block size is queried before the backend is involved.
        let block_size = algo.block_size()?;
        let dec = Backend::decryptor(algo, mode, key.as_protected(), iv)?;
        let source = source.into_boxed();
        // At most one block of plaintext is ever held back.
        let buffer = Vec::with_capacity(block_size);

        Ok(Self {
            source,
            mode,
            padding,
            dec,
            block_size,
            buffer,
        })
    }
}
impl<'a> io::Read for InternalDecryptor<'a> {
    /// Decrypts up to `plaintext.len()` bytes into `plaintext`.
    ///
    /// Works in three steps:
    ///
    ///  1. Copy out plaintext that a previous call left in `self.buffer`.
    ///  2. Decrypt as many whole blocks as fit into the rest of
    ///     `plaintext`, directly into it.
    ///  3. If room for less than one block remains, decrypt one more
    ///     block into `self.buffer`, copy the part that fits, and keep
    ///     the remainder for the next call.
    ///
    /// Returns the number of bytes written. `Ok(0)` for a non-empty
    /// `plaintext` means the source returned no more ciphertext
    /// (presumably EOF).
    ///
    /// # Errors
    ///
    /// In modes that require whole blocks
    /// ([`BlockCipherMode::requires_padding`]), a ciphertext that ends
    /// in a partial block yields an `UnexpectedEof` error. Backend
    /// failures are reported as `InvalidInput`. An error from the
    /// source is returned as is, unless plaintext was already produced
    /// during this call; then that byte count is returned instead.
    fn read(&mut self, plaintext: &mut [u8]) -> io::Result<usize> {
        let mut pos = 0;
        // 1. Hand out plaintext buffered by the previous call.
        if !self.buffer.is_empty() {
            let to_copy = cmp::min(self.buffer.len(), plaintext.len());
            plaintext[..to_copy].copy_from_slice(&self.buffer[..to_copy]);
            crate::vec_drain_prefix(&mut self.buffer, to_copy);
            pos = to_copy;
        }
        // The caller's buffer is already full.
        if pos == plaintext.len() {
            return Ok(pos);
        }
        // 2. Largest multiple of the block size that still fits.
        let mut to_copy
            = ((plaintext.len() - pos) / self.block_size) * self.block_size;
        let result = self.source.data_consume(to_copy);
        let short_read;
        let ciphertext = match result {
            Ok(data) => {
                // Fewer bytes than requested: presumably the source is
                // at EOF. `data` may also be longer than requested, so
                // clamp it.
                short_read = data.len() < to_copy;
                to_copy = data.len().min(to_copy);
                &data[..to_copy]
            },
            // Report the plaintext already produced instead of the error.
            Err(_) if pos > 0 => return Ok(pos),
            Err(e) => return Err(e),
        };
        if ! ciphertext.is_empty() {
            // Whole-block modes cannot decrypt a trailing partial block.
            match self.padding {
                UnpaddingMode::None => if self.mode.requires_padding()
                    && ciphertext.len() % self.block_size > 0
                {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        Error::InvalidOperation(
                            "incomplete last block".into())));
                },
            }
            self.dec.decrypt(&mut plaintext[pos..pos + to_copy],
                             ciphertext)
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput,
                                            format!("{}", e)))?;
            // No-op for now. Presumably the exhaustive match exists so
            // that adding an `UnpaddingMode` variant forces a revisit.
            match self.padding {
                UnpaddingMode::None => (),
            }
            pos += to_copy;
        }
        // Either the source ran short or the caller's buffer is full.
        if short_read || pos == plaintext.len() {
            return Ok(pos);
        }
        // 3. Less than one block of room is left in `plaintext`.
        // Decrypt a whole block into `self.buffer` and copy out a prefix.
        let mut to_copy = plaintext.len() - pos;
        assert!(0 < to_copy);
        assert!(to_copy < self.block_size);
        let to_read = self.block_size;
        let result = self.source.data_consume(to_read);
        let ciphertext = match result {
            Ok(data) => {
                // Near EOF fewer bytes than a block may be available.
                to_copy = cmp::min(to_copy, data.len());
                &data[..data.len().min(to_read)]
            },
            // Report the plaintext already produced instead of the error.
            Err(_) if pos > 0 => return Ok(pos),
            Err(e) => return Err(e),
        };
        assert!(ciphertext.len() <= self.block_size);
        if ciphertext.is_empty() {
            return Ok(pos);
        }
        // `self.buffer` is empty here: step 1 either drained it or
        // returned early. This sizes it to hold the decrypted block
        // (assumes `vec_resize` sets the length exactly; TODO confirm).
        vec_resize(&mut self.buffer, ciphertext.len());
        // Whole-block modes cannot decrypt a trailing partial block.
        match self.padding {
            UnpaddingMode::None => if self.mode.requires_padding()
                && ciphertext.len() % self.block_size > 0
            {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    Error::InvalidOperation(
                        "incomplete last block".into())));
            },
        }
        self.dec.decrypt(&mut self.buffer, ciphertext)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput,
                                        format!("{}", e)))?;
        // No-op for now (see the matching placeholder in step 2).
        match self.padding {
            UnpaddingMode::None => (),
        }
        // Hand out what fits; the rest stays buffered for the next call.
        plaintext[pos..pos + to_copy].copy_from_slice(&self.buffer[..to_copy]);
        crate::vec_drain_prefix(&mut self.buffer, to_copy);
        pos += to_copy;
        Ok(pos)
    }
}
pub struct Decryptor<'a> {
reader: buffered_reader::Generic<InternalDecryptor<'a>, Cookie>,
}
impl<'a> Decryptor<'a> {
pub fn new<R>(algo: SymmetricAlgorithm,
mode: BlockCipherMode,
padding: UnpaddingMode,
key: &SessionKey,
iv: Option<&[u8]>,
source: R)
-> Result<Self>
where
R: BufferedReader<Cookie> + 'a,
{
Self::with_cookie(
algo, mode, padding, key, iv, source, Default::default())
}
pub fn with_cookie<R>(algo: SymmetricAlgorithm,
mode: BlockCipherMode,
padding: UnpaddingMode,
key: &SessionKey,
iv: Option<&[u8]>,
reader: R,
cookie: Cookie)
-> Result<Self>
where
R: BufferedReader<Cookie> + 'a,
{
Ok(Decryptor {
reader: buffered_reader::Generic::with_cookie(
InternalDecryptor::new(algo, mode, padding, key, iv, reader)?,
None, cookie),
})
}
}
// Plaintext is read through the buffering layer, which pulls from the
// internal decryptor as needed.
impl<'a> io::Read for Decryptor<'a> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        io::Read::read(&mut self.reader, buf)
    }
}
impl<'a> fmt::Display for Decryptor<'a> {
    /// Writes the fixed name `Decryptor`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Decryptor")
    }
}
impl<'a> fmt::Debug for Decryptor<'a> {
    /// Shows the underlying ciphertext source as the `reader` field.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `get_ref` on `Decryptor` always returns `Some`.
        let source = self.get_ref().unwrap();
        f.debug_struct("Decryptor")
            .field("reader", &source)
            .finish()
    }
}
impl<'a> BufferedReader<Cookie> for Decryptor<'a> {
    // Data access: forwarded to the `Generic` layer that buffers the
    // decrypted plaintext.

    fn buffer(&self) -> &[u8] { self.reader.buffer() }

    fn data(&mut self, amount: usize) -> io::Result<&[u8]> {
        self.reader.data(amount)
    }

    fn data_hard(&mut self, amount: usize) -> io::Result<&[u8]> {
        self.reader.data_hard(amount)
    }

    fn data_eof(&mut self) -> io::Result<&[u8]> { self.reader.data_eof() }

    fn consume(&mut self, amount: usize) -> &[u8] {
        self.reader.consume(amount)
    }

    fn data_consume(&mut self, amount: usize) -> io::Result<&[u8]> {
        self.reader.data_consume(amount)
    }

    fn data_consume_hard(&mut self, amount: usize) -> io::Result<&[u8]> {
        self.reader.data_consume_hard(amount)
    }

    fn read_be_u16(&mut self) -> io::Result<u16> { self.reader.read_be_u16() }

    fn read_be_u32(&mut self) -> io::Result<u32> { self.reader.read_be_u32() }

    fn steal(&mut self, amount: usize) -> io::Result<Vec<u8>> {
        self.reader.steal(amount)
    }

    fn steal_eof(&mut self) -> io::Result<Vec<u8>> { self.reader.steal_eof() }

    // Stack navigation: these expose the *ciphertext* source held by
    // the internal decryptor, not the plaintext buffer.

    fn get_mut(&mut self) -> Option<&mut dyn BufferedReader<Cookie>> {
        let decryptor = self.reader.reader_mut();
        Some(&mut decryptor.source)
    }

    fn get_ref(&self) -> Option<&dyn BufferedReader<Cookie>> {
        let decryptor = self.reader.reader_ref();
        Some(&decryptor.source)
    }

    fn into_inner<'b>(self: Box<Self>)
                      -> Option<Box<dyn BufferedReader<Cookie> + 'b>>
        where Self: 'b
    {
        // Any plaintext still buffered is discarded with the adapter.
        let source = self.reader.into_reader().source;
        Some(source.into_boxed())
    }

    // Cookie handling: forwarded to the `Generic` layer.

    fn cookie_set(&mut self, cookie: Cookie) -> Cookie {
        self.reader.cookie_set(cookie)
    }

    fn cookie_ref(&self) -> &Cookie { self.reader.cookie_ref() }

    fn cookie_mut(&mut self) -> &mut Cookie { self.reader.cookie_mut() }
}
/// A writer that encrypts everything written to it with a block cipher
/// and forwards the ciphertext to an inner writer.
///
/// Data is encrypted a whole block at a time. A trailing partial block
/// is held back until [`Encryptor::finalize`] is called or the
/// encryptor is dropped.
pub struct Encryptor<W: io::Write> {
    /// Sink receiving the ciphertext; `None` once finalized.
    inner: Option<W>,
    /// Mode of operation; decides whether a partial final block is allowed.
    mode: BlockCipherMode,
    /// How the final partial block is padded.
    padding: PaddingMode,
    /// Backend cipher state.
    cipher: Box<dyn Context>,
    /// Block size of the symmetric algorithm, in bytes.
    block_size: usize,
    /// Plaintext of an incomplete block awaiting more input.
    buffer: Vec<u8>,
    /// Reusable output buffer for ciphertext.
    scratch: Vec<u8>,
}

assert_send_and_sync!(Encryptor<W> where W: io::Write);

impl<W: io::Write> Encryptor<W> {
    /// Creates an encryptor for `algo` in `mode`, keyed with `key`,
    /// writing ciphertext to `sink`.
    ///
    /// `iv` is forwarded unchanged to the backend.
    ///
    /// # Errors
    ///
    /// Fails if the block size of `algo` cannot be determined or the
    /// backend cannot create an encryptor for these parameters.
    pub fn new(algo: SymmetricAlgorithm,
               mode: BlockCipherMode,
               padding: PaddingMode,
               key: &SessionKey,
               iv: Option<&[u8]>,
               sink: W) -> Result<Self> {
        use crate::crypto::backend::{Backend, interface::Symmetric};

        let block_size = algo.block_size()?;
        let cipher = Backend::encryptor(algo, mode, key.as_protected(), iv)?;

        Ok(Self {
            inner: Some(sink),
            mode,
            padding,
            cipher,
            block_size,
            buffer: Vec::with_capacity(block_size),
            // Initial 4096 bytes; `write` grows this when a longer run
            // of whole blocks arrives.
            scratch: vec![0; 4096],
        })
    }

    /// Encrypts and writes out any buffered partial block, then returns
    /// the inner writer.
    ///
    /// # Errors
    ///
    /// Fails if a partial block remains and `mode` needs whole blocks,
    /// if encryption fails, or if writing to the inner writer fails.
    pub fn finalize(mut self) -> Result<W> {
        self.finalize_intern()
    }

    /// Shared body of [`Encryptor::finalize`] and `Drop`.
    ///
    /// Takes the inner writer, so a second call reports a
    /// `BrokenPipe` error.
    fn finalize_intern(&mut self) -> Result<W> {
        let mut inner = match self.inner.take() {
            Some(inner) => inner,
            None => return Err(io::Error::new(io::ErrorKind::BrokenPipe,
                                              "Inner writer was taken").into()),
        };

        // Nothing held back: the writer can be returned right away.
        if self.buffer.is_empty() {
            return Ok(inner);
        }

        let pending = self.buffer.len();
        assert!(pending < self.block_size);

        match self.padding {
            PaddingMode::None => if self.mode.requires_padding() {
                return Err(Error::InvalidOperation(
                    "incomplete last block".into()).into());
            },
        }

        self.cipher.encrypt(&mut self.scratch[..pending], &self.buffer)?;

        // No-op for now. Presumably the exhaustive match exists so that
        // adding a `PaddingMode` variant forces a revisit.
        match self.padding {
            PaddingMode::None => (),
        }

        crate::vec_truncate(&mut self.buffer, 0);
        inner.write_all(&self.scratch[..pending])?;
        crate::vec_truncate(&mut self.scratch, 0);
        Ok(inner)
    }

    /// Returns the inner writer, or `None` once finalized.
    pub(crate) fn get_ref(&self) -> Option<&W> {
        Option::as_ref(&self.inner)
    }

    /// Returns the inner writer mutably, or `None` once finalized.
    // NOTE(review): kept for symmetry with `get_ref`; it may have no
    // callers, hence the lint allowance.
    #[allow(dead_code)]
    pub(crate) fn get_mut(&mut self) -> Option<&mut W> {
        Option::as_mut(&mut self.inner)
    }
}
impl<W: io::Write> io::Write for Encryptor<W> {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
if self.inner.is_none() {
return Err(io::Error::new(io::ErrorKind::BrokenPipe,
"Inner writer was taken"));
}
let inner = self.inner.as_mut().unwrap();
let amount = buf.len();
if !self.buffer.is_empty() {
let n = cmp::min(buf.len(), self.block_size - self.buffer.len());
self.buffer.extend_from_slice(&buf[..n]);
assert!(self.buffer.len() <= self.block_size);
buf = &buf[n..];
if self.buffer.len() == self.block_size {
self.cipher.encrypt(&mut self.scratch[..self.block_size],
&self.buffer)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput,
format!("{}", e)))?;
crate::vec_truncate(&mut self.buffer, 0);
inner.write_all(&self.scratch[..self.block_size])?;
}
}
let whole_blocks = (buf.len() / self.block_size) * self.block_size;
if whole_blocks > 0 {
if self.scratch.len() < whole_blocks {
vec_resize(&mut self.scratch, whole_blocks);
}
self.cipher.encrypt(&mut self.scratch[..whole_blocks],
&buf[..whole_blocks])
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput,
format!("{}", e)))?;
inner.write_all(&self.scratch[..whole_blocks])?;
}
assert!(buf.is_empty() || self.buffer.is_empty());
self.buffer.extend_from_slice(&buf[whole_blocks..]);
assert!(self.buffer.len() < self.block_size);
Ok(amount)
}
fn flush(&mut self) -> io::Result<()> {
if let Some(ref mut inner) = self.inner {
inner.flush()
} else {
Err(io::Error::new(io::ErrorKind::BrokenPipe,
"Inner writer was taken"))
}
}
}
impl<W: io::Write> Drop for Encryptor<W> {
    fn drop(&mut self) {
        // Best effort: `drop` cannot report failures, so any error from
        // writing the final partial block is discarded. If `finalize`
        // already ran, this only produces (and discards) the
        // "taken" error.
        std::mem::drop(self.finalize_intern());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, Read, Write};

    /// Checks the CFB backend against known-answer vectors. Covers one
    /// block, two blocks, a length that is not a multiple of the block
    /// size, and chunked encryption matching one-shot encryption.
    #[test]
    fn smoke_test() {
        use crate::crypto::mem::Protected;
        use crate::crypto::symmetric::BlockCipherMode;
        use crate::crypto::backend::{Backend, interface::Symmetric};
        use crate::fmt::hex;

        let algo = SymmetricAlgorithm::AES128;
        let key: Protected =
            hex::decode("2b7e151628aed2a6abf7158809cf4f3c").unwrap().into();
        assert_eq!(key.len(), 16);

        // One block: first block of the NIST SP 800-38A CFB128-AES128
        // example.
        let iv = hex::decode("000102030405060708090A0B0C0D0E0F").unwrap();
        let mut cfb =
            Backend::encryptor(algo, BlockCipherMode::CFB, &key, Some(&iv)).unwrap();
        let msg = hex::decode("6bc1bee22e409f96e93d7e117393172a").unwrap();
        let mut dst = vec![0; msg.len()];
        cfb.encrypt(&mut dst, &*msg).unwrap();
        assert_eq!(&dst[..16], &*hex::decode("3b3fd92eb72dad20333449f8e83cfb4a").unwrap());

        // Two whole blocks.
        let iv = hex::decode("000102030405060708090A0B0C0D0E0F").unwrap();
        let mut cfb =
            Backend::encryptor(algo, BlockCipherMode::CFB, &key, Some(&iv)).unwrap();
        let msg = b"This is a very important message";
        let mut dst = vec![0; msg.len()];
        cfb.encrypt(&mut dst, &*msg).unwrap();
        assert_eq!(&dst, &hex::decode(
            "04960ebfb9044196bb29418ce9d6cc0939d5ccb1d0712fa8e45fe5673456fded"
        ).unwrap());

        // Two blocks plus one trailing byte.
        let iv = hex::decode("000102030405060708090A0B0C0D0E0F").unwrap();
        let mut cfb =
            Backend::encryptor(algo, BlockCipherMode::CFB, &key, Some(&iv)).unwrap();
        let msg = b"This is a very important message!";
        let mut dst = vec![0; msg.len()];
        cfb.encrypt(&mut dst, &*msg).unwrap();
        assert_eq!(&dst, &hex::decode(
            "04960ebfb9044196bb29418ce9d6cc0939d5ccb1d0712fa8e45fe5673456fded0b"
        ).unwrap());

        // Same message fed in 16-byte chunks to a fresh context; must
        // match the one-shot result above.
        let iv = hex::decode("000102030405060708090A0B0C0D0E0F").unwrap();
        let mut cfb =
            Backend::encryptor(algo, BlockCipherMode::CFB, &key, Some(&iv)).unwrap();
        let mut dst = vec![0; msg.len()];
        for (dst, msg) in dst.chunks_mut(16).zip(msg.chunks(16)) {
            // `dst` already is `&mut [u8]`; no extra borrow needed.
            cfb.encrypt(dst, msg).unwrap();
        }
        assert_eq!(&dst, &hex::decode(
            "04960ebfb9044196bb29418ce9d6cc0939d5ccb1d0712fa8e45fe5673456fded0b"
        ).unwrap());
    }

    /// Decrypts the reference ciphertexts of the manifesto (AES in CFB
    /// mode, key = key length - 1 as little-endian, no IV given) and
    /// compares against the plaintext, reading one byte at a time.
    #[test]
    fn decryptor() {
        for algo in [SymmetricAlgorithm::AES128,
                     SymmetricAlgorithm::AES192,
                     SymmetricAlgorithm::AES256].iter() {
            let mut key = vec![0u8; algo.key_size().unwrap()];
            key[0] = key.len() as u8 - 1;
            let key = key.into();

            let filename = &format!(
                "raw/a-cypherpunks-manifesto.aes{}.key_is_key_len_dec1_as_le",
                algo.key_size().unwrap() * 8);
            let ciphertext = buffered_reader::Memory::with_cookie(
                crate::tests::file(filename), Default::default());
            let decryptor = InternalDecryptor::new(
                *algo, BlockCipherMode::CFB, UnpaddingMode::None,
                &key, None, ciphertext).unwrap();

            let mut plaintext = Vec::new();
            for b in decryptor.bytes() {
                plaintext.push(b.unwrap());
            }

            assert_eq!(crate::tests::manifesto(), &plaintext[..]);
        }
    }

    /// Encrypts the manifesto one byte per `write` call and compares
    /// against the reference ciphertext files. This exercises the
    /// partial-block buffering in `Encryptor::write` and the final
    /// flush on drop.
    #[test]
    fn encryptor() {
        for algo in [SymmetricAlgorithm::AES128,
                     SymmetricAlgorithm::AES192,
                     SymmetricAlgorithm::AES256].iter() {
            let mut key = vec![0u8; algo.key_size().unwrap()];
            key[0] = key.len() as u8 - 1;
            let key = key.into();

            let mut ciphertext = Vec::new();
            {
                let mut encryptor = Encryptor::new(
                    *algo, BlockCipherMode::CFB, PaddingMode::None,
                    &key, None, &mut ciphertext).unwrap();
                for b in crate::tests::manifesto().chunks(1) {
                    encryptor.write_all(b).unwrap();
                }
                // Dropping `encryptor` here writes the final partial block.
            }

            let filename = format!(
                "raw/a-cypherpunks-manifesto.aes{}.key_is_key_len_dec1_as_le",
                algo.key_size().unwrap() * 8);
            let mut cipherfile = Cursor::new(crate::tests::file(&filename));
            let mut reference = Vec::new();
            cipherfile.read_to_end(&mut reference).unwrap();
            assert_eq!(&reference[..], &ciphertext[..]);
        }
    }

    /// Encrypts and decrypts the manifesto for every supported
    /// algorithm and mode. For whole-block modes the text is truncated
    /// to a multiple of the block size, since no padding is applied.
    #[test]
    fn roundtrip() {
        for algo in SymmetricAlgorithm::variants()
            .filter(|x| x.is_supported()) {
            for mode in [BlockCipherMode::CFB,
                         BlockCipherMode::CBC,
                         BlockCipherMode::ECB] {
                eprintln!("Testing {:?}/{:?}", algo, mode);
                let bs = algo.block_size().unwrap();
                let text = if mode.requires_padding() {
                    let l = (crate::tests::manifesto().len() / bs) * bs;
                    &crate::tests::manifesto()[..l]
                } else {
                    crate::tests::manifesto()
                };

                let key = SessionKey::new(algo.key_size().unwrap()).unwrap();
                let mut ciphertext = Vec::new();
                let mut encryptor = Encryptor::new(
                    algo, mode, PaddingMode::None,
                    &key, None, &mut ciphertext).unwrap();
                encryptor.write_all(text).unwrap();
                encryptor.finalize().unwrap();

                let mut plaintext = Vec::new();
                let reader = buffered_reader::Memory::with_cookie(
                    &ciphertext, Default::default());
                let mut decryptor = InternalDecryptor::new(
                    algo, mode, UnpaddingMode::None,
                    &key, None, reader).unwrap();
                decryptor.read_to_end(&mut plaintext).unwrap();
                assert_eq!(&plaintext[..], text);
            }
        }
    }
}