use alloc::vec::Vec;
use crate::aead::Algorithm;
use crate::error::{Error, Result};
use super::aead::encrypt_chunk;
use super::frame::{
DEFAULT_CHUNK_SIZE_LOG2, HEADER_LEN, MAX_CHUNK_SIZE_LOG2, MIN_CHUNK_SIZE_LOG2,
NONCE_PREFIX_LEN, build_header, build_nonce, chunk_size_from_log2,
};
/// Incremental AEAD stream encryptor.
///
/// Plaintext fed through [`StreamEncryptor::update`] is split into fixed-size
/// chunks. Each chunk is encrypted under a nonce built from a random per-stream
/// prefix and a chunk counter, with the stream header passed as `aad`.
/// [`StreamEncryptor::finalize`] emits the last chunk with the final flag set.
pub struct StreamEncryptor {
    /// AEAD algorithm used for every chunk.
    algorithm: Algorithm,
    /// Secret key. Deliberately omitted from the `Debug` output.
    key: [u8; 32],
    /// Random per-stream nonce prefix (passed to `build_header` at creation).
    nonce_prefix: [u8; NONCE_PREFIX_LEN],
    /// Stream header, supplied as associated data to every chunk.
    aad: [u8; HEADER_LEN],
    /// Counter value used for the next chunk to be sealed.
    counter: u32,
    /// Plaintext bytes per non-final chunk.
    chunk_size: usize,
    /// The log2 chunk-size value recorded in the header.
    chunk_size_log2: u8,
    /// Plaintext not yet sealed into a chunk.
    buffer: Vec<u8>,
}

// Hand-written instead of `#[derive(Debug)]`: the derive would print the raw
// key bytes and any buffered (not yet encrypted) plaintext.
impl core::fmt::Debug for StreamEncryptor {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("StreamEncryptor")
            .field("algorithm", &self.algorithm)
            .field("key", &"<redacted>")
            .field("nonce_prefix", &self.nonce_prefix)
            .field("counter", &self.counter)
            .field("chunk_size", &self.chunk_size)
            .field("chunk_size_log2", &self.chunk_size_log2)
            .field("buffered_len", &self.buffer.len())
            .finish_non_exhaustive()
    }
}
impl StreamEncryptor {
pub fn new(key: &[u8], algorithm: Algorithm) -> Result<(Self, [u8; HEADER_LEN])> {
Self::new_with_chunk_size(key, algorithm, DEFAULT_CHUNK_SIZE_LOG2)
}
pub fn new_with_chunk_size(
key: &[u8],
algorithm: Algorithm,
chunk_size_log2: u8,
) -> Result<(Self, [u8; HEADER_LEN])> {
check_key(key)?;
if !(MIN_CHUNK_SIZE_LOG2..=MAX_CHUNK_SIZE_LOG2).contains(&chunk_size_log2) {
return Err(Error::InvalidCiphertext(alloc::format!(
"chunk_size_log2 out of range: {chunk_size_log2}"
)));
}
let mut nonce_prefix = [0u8; NONCE_PREFIX_LEN];
mod_rand::tier3::fill_bytes(&mut nonce_prefix)
.map_err(|_| Error::RandomFailure("mod_rand::tier3::fill_bytes"))?;
let header = build_header(algorithm, chunk_size_log2, &nonce_prefix);
let chunk_size = chunk_size_from_log2(chunk_size_log2);
let mut key_arr = [0u8; 32];
key_arr.copy_from_slice(key);
let enc = Self {
algorithm,
key: key_arr,
nonce_prefix,
aad: header,
counter: 0,
chunk_size,
chunk_size_log2,
buffer: Vec::with_capacity(chunk_size),
};
Ok((enc, header))
}
#[must_use]
pub fn chunk_size(&self) -> usize {
self.chunk_size
}
#[must_use]
pub fn chunk_size_log2(&self) -> u8 {
self.chunk_size_log2
}
pub fn update(&mut self, data: &[u8]) -> Result<Vec<u8>> {
if data.is_empty() {
return Ok(Vec::new());
}
let estimated_chunks = data.len() / self.chunk_size + 1;
let mut out = Vec::with_capacity(estimated_chunks * (self.chunk_size + 16));
let mut cursor = 0usize;
while cursor < data.len() {
let needed = self.chunk_size - self.buffer.len();
let take = needed.min(data.len() - cursor);
self.buffer.extend_from_slice(&data[cursor..cursor + take]);
cursor += take;
if self.buffer.len() == self.chunk_size {
let nonce = build_nonce(&self.nonce_prefix, self.counter, false);
let chunk =
encrypt_chunk(self.algorithm, &self.key, &nonce, &self.buffer, &self.aad)?;
out.extend_from_slice(&chunk);
self.counter = self.counter.checked_add(1).ok_or(Error::InvalidCiphertext(
alloc::string::String::from("stream chunk counter overflow"),
))?;
self.buffer.clear();
}
}
Ok(out)
}
pub fn finalize(mut self) -> Result<Vec<u8>> {
let mut out = Vec::with_capacity(self.chunk_size + 16);
if self.buffer.len() == self.chunk_size {
let nonce = build_nonce(&self.nonce_prefix, self.counter, false);
let chunk = encrypt_chunk(self.algorithm, &self.key, &nonce, &self.buffer, &self.aad)?;
out.extend_from_slice(&chunk);
self.counter = self.counter.checked_add(1).ok_or(Error::InvalidCiphertext(
alloc::string::String::from("stream chunk counter overflow"),
))?;
self.buffer.clear();
}
let nonce = build_nonce(&self.nonce_prefix, self.counter, true);
let final_chunk =
encrypt_chunk(self.algorithm, &self.key, &nonce, &self.buffer, &self.aad)?;
out.extend_from_slice(&final_chunk);
Ok(out)
}
}
fn check_key(key: &[u8]) -> Result<()> {
if key.len() == 32 {
Ok(())
} else {
Err(Error::InvalidKey {
expected: 32,
actual: key.len(),
})
}
}