use crate::{
bytes::{ByteArray, ByteCursor, ByteCursorMut},
error::{Error, ErrorKind},
key::{ParentKey, ParentKeyId},
};
use aes_gcm::{
aead::{
generic_array::typenum::{Unsigned, U12},
AeadInPlace,
},
aes::Aes256,
AesGcm, KeyInit as AesKeyInit, Nonce as AesNonce, Tag,
};
use alloc::vec::Vec;
use core::{fmt, ops::Range};
use hkdf::Hkdf;
use sha2::Sha256;
#[cfg(feature = "std")]
use std::io::{Read, Seek, SeekFrom, Write};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
// Length in bytes of a 256-bit key.
const KEY_LEN_256: usize = 32;
/// Nonce stored in the stream header (16 bytes).
pub type HeaderNonce = ByteArray<16>;
// MAC stored in the trailing `HeaderMac::LEN` bytes of the header (32 bytes).
type HeaderMac = ByteArray<32>;
// Total serialized header length: the sum of every header field length.
const HEADER_LEN: usize = MAGIC_NUM_LEN
+ VERSION_LEN
+ ALGO_LEN
+ FRAME_LEN_LEN
+ RESERVED_LEN
+ HeaderNonce::LEN
+ ParentKeyId::LEN
+ HeaderMac::LEN;
/// Fixed-size byte array holding one complete serialized header.
pub type HeaderBytes = ByteArray<HEADER_LEN>;
// Lengths in bytes of the leading header fields.
const MAGIC_NUM_LEN: usize = 4;
const VERSION_LEN: usize = 1;
const ALGO_LEN: usize = 2;
const FRAME_LEN_LEN: usize = 1;
const RESERVED_LEN: usize = 8;
// Byte offsets of the header fields; each field starts where the previous one ends.
const MAGIC_NUM_OFFSET: usize = 0;
const VERSION_OFFSET: usize = MAGIC_NUM_OFFSET + MAGIC_NUM_LEN;
const ALGO_OFFSET: usize = VERSION_OFFSET + VERSION_LEN;
const FRAME_LEN_OFFSET: usize = ALGO_OFFSET + ALGO_LEN;
const RESERVED_OFFSET: usize = FRAME_LEN_OFFSET + FRAME_LEN_LEN;
const NONCE_OFFSET: usize = RESERVED_OFFSET + RESERVED_LEN;
const KEY_ID_OFFSET: usize = NONCE_OFFSET + HeaderNonce::LEN;
const HEADER_MAC_OFFSET: usize = KEY_ID_OFFSET + ParentKeyId::LEN;
// Header byte ranges of the parent key id and of the header MAC.
const HEADER_KEY_ID_RANGE: Range<usize> = KEY_ID_OFFSET..KEY_ID_OFFSET + ParentKeyId::LEN;
const HEADER_MAC_RANGE: Range<usize> = HEADER_MAC_OFFSET..HEADER_MAC_OFFSET + HeaderMac::LEN;
// Header bytes passed to HKDF as `info`: everything before the nonce.
const HKDF_INFO_RANGE: Range<usize> = 0..NONCE_OFFSET;
// Header bytes passed to HKDF as `salt`: the nonce followed by the parent key id.
const HKDF_SALT_RANGE: Range<usize> =
NONCE_OFFSET..NONCE_OFFSET + HeaderNonce::LEN + ParentKeyId::LEN;
// Lengths in bytes of the per-frame fields.
const SEQ_NUM_LEN: usize = 4;
const INVOCATION_LEN: usize = 8;
const END_LEN: usize = 4;
const FRAME_TAG_LEN: usize = 16;
// Bytes of every frame that are not payload: the frame header plus the authentication tag.
const FRAME_META_LEN: usize = FRAME_TAG_LEN + SEQ_NUM_LEN + END_LEN + INVOCATION_LEN;
// Bytes preceding the payload in a frame: sequence number, invocation and end length.
const FRAME_HEADER_LEN: usize = SEQ_NUM_LEN + END_LEN + INVOCATION_LEN;
// AES-GCM nonce size (12 bytes); the nonce is the frame's sequence number and invocation.
type FrameNonceLen = U12;
const FRAME_NONCE_LEN: usize = FrameNonceLen::USIZE;
// Byte offsets of the frame fields; the payload follows the end-length field.
const SEQ_NUM_OFFSET: usize = 0;
const INVOCATION_OFFSET: usize = SEQ_NUM_OFFSET + SEQ_NUM_LEN;
const END_LEN_OFFSET: usize = INVOCATION_OFFSET + INVOCATION_LEN;
const PAYLOAD_OFFSET: usize = END_LEN_OFFSET + END_LEN;
// Length in bytes of the derived data key.
const DATA_KEY_LEN: usize = KEY_LEN_256;
// Magic number expected at the start of a header (serialized little-endian).
const MAGIC_NUM: u32 = 0x6d797a2e;
// The only header version accepted by `Header::from_bytes`.
const VERSION: u8 = 1;
// AES-256 in GCM mode with a 12-byte nonce.
type Aes256Gcm = AesGcm<Aes256, FrameNonceLen>;
// Algorithm identifier stored in the header as a little-endian `u16`.
// `Header::from_bytes` rejects every value other than `Aes256GcmHkdfSha256`.
#[repr(u16)]
#[derive(Debug, PartialEq)]
enum CryptoAlgorithm {
// AES-256-GCM frames with an HKDF-SHA256 derived data key.
Aes256GcmHkdfSha256 = 0,
}
/// Total length of one encrypted frame, metadata included.
///
/// Each discriminant is the base-2 exponent of the length in bytes (see
/// `FrameLength::as_usize`) and is the value stored in the header's frame-length byte.
#[repr(u8)]
#[derive(Debug, PartialEq, Clone, Copy, Default)]
pub enum FrameLength {
/// 4096-byte frames.
Len4KiB = 12,
/// 8192-byte frames.
Len8KiB = 13,
/// 16384-byte frames (the default).
#[default]
Len16KiB = 14,
/// 32768-byte frames.
Len32KiB = 15,
/// 65536-byte frames.
Len64KiB = 16,
}
/// Stream header together with the data key derived from it.
#[derive(PartialEq, Clone, Debug)]
pub struct Header {
// Frame length decoded from (or written to) the header bytes.
frame_len: FrameLength,
// AES-256-GCM key derived from the parent key and the header contents.
data_key: aes_gcm::Key<Aes256Gcm>,
// Serialized form of the header, including the trailing MAC.
bytes: HeaderBytes,
}
/// Builder that serializes a new [`Header`] for a parent key and nonce.
pub struct HeaderBuilder<'a> {
// Key whose id is written into the header and whose secret feeds the key derivation.
parent_key: &'a ParentKey,
// Nonce written into the header; part of the HKDF salt.
nonce: &'a HeaderNonce,
// Frame length to record; `FrameLength::default()` unless overridden.
frame_len: FrameLength,
}
/// Per-frame metadata written in front of (or decoded from) a frame's payload.
#[derive(Default)]
pub struct FrameHeader {
// Sequence number of the frame; first 4 bytes of the frame nonce.
seq_num: u32,
// Invocation counter; remaining 8 bytes of the frame nonce.
invocation: u64,
// Whether this frame is marked as the final frame of the stream.
is_end: bool,
}
/// Builder for [`FrameHeader`]; invocation starts at 0 and the frame is not an end frame
/// unless `end()` is called.
pub struct FrameHeaderBuilder {
// Sequence number given to `new`.
seq_num: u32,
// Invocation counter, 0 unless set through `invocation()`.
invocation: u64,
// Set by `end()`.
is_end: bool,
}
/// Buffer holding a single frame, either as plaintext being assembled or as the bytes of an
/// encrypted frame, together with the cipher used to encrypt/decrypt it.
pub struct FrameBuf {
// Frame bytes: frame header, payload and (once encrypted or when read) the tag.
buf: Vec<u8>,
// Full frame length in bytes, taken from the stream header.
frame_len: usize,
// Largest payload a frame can carry: `frame_len - FRAME_META_LEN`.
max_payload_len: usize,
// Buffer index one past the last possible payload byte: `PAYLOAD_OFFSET + max_payload_len`.
max_payload_pos: usize,
// Number of payload bytes currently held in `buf`.
payload_len: usize,
// AES-256-GCM instance keyed with the header's data key.
cipher: Aes256Gcm,
}
/// Stream adapter that encrypts data written to, and decrypts data read from, `inner`
/// one frame at a time.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct ZymicStream<T> {
// Sequence number used for the next frame that is encrypted or decrypted.
seq_num: u32,
// Sequence number the stream was created with; subtracted to obtain a frame index.
start_seq_num: u32,
// Invocation value written into the next encrypted frame; after a frame is read it is
// the decoded invocation plus one.
invocation: u64,
// Current position inside the payload of the buffered frame.
payload_pos: usize,
// Payload length of the end frame, once an end frame has been read or written.
end_len: Option<usize>,
// Buffer for the frame currently being read or assembled.
frame_buf: FrameBuf,
// Underlying reader/writer holding the encrypted frames.
inner: T,
}
/// Derives the header MAC and the frame data key from the parent key.
///
/// HKDF-SHA256 is extracted from `salt` and the parent key secret, then expanded with
/// `info` into `HeaderMac::LEN + DATA_KEY_LEN` bytes. The first `HeaderMac::LEN` bytes are
/// returned as the header MAC and the remaining bytes as the AES-256 data key.
///
/// # Panics
///
/// Panics if the HKDF expansion reports an error.
fn derive_data_key(
    parent_key: &ParentKey,
    salt: &[u8],
    info: &[u8],
) -> (HeaderMac, aes_gcm::Key<Aes256>) {
    let mut hkdf_out = [0u8; HeaderMac::LEN + DATA_KEY_LEN];
    let hkdf = Hkdf::<Sha256>::new(Some(salt), parent_key.secret());
    hkdf.expand(info, &mut hkdf_out).expect("hdkf expansion");
    let digest = HeaderMac::from(&hkdf_out[..HeaderMac::LEN]);
    let mut data_key = aes_gcm::Key::<Aes256Gcm>::default();
    data_key.copy_from_slice(&hkdf_out[HeaderMac::LEN..]);
    // The scratch buffer still holds a copy of the data key; wipe it when zeroization
    // is enabled, matching what `Header` does with `data_key` on drop.
    #[cfg(feature = "zeroize")]
    hkdf_out.zeroize();
    (digest, data_key)
}
/// Decodes the frame-length byte of a stream header.
///
/// The byte must be one of the `FrameLength` discriminants (12 through 16); any other
/// value yields `ErrorKind::InvalidFrameLength` carrying the offending byte.
impl TryFrom<u8> for FrameLength {
    type Error = Error;

