use aead::{AeadInPlace, KeyInit};
use aes_gcm::Aes256Gcm as AesGcm256;
use chacha20poly1305::XChaCha20Poly1305;
use oxicrypto_core::{CryptoError, StreamingAead};
use subtle::ConstantTimeEq as _;
/// Lifecycle / direction state of a streaming AEAD instance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamMode {
    /// Accepting plaintext chunks (also the initial state of a freshly created stream).
    Encrypting,
    /// Accepting ciphertext chunks.
    Decrypting,
    /// The stream has been finalized.
    Finished,
}
/// Assembles a STREAM-style per-chunk nonce of `NONCE_FULL` bytes.
///
/// Layout: `prefix || counter (u32, big-endian) || flag`, where the final byte is `0x01` for
/// the last chunk of a stream and `0x00` otherwise.
///
/// # Panics
///
/// Panics if `prefix.len() != NONCE_FULL - 5` (the counter and flag occupy the last 5 bytes),
/// or if `NONCE_FULL < 5`.
fn build_nonce<const NONCE_FULL: usize>(
    prefix: &[u8],
    counter: u32,
    is_final: bool,
) -> [u8; NONCE_FULL] {
    let mut assembled = [0u8; NONCE_FULL];
    let (head, tail) = assembled.split_at_mut(NONCE_FULL - 5);
    head.copy_from_slice(prefix);
    let (counter_slot, flag_slot) = tail.split_at_mut(4);
    counter_slot.copy_from_slice(&counter.to_be_bytes());
    flag_slot[0] = u8::from(is_final);
    assembled
}
/// Encrypts one chunk with `cipher` under `nonce`/`aad` and writes `ciphertext || tag` to
/// the front of `ct_out`.
///
/// Returns the number of bytes written (`pt.len()` plus the cipher's tag size).
///
/// # Errors
///
/// - [`CryptoError::BadInput`] if `pt.len() + tag size` overflows `usize`.
/// - [`CryptoError::BufferTooSmall`] if `ct_out` cannot hold the sealed chunk. Nothing is
///   written in this case.
/// - [`CryptoError::Internal`] if the underlying AEAD refuses to encrypt.
fn stream_seal_chunk<C, const NONCE_FULL: usize>(
    cipher: &C,
    nonce: &[u8; NONCE_FULL],
    aad: &[u8],
    pt: &[u8],
    ct_out: &mut [u8],
) -> Result<usize, CryptoError>
where
    C: AeadInPlace,
{
    let tag_size =
        <<C as aead::AeadCore>::TagSize as aead::generic_array::typenum::Unsigned>::USIZE;
    let total = pt.len().checked_add(tag_size).ok_or(CryptoError::BadInput)?;
    let sealed = ct_out.get_mut(..total).ok_or(CryptoError::BufferTooSmall)?;
    // Encrypt in place: copy the plaintext into the body slot, then append the detached tag.
    let (body, tag_slot) = sealed.split_at_mut(pt.len());
    body.copy_from_slice(pt);
    let tag = cipher
        .encrypt_in_place_detached(
            aead::generic_array::GenericArray::from_slice(&nonce[..]),
            aad,
            body,
        )
        .map_err(|_| CryptoError::Internal("STREAM encrypt chunk failed"))?;
    tag_slot.copy_from_slice(&tag);
    Ok(total)
}
/// Authenticates and decrypts one `ciphertext || tag` chunk into the front of `pt_out`.
///
/// Returns the plaintext length (`ct_and_tag.len()` minus the cipher's tag size).
///
/// # Errors
///
/// - [`CryptoError::BadInput`] if `ct_and_tag` is shorter than a tag.
/// - [`CryptoError::BufferTooSmall`] if `pt_out` cannot hold the plaintext.
/// - [`CryptoError::InvalidTag`] if authentication fails (wrong key, nonce, AAD or tampered data).
fn stream_open_chunk<C, const NONCE_FULL: usize>(
    cipher: &C,
    nonce: &[u8; NONCE_FULL],
    aad: &[u8],
    ct_and_tag: &[u8],
    pt_out: &mut [u8],
) -> Result<usize, CryptoError>
where
    C: AeadInPlace,
{
    let tag_size =
        <<C as aead::AeadCore>::TagSize as aead::generic_array::typenum::Unsigned>::USIZE;
    let body_len = ct_and_tag
        .len()
        .checked_sub(tag_size)
        .ok_or(CryptoError::BadInput)?;
    let (ciphertext, tag_bytes) = ct_and_tag.split_at(body_len);
    let opened = pt_out
        .get_mut(..body_len)
        .ok_or(CryptoError::BufferTooSmall)?;
    // Decrypt in place: stage the ciphertext in the output slot, then verify + decrypt it.
    opened.copy_from_slice(ciphertext);
    let tag = aead::Tag::<C>::clone_from_slice(tag_bytes);
    cipher
        .decrypt_in_place_detached(
            aead::generic_array::GenericArray::from_slice(&nonce[..]),
            aad,
            opened,
            &tag,
        )
        .map_err(|_| CryptoError::InvalidTag)?;
    Ok(body_len)
}
pub struct Aes256GcmStream {
cipher: Option<AesGcm256>,
nonce_prefix: [u8; 7],
counter: u32,
aad: alloc::vec::Vec<u8>,
pending: alloc::vec::Vec<u8>,
mode: StreamMode,
}
impl core::fmt::Debug for Aes256GcmStream {
    /// Prints only stream bookkeeping; the cipher, nonce prefix, AAD and buffered bytes are
    /// deliberately left out.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut view = f.debug_struct("Aes256GcmStream");
        view.field("mode", &self.mode);
        view.field("counter", &self.counter);
        view.field("pending_len", &self.pending.len());
        view.finish()
    }
}
extern crate alloc;
use alloc::vec::Vec;
impl Aes256GcmStream {
    /// Returns the 12-byte nonce for the chunk at the current counter position.
    ///
    /// `is_final` sets the trailing flag byte that marks the last chunk of the stream.
    fn current_nonce(&self, is_final: bool) -> [u8; 12] {
        let prefix: &[u8] = &self.nonce_prefix;
        build_nonce::<12>(prefix, self.counter, is_final)
    }

