use std::cmp;
use std::io::{self, Read, Write};
use chacha20poly1305::{
XChaCha20Poly1305,
aead::{KeyInit as AeadKeyInit, stream},
};
use zeroize::Zeroize;
use crate::CryptoError;
use crate::crypto::aead::TAG_SIZE;
use crate::crypto::keys::PayloadKey;
use crate::error::StreamError;
/// Plaintext bytes per STREAM chunk (64 KiB). Every sealed non-final chunk on
/// the wire is `BUFFER_SIZE + TAG_SIZE` bytes.
pub(crate) const BUFFER_SIZE: usize = 64 * 1024;
/// STREAM nonce-prefix length for XChaCha20-Poly1305 under the BE32
/// construction: the 24-byte XChaCha nonce minus a 4-byte chunk counter and a
/// 1-byte last-chunk flag.
pub(crate) const STREAM_NONCE_SIZE: usize = 24 - 4 - 1;
/// Upper bound on chunks per stream: one per value of the 32-bit STREAM counter.
const STREAM_CHUNK_COUNT_MAX: u64 = u32::MAX as u64 + 1;
/// Wraps a [`StreamError`] marker in an [`io::Error`] of the given `kind`.
///
/// This lets the marker cross `Read`/`Write` boundaries. Callers recover it
/// with `get_ref()` followed by `downcast_ref::<StreamError>()`.
fn stream_io_error(kind: io::ErrorKind, err: StreamError) -> io::Error {
    // Box the concrete marker once so the io::Error payload's concrete type is
    // `StreamError` itself (required for `downcast_ref` to find it).
    let payload: Box<dyn std::error::Error + Send + Sync> = Box::new(err);
    io::Error::new(kind, payload)
}
/// Streaming encryptor.
///
/// Buffers plaintext into `BUFFER_SIZE` chunks and writes each sealed chunk
/// (ciphertext followed by its tag) to the wrapped writer.
///
/// Only [`EncryptWriter::finish`] seals and writes the final chunk. `Drop`
/// merely zeroizes the buffer, so dropping the writer unfinished leaves the
/// stream without a final chunk.
#[must_use = "EncryptWriter must be finalized via finish() — drop without finish produces an unverifiable stream"]
pub(crate) struct EncryptWriter<W: Write> {
    /// STREAM encryptor; an `Option` so `finish` can move it out by value.
    encryptor: Option<stream::EncryptorBE32<XChaCha20Poly1305>>,
    /// Pending plaintext (at most `BUFFER_SIZE` bytes). It briefly holds the
    /// sealed chunk, with the tag appended, while that chunk is written out.
    chunk: Vec<u8>,
    /// Destination for sealed chunks; an `Option` so `finish` can hand it back.
    output: Option<W>,
    /// Chunks sealed so far; checked against `STREAM_CHUNK_COUNT_MAX`.
    chunk_count: u64,
}
impl<W: Write> EncryptWriter<W> {
    /// Wraps `output`: plaintext written through the returned writer is
    /// sealed chunk-by-chunk with `encryptor`.
    pub(crate) fn new(encryptor: stream::EncryptorBE32<XChaCha20Poly1305>, output: W) -> Self {
        // Room for one full plaintext chunk plus the tag appended in place.
        let chunk = Vec::with_capacity(BUFFER_SIZE + TAG_SIZE);
        Self {
            encryptor: Some(encryptor),
            chunk,
            output: Some(output),
            chunk_count: 0,
        }
    }