    fn try_from(val: u8) -> Result<Self, Error> {
        match val {
            12 => Ok(FrameLength::Len4KiB),
            13 => Ok(FrameLength::Len8KiB),
            14 => Ok(FrameLength::Len16KiB),
            15 => Ok(FrameLength::Len32KiB),
            // Every variant must decode, otherwise a header built with `Len64KiB`
            // could never be parsed back.
            16 => Ok(FrameLength::Len64KiB),
            _ => Err(Error::new(ErrorKind::InvalidFrameLength(val))),
        }
    }
}
/// Encodes a frame length as its header byte, which is the enum discriminant.
impl From<FrameLength> for u8 {
    fn from(frame_len: FrameLength) -> u8 {
        frame_len as u8
    }
}
/// Formats the frame length as its size in bytes, in decimal.
impl fmt::Display for FrameLength {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let byte_len = self.as_usize();
        f.write_fmt(format_args!("{}", byte_len))
    }
}
impl FrameLength {
    /// Returns the frame length in bytes: two raised to the enum discriminant.
    pub fn as_usize(self) -> usize {
        let exponent = self as u8;
        1usize << exponent
    }
}
impl FrameBuf {
/// Creates an empty frame buffer sized for the frame length of `header` and keyed with
/// the header's data key.
pub fn new(header: &Header) -> Self {
let frame_len = header.frame_len.as_usize();
let max_payload_len = frame_len - FRAME_META_LEN;
let cipher = AesKeyInit::new(&header.data_key);
Self {
buf: Vec::with_capacity(frame_len),
frame_len,
max_payload_len,
max_payload_pos: PAYLOAD_OFFSET + max_payload_len,
payload_len: 0,
cipher,
}
}
/// Copies `payload` into the frame's payload area starting at `payload_off`.
///
/// Returns the number of bytes copied, which is smaller than `payload.len()` when the
/// frame's maximum payload would otherwise be exceeded.
///
/// # Errors
///
/// `ErrorKind::InvalidArgument` if `payload_off` lies beyond the current payload length.
pub fn write_payload(&mut self, payload_off: usize, payload: &[u8]) -> Result<usize, Error> {
if payload_off > self.payload_len {
return Err(Error::new(ErrorKind::InvalidArgument));
}
let abs_payload_off = PAYLOAD_OFFSET + payload_off;
// Clamp the write so that it never extends past the last possible payload byte.
let buf_len = usize::min(self.max_payload_pos, abs_payload_off + payload.len());
if buf_len > self.buf.len() {
self.buf.resize(buf_len, 0);
}
let copy_len = buf_len - abs_payload_off;
self.buf[abs_payload_off..abs_payload_off + copy_len].copy_from_slice(&payload[..copy_len]);
// Writing inside the existing payload must not shrink it.
self.payload_len = usize::max(self.payload_len, payload_off + copy_len);
Ok(copy_len)
}
/// Returns the payload bytes currently held, or an empty slice when the buffer is
/// shorter than the frame header.
pub fn payload(&self) -> &[u8] {
if self.buf.len() < PAYLOAD_OFFSET {
&self.buf[..0]
} else {
&self.buf[PAYLOAD_OFFSET..PAYLOAD_OFFSET + self.payload_len]
}
}
/// Returns `true` while at least one more payload byte fits into the frame.
pub fn has_payload_capacity(&self) -> bool {
self.payload_capacity() > 0
}
/// Returns how many more payload bytes fit into the frame.
pub fn payload_capacity(&self) -> usize {
self.max_payload_len - self.payload_len
}
/// Writes the frame header fields and encrypts the payload in place.
///
/// The sequence number and invocation (the first `FRAME_NONCE_LEN` bytes) form the
/// AES-GCM nonce, the end-length field is passed as associated data, and the resulting
/// tag is appended after the payload. The end-length field holds the payload length for
/// an end frame and `u32::MAX` otherwise.
///
/// # Panics
///
/// Panics if the payload length does not fit in a `u32` or if encryption fails.
pub fn encrypt(&mut self, frame_header: &FrameHeader) {
// An empty frame still needs room for the header fields.
if self.buf.len() < FRAME_HEADER_LEN {
self.buf.resize(FRAME_HEADER_LEN, 0);
}
debug_assert!(self.payload_len <= self.buf.len() - FRAME_HEADER_LEN);
let seq_num_bytes = frame_header.seq_num().to_le_bytes();
self.set_bytes(seq_num_bytes.as_slice(), SEQ_NUM_OFFSET);
let invocation_bytes = frame_header.invocation().to_le_bytes();
self.set_bytes(invocation_bytes.as_slice(), INVOCATION_OFFSET);
let eof_len_bytes = if frame_header.is_end() {
u32::try_from(self.payload_len)
.expect("payload len should be 4 bytes")
.to_le_bytes()
} else {
u32::MAX.to_le_bytes()
};
self.set_bytes(eof_len_bytes.as_slice(), END_LEN_OFFSET);
// Split the buffer into nonce | end length (associated data) | payload.
let (nonce, frame) = self.buf.split_at_mut(FRAME_NONCE_LEN);
let (eof_len, payload) = frame.split_at_mut(END_LEN);
let nonce = AesNonce::<FrameNonceLen>::from_slice(nonce);
let tag = self
.cipher
.encrypt_in_place_detached(nonce, eof_len, &mut payload[..self.payload_len])
.expect("buffer of sufficient size");
// Drop anything that followed the payload, then append the tag.
self.buf.truncate(self.payload_len + FRAME_HEADER_LEN);
self.buf.extend_from_slice(&tag);
}
/// Authenticates and decrypts the encrypted frame held in the buffer.
///
/// On success the payload is decrypted in place, `payload_len` is updated and the decoded
/// frame header is returned.
///
/// # Errors
///
/// - `ErrorKind::InvalidBufLength` if the buffer is shorter than `FRAME_META_LEN`.
/// - `ErrorKind::InvalidEndLength` if the end length exceeds the maximum payload or the
///   buffer is too short for the payload and tag it implies.
/// - the error converted from the cipher when authentication fails.
/// - `ErrorKind::UnexpectedSeqNum` if the decoded sequence number differs from `seq_num`.
pub fn decrypt(&mut self, seq_num: u32) -> Result<FrameHeader, Error> {
if self.buf.len() < FRAME_META_LEN {
return Err(Error::new(ErrorKind::InvalidBufLength));
}
let (nonce, frame) = self.buf.split_at_mut(FRAME_NONCE_LEN);
let (eof_len_bytes, frame) = frame.split_at_mut(END_LEN);
let eof_len =
u32::from_le_bytes(eof_len_bytes.try_into().expect("eof len should be 4 bytes"));
// `u32::MAX` marks a full, non-final frame; any other value is the final payload length.
let (payload_len, is_end) = if eof_len != u32::MAX {
if eof_len as usize > self.max_payload_len {
return Err(Error::new(ErrorKind::InvalidEndLength(eof_len)));
}
(eof_len as usize, true)
} else {
(self.frame_len - FRAME_META_LEN, false)
};
let body_len = payload_len + FRAME_TAG_LEN;
if frame.len() < body_len {
return Err(Error::new(ErrorKind::InvalidEndLength(eof_len)));
}
let (payload, mac) = frame.split_at_mut(payload_len);
let tag = Tag::from_slice(&mac[..FRAME_TAG_LEN]);
let nonce = AesNonce::from_slice(nonce);
self.cipher
.decrypt_in_place_detached(nonce, eof_len_bytes, payload, tag)?;
// The sequence number is only inspected after the tag has been verified.
let seq_num_decoded = u32::from_le_bytes(
nonce[..SEQ_NUM_LEN]
.try_into()
.expect("seq num should be 4 bytes"),
);
if seq_num != seq_num_decoded {
return Err(Error::new(ErrorKind::UnexpectedSeqNum(
seq_num,
seq_num_decoded,
)));
}
let invocation = u64::from_le_bytes(
nonce[SEQ_NUM_LEN..]
.try_into()
.expect("invocation should be 8 bytes"),
);
self.payload_len = payload_len;
Ok(FrameHeader::new(seq_num, invocation, is_end))
}
/// Empties the buffer and resets the payload length to zero.
pub fn clear(&mut self) {
self.buf.clear();
self.payload_len = 0;
}
/// Returns `true` when the buffer holds no bytes at all.
pub fn is_empty(&self) -> bool {
self.buf.is_empty()
}
/// Returns the number of bytes held in the buffer (header and tag included).
pub fn len(&self) -> usize {
self.buf.len()
}
/// Replaces the buffer contents with at most one frame length of bytes from `src` and
/// resets the payload length; returns the number of bytes copied.
pub fn copy_from_encrypted_bytes(&mut self, src: &[u8]) -> usize {
let len = usize::min(src.len(), self.frame_len);
self.buf.resize(len, 0);
self.payload_len = 0;
self.buf[..len].copy_from_slice(&src[..len]);
len
}
/// Clears the buffer, resizes it to a full zero-filled frame and returns it for the
/// caller to fill; follow up with `commit_chunk_mut`.
pub fn chunk_mut(&mut self) -> &mut [u8] {
self.clear_resize_to_full();
&mut self.buf
}
/// Truncates the buffer to the `len` bytes actually filled and resets the payload length.
///
/// # Errors
///
/// `ErrorKind::InvalidBufLength` if `len` exceeds the current buffer length.
pub fn commit_chunk_mut(&mut self, len: usize) -> Result<(), Error> {
if len > self.buf.len() {
return Err(Error::new(ErrorKind::InvalidBufLength));
}
self.buf.truncate(len);
self.payload_len = 0;
Ok(())
}
// Returns `true` when the buffer is too short to contain even a frame header.
#[cfg(any(feature = "std", test))]
fn is_partial(&self) -> bool {
self.buf.len() < FRAME_HEADER_LEN
}
// Overwrites `bytes.len()` buffer bytes starting at `offset`; panics if out of range.
fn set_bytes(&mut self, bytes: &[u8], offset: usize) {
self.buf[offset..offset + bytes.len()].copy_from_slice(bytes)
}
// Resets the buffer to `frame_len` zero bytes with no payload.
fn clear_resize_to_full(&mut self) {
self.buf.clear();
self.buf.resize(self.frame_len, 0);
self.payload_len = 0;
}
}
/// Exposes every byte currently held in the frame buffer.
impl AsRef<[u8]> for FrameBuf {
    fn as_ref(&self) -> &[u8] {
        self.buf.as_slice()
    }
}
/// Dereferences to the same bytes as the `AsRef<[u8]>` implementation.
impl core::ops::Deref for FrameBuf {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        AsRef::<[u8]>::as_ref(self)
    }
}
impl Header {
    /// Parses and authenticates a serialized header.
    ///
    /// The fields are validated in this order: magic number, version, algorithm, parent key
    /// id, frame length. The MAC and data key are then re-derived from `parent_key` and the
    /// header contents, and the derived MAC is compared with the one stored in `bytes`.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::InvalidMagicNumber`, `ErrorKind::UnsupportedVersion` or
    ///   `ErrorKind::UnsupportedCrypto` when the corresponding field has an unexpected value.
    /// - `ErrorKind::ParentKeyIdMismatch` when the header names a different parent key.
    /// - `ErrorKind::InvalidFrameLength` when the frame-length byte is not recognized.
    /// - `ErrorKind::Authentication` when the stored MAC does not match the derived one.
    pub fn from_bytes(parent_key: &ParentKey, bytes: HeaderBytes) -> Result<Self, Error> {
        let mut byte_buf = ByteCursor::new(&bytes);
        let magic_num = byte_buf.get_u32_le();
        if magic_num != MAGIC_NUM {
            return Err(Error::new(ErrorKind::InvalidMagicNumber(magic_num)));
        }
        let version = byte_buf.get_u8();
        if version != VERSION {
            return Err(Error::new(ErrorKind::UnsupportedVersion(version)));
        }
        let algo = byte_buf.get_u16_le();
        if algo != CryptoAlgorithm::Aes256GcmHkdfSha256 as u16 {
            return Err(Error::new(ErrorKind::UnsupportedCrypto(algo)));
        }
        if &bytes[HEADER_KEY_ID_RANGE] != parent_key.id().as_slice() {
            return Err(Error::new(ErrorKind::ParentKeyIdMismatch));
        }
        let frame_len = FrameLength::try_from(byte_buf.get_u8())?;
        let info = &bytes.as_slice()[HKDF_INFO_RANGE];
        let salt = &bytes.as_slice()[HKDF_SALT_RANGE];
        let expected_mac = &bytes.as_slice()[HEADER_MAC_RANGE];
        let (header_mac, data_key) = derive_data_key(parent_key, salt, info);
        // Compare the MACs without stopping at the first mismatching byte, so the time
        // taken does not reveal how long a matching prefix an attacker has guessed.
        let derived_mac = header_mac.as_slice();
        let mac_diff = derived_mac
            .iter()
            .zip(expected_mac.iter())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        if derived_mac.len() != expected_mac.len() || mac_diff != 0 {
            return Err(Error::new(ErrorKind::Authentication));
        }
        Ok(Self {
            frame_len,
            data_key,
            bytes,
        })
    }