    /// Moves to the next chunk index.
    ///
    /// # Errors
    ///
    /// Returns `CryptoError::Internal("STREAM counter overflow")` and leaves the counter
    /// unchanged when it is already `u32::MAX`.
    fn advance_counter(&mut self) -> Result<(), CryptoError> {
        match self.counter.checked_add(1) {
            Some(next) => {
                self.counter = next;
                Ok(())
            }
            None => Err(CryptoError::Internal("STREAM counter overflow")),
        }
    }
}
/// STREAM-mode [`StreamingAead`] over AES-256-GCM.
///
/// Each `*_update` call holds back the chunk it receives and processes the chunk buffered by
/// the previous call. The first update therefore always returns `Ok(0)`, and the last chunk
/// is only sealed/opened by `*_finalize`, under a nonce with the "final" flag set.
impl StreamingAead for Aes256GcmStream {
    /// Creates a stream from a 32-byte key, a 7-byte nonce prefix, and associated data that
    /// is authenticated with every chunk.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] if `key` is not 32 bytes; [`CryptoError::InvalidNonce`]
    /// if `nonce` is not 7 bytes.
    fn init(key: &[u8], nonce: &[u8], aad: &[u8]) -> Result<Self, CryptoError> {
        if key.len() != 32 {
            return Err(CryptoError::InvalidKey);
        }
        if nonce.len() != 7 {
            return Err(CryptoError::InvalidNonce);
        }
        let cipher = AesGcm256::new_from_slice(key).map_err(|_| CryptoError::InvalidKey)?;
        let mut nonce_prefix = [0u8; 7];
        nonce_prefix.copy_from_slice(nonce);
        Ok(Self {
            cipher: Some(cipher),
            nonce_prefix,
            counter: 0,
            aad: aad.to_vec(),
            pending: Vec::new(),
            mode: StreamMode::Encrypting,
        })
    }

    /// Buffers `chunk` and seals the previously buffered chunk (non-final) into `out`.
    ///
    /// Returns the number of bytes written: `0` when nothing was buffered yet, otherwise
    /// the previous chunk's length plus 16.
    ///
    /// NOTE: a zero-length chunk leaves the buffer empty, so it is effectively skipped.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] if the stream is not encrypting or has been reset.
    /// - [`CryptoError::BufferTooSmall`] if `out` cannot hold the sealed previous chunk.
    /// - [`CryptoError::Internal`] if the chunk counter is exhausted.
    ///
    /// On error no state changes, so the call can be retried (e.g. with a larger `out`).
    fn encrypt_update(&mut self, chunk: &[u8], out: &mut [u8]) -> Result<usize, CryptoError> {
        if self.mode != StreamMode::Encrypting {
            return Err(CryptoError::BadInput);
        }
        let cipher = self.cipher.as_ref().ok_or(CryptoError::BadInput)?;
        if self.pending.is_empty() {
            self.pending.extend_from_slice(chunk);
            return Ok(0);
        }
        // Refuse before sealing: sealing and only then failing to advance would leave the
        // counter in place, and a retry would reuse the same (counter, non-final) nonce.
        if self.counter == u32::MAX {
            return Err(CryptoError::Internal("STREAM counter overflow"));
        }
        let nonce = self.current_nonce(false);
        // Seal straight from the buffer and only replace it once sealing succeeded, so a
        // failure (e.g. `out` too small) does not lose the buffered chunk.
        let written = stream_seal_chunk::<_, 12>(cipher, &nonce, &self.aad, &self.pending, out)?;
        self.pending.clear();
        self.pending.extend_from_slice(chunk);
        self.advance_counter()?;
        Ok(written)
    }