    /// Seals the buffered plaintext (possibly empty) as the FINAL chunk.
    ///
    /// The sealed chunk is written and the underlying writer flushed, then the
    /// writer is handed back to the caller.
    ///
    /// # Errors
    ///
    /// - `CryptoError::InternalInvariant` if the encryptor or output is missing.
    /// - `CryptoError::PayloadChunkCountExceeded` if the chunk cap is reached.
    /// - `CryptoError::InternalCryptoFailure` if sealing fails.
    /// - The converted I/O error if writing or flushing fails.
    pub(crate) fn finish(mut self) -> Result<W, CryptoError> {
        const ALREADY_FINISHED: &str = "Internal error: encrypt writer already finished";
        let encryptor = match self.encryptor.take() {
            Some(encryptor) => encryptor,
            None => return Err(CryptoError::InternalInvariant(ALREADY_FINISHED)),
        };
        let mut output = match self.output.take() {
            Some(output) => output,
            None => return Err(CryptoError::InternalInvariant(ALREADY_FINISHED)),
        };
        if self.chunk_count >= STREAM_CHUNK_COUNT_MAX {
            return Err(CryptoError::PayloadChunkCountExceeded);
        }
        if encryptor.encrypt_last_in_place(b"", &mut self.chunk).is_err() {
            return Err(CryptoError::InternalCryptoFailure(
                "Internal error: payload encryption failed",
            ));
        }
        self.chunk_count += 1;
        output.write_all(&self.chunk)?;
        output.flush()?;
        // Scrub the sealed bytes now; Drop scrubs again on every other path.
        self.chunk.zeroize();
        Ok(output)
    }
}
impl<W: Write> Write for EncryptWriter<W> {
    /// Buffers `buf` into the pending plaintext chunk.
    ///
    /// A full chunk is sealed with `encrypt_next_in_place` and written out only
    /// once more plaintext arrives. The last chunk is therefore always left for
    /// `finish` to seal as the final one, even when the payload length is an
    /// exact multiple of `BUFFER_SIZE`. On success all of `buf` is consumed.
    ///
    /// # Errors
    ///
    /// - `StreamError::ChunkCountExceeded` (`InvalidData`) when another chunk
    ///   would exceed `STREAM_CHUNK_COUNT_MAX`.
    /// - `StreamError::EncryptAead` if sealing fails.
    /// - `StreamError::StateExhausted` if an earlier failure poisoned the writer.
    /// - Any error from the underlying writer.
    ///
    /// A failure while sealing or emitting a chunk poisons the writer: the
    /// buffer is zeroized and the encryptor dropped. By then the STREAM counter
    /// may have advanced and the output may hold a partially written chunk, so
    /// the stream cannot be resumed. Without poisoning, the sealed chunk
    /// (plaintext + tag, longer than `BUFFER_SIZE`) would stay buffered, and a
    /// retried write would underflow `BUFFER_SIZE - chunk.len()`.
    ///
    /// Every error is either permanent (the cap check fires again before
    /// anything is buffered) or poisoning. A caller retrying after an `Err`
    /// therefore cannot get bytes duplicated that were buffered before the
    /// error.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Only poisoning clears the encryptor here (`finish` consumes `self`).
        if self.encryptor.is_none() {
            return Err(stream_io_error(
                io::ErrorKind::Other,
                StreamError::StateExhausted,
            ));
        }
        let mut written = 0;
        while written < buf.len() {
            if self.chunk.len() == BUFFER_SIZE {
                if self.chunk_count >= STREAM_CHUNK_COUNT_MAX {
                    return Err(stream_io_error(
                        io::ErrorKind::InvalidData,
                        StreamError::ChunkCountExceeded,
                    ));
                }
                let emitted = match (self.encryptor.as_mut(), self.output.as_mut()) {
                    (Some(encryptor), Some(output)) => {
                        match encryptor.encrypt_next_in_place(b"", &mut self.chunk) {
                            Ok(()) => {
                                self.chunk_count += 1;
                                output.write_all(&self.chunk)
                            }
                            Err(_) => Err(stream_io_error(
                                io::ErrorKind::Other,
                                StreamError::EncryptAead,
                            )),
                        }
                    }
                    _ => Err(stream_io_error(
                        io::ErrorKind::Other,
                        StreamError::StateExhausted,
                    )),
                };
                // Scrub the plaintext/sealed chunk on success AND on failure.
                self.chunk.zeroize();
                if let Err(err) = emitted {
                    // Poison: the stream cannot be resumed consistently.
                    self.encryptor = None;
                    return Err(err);
                }
            }
            // Invariant: chunk.len() < BUFFER_SIZE here, so this cannot underflow.
            let take = cmp::min(BUFFER_SIZE - self.chunk.len(), buf.len() - written);
            self.chunk.extend_from_slice(&buf[written..written + take]);
            written += take;
        }
        Ok(buf.len())
    }

    /// Flushes the underlying writer only.
    ///
    /// Buffered plaintext is not sealed early: the reader treats every short
    /// chunk as the final one, so a short non-final chunk would fail to verify.
    fn flush(&mut self) -> io::Result<()> {
        match self.output.as_mut() {
            Some(output) => output.flush(),
            None => Ok(()),
        }
    }
}
impl<W: Write> Drop for EncryptWriter<W> {
fn drop(&mut self) {
self.chunk.zeroize();
}
}
/// Streaming decryptor.
///
/// Reads sealed chunks from the wrapped reader, authenticates and decrypts
/// each one in place, and serves the plaintext through [`Read`].
pub(crate) struct DecryptReader<R: Read> {
    /// STREAM decryptor; an `Option` because decrypting the final chunk
    /// consumes it by value.
    decryptor: Option<stream::DecryptorBE32<XChaCha20Poly1305>>,
    /// Source of sealed chunks.
    input: R,
    /// Current chunk: ciphertext while it is being filled, plaintext after
    /// successful decryption.
    chunk: Vec<u8>,
    /// Cursor into `chunk` marking plaintext not yet handed to the caller.
    pos: usize,
    /// Set once the final chunk has been decrypted.
    done: bool,
    /// First byte of the next chunk. It is read while probing whether a
    /// full-size chunk was followed by more data.
    lookahead: Option<u8>,
    /// Chunks decrypted so far; checked against `STREAM_CHUNK_COUNT_MAX`.
    chunk_count: u64,
}
impl<R: Read> DecryptReader<R> {
    /// Wraps `input`: reads from the returned reader yield the plaintext of
    /// the STREAM chunks decrypted with `decryptor`.
    pub(crate) fn new(decryptor: stream::DecryptorBE32<XChaCha20Poly1305>, input: R) -> Self {
        Self {
            decryptor: Some(decryptor),
            input,
            chunk: Vec::with_capacity(BUFFER_SIZE + TAG_SIZE),
            pos: 0,
            done: false,
            chunk_count: 0,
            lookahead: None,
        }
    }