    /// Returns the serialized header, including its trailing MAC.
    pub fn bytes(&self) -> &HeaderBytes {
        &self.bytes
    }
}
/// Wipes the derived data key when the header is dropped.
#[cfg(feature = "zeroize")]
impl Drop for Header {
    fn drop(&mut self) {
        Zeroize::zeroize(&mut self.data_key);
    }
}
impl<'a> HeaderBuilder<'a> {
pub fn new(parent_key: &'a ParentKey, nonce: &'a HeaderNonce) -> Self {
Self {
parent_key,
nonce,
frame_len: Default::default(),
}
}
pub fn with_frame_len(mut self, len: FrameLength) -> Self {
self.frame_len = len;
self
}
pub fn build(self) -> Header {
let bytes = HeaderBytes::default();
let mut cur = ByteCursorMut::new(bytes);
cur.push_u32_le(MAGIC_NUM);
cur.push_u8(VERSION);
cur.push_u16_le(CryptoAlgorithm::Aes256GcmHkdfSha256 as u16);
cur.push_u8(self.frame_len.into());
cur.push_bytes(&[0u8; RESERVED_LEN]);
cur.push_bytes(self.nonce);
cur.push_bytes(self.parent_key.id());
let mut bytes = cur.into_inner();
let info = &bytes[HKDF_INFO_RANGE];
let salt = &bytes[HKDF_SALT_RANGE];
let (header_mac, data_key) = derive_data_key(self.parent_key, salt, info);
bytes.as_mut()[HEADER_MAC_OFFSET..].copy_from_slice(&header_mac);
Header {
frame_len: self.frame_len,
data_key,
bytes,
}
}
}
impl FrameHeader {
fn new(seq_num: u32, invocation: u64, is_end: bool) -> Self {
Self {
seq_num,
invocation,
is_end,
}
}
pub fn seq_num(&self) -> u32 {
self.seq_num
}
pub fn invocation(&self) -> u64 {
self.invocation
}
pub fn is_end(&self) -> bool {
self.is_end
}
}
impl FrameHeaderBuilder {
pub fn new(seq_num: u32) -> Self {
Self {
seq_num,
invocation: 0,
is_end: false,
}
}
pub fn end(mut self) -> Self {
self.is_end = true;
self
}
pub fn invocation(mut self, invocation: u64) -> Self {
self.invocation = invocation;
self
}
pub fn build(self) -> FrameHeader {
FrameHeader::new(self.seq_num, self.invocation, self.is_end)
}
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T> ZymicStream<T> {
    /// Wraps `inner`, starting at sequence number 0.
    pub fn new(inner: T, header: &Header) -> Self {
        Self::new_with_seq_num(inner, header, 0)
    }

    /// Wraps `inner`, treating `seq_num` as the sequence number of the first frame.
    pub fn new_with_seq_num(inner: T, header: &Header, seq_num: u32) -> Self {
        let frame_buf = FrameBuf::new(header);
        Self {
            seq_num,
            start_seq_num: seq_num,
            invocation: 0,
            payload_pos: 0,
            end_len: None,
            frame_buf,
            inner,
        }
    }

    /// Consumes the stream and returns the wrapped reader/writer.
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Returns `true` once an end frame has been read or written and the payload position
    /// sits at the end of that frame's payload.
    pub fn is_eof(&self) -> bool {
        self.end_len == Some(self.payload_pos)
    }

    /// Returns `Ok(())` when [`is_eof`](Self::is_eof) holds and `ErrorKind::Truncation`
    /// otherwise.
    pub fn is_eof_or_err(&self) -> Result<(), Error> {
        if self.is_eof() {
            Ok(())
        } else {
            Err(Error::new(ErrorKind::Truncation))
        }
    }

    // Byte offset in `inner` at which frame `frame_idx` starts.
    #[inline]
    fn frame_idx_to_frame_off(&self, frame_idx: u32) -> Result<u64, Error> {
        u64::from(frame_idx)
            .checked_mul(self.frame_buf.frame_len as u64)
            .ok_or_else(|| Error::new(ErrorKind::IntegerOverflow))
    }

    // Index of the frame containing byte offset `abs_off` of `inner`.
    #[inline]
    fn byte_off_to_frame_idx(&self, abs_off: u64) -> Result<u32, Error> {
        let frame_idx = abs_off / self.frame_buf.frame_len as u64;
        Ok(u32::try_from(frame_idx)?)
    }

    // Index of the frame containing plaintext offset `payload_offset`.
    #[inline]
    fn payload_off_to_frame_idx(&self, payload_offset: u64) -> Result<u32, Error> {
        let frame_idx = payload_offset / self.frame_buf.max_payload_len as u64;
        Ok(u32::try_from(frame_idx)?)
    }

    // Byte offset in `inner` of the frame containing plaintext offset `payload_offset`.
    #[inline]
    fn payload_off_to_frame_off(&self, payload_offset: u64) -> Result<u64, Error> {
        let frame_idx = payload_offset / self.frame_buf.max_payload_len as u64;
        frame_idx
            .checked_mul(self.frame_buf.frame_len as u64)
            .ok_or_else(|| Error::new(ErrorKind::IntegerOverflow))
    }

    // Plaintext offset given by the current frame index and payload position.
    #[inline]
    fn current_payload_off(&self) -> Result<u64, Error> {
        // Computed in `u64` so that 32-bit targets do not report overflow early.
        u64::from(self.current_frame_idx())
            .checked_mul(self.frame_buf.max_payload_len as u64)
            .and_then(|v| v.checked_add(self.payload_pos as u64))
            .ok_or_else(|| Error::new(ErrorKind::IntegerOverflow))
    }

    // Plaintext offset of the last payload byte of the buffered frame. For an empty
    // payload this is the offset at which the frame's payload would start.
    #[inline]
    fn payload_end_off(&self) -> Result<u64, Error> {
        let last_byte = self.frame_buf.payload_len.saturating_sub(1) as u64;
        u64::from(self.current_frame_idx())
            .checked_mul(self.frame_buf.max_payload_len as u64)
            .and_then(|v| v.checked_add(last_byte))
            .ok_or_else(|| Error::new(ErrorKind::IntegerOverflow))
    }

    // Frame index corresponding to `seq_num`, relative to the starting sequence number.
    #[inline]
    fn current_frame_idx(&self) -> u32 {
        self.seq_num - self.start_seq_num
    }

    // Payload bytes of the buffered frame that lie at or after `payload_pos`.
    #[inline]
    fn frame_payload_remaining(&self) -> usize {
        // A failed or empty frame read resets the buffer's payload length to zero while
        // `payload_pos` still refers to the previous frame; saturate instead of
        // underflowing in that state.
        self.frame_buf.payload_len.saturating_sub(self.payload_pos)
    }
}
// Empty inherent impl block: it declares no methods or associated items.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T> ZymicStream<T> {}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: Write> ZymicStream<T> {
    /// Finishes the stream: encrypts the buffered frame as an end frame, writes it to the
    /// inner writer and flushes that writer.
    ///
    /// Afterwards [`is_eof`](Self::is_eof) reports `true`.
    ///
    /// # Errors
    ///
    /// Returns the converted I/O error if writing or flushing the inner writer fails.
    pub fn eof(&mut self) -> Result<(), Error> {
        let end_header = FrameHeaderBuilder::new(self.seq_num)
            .invocation(self.invocation)
            .end()
            .build();
        self.frame_buf.encrypt(&end_header);
        self.inner.write_all(&self.frame_buf.buf)?;
        self.inner.flush()?;
        let final_len = self.frame_buf.payload_len;
        self.payload_pos = final_len;
        self.end_len = Some(final_len);
        Ok(())
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: Read> ZymicStream<T> {
    // Reads one frame from the inner reader into the frame buffer and decrypts it with the
    // current sequence number.
    //
    // Returns `Ok(false)` when the inner reader yields no bytes at all, `Ok(true)` once a
    // frame has been decrypted. On success the invocation is set to the frame's invocation
    // plus one, `end_len` records the payload length if this is an end frame, and the
    // payload position is reset to 0.
    //
    // Errors: `UnexpectedEof` when fewer bytes than a frame header were available,
    // `IntegerOverflow` when the invocation cannot be incremented, plus any error from the
    // inner reader or from decryption.
    fn read_next_frame(&mut self) -> Result<bool, Error> {
        // `chunk_mut` already clears the buffer and resizes it to a full frame, so no
        // separate reset is needed beforehand.
        let mut buf = self.frame_buf.chunk_mut();
        let mut total_len = 0;
        while !buf.is_empty() {
            let len = match self.inner.read(buf) {
                Ok(0) => break,
                Ok(len) => len,
                // An interrupted read is not fatal and should simply be retried.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e.into()),
            };
            buf = &mut buf[len..];
            total_len += len;
        }
        self.frame_buf.commit_chunk_mut(total_len)?;
        if total_len == 0 {
            return Ok(false);
        }
        if self.frame_buf.is_partial() {
            return Err(Error::new(ErrorKind::UnexpectedEof));
        }
        let frame_header = self.frame_buf.decrypt(self.seq_num)?;
        self.invocation = frame_header
            .invocation()
            .checked_add(1)
            .ok_or(Error::new(ErrorKind::IntegerOverflow))?;
        self.end_len = frame_header.is_end().then_some(self.frame_buf.payload_len);
        self.payload_pos = 0;
        Ok(true)
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: Read> Read for ZymicStream<T> {
    /// Copies decrypted payload into `buf`, pulling further frames from the inner reader
    /// as the buffered frame is exhausted.
    ///
    /// Stops when `buf` is full, when the end of the stream has been reached, or when no
    /// further payload is available; returns the number of bytes copied.
    fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, std::io::Error> {
        let mut copied = 0;
        loop {
            if buf.is_empty() || self.is_eof() {
                break;
            }
            // The buffered frame is exhausted: fetch the next one and advance the
            // sequence number when a frame was actually read.
            let frame_exhausted = self.frame_payload_remaining() == 0;
            if frame_exhausted && self.read_next_frame()? {
                self.seq_num = match self.seq_num.checked_add(1) {
                    Some(next) => next,
                    None => return Err(Error::new(ErrorKind::IntegerOverflow).into()),
                };
            }
            let available = self.frame_payload_remaining();
            if available == 0 {
                break;
            }
            let payload = self.frame_buf.payload();
            if payload.is_empty() {
                break;
            }
            let chunk_len = usize::min(available, buf.len());
            let start = self.payload_pos;
            buf[..chunk_len].copy_from_slice(&payload[start..start + chunk_len]);
            buf = &mut buf[chunk_len..];
            self.payload_pos += chunk_len;
            copied += chunk_len;
        }
        Ok(copied)
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: Write> Write for ZymicStream<T> {
    /// Buffers `buf` into frames, encrypting and writing each frame to the inner writer
    /// once it is full and more data arrives.
    ///
    /// All of `buf` is consumed unless an error occurs. The last, partially filled frame
    /// stays buffered; it is only written by `eof` or when later writes fill it.
    fn write(&mut self, mut buf: &[u8]) -> Result<usize, std::io::Error> {
        let mut total_len = 0;
        while !buf.is_empty() {
            if !self.frame_buf.has_payload_capacity() {
                // The buffered frame is full: emit it as a non-final frame and start
                // the next one.
                self.frame_buf.encrypt(
                    &FrameHeaderBuilder::new(self.seq_num)
                        .invocation(self.invocation)
                        .build(),
                );
                self.inner.write_all(self.frame_buf.as_ref())?;
                self.frame_buf.clear();
                self.seq_num = self
                    .seq_num
                    .checked_add(1)
                    .ok_or(Error::new(ErrorKind::IntegerOverflow))?;
                self.invocation = 0;
                self.payload_pos = 0;
            }
            let len = self.frame_buf.write_payload(self.payload_pos, buf)?;
            buf = &buf[len..];
            self.payload_pos += len;
            total_len += len;
        }
        Ok(total_len)
    }

    /// Flushes the inner writer so that frames already handed to it reach their
    /// destination.
    ///
    /// The partially filled frame is not written here, because a frame can only be
    /// emitted once it is full or the stream is ended with `eof`.
    fn flush(&mut self) -> Result<(), std::io::Error> {
        self.inner.flush()
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: Seek + Read> ZymicStream<T> {
    // Positions the stream at plaintext offset `payload_off`.
    //
    // The frame containing the offset is read and decrypted, after which the inner stream
    // is moved back to the start of that frame and the payload position is set to the
    // offset within it.
    //
    // NOTE(review): because the inner stream is rewound to the start of the frame
    // (presumably so that a following write replaces the frame in place), a following
    // read that runs past the end of this frame fetches the same frame from `inner`
    // again — confirm this is intended.
    //
    // Errors: `UnexpectedEof` when no frame exists at the offset or the offset lies at or
    // beyond the end of the frame's payload (offset 0 within a frame is always accepted),
    // `IntegerOverflow` on offset arithmetic overflow, plus I/O and decryption errors.
    fn seek_to_payload_off(&mut self, payload_off: u64) -> Result<(), Error> {
        let frame_off = self.payload_off_to_frame_off(payload_off)?;
        self.inner.seek(SeekFrom::Start(frame_off))?;
        self.seq_num = self
            .payload_off_to_frame_idx(payload_off)?
            .checked_add(self.start_seq_num)
            .ok_or(Error::new(ErrorKind::IntegerOverflow))?;
        if !self.read_next_frame()? {
            return Err(Error::new(ErrorKind::UnexpectedEof));
        }
        self.inner.seek(SeekFrom::Start(frame_off))?;
        // Reduce modulo the payload size in `u64` first: narrowing the absolute offset
        // to `usize` before the modulo would truncate it on 32-bit targets. The
        // remainder is smaller than `max_payload_len`, so the cast is lossless.
        let max_payload_len = self.frame_buf.max_payload_len as u64;
        let frame_payload_off = (payload_off % max_payload_len) as usize;
        if frame_payload_off < self.frame_buf.payload_len || frame_payload_off == 0 {
            self.payload_pos = frame_payload_off;
        } else {
            return Err(Error::new(ErrorKind::UnexpectedEof));
        }
        Ok(())
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: Seek + Read + Write> Seek for ZymicStream<T> {
    /// Seeks to a plaintext offset and returns the resulting offset.
    ///
    /// - `SeekFrom::Start(n)` seeks to plaintext offset `n`.
    /// - `SeekFrom::End(n)` requires `n <= 0`. The plaintext length is determined by
    ///   decrypting the last frame of the inner stream. For `n == 0` the total length is
    ///   returned while the stream is positioned on the last payload byte (or at offset
    ///   0 for an empty stream); for negative `n` the stream is positioned at
    ///   `length + n`.
    /// - `SeekFrom::Current(n)` seeks relative to the current plaintext offset.
    ///
    /// # Errors
    ///
    /// `InvalidInput` when the resulting offset would be negative; the converted
    /// `UnexpectedEof`, `IntegerOverflow`, I/O or decryption error otherwise.
    fn seek(&mut self, pos: SeekFrom) -> Result<u64, std::io::Error> {
        let payload_off = match pos {
            SeekFrom::Start(payload_off) => {
                self.seek_to_payload_off(payload_off)?;
                payload_off
            }
            SeekFrom::End(payload_off) => {
                if payload_off > 0 {
                    return Err(Error::new(ErrorKind::UnexpectedEof).into());
                }
                // Locate and decrypt the frame holding the last byte of the inner stream.
                let abs_end = self.inner.seek(SeekFrom::End(0))?.saturating_sub(1);
                let end_frame_idx = self.byte_off_to_frame_idx(abs_end)?;
                let end_frame_off = self.frame_idx_to_frame_off(end_frame_idx)?;
                self.inner.seek(SeekFrom::Start(end_frame_off))?;
                self.seq_num = end_frame_idx
                    .checked_add(self.start_seq_num)
                    .ok_or(Error::new(ErrorKind::IntegerOverflow))?;
                if !self.read_next_frame()? {
                    return Err(Error::new(ErrorKind::UnexpectedEof).into());
                }
                // `payload_end_off` is the offset of the last payload byte, so the total
                // length is one more than that whenever the last frame carries payload.
                // Testing the frame's payload length (rather than the offset itself)
                // keeps a one-byte stream, whose last byte sits at offset 0, correct.
                let payload_end_off = self.payload_end_off()?;
                let abs_payload_len = payload_end_off
                    .checked_add(u64::from(self.frame_buf.payload_len > 0))
                    .ok_or(Error::new(ErrorKind::IntegerOverflow))?;
                if payload_off == 0 {
                    let inner_seek_off = abs_payload_len.saturating_sub(1);
                    self.seek_to_payload_off(inner_seek_off)?;
                    abs_payload_len
                } else {
                    let abs_payload_len = i64::try_from(abs_payload_len)
                        .map_err(|e| Error::new(ErrorKind::TryFromInt(e)))?;
                    let inner_seek_off = abs_payload_len
                        .checked_add(payload_off)
                        .ok_or(Error::new(ErrorKind::IntegerOverflow))?;
                    if inner_seek_off < 0 {
                        return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput));
                    }
                    self.seek_to_payload_off(inner_seek_off as u64)?;
                    inner_seek_off as u64
                }
            }
            SeekFrom::Current(payload_off) => {
                let current_abs_payload_off = i64::try_from(self.current_payload_off()?)
                    .map_err(|e| Error::new(ErrorKind::TryFromInt(e)))?;
                let new_abs_payload_off = payload_off
                    .checked_add(current_abs_payload_off)
                    .ok_or(Error::new(ErrorKind::IntegerOverflow))?;
                if new_abs_payload_off < 0 {
                    return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput));
                }
                self.seek_to_payload_off(new_abs_payload_off as u64)?;
                new_abs_payload_off as u64
            }
        };
        Ok(payload_off)
    }
}
#[cfg(test)]
mod tests {
use super::{
Aes256Gcm, CryptoAlgorithm, FrameBuf, FrameHeader, FrameHeaderBuilder, FrameLength, Header,
HeaderBuilder, HeaderNonce, ALGO_OFFSET, END_LEN_OFFSET, FRAME_HEADER_LEN, FRAME_LEN_LEN,
FRAME_LEN_OFFSET, FRAME_META_LEN, FRAME_TAG_LEN, KEY_ID_OFFSET, MAGIC_NUM, NONCE_OFFSET,
PAYLOAD_OFFSET, RESERVED_LEN, RESERVED_OFFSET, VERSION, VERSION_OFFSET,
};
use crate::{
byte_array,
bytes::ByteCursor,
error::ErrorKind,
key::{ParentKey, ParentKeyId, ParentKeySecret},
};
use alloc::{vec, vec::Vec};
#[cfg(feature = "std")]
use super::ZymicStream;
#[cfg(feature = "std")]
use crate::error::Error;
#[cfg(feature = "std")]
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
// Fixed header nonce shared by the tests: every byte is 3.
const TEST_NONCE: HeaderNonce = byte_array![3u8; {HeaderNonce::LEN}];
// Builds a deterministic parent key: the id is all 1s and the secret all 2s.
fn mock_parent_key() -> ParentKey {
    const ID: ParentKeyId = byte_array![1u8; {ParentKeyId::LEN}];
    const SECRET: ParentKeySecret = byte_array![2u8; {ParentKeySecret::LEN}];
    ParentKey::new(ParentKeyId::from(ID), ParentKeySecret::from(SECRET))
}
// Shannon entropy of `bytes` in bits per byte, computed from the byte-value histogram.
fn entropy(bytes: &[u8]) -> f64 {
    let mut counts = [0u32; 256];
    bytes
        .iter()
        .for_each(|&byte| counts[usize::from(byte)] += 1);
    let total = bytes.len() as f64;
    counts
        .iter()
        .filter_map(|&count| {
            // Byte values that never occur contribute nothing.
            if count == 0 {
                return None;
            }
            let p = f64::from(count) / total;
            Some(-p * p.log2())
        })
        .sum()
}
// Checks the plaintext fields of an encrypted frame against `metadata` and verifies that
// the frame's length is consistent with its end-length field.
fn validate_frame_bytes(frame: &[u8], metadata: &FrameHeader) {
    let mut cursor = ByteCursor::new(frame);
    assert_eq!(cursor.get_u32_le(), metadata.seq_num());
    assert_eq!(cursor.get_u64_le(), metadata.invocation());
    let end_len = cursor.get_u32_le();
    // An end frame carries a real length; every other frame carries `u32::MAX`.
    let marked_as_end = end_len != u32::MAX;
    assert!(metadata.is_end() == marked_as_end);
    let payload_len = frame.len() - FRAME_META_LEN;
    if marked_as_end {
        assert_eq!(payload_len, end_len as usize);
    }
    assert_eq!(FRAME_TAG_LEN + payload_len, cursor.remaining());
}
// Checks every serialized header field except the MAC against the expected algorithm,
// frame length, test nonce and mock parent key id.
fn validate_header(header: &[u8], algo: CryptoAlgorithm, frame_len: FrameLength) {
    let mut cursor = ByteCursor::new(header);
    assert_eq!(MAGIC_NUM, cursor.get_u32_le());
    assert_eq!(VERSION, cursor.get_u8());
    assert_eq!(algo as u16, cursor.get_u16_le());
    assert_eq!(frame_len as u8, cursor.get_u8());
    let reserved = &header[RESERVED_OFFSET..RESERVED_OFFSET + RESERVED_LEN];
    reserved.iter().for_each(|byte| assert_eq!(0, *byte));
    let nonce_bytes = &header[NONCE_OFFSET..NONCE_OFFSET + HeaderNonce::LEN];
    assert_eq!(TEST_NONCE, nonce_bytes.into());
    let expected_key = mock_parent_key();
    assert_eq!(
        expected_key.id().as_array(),
        &header[KEY_ID_OFFSET..KEY_ID_OFFSET + ParentKeyId::LEN]
    );
}
// Checks the bookkeeping of a frame buffer that holds `expected_payload_len` plaintext
// bytes for frames of `expected_frame_len` bytes.
fn validate_framebuf(
    frame_buf: &FrameBuf,
    expected_payload_len: usize,
    expected_frame_len: usize,
) {
    // A buffer without payload is expected to be completely empty (and thus partial).
    let without_payload = expected_payload_len == 0;
    assert!(frame_buf.is_empty() == without_payload);
    assert!(frame_buf.is_partial() == without_payload);
    assert_eq!(frame_buf.payload().len(), expected_payload_len);
    assert_eq!(frame_buf.payload_len, expected_payload_len);
    assert_eq!(frame_buf.frame_len, expected_frame_len);
    let max_payload_len = frame_buf.max_payload_len;
    assert_eq!(frame_buf.max_payload_pos, PAYLOAD_OFFSET + max_payload_len);
    assert_eq!(
        frame_buf.payload_capacity(),
        max_payload_len - expected_payload_len,
    );
}
// Checks the total length of an encrypted stream body and the plaintext fields of each
// frame, assuming sequence numbers start at 0 and every invocation is 0.
#[cfg(feature = "std")]
fn validate_stream_body(stream_body: &[u8], plain_txt_len: usize, frame_len: FrameLength) {
    let frame_byte_len = frame_len.as_usize();
    let frame_count = plain_txt_len.div_ceil(frame_byte_len - FRAME_META_LEN);
    assert_eq!(
        plain_txt_len + FRAME_META_LEN * frame_count,
        stream_body.len()
    );
    let last_idx = frame_count - 1;
    stream_body
        .chunks(frame_byte_len)
        .enumerate()
        .for_each(|(idx, frame)| {
            // Only the last frame is expected to carry the end marker.
            let expected = FrameHeader::new(idx.try_into().unwrap(), 0, idx == last_idx);
            validate_frame_bytes(frame, &expected);
        });
}
// Exchanges the bytes of two frames of an encrypted stream body in place.
#[cfg(feature = "std")]
fn swap_frames(
    stream_body: &mut [u8],
    frame_len: FrameLength,
    frame_idx_1: usize,
    frame_idx_2: usize,
) {
    let chunk_len = frame_len.as_usize();
    let copy_of = |body: &[u8], idx: usize| body.chunks(chunk_len).nth(idx).unwrap().to_vec();
    let first = copy_of(stream_body, frame_idx_1);
    let second = copy_of(stream_body, frame_idx_2);
    // Write each saved copy into the other frame's slot.
    stream_body
        .chunks_mut(chunk_len)
        .nth(frame_idx_2)
        .unwrap()
        .copy_from_slice(&first);
    stream_body
        .chunks_mut(chunk_len)
        .nth(frame_idx_1)
        .unwrap()
        .copy_from_slice(&second);
}
// Returns an all-zero plaintext that exactly fills `frame_count` frames of `frame_len`.
#[cfg(feature = "std")]
fn payload_from_frame_count(frame_count: u32, frame_len: FrameLength) -> Vec<u8> {
    let frames = frame_count as usize;
    let total_frame_bytes = frames * frame_len.as_usize();
    let total_meta_bytes = FRAME_META_LEN * frames;
    vec![0u8; total_frame_bytes - total_meta_bytes]
}
// Round-trips plaintexts of length `alignment`, `2 * alignment`, ... (below four 4 KiB
// frames) through an encrypting writer and a decrypting reader using `std::io::copy`.
#[cfg(feature = "std")]
fn stream_io_copy(alignment: usize) {
    use std::io::Cursor;
    let frame_len = FrameLength::Len4KiB;
    let max_plain_txt_len = frame_len.as_usize() * 4;
    let parent_key = mock_parent_key();
    for plain_txt_len in (alignment..max_plain_txt_len).step_by(alignment) {
        let mut source = Cursor::new(vec![0xffu8; plain_txt_len]);
        let header = HeaderBuilder::new(&parent_key, &TEST_NONCE)
            .with_frame_len(frame_len)
            .build();

        // Encrypt.
        let mut writer = ZymicStream::new(Vec::default(), &header);
        std::io::copy(&mut source, &mut writer).unwrap();
        writer.eof().unwrap();
        let cipher_txt = writer.into_inner();
        validate_stream_body(&cipher_txt, plain_txt_len, frame_len);

        // Decrypt and compare with the original plaintext.
        let mut reader = ZymicStream::new(Cursor::new(cipher_txt), &header);
        let mut decrypted = Vec::default();
        std::io::copy(&mut reader, &mut decrypted).unwrap();
        assert!(reader.is_eof());
        assert_eq!(source.into_inner(), decrypted);
    }
}
// A default-built header serializes the expected field values.
#[test]
fn header_format() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    validate_header(
        built.bytes(),
        CryptoAlgorithm::Aes256GcmHkdfSha256,
        FrameLength::default(),
    );
}
// Without an explicit frame length the builder uses the default and derives a non-zero key.
#[test]
fn header_default_frame_len() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    assert_eq!(FrameLength::default(), built.frame_len);
    let all_zero_key = aes_gcm::Key::<Aes256Gcm>::default();
    assert_ne!(all_zero_key, built.data_key);
}
// An explicitly requested frame length is recorded in the built header.
#[test]
fn header_explicit_frame_len() {
    let requested = FrameLength::Len32KiB;
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE)
        .with_frame_len(requested)
        .build();
    assert_eq!(requested, built.frame_len);
    let all_zero_key = aes_gcm::Key::<Aes256Gcm>::default();
    assert_ne!(all_zero_key, built.data_key);
}
// Parsing the bytes of a built header yields an equal header.
#[test]
fn header_from_bytes() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let parsed = Header::from_bytes(&key, built.bytes().clone()).unwrap();
    assert_eq!(built, parsed);
}
// A parent key with the right id but a different secret fails header authentication.
#[test]
fn header_from_bytes_err() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let wrong_secret_key = ParentKey::new(key.id().clone(), ParentKeySecret::default());
    match Header::from_bytes(&wrong_secret_key, built.bytes().clone()) {
        Err(e) => assert_eq!(*e.kind(), ErrorKind::Authentication),
        Ok(_) => panic!("expected an error"),
    }
}
// A parent key with a different id is rejected before any authentication is attempted.
#[test]
fn header_key_id_err() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let other_key = ParentKey::default();
    match Header::from_bytes(&other_key, built.bytes().clone()) {
        Err(e) => assert_eq!(*e.kind(), ErrorKind::ParentKeyIdMismatch),
        Ok(_) => panic!("expected an error"),
    }
}
// Corrupting the first magic-number byte is reported as an invalid magic number.
#[test]
fn header_magic_num_err() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut corrupted = built.bytes().clone();
    corrupted[0] = 0;
    match Header::from_bytes(&key, corrupted) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::InvalidMagicNumber(_))),
        Ok(_) => panic!("expected an error"),
    }
}
// An unknown version byte is reported together with its value.
#[test]
fn header_version_err() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut corrupted = built.bytes().clone();
    corrupted[VERSION_OFFSET] = 0xff;
    match Header::from_bytes(&key, corrupted) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::UnsupportedVersion(0xff))),
        Ok(_) => panic!("expected an error"),
    }
}
// An unknown algorithm id is reported together with its value.
#[test]
fn header_algo_err() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut corrupted = built.bytes().clone();
    // Overwrite both bytes of the little-endian algorithm field.
    corrupted[ALGO_OFFSET] = 0xff;
    corrupted[ALGO_OFFSET + 1] = 0xff;
    match Header::from_bytes(&key, corrupted) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::UnsupportedCrypto(0xffff))),
        Ok(_) => panic!("expected an error"),
    }
}
// An unrecognized frame-length byte is reported as an invalid frame length.
#[test]
fn header_frame_len_err() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut corrupted = built.bytes().clone();
    let field_end = FRAME_LEN_OFFSET + FRAME_LEN_LEN;
    for idx in FRAME_LEN_OFFSET..field_end {
        corrupted[idx] = 0xff;
    }
    match Header::from_bytes(&key, corrupted) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::InvalidFrameLength(_))),
        Ok(_) => panic!("expected an error"),
    }
}
// Flipping every bit of the nonce makes header authentication fail.
#[test]
fn header_nonce_err() {
    let key = mock_parent_key();
    let built = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut corrupted = built.bytes().clone();
    let nonce_end = NONCE_OFFSET + HeaderNonce::LEN;
    for idx in NONCE_OFFSET..nonce_end {
        corrupted[idx] = !corrupted[idx]
    }
    match Header::from_bytes(&key, corrupted) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::Authentication)),
        Ok(_) => panic!("expected an error"),
    }
}
// A freshly created frame buffer is empty and sized for the header's frame length.
#[test]
fn framebuf_new() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let fresh = FrameBuf::new(&header);
    validate_framebuf(&fresh, 0, header.frame_len.as_usize());
}
// Writing at offset 0 of an empty buffer stores the whole payload.
#[test]
fn framebuf_write_payload() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let frame_byte_len = header.frame_len.as_usize();
    let plain_txt = vec![1, 2, 3, 4, 5];
    let mut frame_buf = FrameBuf::new(&header);
    let written = frame_buf.write_payload(0, &plain_txt).unwrap();
    assert_eq!(written, plain_txt.len());
    validate_framebuf(&frame_buf, plain_txt.len(), frame_byte_len);
}
// Overwriting bytes inside the existing payload leaves the payload length unchanged.
#[test]
fn framebuf_write_payload_inline() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let frame_byte_len = header.frame_len.as_usize();
    let mut frame_buf = FrameBuf::new(&header);

    let initial = vec![1, 2, 3, 4, 5];
    let written = frame_buf.write_payload(0, &initial).unwrap();
    assert_eq!(written, initial.len());
    validate_framebuf(&frame_buf, initial.len(), frame_byte_len);

    let overwrite = vec![6, 7];
    let written = frame_buf.write_payload(2, &overwrite).unwrap();
    assert_eq!(written, overwrite.len());
    validate_framebuf(&frame_buf, initial.len(), frame_byte_len);

    assert_eq!(frame_buf.payload(), vec![1, 2, 6, 7, 5]);
}
// A write that starts inside the payload and runs past its end extends the payload.
#[test]
fn framebuf_write_payload_extend() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let frame_byte_len = header.frame_len.as_usize();
    let mut frame_buf = FrameBuf::new(&header);

    let initial = vec![1, 2, 3, 4, 5];
    let written = frame_buf.write_payload(0, &initial).unwrap();
    assert_eq!(written, initial.len());
    validate_framebuf(&frame_buf, initial.len(), frame_byte_len);

    let overlapping = vec![6, 7, 8, 9, 10, 11, 12];
    let written = frame_buf.write_payload(2, &overlapping).unwrap();
    assert_eq!(written, overlapping.len());
    validate_framebuf(&frame_buf, 9, frame_byte_len);

    assert_eq!(frame_buf.payload(), vec![1, 2, 6, 7, 8, 9, 10, 11, 12]);
}
/// A write that starts exactly at the end of the existing payload must append.
///
/// Writes five bytes at offset 0 and five more at offset 5, then expects the
/// payload to be the concatenation `[1, 2, ..., 10]`.
#[test]
fn framebuf_write_payload_append() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let frame_len = header.frame_len.as_usize();
    let mut buf = FrameBuf::new(&header);

    let first = vec![1, 2, 3, 4, 5];
    let written = buf.write_payload(0, &first).unwrap();
    assert_eq!(written, first.len());
    validate_framebuf(&buf, first.len(), frame_len);

    let second = vec![6, 7, 8, 9, 10];
    let written = buf.write_payload(5, &second).unwrap();
    assert_eq!(written, second.len());
    let total = first.len() + second.len();
    validate_framebuf(&buf, total, frame_len);

    assert_eq!(buf.payload(), vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
}
/// Writes five bytes at offset 100 into an empty `FrameBuf` and expects the
/// test body to panic.
///
/// NOTE(review): because of `#[should_panic]`, this test passes when the call
/// itself panics, when it returns `Ok`, or when it returns an error of a kind
/// other than `InvalidBufLength`. A returned `InvalidBufLength` error would
/// make the test fail. Confirm which outcome is intended.
#[test]
#[should_panic]
fn framebuf_write_payload_panic() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut buf = FrameBuf::new(&header);

    match buf.write_payload(100, &data) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::InvalidBufLength)),
        Ok(_) => panic!("expecting an error"),
    }
}
/// Encrypts a frame holding a five-byte payload and checks the resulting bytes
/// with `validate_frame_bytes`.
#[test]
fn framebuf_encrypt_lt_capacity() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut buf = FrameBuf::new(&header);

    let written = buf.write_payload(0, &data).unwrap();
    assert_eq!(written, data.len());
    validate_framebuf(&buf, data.len(), header.frame_len.as_usize());

    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);
    validate_frame_bytes(buf.as_ref(), &meta);
}
/// Encrypts a frame whose payload is exactly `frame_len - FRAME_META_LEN`
/// bytes long and checks the resulting bytes with `validate_frame_bytes`.
#[test]
fn framebuf_encrypt_eq_capacity() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let frame_len = header.frame_len.as_usize();
    let data = vec![0u8; frame_len - FRAME_META_LEN];
    let mut buf = FrameBuf::new(&header);

    let written = buf.write_payload(0, &data).unwrap();
    assert_eq!(written, data.len());
    validate_framebuf(&buf, data.len(), frame_len);

    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);
    validate_frame_bytes(buf.as_ref(), &meta);
}
/// Offers twice `frame_len - FRAME_META_LEN` bytes to `write_payload` and
/// expects only `frame_len - FRAME_META_LEN` of them to be accepted before the
/// frame is encrypted and validated.
#[test]
fn framebuf_encrypt_gt_capacity() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let per_frame = header.frame_len.as_usize() - FRAME_META_LEN;
    let data = vec![0u8; per_frame * 2];
    let mut buf = FrameBuf::new(&header);

    let written = buf.write_payload(0, &data).unwrap();
    assert_eq!(written, per_frame);
    validate_framebuf(&buf, per_frame, header.frame_len.as_usize());

    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);
    validate_frame_bytes(buf.as_ref(), &meta);
}
/// Encrypting a frame with nothing written to it must still produce bytes
/// accepted by `validate_frame_bytes`, and the payload must stay empty.
#[test]
fn framebuf_encrypt_empty_payload() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut buf = FrameBuf::new(&header);

    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);
    validate_frame_bytes(buf.as_ref(), &meta);

    assert!(buf.payload().is_empty());
}
/// Forces `payload_len` to `1 << 31` on an otherwise empty `FrameBuf` and
/// expects `encrypt` to panic.
#[test]
#[should_panic]
fn framebuf_encrypt_panic() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let meta = FrameHeader::new(1, 2, true);

    let mut buf = FrameBuf::new(&header);
    // Deliberately inconsistent state: the length field no longer matches the buffer.
    buf.payload_len = 1 << 31;
    buf.encrypt(&meta);
}
/// After `clear`, a buffer that previously held a payload must validate the
/// same way a new one does (checked against `0` and the header's frame length).
#[test]
fn framebuf_clear() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];

    let mut buf = FrameBuf::new(&header);
    buf.write_payload(0, &data).unwrap();
    buf.clear();

    validate_framebuf(&buf, 0, header.frame_len.as_usize());
}
/// After `clear_resize_to_full` the buffer must report neither `is_empty` nor
/// `is_partial`, while the payload stays empty, `payload_len` stays 0 and the
/// payload capacity equals `frame_len - FRAME_META_LEN`.
#[test]
fn framebuf_clear_resize_to_full() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut buf = FrameBuf::new(&header);

    buf.clear_resize_to_full();

    assert!(!buf.is_empty());
    assert!(!buf.is_partial());
    assert!(buf.payload().is_empty());
    assert_eq!(0, buf.payload_len);

    let expected_capacity = header.frame_len.as_usize() - FRAME_META_LEN;
    assert_eq!(expected_capacity, buf.payload_capacity());
}
/// Round trip within a single `FrameBuf`: write, encrypt, decrypt with the
/// matching sequence number, then compare the payload with the plaintext.
#[test]
fn framebuf_decrypt_in_place() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut buf = FrameBuf::new(&header);

    let written = buf.write_payload(0, &data).unwrap();
    assert_eq!(written, data.len());

    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);
    validate_frame_bytes(buf.as_ref(), &meta);

    buf.decrypt(1).unwrap();
    assert_eq!(buf.payload(), data);
}
/// Round trip across two buffers: the encrypted bytes of one `FrameBuf` are
/// copied into a second one via `copy_from_encrypted_bytes`, which must then
/// decrypt to the original plaintext.
#[test]
fn framebuf_decrypt_from_copy() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];

    // Producer side.
    let mut writer = FrameBuf::new(&header);
    let written = writer.write_payload(0, &data).unwrap();
    assert_eq!(written, data.len());
    let meta = FrameHeader::new(1, 2, true);
    writer.encrypt(&meta);
    validate_frame_bytes(writer.as_ref(), &meta);

    // Consumer side.
    let mut reader = FrameBuf::new(&header);
    let copied = reader.copy_from_encrypted_bytes(writer.as_ref());
    assert_eq!(copied, writer.as_ref().len());
    reader.decrypt(1).unwrap();
    assert_eq!(reader.payload(), data);
}
/// `copy_from_encrypted_bytes` must report the number of bytes it took: all of
/// a five-byte input, but only `frame_len` bytes of an input that is one byte
/// longer than a frame.
#[test]
fn framebuf_copy_from_encrypted_bytes() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let frame_len = header.frame_len.as_usize();
    let mut buf = FrameBuf::new(&header);

    let short = vec![1, 2, 3, 4, 5];
    let copied = buf.copy_from_encrypted_bytes(&short);
    assert_eq!(copied, short.len());

    let oversized = vec![0u8; frame_len + 1];
    let copied = buf.copy_from_encrypted_bytes(&oversized);
    assert_eq!(copied, frame_len);
}
/// A frame encrypted without any payload must decrypt successfully and yield
/// an empty payload.
#[test]
fn framebuf_decrypt_empty_payload() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut buf = FrameBuf::new(&header);

    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);
    buf.decrypt(1).unwrap();

    assert!(buf.payload().is_empty());
}
/// Encrypts a run of all-zero frames and checks that the collected output has
/// an entropy that rounds to 8 bits per byte.
///
/// The number of frames is derived from a 4 MiB (`1 << 22`) target split into
/// chunks of `frame_len - FRAME_META_LEN` bytes; `frame_count - 1` frames are
/// produced. Each frame is given its own sequence number so that the sample is
/// not one frame header repeated for every frame.
#[test]
fn framebuf_entropy() {
    let parent_key = mock_parent_key();
    let header = HeaderBuilder::new(&parent_key, &TEST_NONCE).build();
    let mut frame_buf = FrameBuf::new(&header);
    let payload_len: usize = 1 << 22;
    let payload_chunk_len = header.frame_len.as_usize() - FRAME_META_LEN;
    let frame_count = payload_len.div_ceil(payload_chunk_len);
    let plain_txt = vec![0u8; payload_chunk_len];
    let mut payload = Vec::with_capacity(payload_len);
    // NOTE(review): presumably the sequence number feeds the per-frame nonce,
    // in which case a constant value would encrypt every frame identically;
    // confirm against `FrameBuf::encrypt`.
    let mut seq_num = 0;
    for _ in 0..frame_count - 1 {
        frame_buf.write_payload(0, &plain_txt).unwrap();
        assert!(!frame_buf.has_payload_capacity());
        let metadata = FrameHeaderBuilder::new(seq_num).build();
        frame_buf.encrypt(&metadata);
        payload.extend_from_slice(frame_buf.payload());
        seq_num += 1;
    }
    let entropy = entropy(&payload);
    assert_eq!(f64::round(entropy), 8.0);
}
/// Calling `decrypt` on a `FrameBuf` that never received any bytes must return
/// an `InvalidBufLength` error (the test does not expect a panic, despite its
/// name).
#[test]
fn framebuf_decrypt_empty_buf_panic() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut buf = FrameBuf::new(&header);

    match buf.decrypt(0) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::InvalidBufLength)),
        Ok(_) => panic!("expected an error"),
    }
}
/// Overwrites the four bytes at `END_LEN_OFFSET` of an encrypted frame with
/// the little-endian encoding of `1 << 31` and expects `decrypt` to fail with
/// `InvalidEndLength`.
#[test]
fn framebuf_decrypt_end_len_err() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut buf = FrameBuf::new(&header);

    let written = buf.write_payload(0, &data).unwrap();
    assert_eq!(written, data.len());
    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);

    // Tamper with the stored length field.
    let bogus: u32 = 1 << 31;
    let bogus_bytes = bogus.to_le_bytes();
    let field = END_LEN_OFFSET..END_LEN_OFFSET + bogus_bytes.len();
    buf.buf[field].copy_from_slice(&bogus_bytes);

    match buf.decrypt(1) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::InvalidEndLength(_))),
        Ok(_) => panic!("expected an error"),
    }
}
/// Truncates an encrypted 16-byte-payload frame to
/// `FRAME_HEADER_LEN + FRAME_TAG_LEN` bytes and expects `decrypt` to fail with
/// `InvalidEndLength`.
#[test]
fn framebuf_decrypt_truncate() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut buf = FrameBuf::new(&header);

    let meta = FrameHeader::new(1, 2, true);
    buf.write_payload(0, &[0u8; 16]).unwrap();
    buf.encrypt(&meta);

    // Drop everything beyond the frame header plus one tag's worth of bytes.
    buf.buf.truncate(FRAME_HEADER_LEN + FRAME_TAG_LEN);

    match buf.decrypt(1) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::InvalidEndLength(_))),
        Ok(_) => panic!("expected an error"),
    }
}
/// A frame encrypted with sequence number 1 must be rejected with
/// `UnexpectedSeqNum(2, 1)` when `decrypt` is asked for sequence number 2.
#[test]
fn framebuf_decrypt_seq_num_err() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut buf = FrameBuf::new(&header);

    let written = buf.write_payload(0, &data).unwrap();
    assert_eq!(written, data.len());
    let meta = FrameHeader::new(1, 2, true);
    buf.encrypt(&meta);

    let Err(e) = buf.decrypt(2) else {
        panic!("expected an error");
    };
    assert!(matches!(e.kind(), ErrorKind::UnexpectedSeqNum(2, 1)));
}
/// Fills the slice returned by `chunk_mut` directly and commits five bytes.
///
/// The chunk must be `frame_len` bytes long; after `commit_chunk_mut` the
/// buffer must expose exactly the committed bytes through `as_ref`, while
/// `payload_len` stays at 0 both before and after the commit.
#[test]
fn framebuf_chunk_mut_commit() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let raw = vec![1, 2, 3, 4, 5];
    let mut buf = FrameBuf::new(&header);

    let slot = buf.chunk_mut();
    assert_eq!(slot.len(), header.frame_len.as_usize());
    slot[..raw.len()].copy_from_slice(&raw);
    assert_eq!(buf.payload_len, 0);

    buf.commit_chunk_mut(raw.len()).unwrap();
    assert_eq!(buf.payload_len, 0);
    assert_eq!(buf.as_ref(), &raw);
}
/// Committing `1 << 32` bytes to a new `FrameBuf` must be rejected with
/// `InvalidBufLength`.
#[test]
fn framebuf_chunk_mut_commit_err() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut buf = FrameBuf::new(&header);

    match buf.commit_chunk_mut(1 << 32) {
        Err(e) => assert!(matches!(e.kind(), ErrorKind::InvalidBufLength)),
        Ok(_) => panic!("expected an error"),
    }
}
/// Every single-byte corruption of an encrypted frame must be rejected.
///
/// A five-byte payload is encrypted with sequence number 1. For each byte
/// position of the resulting buffer, a copy with that byte inverted is loaded
/// into a new `FrameBuf`, and `decrypt` is expected to return an error.
#[test]
fn framebuf_integrity_err() {
    let parent_key = mock_parent_key();
    let header = HeaderBuilder::new(&parent_key, &TEST_NONCE).build();
    let plain_txt = vec![1, 2, 3, 4, 5];
    let mut frame_buf = FrameBuf::new(&header);
    let len = frame_buf.write_payload(0, &plain_txt).unwrap();
    assert_eq!(len, plain_txt.len());
    let frame_header = FrameHeader::new(1, 2, true);
    frame_buf.encrypt(&frame_header);
    for i in 0..frame_buf.buf.len() {
        let mut buf_copy = frame_buf.buf.clone();
        buf_copy[i] = !buf_copy[i];
        let mut frame_buf_reader = FrameBuf::new(&header);
        frame_buf_reader.buf = buf_copy;
        // Decrypt the corrupted copy, and ask for the sequence number the
        // frame was encrypted with, so that the only reason for a failure is
        // the flipped byte rather than an unrelated sequence mismatch.
        let result = frame_buf_reader.decrypt(1);
        assert!(
            result.is_err(),
            "corruption at byte {} was not detected",
            i
        );
    }
}
/// Writes five bytes through a `ZymicStream` backed by a `Vec<u8>`, calls
/// `eof`, and checks that the produced bytes validate against a frame header
/// built with `FrameHeaderBuilder::new(0).end()` and differ from the plaintext.
#[cfg(feature = "std")]
#[test]
fn stream_write() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];

    let sink: Vec<u8> = Vec::default();
    let mut stream = ZymicStream::new(sink, &header);
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();
    assert!(stream.is_eof());

    let encrypted = stream.into_inner();
    let expected_meta = FrameHeaderBuilder::new(0).end().build();
    validate_frame_bytes(&encrypted, &expected_meta);
    assert_ne!(data, encrypted);
}
/// Reading from a stream right after it was written and `eof` was called must
/// return 0 bytes and leave the destination buffer untouched.
#[cfg(feature = "std")]
#[test]
fn stream_write_read_eof() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];

    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();
    assert!(stream.is_eof());

    let mut out = vec![0u8; 5];
    let read = stream.read(&mut out).unwrap();
    assert_eq!(read, 0);
    assert_eq!(out, vec![0u8; 5]);
}
/// The stream's `invocation` field must be 0 after the first write pass and 1
/// after seeking back to the start and writing the same data again.
#[cfg(feature = "std")]
#[test]
fn stream_write_invocation() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);

    // First pass.
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();
    assert!(stream.is_eof());
    assert_eq!(stream.invocation, 0);

    // Second pass over the same region.
    stream.seek(SeekFrom::Start(0)).unwrap();
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();
    assert!(stream.is_eof());
    assert_eq!(stream.invocation, 1)
}
/// After writing five bytes and seeking back to offset 0, reading five bytes
/// must return the original plaintext.
#[cfg(feature = "std")]
#[test]
fn stream_seek_read() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);

    stream.write_all(&data).unwrap();
    stream.eof().unwrap();
    stream.seek(SeekFrom::Start(0)).unwrap();

    let mut out = vec![0u8; 5];
    stream.read_exact(&mut out).unwrap();
    assert_eq!(data, out);
}
/// Writes the data from `payload_from_frame_count(4, ..)` with 4 KiB frames,
/// reopens the ciphertext in a new stream, reads back exactly as many bytes as
/// were written and expects the reader to report `is_eof`.
#[cfg(feature = "std")]
#[test]
fn stream_read_eof() {
    let key = mock_parent_key();
    let frame_len = FrameLength::Len4KiB;
    let header = HeaderBuilder::new(&key, &TEST_NONCE)
        .with_frame_len(frame_len)
        .build();
    let data = payload_from_frame_count(4, frame_len);

    let mut writer = ZymicStream::new(Vec::default(), &header);
    writer.write_all(&data).unwrap();
    writer.eof().unwrap();
    let encrypted = writer.into_inner();

    let mut reader = ZymicStream::new(Cursor::new(encrypted), &header);
    let mut out = vec![0u8; data.len()];
    reader.read_exact(&mut out).unwrap();
    assert!(reader.is_eof());
}
/// Overwrites the last three of five bytes after `SeekFrom::Start(2)` and
/// expects a full read from the start to return `[1, 2, 6, 7, 8]`.
#[cfg(feature = "std")]
#[test]
fn stream_seek_write_read() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);

    let initial = vec![1, 2, 3, 4, 5];
    stream.write_all(&initial).unwrap();
    stream.eof().unwrap();

    stream.seek(SeekFrom::Start(2)).unwrap();
    assert_eq!(stream.payload_pos, 2);
    let patch = vec![6, 7, 8];
    stream.write_all(&patch).unwrap();
    stream.eof().unwrap();

    stream.rewind().unwrap();
    let mut out = vec![0u8; 5];
    stream.read_exact(&mut out).unwrap();
    assert_eq!(out, vec![1, 2, 6, 7, 8]);
}
/// Overwrites the last three of five bytes after `SeekFrom::Start(2)` and
/// expects a full read from the start to return `[1, 2, 6, 7, 8]`.
///
/// NOTE(review): this test performs the same steps with the same values as
/// `stream_seek_write_read`; confirm whether a different scenario was intended.
#[cfg(feature = "std")]
#[test]
fn stream_seek_write_read_2() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);

    let initial = vec![1, 2, 3, 4, 5];
    stream.write_all(&initial).unwrap();
    stream.eof().unwrap();

    stream.seek(SeekFrom::Start(2)).unwrap();
    assert_eq!(stream.payload_pos, 2);
    let patch = vec![6, 7, 8];
    stream.write_all(&patch).unwrap();
    stream.eof().unwrap();

    stream.rewind().unwrap();
    let mut out = vec![0u8; 5];
    stream.read_exact(&mut out).unwrap();
    assert_eq!(out, vec![1, 2, 6, 7, 8]);
}
/// `SeekFrom::End(-3)` on a five-byte stream must land on payload position 2;
/// writing three bytes there and reading everything back must give
/// `[1, 2, 6, 7, 8]`.
#[cfg(feature = "std")]
#[test]
fn stream_seek_end() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);

    let initial = vec![1, 2, 3, 4, 5];
    stream.write_all(&initial).unwrap();
    stream.eof().unwrap();

    stream.seek(SeekFrom::End(-3)).unwrap();
    assert_eq!(stream.payload_pos, 2);
    let patch = vec![6, 7, 8];
    stream.write_all(&patch).unwrap();
    stream.eof().unwrap();

    stream.rewind().unwrap();
    let mut out = vec![0u8; 5];
    stream.read_exact(&mut out).unwrap();
    assert_eq!(out, vec![1, 2, 6, 7, 8]);
}
/// `SeekFrom::End(0)` must return an offset equal to the number of plaintext
/// bytes written (data from `payload_from_frame_count(4, ..)`, 4 KiB frames).
#[cfg(feature = "std")]
#[test]
fn stream_seek_end_len() {
    let key = mock_parent_key();
    let frame_len = FrameLength::Len4KiB;
    let header = HeaderBuilder::new(&key, &TEST_NONCE)
        .with_frame_len(frame_len)
        .build();
    let data = payload_from_frame_count(4, frame_len);

    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();

    let end = stream.seek(SeekFrom::End(0)).unwrap();
    assert_eq!(end as usize, data.len());
}
/// `SeekFrom::Current(2)` after a rewind must land on payload position 2;
/// writing three bytes there and reading everything back must give
/// `[1, 2, 6, 7, 8]`.
#[cfg(feature = "std")]
#[test]
fn stream_seek_current() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);

    let initial = vec![1, 2, 3, 4, 5];
    stream.write_all(&initial).unwrap();
    stream.eof().unwrap();

    stream.rewind().unwrap();
    stream.seek(SeekFrom::Current(2)).unwrap();
    assert_eq!(stream.payload_pos, 2);
    let patch = vec![6, 7, 8];
    stream.write_all(&patch).unwrap();
    stream.eof().unwrap();

    stream.rewind().unwrap();
    let mut out = vec![0u8; 5];
    stream.read_exact(&mut out).unwrap();
    assert_eq!(out, vec![1, 2, 6, 7, 8]);
}
/// On a stream that only ever received an empty write, seeking to the start,
/// seeking to the end and querying the position must all report offset 0.
#[cfg(feature = "std")]
#[test]
fn stream_seek_empty_payload() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);

    stream.write_all(&[]).unwrap();
    stream.eof().unwrap();

    let from_start = stream.seek(SeekFrom::Start(0)).unwrap();
    assert_eq!(from_start, 0);
    let from_end = stream.seek(SeekFrom::End(0)).unwrap();
    assert_eq!(from_end, 0);
    let position = stream.stream_position().unwrap();
    assert_eq!(position, 0);
}
/// Seeks to the first byte after `frame_len - FRAME_META_LEN` plaintext bytes
/// in three ways (`Start`, `Current` after a rewind, and `End`).
///
/// The plaintext comes from `payload_from_frame_count(2, ..)` with everything
/// from that offset onward filled with `0xff`. Each seek must return that
/// offset, set `seq_num` to 1, and be followed by a read of
/// `frame_len - FRAME_META_LEN` bytes that are all `0xff`.
#[cfg(feature = "std")]
#[test]
fn stream_seek_multi_frame() {
    let key = mock_parent_key();
    let frame_len = FrameLength::Len4KiB;
    let per_frame = frame_len.as_usize() - FRAME_META_LEN;
    let header = HeaderBuilder::new(&key, &TEST_NONCE)
        .with_frame_len(frame_len)
        .build();

    let mut data = payload_from_frame_count(2, frame_len);
    data[per_frame..].fill(0xff);

    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();

    let target = per_frame as u64;
    let delta = per_frame as i64;
    let seeks = [
        SeekFrom::Start(target),
        SeekFrom::Current(delta),
        SeekFrom::End(-delta),
    ];
    for seek in seeks {
        // Every variant starts from offset 0 so that `Current` is well defined.
        stream.rewind().unwrap();
        let reached = stream.seek(seek).unwrap();
        assert_eq!(reached, target);
        assert_eq!(stream.seq_num, 1);

        let mut out = vec![0u8; per_frame];
        stream.read_exact(&mut out).unwrap();
        assert!(out.iter().all(|&byte| byte == 0xff));
    }
}
/// Seeking past the end of a five-byte stream must fail with an I/O error that
/// wraps a crate `Error` of kind `UnexpectedEof`.
///
/// Three targets are tried in sequence on the same stream:
/// `Start(1 << 21)`, `End(1)` and `Current(32)`.
#[cfg(feature = "std")]
#[test]
fn stream_seek_unexpected_eof_err() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();

    let out_of_range = [
        SeekFrom::Start(1 << 21),
        SeekFrom::End(1),
        SeekFrom::Current(32),
    ];
    for target in out_of_range {
        let Err(e) = stream.seek(target) else {
            panic!("expecting an error")
        };
        let inner = e.get_ref().unwrap().downcast_ref::<Error>().unwrap();
        assert!(matches!(inner.kind(), ErrorKind::UnexpectedEof))
    }
}
/// Seeking 32 bytes backwards on a five-byte stream, from the end or from the
/// current position, must fail with `std::io::ErrorKind::InvalidInput`.
#[cfg(feature = "std")]
#[test]
fn stream_seek_invalid_err() {
    let key = mock_parent_key();
    let header = HeaderBuilder::new(&key, &TEST_NONCE).build();
    let data = vec![1, 2, 3, 4, 5];
    let mut stream = ZymicStream::new(Cursor::new(Vec::default()), &header);
    stream.write_all(&data).unwrap();
    stream.eof().unwrap();

    for target in [SeekFrom::End(-32), SeekFrom::Current(-32)] {
        match stream.seek(target) {
            Err(e) => assert!(matches!(e.kind(), std::io::ErrorKind::InvalidInput)),
            Ok(_) => panic!("expecting an error"),
        }
    }
}
/// After `swap_frames(.., 2, 3)` on the ciphertext, reading the whole stream
/// must fail with an I/O error wrapping `UnexpectedSeqNum(2, 3)`.
#[cfg(feature = "std")]
#[test]
fn stream_seq_num_err() {
    let key = mock_parent_key();
    let frame_len = FrameLength::Len4KiB;
    let header = HeaderBuilder::new(&key, &TEST_NONCE)
        .with_frame_len(frame_len)
        .build();
    let data = payload_from_frame_count(4, frame_len);

    let mut writer = ZymicStream::new(Vec::default(), &header);
    writer.write_all(&data).unwrap();
    writer.eof().unwrap();
    let mut encrypted = writer.into_inner();
    swap_frames(&mut encrypted, frame_len, 2, 3);

    let mut reader = ZymicStream::new(Cursor::new(encrypted), &header);
    let mut out = vec![0u8; data.len()];
    match reader.read_exact(&mut out) {
        Err(e) => {
            let inner = e.get_ref().unwrap().downcast_ref::<Error>().unwrap();
            assert!(matches!(inner.kind(), ErrorKind::UnexpectedSeqNum(2, 3)))
        }
        Ok(_) => panic!("expecting an error"),
    }
}
/// After `swap_frames(.., 1, 2)` on the ciphertext, reading the whole stream
/// must fail with an I/O error wrapping `UnexpectedSeqNum(1, 2)`.
#[cfg(feature = "std")]
#[test]
fn stream_seq_num_err_2() {
    let key = mock_parent_key();
    let frame_len = FrameLength::Len4KiB;
    let header = HeaderBuilder::new(&key, &TEST_NONCE)
        .with_frame_len(frame_len)
        .build();
    let data = payload_from_frame_count(4, frame_len);

    let mut writer = ZymicStream::new(Vec::default(), &header);
    writer.write_all(&data).unwrap();
    writer.eof().unwrap();
    let mut encrypted = writer.into_inner();
    swap_frames(&mut encrypted, frame_len, 1, 2);

    let mut reader = ZymicStream::new(Cursor::new(encrypted), &header);
    let mut out = vec![0u8; data.len()];
    match reader.read_exact(&mut out) {
        Err(e) => {
            let inner = e.get_ref().unwrap().downcast_ref::<Error>().unwrap();
            assert!(matches!(inner.kind(), ErrorKind::UnexpectedSeqNum(1, 2)))
        }
        Ok(_) => panic!("expecting an error"),
    }
}
/// Truncates the ciphertext to `frame_len * 3` bytes, then reads
/// `plain_txt.len() - frame_len` bytes from it and expects the read to succeed
/// without the stream reporting `is_eof`.
///
/// NOTE(review): despite the `_err` suffix, no error is asserted here; the
/// test only checks that the end of the stream was not signalled. Confirm
/// whether a further read that fails was meant to be part of this test.
#[cfg(feature = "std")]
#[test]
fn stream_truncated_err() {
    let key = mock_parent_key();
    let frame_len = FrameLength::Len4KiB;
    let header = HeaderBuilder::new(&key, &TEST_NONCE)
        .with_frame_len(frame_len)
        .build();
    let data = payload_from_frame_count(4, frame_len);

    let mut writer = ZymicStream::new(Vec::default(), &header);
    writer.write_all(&data).unwrap();
    writer.eof().unwrap();
    let mut encrypted = writer.into_inner();
    encrypted.truncate(frame_len.as_usize() * 3);

    let mut reader = ZymicStream::new(Cursor::new(encrypted), &header);
    let mut out = vec![0u8; data.len() - frame_len.as_usize()];
    reader.read_exact(&mut out).unwrap();
    assert!(!reader.is_eof());
}
/// Runs the shared `stream_io_copy` scenario with an argument of 128.
#[cfg(feature = "std")]
#[test]
fn stream_io_copy_aligned() {
    let size = 128;
    stream_io_copy(size);
}
/// Runs the shared `stream_io_copy` scenario with an argument of 317.
#[cfg(feature = "std")]
#[test]
fn stream_io_copy_unaligned() {
    let size = 317;
    stream_io_copy(size);
}
}