    /// Seals the buffered (last) chunk with the final-flag nonce into `out` and returns its
    /// 16-byte tag. The tag is also written as the trailing 16 bytes of the chunk in `out`.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] if the stream is not encrypting or has been reset.
    /// - [`CryptoError::BufferTooSmall`] if `out` is shorter than the buffered chunk plus 16.
    fn encrypt_finalize(mut self, out: &mut [u8]) -> Result<[u8; 16], CryptoError> {
        if self.mode != StreamMode::Encrypting {
            return Err(CryptoError::BadInput);
        }
        self.mode = StreamMode::Finished;
        let cipher = self.cipher.take().ok_or(CryptoError::BadInput)?;
        let nonce = self.current_nonce(true);
        let written = stream_seal_chunk::<_, 12>(&cipher, &nonce, &self.aad, &self.pending, out)?;
        let mut tag = [0u8; 16];
        tag.copy_from_slice(&out[written - 16..written]);
        Ok(tag)
    }

    /// Buffers the ciphertext `chunk` and opens the previously buffered (non-final) chunk
    /// into `out`.
    ///
    /// A fresh stream starts in encrypting mode. It switches itself to decrypting mode on the
    /// first call if nothing has been processed yet.
    ///
    /// Returns the number of plaintext bytes written: `0` when nothing was buffered yet.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] on a direction mismatch, after a reset, or for a chunk
    ///   shorter than a tag.
    /// - [`CryptoError::BufferTooSmall`] if `out` is too small.
    /// - [`CryptoError::InvalidTag`] if the buffered chunk fails authentication.
    /// - [`CryptoError::Internal`] if the chunk counter is exhausted.
    ///
    /// On error no state changes. After `InvalidTag` the stream should be abandoned.
    fn decrypt_update(&mut self, chunk: &[u8], out: &mut [u8]) -> Result<usize, CryptoError> {
        if self.mode != StreamMode::Decrypting {
            if self.mode == StreamMode::Encrypting && self.counter == 0 && self.pending.is_empty() {
                self.mode = StreamMode::Decrypting;
            } else {
                return Err(CryptoError::BadInput);
            }
        }
        let cipher = self.cipher.as_ref().ok_or(CryptoError::BadInput)?;
        if self.pending.is_empty() {
            self.pending.extend_from_slice(chunk);
            return Ok(0);
        }
        if self.counter == u32::MAX {
            return Err(CryptoError::Internal("STREAM counter overflow"));
        }
        let nonce = self.current_nonce(false);
        // Open from the buffer first; only commit the new chunk after it succeeds.
        let written = stream_open_chunk::<_, 12>(cipher, &nonce, &self.aad, &self.pending, out)?;
        self.pending.clear();
        self.pending.extend_from_slice(chunk);
        self.advance_counter()?;
        Ok(written)
    }

    /// Verifies the buffered last chunk against `expected_tag` and authenticates it under
    /// the final-flag nonce.
    ///
    /// Opening under the final-flag nonce makes a truncated stream fail: there, the buffered
    /// chunk was sealed as non-final, so authentication fails.
    ///
    /// NOTE(review): the final chunk's plaintext is decrypted into a scratch buffer and then
    /// discarded, because this trait method has no output parameter. Callers cannot obtain
    /// the last chunk's plaintext through this API.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] if the stream is not decrypting, was reset, or the
    ///   buffered chunk is shorter than a tag.
    /// - [`CryptoError::InvalidTag`] on tag mismatch or failed authentication.
    fn decrypt_finalize(mut self, expected_tag: &[u8]) -> Result<(), CryptoError> {
        const TAG_LEN: usize = 16;
        if self.mode != StreamMode::Decrypting {
            return Err(CryptoError::BadInput);
        }
        self.mode = StreamMode::Finished;
        let cipher = self.cipher.take().ok_or(CryptoError::BadInput)?;
        let last = &self.pending;
        if last.len() < TAG_LEN {
            return Err(CryptoError::BadInput);
        }
        let embedded_tag = &last[last.len() - TAG_LEN..];
        if !bool::from(expected_tag.ct_eq(embedded_tag)) {
            return Err(CryptoError::InvalidTag);
        }
        let nonce = self.current_nonce(true);
        let mut scratch = alloc::vec![0u8; last.len() - TAG_LEN];
        stream_open_chunk::<_, 12>(&cipher, &nonce, &self.aad, last, &mut scratch).map(|_| ())
    }