    /// Reads from `input` into `buf`, transparently retrying
    /// `io::ErrorKind::Interrupted`.
    ///
    /// `Interrupted` must not escape the refill path. By the time it fires,
    /// part of an encrypted chunk may already have been pulled from `input`,
    /// and `fill_buffer` discards that progress on error. A caller that retries
    /// on `Interrupted` (e.g. `read_to_end`) would then resume mid-chunk,
    /// leaving every later chunk boundary misaligned.
    fn read_retrying(input: &mut R, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            match input.read(buf) {
                Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
                result => return result,
            }
        }
    }

    /// Replaces the buffer with the next decrypted chunk.
    ///
    /// On any error, three things happen:
    /// - the buffer is zeroized;
    /// - `pos` is parked at its (now empty) end, so no stale bytes can be served;
    /// - the decryptor is dropped.
    ///
    /// Dropping the decryptor makes the failure sticky. A rejected non-final
    /// chunk does not advance the STREAM position, so without this a retried
    /// `read` could decrypt the *following* chunk at that position and
    /// silently skip an injected garbage chunk. A retried read that needs a
    /// refill now gets `StateExhausted`.
    fn fill_buffer(&mut self) -> io::Result<()> {
        let result = self.fill_buffer_inner();
        if result.is_err() {
            self.chunk.zeroize();
            self.pos = self.chunk.len();
            self.decryptor = None;
        }
        result
    }

    /// Reads one encrypted chunk (plus a one-byte probe) from `input` and
    /// decrypts it in place.
    ///
    /// A full-size chunk followed by at least one more byte is decrypted as a
    /// non-final chunk; the probed byte is kept in `lookahead` for the next
    /// call. Anything else is decrypted as the final chunk: a short chunk, or a
    /// full-size chunk at EOF. After the final chunk, one more read must report
    /// EOF or the stream is rejected with `ExtraData`.
    fn fill_buffer_inner(&mut self) -> io::Result<()> {
        const ENCRYPTED_CHUNK_SIZE: usize = BUFFER_SIZE + TAG_SIZE;
        // No decryptor means an earlier refill failed (or the stream already
        // ended); refuse before consuming any more input.
        if self.decryptor.is_none() {
            return Err(stream_io_error(
                io::ErrorKind::Other,
                StreamError::StateExhausted,
            ));
        }
        self.chunk.zeroize();
        self.chunk.resize(ENCRYPTED_CHUNK_SIZE, 0);
        let mut filled = 0;
        if let Some(b) = self.lookahead.take() {
            self.chunk[0] = b;
            filled = 1;
        }
        while filled < ENCRYPTED_CHUNK_SIZE {
            let n = Self::read_retrying(&mut self.input, &mut self.chunk[filled..])?;
            if n == 0 {
                break;
            }
            filled += n;
        }
        self.chunk.truncate(filled);
        if filled == 0 {
            // A stream must end with an authenticated final chunk; EOF where a
            // chunk should start means the stream was cut short.
            return Err(stream_io_error(
                io::ErrorKind::UnexpectedEof,
                StreamError::Truncated,
            ));
        }
        // Only a full-size chunk can be non-final; probe one byte to learn
        // whether more ciphertext follows it.
        let mut probe = [0u8; 1];
        let probe_n = if filled == ENCRYPTED_CHUNK_SIZE {
            Self::read_retrying(&mut self.input, &mut probe)?
        } else {
            0
        };
        if self.chunk_count >= STREAM_CHUNK_COUNT_MAX {
            return Err(stream_io_error(
                io::ErrorKind::InvalidData,
                StreamError::ChunkCountExceeded,
            ));
        }
        if filled == ENCRYPTED_CHUNK_SIZE && probe_n > 0 {
            self.lookahead = Some(probe[0]);
            let decryptor = self.decryptor.as_mut().ok_or_else(|| {
                stream_io_error(io::ErrorKind::Other, StreamError::StateExhausted)
            })?;
            decryptor
                .decrypt_next_in_place(b"", &mut self.chunk)
                .map_err(|_| {
                    stream_io_error(io::ErrorKind::InvalidData, StreamError::DecryptAead)
                })?;
            self.chunk_count += 1;
        } else {
            let decryptor = self.decryptor.take().ok_or_else(|| {
                stream_io_error(io::ErrorKind::Other, StreamError::StateExhausted)
            })?;
            decryptor
                .decrypt_last_in_place(b"", &mut self.chunk)
                .map_err(|_| {
                    stream_io_error(io::ErrorKind::InvalidData, StreamError::DecryptAead)
                })?;
            self.chunk_count += 1;
            self.done = true;
            // Nothing may follow the final chunk.
            let mut probe2 = [0u8; 1];
            let n = Self::read_retrying(&mut self.input, &mut probe2)?;
            if n > 0 {
                return Err(stream_io_error(
                    io::ErrorKind::InvalidData,
                    StreamError::ExtraData,
                ));
            }
        }
        self.pos = 0;
        Ok(())
    }
}
impl<R: Read> Read for DecryptReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
if self.pos >= self.chunk.len() {
if self.done {
return Ok(0);
}
self.fill_buffer()?;
if self.done && self.chunk.is_empty() {
return Ok(0);
}
}
let available = self.chunk.len() - self.pos;
let n = cmp::min(buf.len(), available);
buf[..n].copy_from_slice(&self.chunk[self.pos..self.pos + n]);
self.pos += n;
Ok(n)
}
}
impl<R: Read> Drop for DecryptReader<R> {
fn drop(&mut self) {
self.chunk.zeroize();
}
}
/// Builds an [`EncryptWriter`] that seals plaintext written to it.
///
/// Sealing uses XChaCha20-Poly1305 in STREAM (BE32) mode under `payload_key`
/// and `stream_nonce`, and the ciphertext goes to `writer`. The returned writer
/// must be completed with `finish()`.
pub(crate) fn payload_encryptor<W: Write>(
    payload_key: &PayloadKey,
    stream_nonce: &[u8; STREAM_NONCE_SIZE],
    writer: W,
) -> EncryptWriter<W> {
    let aead = XChaCha20Poly1305::new(payload_key.expose().into());
    EncryptWriter::new(
        stream::EncryptorBE32::from_aead(aead, stream_nonce.into()),
        writer,
    )
}
/// Builds a [`DecryptReader`] that authenticates and decrypts a payload stream.
///
/// The stream is read from `reader` and was produced under the same
/// `payload_key` and `stream_nonce` (XChaCha20-Poly1305, STREAM BE32).
pub(crate) fn payload_decryptor<R: Read>(
    payload_key: &PayloadKey,
    stream_nonce: &[u8; STREAM_NONCE_SIZE],
    reader: R,
) -> DecryptReader<R> {
    let aead = XChaCha20Poly1305::new(payload_key.expose().into());
    DecryptReader::new(
        stream::DecryptorBE32::from_aead(aead, stream_nonce.into()),
        reader,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::keys::ENCRYPTION_KEY_SIZE;

    const TEST_NONCE: [u8; STREAM_NONCE_SIZE] = [0x37; STREAM_NONCE_SIZE];

    fn test_key() -> PayloadKey {
        PayloadKey::from_bytes_for_tests([0x42; ENCRYPTION_KEY_SIZE])
    }

    /// Deterministic `len`-byte test payload. 251 is prime and does not divide
    /// `BUFFER_SIZE`, so the byte pattern does not repeat in step with chunks.
    fn pattern(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 251) as u8).collect()
    }

    /// Returns the `StreamError` marker carried inside `err`; panics if the
    /// error does not wrap one.
    fn stream_marker(err: &io::Error) -> &StreamError {
        err.get_ref()
            .and_then(|inner| inner.downcast_ref::<StreamError>())
            .expect("expected StreamError marker")
    }

    /// Encrypts `plaintext` with a single `write_all` and returns the stream.
    fn encrypt_to_vec(plaintext: &[u8]) -> Vec<u8> {
        let mut ciphertext: Vec<u8> = Vec::new();
        let mut writer = payload_encryptor(&test_key(), &TEST_NONCE, &mut ciphertext);
        writer.write_all(plaintext).unwrap();
        let _ = writer.finish().unwrap();
        ciphertext
    }

    /// Decrypts a complete stream, panicking on any error.
    fn decrypt_to_vec(ciphertext: &[u8]) -> Vec<u8> {
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, ciphertext);
        let mut out = Vec::new();
        reader.read_to_end(&mut out).unwrap();
        out
    }

    #[test]
    fn streaming_aead_round_trip_exact_buffer_size() {
        let plaintext = pattern(BUFFER_SIZE);
        let ciphertext = encrypt_to_vec(&plaintext);
        assert_eq!(
            ciphertext.len(),
            BUFFER_SIZE + TAG_SIZE,
            "expected exactly one full final chunk (FORMAT.md §5: no empty trailer)"
        );
        assert_eq!(decrypt_to_vec(&ciphertext), plaintext);
    }

    #[test]
    fn streaming_aead_round_trip_byte_at_a_time_writes() {
        let plaintext = pattern(BUFFER_SIZE * 2 + 50);
        let mut ciphertext: Vec<u8> = Vec::new();
        let mut writer = payload_encryptor(&test_key(), &TEST_NONCE, &mut ciphertext);
        for byte in &plaintext {
            writer.write_all(std::slice::from_ref(byte)).unwrap();
        }
        let _ = writer.finish().unwrap();
        assert_eq!(decrypt_to_vec(&ciphertext), plaintext);
    }

    #[test]
    fn streaming_aead_exact_multiple_no_empty_trailer() {
        let plaintext = pattern(BUFFER_SIZE * 3);
        let ciphertext = encrypt_to_vec(&plaintext);
        assert_eq!(
            ciphertext.len(),
            3 * (BUFFER_SIZE + TAG_SIZE),
            "expected three full chunks (last one is the FINAL chunk; FORMAT.md §5)"
        );
        assert_eq!(decrypt_to_vec(&ciphertext), plaintext);
    }

    #[test]
    fn streaming_aead_empty_plaintext_is_single_tag_only_chunk() {
        let ciphertext = encrypt_to_vec(&[]);
        assert_eq!(
            ciphertext.len(),
            TAG_SIZE,
            "empty plaintext must produce exactly one tag-only final chunk"
        );
        let decrypted = decrypt_to_vec(&ciphertext);
        assert_eq!(decrypted, &[] as &[u8]);
    }

    #[test]
    fn streaming_aead_decrypt_with_small_read_buffers() {
        let plaintext = pattern(BUFFER_SIZE * 2 + 1234);
        let ciphertext = encrypt_to_vec(&plaintext);
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, ciphertext.as_slice());
        let mut decrypted = Vec::with_capacity(plaintext.len());
        let mut tiny_buf = [0u8; 7];
        loop {
            let n = reader.read(&mut tiny_buf).unwrap();
            if n == 0 {
                break;
            }
            decrypted.extend_from_slice(&tiny_buf[..n]);
        }
        assert_eq!(decrypted, plaintext);
    }

    /// Reads `reader` until EOF or the first error, using a 4 KiB scratch
    /// buffer. Returns the plaintext served so far and that error, if any.
    fn drain_decrypt_reader(reader: &mut DecryptReader<&[u8]>) -> (Vec<u8>, Option<io::Error>) {
        let mut out = Vec::new();
        let mut scratch = [0u8; 4096];
        loop {
            match reader.read(&mut scratch) {
                Ok(0) => return (out, None),
                Ok(n) => out.extend_from_slice(&scratch[..n]),
                Err(e) => return (out, Some(e)),
            }
        }
    }

    #[test]
    fn streaming_aead_empty_input_rejected_as_truncation() {
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, &[][..]);
        let (out, err) = drain_decrypt_reader(&mut reader);
        let err = err.expect("expected truncation error, got clean EOF");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::Truncated),
            "expected StreamError::Truncated, got {marker:?}"
        );
        assert!(
            out.is_empty(),
            "no plaintext should be served on empty input"
        );
    }

    #[test]
    fn streaming_aead_zero_len_read_is_noop_on_empty_input() {
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, &[][..]);
        let mut empty: [u8; 0] = [];
        assert_eq!(reader.read(&mut empty).unwrap(), 0);
    }

    #[test]
    fn streaming_aead_zero_len_read_does_not_consume_input() {
        let plaintext = b"hello, ferrocrypt";
        let ciphertext = encrypt_to_vec(plaintext);
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, ciphertext.as_slice());
        let mut empty: [u8; 0] = [];
        assert_eq!(reader.read(&mut empty).unwrap(), 0);
        let mut recovered = Vec::new();
        reader.read_to_end(&mut recovered).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn streaming_aead_chunk_boundary_truncation_rejected() {
        let plaintext = pattern(BUFFER_SIZE * 2);
        let mut ciphertext = encrypt_to_vec(&plaintext);
        ciphertext.truncate(BUFFER_SIZE + TAG_SIZE);
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, ciphertext.as_slice());
        let (out, err) = drain_decrypt_reader(&mut reader);
        let err = err.expect("expected AEAD error on chunk-boundary truncation");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::DecryptAead),
            "expected StreamError::DecryptAead, got {marker:?}"
        );
        assert!(
            out.is_empty(),
            "no plaintext should leak from a truncated `next` chunk"
        );
    }

    #[test]
    fn streaming_aead_late_ciphertext_bit_flip_rejected() {
        let plaintext = pattern(BUFFER_SIZE * 2 + 1234);
        let mut ciphertext = encrypt_to_vec(&plaintext);
        let second_chunk_offset = BUFFER_SIZE + TAG_SIZE;
        ciphertext[second_chunk_offset + 100] ^= 0x01;
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, ciphertext.as_slice());
        let (out, err) = drain_decrypt_reader(&mut reader);
        let err = err.expect("expected AEAD tamper error, got clean EOF");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::DecryptAead),
            "expected StreamError::DecryptAead, got {marker:?}"
        );
        assert_eq!(out.as_slice(), &plaintext[..BUFFER_SIZE]);
    }

    #[test]
    fn streaming_aead_mid_chunk_truncation_rejected() {
        let plaintext = pattern(BUFFER_SIZE + 500);
        let mut ciphertext = encrypt_to_vec(&plaintext);
        ciphertext.truncate(ciphertext.len() - 10);
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, ciphertext.as_slice());
        let (out, err) = drain_decrypt_reader(&mut reader);
        let err = err.expect("expected AEAD error on mid-chunk truncation");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::DecryptAead),
            "expected StreamError::DecryptAead, got {marker:?}"
        );
        assert_eq!(out.as_slice(), &plaintext[..BUFFER_SIZE]);
    }

    /// Reader that yields `legit`, then reports EOF (`Ok(0)`) exactly once,
    /// then yields `extra`.
    ///
    /// It models a source that grows after an apparent EOF, so the test can
    /// check that data appended after the final chunk is still detected.
    struct LegitThenExtraReader<'a> {
        legit: &'a [u8],
        extra: &'a [u8],
        legit_pos: usize,
        extra_pos: usize,
        legit_exhausted: bool,
    }

    impl<'a> LegitThenExtraReader<'a> {
        fn new(legit: &'a [u8], extra: &'a [u8]) -> Self {
            Self {
                legit,
                extra,
                legit_pos: 0,
                extra_pos: 0,
                legit_exhausted: false,
            }
        }
    }

    impl<'a> Read for LegitThenExtraReader<'a> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            if !self.legit_exhausted {
                let remaining = self.legit.len() - self.legit_pos;
                if remaining == 0 {
                    self.legit_exhausted = true;
                    return Ok(0);
                }
                let n = cmp::min(buf.len(), remaining);
                buf[..n].copy_from_slice(&self.legit[self.legit_pos..self.legit_pos + n]);
                self.legit_pos += n;
                return Ok(n);
            }
            let remaining = self.extra.len() - self.extra_pos;
            if remaining == 0 {
                return Ok(0);
            }
            let n = cmp::min(buf.len(), remaining);
            buf[..n].copy_from_slice(&self.extra[self.extra_pos..self.extra_pos + n]);
            self.extra_pos += n;
            Ok(n)
        }
    }

    #[test]
    fn streaming_aead_extra_data_after_final_chunk_rejected() {
        let plaintext = pattern(BUFFER_SIZE + 500);
        let ciphertext = encrypt_to_vec(&plaintext);
        let trailing = b"garbage-appended-by-attacker";
        let reader_wrapper = LegitThenExtraReader::new(&ciphertext, trailing);
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, reader_wrapper);
        let mut out = Vec::new();
        let mut scratch = [0u8; 4096];
        let err = loop {
            match reader.read(&mut scratch) {
                Ok(0) => panic!("expected ExtraData error, got clean EOF"),
                Ok(n) => out.extend_from_slice(&scratch[..n]),
                Err(e) => break e,
            }
        };
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::ExtraData),
            "expected StreamError::ExtraData, got {marker:?}"
        );
        assert_eq!(out.as_slice(), &plaintext[..BUFFER_SIZE]);
    }

    #[test]
    fn streaming_aead_no_plaintext_after_err_retry() {
        let plaintext = pattern(500);
        let ciphertext = encrypt_to_vec(&plaintext);
        let trailing = b"trailing";
        let reader_wrapper = LegitThenExtraReader::new(&ciphertext, trailing);
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, reader_wrapper);
        let mut scratch = [0u8; 4096];
        let first = reader.read(&mut scratch);
        let err = first.expect_err("expected ExtraData error on first read");
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::ExtraData),
            "expected StreamError::ExtraData, got {marker:?}"
        );
        let retry = reader.read(&mut scratch);
        assert_eq!(
            retry.expect("retry-after-Err must not surface a new I/O error"),
            0,
            "no plaintext should leak after Err return"
        );
    }

    #[test]
    fn streaming_aead_writer_chunk_count_cap_rejects() {
        let mut ciphertext: Vec<u8> = Vec::new();
        let mut writer = payload_encryptor(&test_key(), &TEST_NONCE, &mut ciphertext);
        writer.chunk_count = STREAM_CHUNK_COUNT_MAX;
        let plaintext = vec![0u8; BUFFER_SIZE + 1];
        let err = writer
            .write_all(&plaintext)
            .expect_err("expected cap rejection from EncryptWriter::write");
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::ChunkCountExceeded),
            "expected StreamError::ChunkCountExceeded, got {marker:?}"
        );
    }

    #[test]
    fn streaming_aead_writer_finish_chunk_count_cap_rejects() {
        let mut ciphertext: Vec<u8> = Vec::new();
        let mut writer = payload_encryptor(&test_key(), &TEST_NONCE, &mut ciphertext);
        writer.chunk_count = STREAM_CHUNK_COUNT_MAX;
        let err = writer
            .finish()
            .expect_err("expected cap rejection from EncryptWriter::finish");
        assert!(
            matches!(err, CryptoError::PayloadChunkCountExceeded),
            "expected PayloadChunkCountExceeded, got {err:?}"
        );
    }

    #[test]
    fn streaming_aead_reader_chunk_count_cap_rejects() {
        let plaintext = pattern(BUFFER_SIZE * 2);
        let ciphertext = encrypt_to_vec(&plaintext);
        let mut reader = payload_decryptor(&test_key(), &TEST_NONCE, ciphertext.as_slice());
        reader.chunk_count = STREAM_CHUNK_COUNT_MAX;
        let (out, err) = drain_decrypt_reader(&mut reader);
        let err = err.expect("expected cap rejection from DecryptReader");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let marker = stream_marker(&err);
        assert!(
            matches!(marker, StreamError::ChunkCountExceeded),
            "expected StreamError::ChunkCountExceeded, got {marker:?}"
        );
        assert!(
            out.is_empty(),
            "no plaintext should leak when the chunk-count cap fires on the first chunk"
        );
    }
}