use byteorder::{ReadBytesExt, WriteBytesExt, LE};
use cipher::StreamCipherError;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::fmt::Debug;
use std::io::{self, ErrorKind, Read, Seek, SeekFrom, Write};
use std::{collections::HashMap, marker::PhantomData};
use crate::{cipher_factory::*, dyn_cipher::*, error::*};
/// Six-byte signature that opens every ENARD record.
pub const MAGIC: &[u8; 6] = b"\x03ENARD";
/// Alignment, in bytes, of the encrypted payload relative to the start of the
/// record; the writer pads the variable header up to this boundary.
pub const DATA_ALIGNMENT: usize = 8;
/// Length of the fixed record prefix: the magic, a `u16` version, a `u32`
/// header size and a `u64` data size.
const HEADER_START: usize = MAGIC.len()
    + std::mem::size_of::<u16>()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>();
/// Metadata stored in a record header: raw byte keys mapped to raw byte values.
pub type MetaMap = HashMap<Vec<u8>, Vec<u8>>;
/// HMAC-SHA256: the authenticator covering the header blocks and payload of a
/// version-1 record.
type HmacV1 = Hmac<Sha256>;
/// Decrypting reader over the payload of an ENARD record.
///
/// Built by [`EnardReader::new`] (or [`EnardReader::new_boxed`]), which checks
/// the HMAC over the header blocks and payload before returning. Reads and
/// seeks are relative to the start of the encrypted payload.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so naming
/// the type does not force the bounds on every use site.
pub struct EnardReader<R, C> {
    /// Underlying source, positioned inside the payload region.
    inner: R,
    /// Stream cipher, kept in step with `current`.
    cipher: C,
    /// Absolute offset in `inner` at which the payload begins.
    data_start: u64,
    /// Length of the payload in bytes.
    data_size: u64,
    /// Current position, relative to `data_start`.
    current: u64,
    /// Metadata entries parsed from the header.
    meta: HashMap<Vec<u8>, Vec<u8>>,
}
impl<R, C> EnardReader<R, C>
where
R: Read + Seek,
C: DynCipher,
{
pub fn new<Cf: CipherFactory<C>>(
reader: R,
factory: Cf,
key: &[u8],
) -> Result<Self, EnardError> {
EnardBuilder::new(reader, factory, key).build()
}
pub fn meta(&self) -> &HashMap<Vec<u8>, Vec<u8>> {
&self.meta
}
pub fn into_inner(self) -> R {
self.inner
}
}
impl<R> EnardReader<R, BoxDynCipher>
where
    R: Read + Seek,
{
    /// Opens an ENARD record using the `BoxDynCipher` cipher type and the factory
    /// returned by `BoxDynCipher::factory()`.
    ///
    /// # Errors
    /// Same conditions as [`EnardReader::new`].
    pub fn new_boxed(reader: R, key: &[u8]) -> Result<Self, EnardError> {
        let factory = BoxDynCipher::factory();
        Self::new(reader, factory, key)
    }
}
impl<R, C> Debug for EnardReader<R, C>
where
    R: Read + Seek + Debug,
    C: DynCipher,
{
    // The cipher is shown by its name only, not by its internal state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cipher_name = self.cipher.get_name();
        let mut out = f.debug_struct("EnardReader");
        out.field("inner", &self.inner);
        out.field("cipher", &cipher_name);
        out.field("data_start", &self.data_start);
        out.field("data_size", &self.data_size);
        out.field("current", &self.current);
        out.field("meta", &self.meta);
        out.finish()
    }
}
impl<R, C> Read for EnardReader<R, C>
where
    R: Read + Seek,
    C: DynCipher,
{
    /// Reads up to `buf.len()` payload bytes and decrypts them in place.
    ///
    /// Never reads past the end of the payload; returns `Ok(0)` once the
    /// payload is exhausted. Cipher failures are reported as `io::Error`s.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // `seek` and this method keep `current <= data_size`; saturate anyway.
        let remaining = self.data_size.saturating_sub(self.current);
        // Clamp before narrowing: on 32-bit targets `remaining as usize` could
        // wrap to a small value (even 0) and cause short reads or a false EOF.
        let limit = usize::try_from(remaining).map_or(buf.len(), |r| r.min(buf.len()));
        let n = self.inner.read(&mut buf[..limit])?;
        self.current += n as u64;
        self.cipher
            .try_apply_keystream(&mut buf[..n])
            .map_err(cipher_to_io_error)?;
        Ok(n)
    }
}
impl<R, C> Seek for EnardReader<R, C>
where
    R: Read + Seek,
    C: DynCipher,
{
    /// Seeks within the payload; positions are relative to the payload start.
    ///
    /// # Errors
    /// `InvalidInput` if the target is negative or past the end of the payload.
    /// Failures from the cipher seek or the underlying seek are propagated; in
    /// either case the reader's logical position is left unchanged.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // i128 holds every u64 and every u64 + i64 exactly, so neither the
        // offset arithmetic nor the sign conversion can overflow.
        let target = match pos {
            SeekFrom::Start(abs) => i128::from(abs),
            SeekFrom::Current(rel) => i128::from(self.current) + i128::from(rel),
            SeekFrom::End(rel) => i128::from(self.data_size) + i128::from(rel),
        };
        let new_pos = match u64::try_from(target) {
            Ok(p) if p <= self.data_size => p,
            _ => {
                let msg = format!(
                    "invalid seek to a negative or overflowing position: {:?}",
                    pos
                );
                return Err(io::Error::new(ErrorKind::InvalidInput, msg));
            }
        };
        // Move the keystream first: it performs no I/O, so if it fails nothing
        // has changed and the reader stays consistent.
        self.cipher.try_seek(new_pos).map_err(cipher_to_io_error)?;
        if let Err(e) = self.inner.seek(SeekFrom::Start(self.data_start + new_pos)) {
            // Best effort: put the keystream back in step with `current`; the
            // seek error is the one the caller needs to see.
            let _ = self.cipher.try_seek(self.current);
            return Err(e);
        }
        self.current = new_pos;
        Ok(new_pos)
    }

    /// Returns the current payload-relative position without touching `inner`.
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.current)
    }
}
/// Converts a stream-cipher failure into an `io::Error` of kind `Other`, using
/// the cipher error's `Debug` text as the message.
fn cipher_to_io_error(e: StreamCipherError) -> io::Error {
    let description = format!("{:?}", e);
    io::Error::new(ErrorKind::Other, description)
}
/// Parses and authenticates an ENARD record header and produces an
/// [`EnardReader`] positioned at the start of the payload.
pub(crate) struct EnardBuilder<R, C, Cf> {
    /// Source read from its current position, which must hold the magic bytes.
    reader: R,
    /// Builds the cipher named in the header.
    factory: Cf,
    /// Owned copy of the key; used both for HMAC verification and the cipher.
    key: Vec<u8>,
    /// Records the cipher type `C` without storing a value of it.
    phantom: PhantomData<C>,
}
impl<R, C, Cf> EnardBuilder<R, C, Cf>
where
R: Read + Seek,
C: DynCipher,
Cf: CipherFactory<C>,
{
pub fn new(reader: R, factory: Cf, key: &[u8]) -> Self {
let key = Vec::from(key);
let phantom = PhantomData;
Self {
reader,
factory,
key,
phantom,
}
}
pub fn build(mut self) -> Result<EnardReader<R, C>, EnardError> {
let mut magic_buf = [0u8; MAGIC.len()];
self.reader.read_exact(&mut magic_buf)?;
if &magic_buf != MAGIC {
return Err(EnardError::new_invalid_magic(MAGIC, &magic_buf).into());
}
let version = self.reader.read_u16::<LE>()?;
match version {
1 => self.read_v1(),
_ => Err(EnardError::UnsupportedVersion { version }.into()),
}
}
fn read_v1(mut self) -> Result<EnardReader<R, C>, EnardError> {
let header_size = self.reader.read_u32::<LE>()?;
let data_size = self.reader.read_u64::<LE>()?;
let header_start = self.reader.stream_position()?;
let data_start = header_start + header_size as u64;
Self::verify_mac(&mut self.reader, &self.key, header_size as u64 + data_size)?;
self.reader.seek(SeekFrom::Start(header_start))?;
let cipher_kind = Self::read_u8_block(&mut self.reader)?;
let cipher_iv = Self::read_u8_block(&mut self.reader)?;
let cipher = self.factory.create(&cipher_kind, &self.key, &cipher_iv)?;
let meta = Self::read_meta_blocks(&mut self.reader, header_size as u64)?;
self.reader.seek(SeekFrom::Start(data_start))?;
Ok(EnardReader {
inner: self.reader,
cipher,
data_start,
data_size,
current: 0,
meta,
})
}
fn verify_mac<R2: Read>(mut reader: R2, key: &[u8], data_size: u64) -> Result<(), EnardError> {
let mut rd = (&mut reader).take(data_size);
let mut mac = HmacV1::new_from_slice(key)?;
io::copy(&mut rd, &mut mac)?;
let mut tag_buf = [0u8; 32];
reader.read_exact(&mut tag_buf)?;
mac.verify_slice(&tag_buf)?;
Ok(())
}
pub fn read_vec<R2: Read>(mut reader: R2, size: usize) -> io::Result<Vec<u8>> {
let mut b_buf = vec![0u8; size];
reader.read_exact(&mut b_buf)?;
Ok(b_buf)
}
pub fn read_u8_block<R2: Read>(mut reader: R2) -> io::Result<Vec<u8>> {
let size = reader.read_u8()? as usize;
Self::read_vec(reader, size)
}
pub fn read_u16_block<R2: Read>(mut reader: R2, limit: usize) -> Result<Vec<u8>, EnardError> {
let size = reader.read_u16::<LE>()? as usize;
if size > limit {
return Err(EnardError::new_block_size(size as u64, limit as u64).into());
}
Ok(Self::read_vec(reader, size)?)
}
pub fn read_meta_blocks<R2: Read>(
mut reader: R2,
max_size: u64,
) -> Result<MetaMap, EnardError> {
let mut result = HashMap::new();
let count = reader.read_u8()? as usize;
for _ in 0..count {
let key = Self::read_u8_block(&mut reader)?;
let value = Self::read_u16_block(&mut reader, max_size as usize)?;
result.insert(key, value);
}
Ok(result)
}
}
/// Encrypting writer that produces a version-1 ENARD record.
///
/// Either call `write_complete`, or call `write_header`, write the payload
/// through `Write`, and then call `finish`.
pub struct EnardWriter<W, C> {
    /// Destination; seekable so `finish` can back-fill the size fields.
    inner: W,
    /// Stream cipher applied to payload bytes.
    cipher: C,
    /// IV recorded in the header.
    iv: Vec<u8>,
    /// Running HMAC over header blocks and ciphertext; `None` after `finish`.
    mac: Option<HmacV1>,
    /// Offset in `inner` at which the record starts.
    start_pos: u64,
    /// Metadata entries to write into the header.
    meta: Option<MetaMap>,
    /// Size of the variable header (blocks plus padding), set by `write_header`.
    header_size: u32,
    /// Scratch buffer in which payload chunks are encrypted before writing.
    crypt_buf: Vec<u8>,
}
impl<W, C> EnardWriter<W, C>
where
    W: Write + Seek,
    C: DynCipher,
{
    /// Creates a writer that encrypts with the cipher called `name`, built by
    /// `factory` from `key` and `iv`, and authenticates with an HMAC keyed by `key`.
    ///
    /// Nothing is written to `inner` until [`write_header`](Self::write_header)
    /// or [`write_complete`](Self::write_complete) is called.
    ///
    /// # Errors
    /// Fails if the factory cannot create the cipher or the HMAC cannot be keyed.
    pub fn new<Cf: CipherFactory<C>>(
        inner: W,
        factory: Cf,
        name: &[u8],
        key: &[u8],
        iv: &[u8],
        meta: MetaMap,
    ) -> Result<Self, EnardError> {
        let cipher = factory.create(name, key, iv)?;
        Ok(Self {
            inner,
            cipher,
            iv: Vec::from(iv),
            mac: Some(HmacV1::new_from_slice(key)?),
            start_pos: 0,
            meta: Some(meta),
            header_size: 0,
            crypt_buf: vec![0u8; 256],
        })
    }

    /// Writes a complete record: the header, the encrypted contents of `rd`, and
    /// the MAC tag. Returns the total number of bytes written.
    pub fn write_complete(&mut self, mut rd: impl Read) -> io::Result<u64> {
        let mut n = self.write_header()? as u64;
        n += io::copy(&mut rd, self)?;
        n += self.finish()? as u64;
        Ok(n)
    }

    /// Writes the record header and returns its length in bytes (fixed prefix
    /// plus variable header, including padding).
    pub fn write_header(&mut self) -> io::Result<usize> {
        self.write_header_v1()?;
        Ok(HEADER_START + self.header_size as usize)
    }

    /// Appends the MAC tag and back-fills the size fields in the header.
    ///
    /// Returns the number of bytes appended (the tag length). Calling it again,
    /// or writing afterwards, returns an error.
    pub fn finish(&mut self) -> io::Result<usize> {
        self.finish_v1()
    }

    /// Consumes the writer and returns the underlying output.
    pub fn into_inner(self) -> W {
        self.inner
    }

    /// Writes the fixed prefix (with zeroed size placeholders), then the
    /// MAC-covered header blocks (cipher name, IV, metadata) and padding.
    fn write_header_v1(&mut self) -> io::Result<()> {
        self.start_pos = self.inner.stream_position()?;
        // Fixed prefix, not covered by the MAC. `write_all` because a short
        // `write` here would silently corrupt the record.
        self.inner.write_all(MAGIC)?;
        self.inner.write_u16::<LE>(1)?;
        // Placeholders for header_size (u32) and data_size (u64); `finish_v1`
        // fills them in.
        self.inner.write_all(&[0u8; 4 + 8])?;
        let name = self.cipher.get_name();
        let iv = self.iv.clone();
        let mut hs = self.write_u8_block(name)?;
        hs += self.write_u8_block(&iv)?;
        hs += self.write_meta_blocks()?;
        // Pad so the payload starts at a multiple of DATA_ALIGNMENT from the
        // start of the record.
        let data_start = hs + HEADER_START;
        let padding = (DATA_ALIGNMENT - (data_start % DATA_ALIGNMENT)) % DATA_ALIGNMENT;
        let pad_buf = [0u8; DATA_ALIGNMENT];
        self.mac_write(&pad_buf[..padding])?;
        hs += padding;
        self.header_size = u32::try_from(hs).map_err(|_| {
            io::Error::new(ErrorKind::Other, "ENARD header does not fit in a u32")
        })?;
        Ok(())
    }

    /// Error reported when the MAC has already been finalized by `finish`.
    fn finished_error() -> io::Error {
        io::Error::new(ErrorKind::Other, "EnardWriter has already been finished")
    }

    /// Writes `b` to the output and feeds it to the running MAC.
    fn mac_write(&mut self, b: &[u8]) -> io::Result<()> {
        let mac = self.mac.as_mut().ok_or_else(Self::finished_error)?;
        self.inner.write_all(b)?;
        mac.update(b);
        Ok(())
    }

    /// Appends the HMAC tag, seeks back to fill in the header and data sizes,
    /// returns to the end of the record, and flushes. Returns the tag length.
    fn finish_v1(&mut self) -> io::Result<usize> {
        let data_start = self.start_pos + u64::from(self.header_size) + HEADER_START as u64;
        let data_len = self
            .inner
            .stream_position()?
            .checked_sub(data_start)
            .ok_or_else(|| {
                io::Error::new(
                    ErrorKind::InvalidInput,
                    "output is positioned before the start of the payload",
                )
            })?;
        let mac = self.mac.take().ok_or_else(Self::finished_error)?;
        let tag = mac.finalize().into_bytes();
        self.inner.write_all(&tag)?;
        let end_pos = self.inner.stream_position()?;
        // The size fields follow the magic and the u16 version.
        self.inner
            .seek(SeekFrom::Start(self.start_pos + (MAGIC.len() + 2) as u64))?;
        self.inner.write_u32::<LE>(self.header_size)?;
        self.inner.write_u64::<LE>(data_len)?;
        self.inner.seek(SeekFrom::Start(end_pos))?;
        self.flush()?;
        Ok(tag.len())
    }

    /// Writes a block prefixed by a one-byte length; returns bytes written.
    fn write_u8_block(&mut self, block: &[u8]) -> io::Result<usize> {
        Self::block_size_check(block, u8::MAX as usize)?;
        let blen = block.len() as u8;
        self.mac_write(&blen.to_le_bytes())?;
        self.mac_write(block)?;
        Ok(1 + block.len())
    }

    /// Rejects blocks of `size` bytes or more (so at most `size - 1` bytes).
    fn block_size_check(block: &[u8], size: usize) -> io::Result<()> {
        if block.len() >= size {
            let msg = format!("block size must be 0-{}, is {}", size - 1, block.len());
            Err(io::Error::new(ErrorKind::Other, msg))
        } else {
            Ok(())
        }
    }

    /// Serializes the metadata (one-byte count, then u8-block key / u16-block
    /// value pairs) and writes it through the MAC. Returns bytes written.
    ///
    /// Everything is validated before anything is written, so an oversized
    /// entry or too many entries leaves the output untouched.
    fn write_meta_blocks(&mut self) -> io::Result<usize> {
        let meta = self
            .meta
            .as_ref()
            .ok_or_else(|| io::Error::new(ErrorKind::Other, "EnardWriter metadata is unavailable"))?;
        // The count is stored in one byte; `as u8` would silently wrap.
        let count = u8::try_from(meta.len()).map_err(|_| {
            let msg = format!(
                "at most {} metadata entries are supported, got {}",
                u8::MAX,
                meta.len()
            );
            io::Error::new(ErrorKind::Other, msg)
        })?;
        let mut encoded = vec![count];
        for (key, val) in meta {
            Self::block_size_check(key, u8::MAX as usize)?;
            Self::block_size_check(val, u16::MAX as usize)?;
            encoded.push(key.len() as u8);
            encoded.extend_from_slice(key);
            encoded.extend_from_slice(&(val.len() as u16).to_le_bytes());
            encoded.extend_from_slice(val);
        }
        self.mac_write(&encoded)?;
        Ok(encoded.len())
    }
}
impl<W, C> Write for EnardWriter<W, C>
where
W: Write + Seek,
C: DynCipher,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let b_size = self.crypt_buf.len();
for chunk in buf.chunks(b_size) {
let cbuf = &mut self.crypt_buf[0..b_size.min(chunk.len())];
cbuf.clone_from_slice(chunk);
self.cipher
.try_apply_keystream(cbuf)
.map_err(cipher_to_io_error)?;
self.inner.write_all(cbuf)?;
self.mac.as_mut().unwrap().update(cbuf);
}
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
}
}
impl<W, C> Debug for EnardWriter<W, C>
where
    W: Write + Seek + Debug,
    C: DynCipher,
{
    // The cipher is shown by its name only, not by its internal state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cipher_name = self.cipher.get_name();
        let mut out = f.debug_struct("EnardWriter");
        out.field("inner", &self.inner);
        out.field("cipher", &cipher_name);
        out.field("iv", &self.iv);
        out.field("mac", &self.mac);
        out.field("start_pos", &self.start_pos);
        out.field("meta", &self.meta);
        out.field("header_size", &self.header_size);
        out.field("crypt_buf", &self.crypt_buf);
        out.finish()
    }
}