    /// Clears the counter and buffered chunk and drops the cipher.
    ///
    /// After a reset, every update/finalize call fails with [`CryptoError::BadInput`].
    /// Keeping the cipher would let the stream re-process chunks under nonces it has already
    /// used, because the nonce prefix is unchanged and the counter restarts at 0. The AAD is
    /// kept, and `clear()` does not wipe the buffer's memory.
    fn reset(&mut self) {
        self.counter = 0;
        self.pending.clear();
        self.mode = StreamMode::Encrypting;
        self.cipher = None;
    }
}
/// Chunked (online) AEAD with a STREAM-style nonce schedule.
///
/// Despite the name, this wraps **XChaCha20-Poly1305**. Every chunk is processed under the
/// 24-byte nonce `nonce_prefix (19 bytes) || counter (u32 big-endian) || final flag (1 byte)`.
pub struct ChaCha20Poly1305Stream {
    /// Keyed cipher; `None` after `reset` (and taken by the finalize methods).
    cipher: Option<XChaCha20Poly1305>,
    /// Caller-supplied nonce prefix shared by every chunk of the stream.
    nonce_prefix: [u8; 19],
    /// Index of the next chunk to be sealed or opened.
    counter: u32,
    /// Associated data authenticated with every chunk.
    aad: Vec<u8>,
    /// Most recently supplied chunk, held back until it is known whether it is the last one.
    pending: Vec<u8>,
    /// Current direction / lifecycle state.
    mode: StreamMode,
}
impl core::fmt::Debug for ChaCha20Poly1305Stream {
    /// Prints only stream bookkeeping; the cipher, nonce prefix, AAD and buffered bytes are
    /// deliberately left out.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut view = f.debug_struct("ChaCha20Poly1305Stream");
        view.field("mode", &self.mode);
        view.field("counter", &self.counter);
        view.field("pending_len", &self.pending.len());
        view.finish()
    }
}
impl ChaCha20Poly1305Stream {
    /// Returns the 24-byte nonce for the chunk at the current counter position.
    ///
    /// `is_final` sets the trailing flag byte that marks the last chunk of the stream.
    fn current_nonce(&self, is_final: bool) -> [u8; 24] {
        let prefix: &[u8] = &self.nonce_prefix;
        build_nonce::<24>(prefix, self.counter, is_final)
    }

    /// Moves to the next chunk index.
    ///
    /// # Errors
    ///
    /// Returns `CryptoError::Internal("STREAM counter overflow")` and leaves the counter
    /// unchanged when it is already `u32::MAX`.
    fn advance_counter(&mut self) -> Result<(), CryptoError> {
        match self.counter.checked_add(1) {
            Some(next) => {
                self.counter = next;
                Ok(())
            }
            None => Err(CryptoError::Internal("STREAM counter overflow")),
        }
    }
}
/// STREAM-mode [`StreamingAead`] over XChaCha20-Poly1305.
///
/// Each `*_update` call holds back the chunk it receives and processes the chunk buffered by
/// the previous call. The first update therefore always returns `Ok(0)`, and the last chunk
/// is only sealed/opened by `*_finalize`, under a nonce with the "final" flag set.
impl StreamingAead for ChaCha20Poly1305Stream {
    /// Creates a stream from a 32-byte key, a 19-byte nonce prefix, and associated data that
    /// is authenticated with every chunk.
    ///
    /// # Errors
    ///
    /// [`CryptoError::InvalidKey`] if `key` is not 32 bytes; [`CryptoError::InvalidNonce`]
    /// if `nonce` is not 19 bytes.
    fn init(key: &[u8], nonce: &[u8], aad: &[u8]) -> Result<Self, CryptoError> {
        if key.len() != 32 {
            return Err(CryptoError::InvalidKey);
        }
        if nonce.len() != 19 {
            return Err(CryptoError::InvalidNonce);
        }
        let cipher = XChaCha20Poly1305::new_from_slice(key).map_err(|_| CryptoError::InvalidKey)?;
        let mut nonce_prefix = [0u8; 19];
        nonce_prefix.copy_from_slice(nonce);
        Ok(Self {
            cipher: Some(cipher),
            nonce_prefix,
            counter: 0,
            aad: aad.to_vec(),
            pending: Vec::new(),
            mode: StreamMode::Encrypting,
        })
    }

