use alloc::vec::Vec;
use crate::aead::Algorithm;
use crate::error::{Error, Result};
use super::aead::decrypt_chunk;
use super::frame::{HEADER_LEN, NONCE_PREFIX_LEN, build_nonce, chunk_size_from_log2, parse_header};
/// Incremental decryptor for a chunked AEAD ciphertext stream.
///
/// Ciphertext is fed in with `update` and the last chunk is decrypted by
/// `finalize`. Each chunk is decrypted with a nonce built from the
/// per-stream prefix, the chunk counter and a final-chunk flag.
pub struct StreamDecryptor {
    /// AEAD algorithm named in the stream header.
    algorithm: Algorithm,
    /// 32-byte decryption key. It is redacted from the `Debug` output below.
    key: [u8; 32],
    /// Per-stream nonce prefix taken from the header.
    nonce_prefix: [u8; NONCE_PREFIX_LEN],
    /// Raw header bytes, passed as associated data for every chunk.
    aad: [u8; HEADER_LEN],
    /// Index of the next chunk to decrypt. It starts at zero.
    counter: u32,
    /// Plaintext bytes per non-final chunk. The tag is not included.
    chunk_size: usize,
    /// Base-2 logarithm of `chunk_size`, as stored in the header.
    chunk_size_log2: u8,
    /// Ciphertext received but not yet decrypted.
    buffer: Vec<u8>,
}

// Hand-written instead of derived so that formatting a decryptor (for example
// in a log line or a panic message) never prints the key.
impl core::fmt::Debug for StreamDecryptor {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("StreamDecryptor")
            .field("algorithm", &self.algorithm)
            .field("key", &"<redacted>")
            .field("nonce_prefix", &self.nonce_prefix)
            .field("counter", &self.counter)
            .field("chunk_size", &self.chunk_size)
            .field("chunk_size_log2", &self.chunk_size_log2)
            .field("buffered", &self.buffer.len())
            .finish_non_exhaustive()
    }
}
impl StreamDecryptor {
pub fn new(key: &[u8], header_bytes: &[u8]) -> Result<Self> {
if key.len() != 32 {
return Err(Error::InvalidKey {
expected: 32,
actual: key.len(),
});
}
let parsed = parse_header(header_bytes)?;
let chunk_size = chunk_size_from_log2(parsed.chunk_size_log2);
let mut key_arr = [0u8; 32];
key_arr.copy_from_slice(key);
Ok(Self {
algorithm: parsed.algorithm,
key: key_arr,
nonce_prefix: parsed.nonce_prefix,
aad: parsed.raw,
counter: 0,
chunk_size,
chunk_size_log2: parsed.chunk_size_log2,
buffer: Vec::with_capacity(chunk_size + 16),
})
}
#[must_use]
pub fn chunk_size(&self) -> usize {
self.chunk_size
}
#[must_use]
pub fn chunk_size_log2(&self) -> u8 {
self.chunk_size_log2
}
#[must_use]
pub fn algorithm(&self) -> Algorithm {
self.algorithm
}
pub fn update(&mut self, data: &[u8]) -> Result<Vec<u8>> {
if data.is_empty() {
return Ok(Vec::new());
}
self.buffer.extend_from_slice(data);
let chunk_frame = self.chunk_size + 16;
let mut out = Vec::new();
while self.buffer.len() > chunk_frame {
let chunk_bytes: Vec<u8> = self.buffer.drain(..chunk_frame).collect();
let nonce = build_nonce(&self.nonce_prefix, self.counter, false);
let pt = decrypt_chunk(self.algorithm, &self.key, &nonce, &chunk_bytes, &self.aad)?;
out.extend_from_slice(&pt);
self.counter = self.counter.checked_add(1).ok_or(Error::InvalidCiphertext(
alloc::string::String::from("stream chunk counter overflow"),
))?;
}
Ok(out)
}
pub fn finalize(self) -> Result<Vec<u8>> {
let chunk_frame = self.chunk_size + 16;
if self.buffer.len() > chunk_frame {
return Err(Error::InvalidCiphertext(alloc::format!(
"stream finalize buffer too large ({} bytes, max {chunk_frame})",
self.buffer.len()
)));
}
if self.buffer.len() < 16 {
return Err(Error::InvalidCiphertext(alloc::format!(
"stream finalize buffer too short ({} bytes, need at least 16 for tag)",
self.buffer.len()
)));
}
let nonce = build_nonce(&self.nonce_prefix, self.counter, true);
decrypt_chunk(self.algorithm, &self.key, &nonce, &self.buffer, &self.aad)
}
}