    /// Buffers `chunk` and seals the previously buffered chunk (non-final) into `out`.
    ///
    /// Returns the number of bytes written: `0` when nothing was buffered yet, otherwise
    /// the previous chunk's length plus 16.
    ///
    /// NOTE: a zero-length chunk leaves the buffer empty, so it is effectively skipped.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] if the stream is not encrypting or has been reset.
    /// - [`CryptoError::BufferTooSmall`] if `out` cannot hold the sealed previous chunk.
    /// - [`CryptoError::Internal`] if the chunk counter is exhausted.
    ///
    /// On error no state changes, so the call can be retried (e.g. with a larger `out`).
    fn encrypt_update(&mut self, chunk: &[u8], out: &mut [u8]) -> Result<usize, CryptoError> {
        if self.mode != StreamMode::Encrypting {
            return Err(CryptoError::BadInput);
        }
        let cipher = self.cipher.as_ref().ok_or(CryptoError::BadInput)?;
        if self.pending.is_empty() {
            self.pending.extend_from_slice(chunk);
            return Ok(0);
        }
        // Refuse before sealing: sealing and only then failing to advance would leave the
        // counter in place, and a retry would reuse the same (counter, non-final) nonce.
        if self.counter == u32::MAX {
            return Err(CryptoError::Internal("STREAM counter overflow"));
        }
        let nonce = self.current_nonce(false);
        // Seal straight from the buffer and only replace it once sealing succeeded, so a
        // failure (e.g. `out` too small) does not lose the buffered chunk.
        let written = stream_seal_chunk::<_, 24>(cipher, &nonce, &self.aad, &self.pending, out)?;
        self.pending.clear();
        self.pending.extend_from_slice(chunk);
        self.advance_counter()?;
        Ok(written)
    }

    /// Seals the buffered (last) chunk with the final-flag nonce into `out` and returns its
    /// 16-byte tag. The tag is also written as the trailing 16 bytes of the chunk in `out`.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] if the stream is not encrypting or has been reset.
    /// - [`CryptoError::BufferTooSmall`] if `out` is shorter than the buffered chunk plus 16.
    fn encrypt_finalize(mut self, out: &mut [u8]) -> Result<[u8; 16], CryptoError> {
        if self.mode != StreamMode::Encrypting {
            return Err(CryptoError::BadInput);
        }
        self.mode = StreamMode::Finished;
        let cipher = self.cipher.take().ok_or(CryptoError::BadInput)?;
        let nonce = self.current_nonce(true);
        let written = stream_seal_chunk::<_, 24>(&cipher, &nonce, &self.aad, &self.pending, out)?;
        let mut tag = [0u8; 16];
        tag.copy_from_slice(&out[written - 16..written]);
        Ok(tag)
    }

    /// Buffers the ciphertext `chunk` and opens the previously buffered (non-final) chunk
    /// into `out`.
    ///
    /// A fresh stream starts in encrypting mode. It switches itself to decrypting mode on the
    /// first call if nothing has been processed yet.
    ///
    /// Returns the number of plaintext bytes written: `0` when nothing was buffered yet.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] on a direction mismatch, after a reset, or for a chunk
    ///   shorter than a tag.
    /// - [`CryptoError::BufferTooSmall`] if `out` is too small.
    /// - [`CryptoError::InvalidTag`] if the buffered chunk fails authentication.
    /// - [`CryptoError::Internal`] if the chunk counter is exhausted.
    ///
    /// On error no state changes. After `InvalidTag` the stream should be abandoned.
    fn decrypt_update(&mut self, chunk: &[u8], out: &mut [u8]) -> Result<usize, CryptoError> {
        if self.mode != StreamMode::Decrypting {
            if self.mode == StreamMode::Encrypting && self.counter == 0 && self.pending.is_empty() {
                self.mode = StreamMode::Decrypting;
            } else {
                return Err(CryptoError::BadInput);
            }
        }
        let cipher = self.cipher.as_ref().ok_or(CryptoError::BadInput)?;
        if self.pending.is_empty() {
            self.pending.extend_from_slice(chunk);
            return Ok(0);
        }
        if self.counter == u32::MAX {
            return Err(CryptoError::Internal("STREAM counter overflow"));
        }
        let nonce = self.current_nonce(false);
        // Open from the buffer first; only commit the new chunk after it succeeds.
        let written = stream_open_chunk::<_, 24>(cipher, &nonce, &self.aad, &self.pending, out)?;
        self.pending.clear();
        self.pending.extend_from_slice(chunk);
        self.advance_counter()?;
        Ok(written)
    }

    /// Verifies the buffered last chunk against `expected_tag` and authenticates it under
    /// the final-flag nonce.
    ///
    /// Opening under the final-flag nonce makes a truncated stream fail: there, the buffered
    /// chunk was sealed as non-final, so authentication fails.
    ///
    /// NOTE(review): the final chunk's plaintext is decrypted into a scratch buffer and then
    /// discarded, because this trait method has no output parameter. Callers cannot obtain
    /// the last chunk's plaintext through this API.
    ///
    /// # Errors
    ///
    /// - [`CryptoError::BadInput`] if the stream is not decrypting, was reset, or the
    ///   buffered chunk is shorter than a tag.
    /// - [`CryptoError::InvalidTag`] on tag mismatch or failed authentication.
    fn decrypt_finalize(mut self, expected_tag: &[u8]) -> Result<(), CryptoError> {
        const TAG_LEN: usize = 16;
        if self.mode != StreamMode::Decrypting {
            return Err(CryptoError::BadInput);
        }
        self.mode = StreamMode::Finished;
        let cipher = self.cipher.take().ok_or(CryptoError::BadInput)?;
        let last = &self.pending;
        if last.len() < TAG_LEN {
            return Err(CryptoError::BadInput);
        }
        let embedded_tag = &last[last.len() - TAG_LEN..];
        if !bool::from(expected_tag.ct_eq(embedded_tag)) {
            return Err(CryptoError::InvalidTag);
        }
        let nonce = self.current_nonce(true);
        let mut scratch = alloc::vec![0u8; last.len() - TAG_LEN];
        stream_open_chunk::<_, 24>(&cipher, &nonce, &self.aad, last, &mut scratch).map(|_| ())
    }

    /// Clears the counter and buffered chunk and drops the cipher.
    ///
    /// After a reset, every update/finalize call fails with [`CryptoError::BadInput`].
    /// Keeping the cipher would let the stream re-process chunks under nonces it has already
    /// used, because the nonce prefix is unchanged and the counter restarts at 0. The AAD is
    /// kept, and `clear()` does not wipe the buffer's memory.
    fn reset(&mut self) {
        self.counter = 0;
        self.pending.clear();
        self.mode = StreamMode::Encrypting;
        self.cipher = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const KEY_256: [u8; 32] = [0x42u8; 32];
    const NONCE_PREFIX_7: [u8; 7] = [0x24u8; 7];
    const NONCE_PREFIX_19: [u8; 19] = [0x24u8; 19];
    const AAD: &[u8] = b"stream aad";
    const TAG_LEN: usize = 16;

    /// Encrypts `chunks` with a fresh [`Aes256GcmStream`].
    ///
    /// Returns every ciphertext chunk (`ciphertext || tag`) in order, plus the tag returned
    /// by `encrypt_finalize`.
    fn encrypt_chunks_aes256(chunks: &[&[u8]]) -> (Vec<Vec<u8>>, [u8; 16]) {
        assert!(!chunks.is_empty());
        let mut enc = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init enc");
        let mut ct_chunks: Vec<Vec<u8>> = Vec::new();
        // `encrypt_update` seals the *previous* chunk, so size every buffer for the largest one.
        let max_chunk_len = chunks.iter().map(|c| c.len()).max().unwrap_or(0);
        let buf_cap = max_chunk_len + TAG_LEN;
        for chunk in chunks {
            let mut buf = alloc::vec![0u8; buf_cap];
            let written = enc.encrypt_update(chunk, &mut buf).expect("encrypt_update");
            if written > 0 {
                ct_chunks.push(buf[..written].to_vec());
            }
        }
        let last = *chunks.last().unwrap();
        let mut final_buf = alloc::vec![0u8; last.len() + TAG_LEN];
        let tag = enc
            .encrypt_finalize(&mut final_buf)
            .expect("encrypt_finalize");
        ct_chunks.push(final_buf);
        (ct_chunks, tag)
    }

    /// Decrypts the output of [`encrypt_chunks_aes256`] and returns the full plaintext.
    ///
    /// `decrypt_finalize` authenticates the final chunk but cannot return its plaintext.
    /// The final chunk is therefore also opened directly with the raw AEAD, using the
    /// STREAM final-flag nonce.
    fn decrypt_chunks_aes256(ct_chunks: &[Vec<u8>], final_tag: &[u8; 16]) -> Vec<u8> {
        let mut dec = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init dec");
        dec.mode = StreamMode::Decrypting;
        let mut plaintext: Vec<u8> = Vec::new();
        for ct in ct_chunks {
            let mut buf = alloc::vec![0u8; ct.len()];
            let written = dec.decrypt_update(ct, &mut buf).expect("decrypt_update");
            plaintext.extend_from_slice(&buf[..written]);
        }
        dec.decrypt_finalize(final_tag).expect("decrypt_finalize");

        let last_ct = ct_chunks.last().expect("at least one ciphertext chunk");
        let pt_len = last_ct.len().saturating_sub(TAG_LEN);
        // Test streams have only a handful of chunks, so this cast cannot truncate.
        let final_counter = (ct_chunks.len() - 1) as u32;
        let nonce: [u8; 12] = build_nonce(&NONCE_PREFIX_7, final_counter, true);
        let cipher = aes_gcm::Aes256Gcm::new_from_slice(&KEY_256).expect("cipher");
        let nonce_ga = aead::generic_array::GenericArray::from_slice(nonce.as_ref());
        let (ct_body, tag_bytes) = last_ct.split_at(pt_len);
        let mut last_pt = ct_body.to_vec();
        let tag_ga = aead::Tag::<aes_gcm::Aes256Gcm>::clone_from_slice(tag_bytes);
        cipher
            .decrypt_in_place_detached(nonce_ga, AAD, &mut last_pt, &tag_ga)
            .expect("last chunk decrypt");
        plaintext.extend_from_slice(&last_pt);
        plaintext
    }

    #[test]
    fn aes256gcm_stream_three_chunks() {
        let chunks: &[&[u8]] = &[b"chunk-one---", b"chunk-two---", b"chunk-three"];
        let expected: Vec<u8> = chunks.iter().flat_map(|c| c.iter().copied()).collect();
        let (ct_chunks, final_tag) = encrypt_chunks_aes256(chunks);
        let recovered = decrypt_chunks_aes256(&ct_chunks, &final_tag);
        assert_eq!(recovered, expected, "three-chunk round-trip failed");
    }

    #[test]
    fn aes256gcm_stream_single_chunk() {
        let chunk = b"only one chunk";
        let mut enc = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init");
        let mut buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let written = enc.encrypt_update(chunk, &mut buf).expect("update");
        assert_eq!(written, 0);
        let mut final_buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let tag = enc.encrypt_finalize(&mut final_buf).expect("finalize");
        assert_eq!(
            &final_buf[chunk.len()..],
            &tag[..],
            "returned tag must be the trailing tag of the final chunk"
        );
        let mut dec = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init dec");
        dec.mode = StreamMode::Decrypting;
        let mut pt_buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let w = dec
            .decrypt_update(&final_buf, &mut pt_buf)
            .expect("decrypt_update");
        assert_eq!(w, 0, "first update must buffer, not emit");
        dec.decrypt_finalize(&tag).expect("decrypt_finalize");
    }

    #[test]
    fn aes256gcm_stream_tamper_middle_chunk_fails() {
        let chunks: &[&[u8]] = &[b"chunk-A-data---", b"chunk-B-tamper-", b"chunk-C-final--"];
        let (mut ct_chunks, _final_tag) = encrypt_chunks_aes256(chunks);
        ct_chunks[1][0] ^= 0xFF;
        let mut dec = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init dec");
        dec.mode = StreamMode::Decrypting;
        let mut pt_buf = alloc::vec![0u8; ct_chunks[0].len()];
        let w0 = dec
            .decrypt_update(&ct_chunks[0], &mut pt_buf)
            .expect("update0");
        assert_eq!(w0, 0);
        let mut pt_buf1 = alloc::vec![0u8; ct_chunks[0].len()];
        let w1 = dec
            .decrypt_update(&ct_chunks[1], &mut pt_buf1)
            .expect("update1");
        assert!(w1 > 0, "should have emitted decrypted chunk-A");
        let mut pt_buf2 = alloc::vec![0u8; ct_chunks[1].len()];
        let result = dec.decrypt_update(&ct_chunks[2], &mut pt_buf2);
        assert!(
            matches!(result, Err(CryptoError::InvalidTag)),
            "expected InvalidTag on tampered chunk, got: {:?}",
            result
        );
    }

    #[test]
    fn aes256gcm_stream_tamper_final_tag_fails() {
        let chunks: &[&[u8]] = &[b"single"];
        let (ct_chunks, mut final_tag) = encrypt_chunks_aes256(chunks);
        final_tag[0] ^= 0xFF;
        let mut dec = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init dec");
        dec.mode = StreamMode::Decrypting;
        let mut pt_buf = alloc::vec![0u8; ct_chunks[0].len()];
        dec.decrypt_update(&ct_chunks[0], &mut pt_buf)
            .expect("update");
        let result = dec.decrypt_finalize(&final_tag);
        assert!(
            matches!(result, Err(CryptoError::InvalidTag)),
            "expected InvalidTag, got: {:?}",
            result
        );
    }

    #[test]
    fn aes256gcm_stream_rejects_direction_switch() {
        // Update-after-finalize is rejected at compile time (`encrypt_finalize` takes `self`
        // by value). What can go wrong at runtime is switching direction mid-stream.
        let chunk = b"data";
        let mut enc = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init");
        let mut buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        enc.encrypt_update(chunk, &mut buf).expect("update");
        let mut scratch = alloc::vec![0u8; 64];
        let result = enc.decrypt_update(&buf, &mut scratch);
        assert!(
            matches!(result, Err(CryptoError::BadInput)),
            "expected BadInput on direction switch, got: {:?}",
            result
        );
        // The rejected call must not disturb the encrypting stream.
        let mut final_buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let _tag = enc.encrypt_finalize(&mut final_buf).expect("finalize");
    }

    #[test]
    fn chacha20poly1305_stream_single_chunk_round_trip() {
        let chunk = b"xchacha20 stream chunk";
        let mut enc = ChaCha20Poly1305Stream::init(&KEY_256, &NONCE_PREFIX_19, AAD).expect("init");
        let mut buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let w = enc.encrypt_update(chunk, &mut buf).expect("update");
        assert_eq!(w, 0);
        let mut final_buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let tag = enc.encrypt_finalize(&mut final_buf).expect("finalize");
        let mut dec =
            ChaCha20Poly1305Stream::init(&KEY_256, &NONCE_PREFIX_19, AAD).expect("init dec");
        dec.mode = StreamMode::Decrypting;
        let mut pt_buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let _w = dec
            .decrypt_update(&final_buf, &mut pt_buf)
            .expect("decrypt_update");
        dec.decrypt_finalize(&tag).expect("decrypt_finalize");
    }

    #[test]
    fn chacha20poly1305_stream_multi_chunk_round_trip() {
        let chunks: &[&[u8]] = &[b"first-chunk-", b"second-chunk", b"third"];
        let buf_cap = chunks.iter().map(|c| c.len()).max().unwrap_or(0) + TAG_LEN;
        let mut enc =
            ChaCha20Poly1305Stream::init(&KEY_256, &NONCE_PREFIX_19, AAD).expect("init");
        let mut ct_chunks: Vec<Vec<u8>> = Vec::new();
        for chunk in chunks {
            let mut buf = alloc::vec![0u8; buf_cap];
            let written = enc.encrypt_update(chunk, &mut buf).expect("encrypt_update");
            if written > 0 {
                ct_chunks.push(buf[..written].to_vec());
            }
        }
        let last = *chunks.last().unwrap();
        let mut final_buf = alloc::vec![0u8; last.len() + TAG_LEN];
        let tag = enc
            .encrypt_finalize(&mut final_buf)
            .expect("encrypt_finalize");
        ct_chunks.push(final_buf);
        assert_eq!(ct_chunks.len(), chunks.len(), "one ciphertext chunk per input chunk");

        // A fresh stream switches itself into decrypting mode on the first `decrypt_update`.
        let mut dec =
            ChaCha20Poly1305Stream::init(&KEY_256, &NONCE_PREFIX_19, AAD).expect("init dec");
        let mut recovered: Vec<u8> = Vec::new();
        for ct in &ct_chunks {
            let mut buf = alloc::vec![0u8; ct.len()];
            let written = dec.decrypt_update(ct, &mut buf).expect("decrypt_update");
            recovered.extend_from_slice(&buf[..written]);
        }
        dec.decrypt_finalize(&tag).expect("decrypt_finalize");
        // `decrypt_finalize` authenticates but does not return the final chunk's plaintext.
        let expected: Vec<u8> = chunks[..chunks.len() - 1]
            .iter()
            .flat_map(|c| c.iter().copied())
            .collect();
        assert_eq!(recovered, expected, "all non-final chunks must round-trip");
    }

    #[test]
    fn aes256gcm_stream_wrong_nonce_prefix_length() {
        let result = Aes256GcmStream::init(&KEY_256, &[0u8; 12], AAD);
        assert!(
            matches!(result, Err(CryptoError::InvalidNonce)),
            "expected InvalidNonce, got: {:?}",
            result.as_ref().map(|_| ()).map_err(|e| format!("{e:?}"))
        );
    }

    #[test]
    fn chacha20poly1305_stream_wrong_nonce_prefix_length() {
        let result = ChaCha20Poly1305Stream::init(&KEY_256, &[0u8; 12], AAD);
        assert!(
            matches!(result, Err(CryptoError::InvalidNonce)),
            "expected InvalidNonce, got: {:?}",
            result.as_ref().map(|_| ()).map_err(|e| format!("{e:?}"))
        );
    }

    #[test]
    fn aes256gcm_stream_reset_clears_state() {
        let chunk = b"some data";
        let mut enc = Aes256GcmStream::init(&KEY_256, &NONCE_PREFIX_7, AAD).expect("init");
        assert_eq!(enc.counter, 0, "initial counter");
        assert!(enc.pending.is_empty(), "initial pending");
        assert_eq!(enc.mode, StreamMode::Encrypting, "initial mode");
        let mut buf = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let w = enc.encrypt_update(chunk, &mut buf).expect("encrypt_update");
        assert_eq!(w, 0, "first update buffers; emits nothing");
        assert!(!enc.pending.is_empty(), "pending filled after update");
        enc.reset();
        assert_eq!(enc.counter, 0, "counter after reset");
        assert!(enc.pending.is_empty(), "pending after reset");
        assert_eq!(enc.mode, StreamMode::Encrypting, "mode after reset");
        assert!(enc.cipher.is_none(), "cipher cleared after reset");
        // With the cipher gone, the stream must refuse further work rather than reuse nonces.
        let mut after = alloc::vec![0u8; chunk.len() + TAG_LEN];
        let result = enc.encrypt_update(chunk, &mut after);
        assert!(
            matches!(result, Err(CryptoError::BadInput)),
            "expected BadInput after reset, got: {:?}",
            result
        );
    }
}