#[cfg(feature = "alloc")]
use alloc::{
string::ToString,
vec::Vec,
};
use lib_q_core::{
Aead,
AeadDecryptSemantic,
AeadKey,
DecryptSemanticOutcome,
Error,
Nonce,
Result,
};
use zeroize::{
Zeroize,
Zeroizing,
};
use crate::core::SaturninCore;
#[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
use crate::simd::{
encrypt_blocks8_dispatch,
simd_xor,
};
/// The Saturnin cores used by the AEAD, one per domain separator.
///
/// Field `dN` is built by `SaturninAeadCores::new` as `SaturninCore::new(10, N)`.
struct SaturninAeadCores {
    /// Domain 1: CTR keystream generation (`ctr_encrypt`).
    d1: SaturninCore,
    /// Domain 2: the nonce block in `cascade_init` and full associated-data blocks.
    d2: SaturninCore,
    /// Domain 3: the final, padded associated-data block.
    d3: SaturninCore,
    /// Domain 4: full ciphertext blocks absorbed into the tag cascade.
    d4: SaturninCore,
    /// Domain 5: the final, padded ciphertext block absorbed into the tag cascade.
    d5: SaturninCore,
}
impl SaturninAeadCores {
    /// Builds the cores for domains 1 through 5, in that order.
    ///
    /// Each core is created with `SaturninCore::new(10, domain)`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `SaturninCore::new`.
    fn new() -> Result<Self> {
        let build = |domain| SaturninCore::new(10, domain);
        let d1 = build(1)?;
        let d2 = build(2)?;
        let d3 = build(3)?;
        let d4 = build(4)?;
        let d5 = build(5)?;
        Ok(Self { d1, d2, d3, d4, d5 })
    }

    /// Looks up the core for domain separator `d`.
    ///
    /// # Panics
    ///
    /// Panics if `d` is outside `1..=5`. Every call site in this module passes one of
    /// the fixed constants 1–5.
    #[inline]
    fn domain(&self, d: u8) -> &SaturninCore {
        match d {
            1 => &self.d1,
            2 => &self.d2,
            3 => &self.d3,
            4 => &self.d4,
            5 => &self.d5,
            _ => unreachable!("AEAD CTR/cascade only uses domains 1–5"),
        }
    }
}
/// Saturnin-CTR-Cascade authenticated encryption.
///
/// Uses a 32-byte key and a 16-byte nonce, and appends a 32-byte tag to the ciphertext.
/// The message is encrypted in CTR mode under domain 1. The tag is a cascade over the
/// nonce, the associated data and the ciphertext, using domains 2–5.
pub struct SaturninAead {
    /// Pre-built cores for domains 1–5.
    cores: SaturninAeadCores,
}
impl SaturninAead {
pub fn new() -> Self {
Self {
cores: SaturninAeadCores::new().expect("Saturnin AEAD uses fixed valid domains"),
}
}
pub const fn key_size() -> usize {
32
}
pub const fn nonce_size() -> usize {
16
}
pub const fn tag_size() -> usize {
32
}
fn cascade_init(&self, key: &[u8], nonce: &[u8]) -> Result<Zeroizing<[u8; 32]>> {
let key32: &[u8; 32] = key.try_into().map_err(|_| Error::InvalidKeySize {
expected: 32,
actual: key.len(),
})?;
let mut r = Zeroizing::new([0u8; 32]);
r[0..16].copy_from_slice(nonce);
r[16] = 0x80;
self.cores.d2.encrypt_block_32(key32, &mut r)?;
for i in 0..16 {
r[i] ^= nonce[i];
}
r[16] ^= 0x80;
Ok(r)
}
fn cascade(&self, r: &mut [u8; 32], d1: u8, d2: u8, data: &[u8]) -> Result<()> {
let core_d1 = self.cores.domain(d1);
let core_d2 = self.cores.domain(d2);
let mut offset = 0;
loop {
let mut t: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
let mut m: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
let remaining = data.len() - offset;
if remaining >= 32 {
t.copy_from_slice(&data[offset..offset + 32]);
offset += 32;
m.copy_from_slice(&*t);
core_d1.encrypt_block_32(&*r, &mut m)?;
} else {
t[0..remaining].copy_from_slice(&data[offset..]);
t[remaining] = 0x80;
m.copy_from_slice(&*t);
core_d2.encrypt_block_32(&*r, &mut m)?;
}
#[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
{
let mut out: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
simd_xor::xor_blocks_32(&m, &t, &mut out);
r.copy_from_slice(&*out);
}
#[cfg(not(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon")))]
{
for i in 0..32 {
r[i] = m[i] ^ t[i];
}
}
if remaining < 32 {
break;
}
}
Ok(())
}
fn ctr_encrypt(&self, key: &[u8], nonce: &[u8], data: &mut [u8]) -> Result<()> {
let key32: &[u8; 32] = key.try_into().map_err(|_| Error::InvalidKeySize {
expected: 32,
actual: key.len(),
})?;
let core = &self.cores.d1;
let mut counter = 1u32; let mut offset = 0;
while offset < data.len() {
#[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
if data.len() - offset >= 32 * 8 {
let mut keystream_blocks = [[0u8; 32]; 8];
for (lane, block) in keystream_blocks.iter_mut().enumerate() {
let c = counter.wrapping_add(lane as u32);
block[0..16].copy_from_slice(nonce);
block[16] = 0x80;
block[28] = (c >> 24) as u8;
block[29] = (c >> 16) as u8;
block[30] = (c >> 8) as u8;
block[31] = c as u8;
}
encrypt_blocks8_dispatch(10, 1, key, &mut keystream_blocks, Some(core))?;
for (lane, ks) in keystream_blocks.iter().enumerate() {
let start = offset + (lane * 32);
let mut input = [0u8; 32];
input.copy_from_slice(&data[start..start + 32]);
let mut out = [0u8; 32];
simd_xor::xor_blocks_32(&input, ks, &mut out);
data[start..start + 32].copy_from_slice(&out);
}
offset += 32 * 8;
let (next_counter, overflowed) = counter.overflowing_add(8);
if overflowed {
return Err(Error::InvalidMessageSize {
max: usize::MAX,
actual: data.len(),
});
}
counter = next_counter;
continue;
}
let mut keystream = [0u8; 32];
keystream[0..16].copy_from_slice(nonce);
keystream[16] = 0x80;
keystream[28] = (counter >> 24) as u8;
keystream[29] = (counter >> 16) as u8;
keystream[30] = (counter >> 8) as u8;
keystream[31] = counter as u8;
core.encrypt_block_32(key32, &mut keystream)?;
let remaining = data.len() - offset;
let block_len = remaining.min(32);
#[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
{
if block_len == 32 {
let mut input = [0u8; 32];
input.copy_from_slice(&data[offset..offset + 32]);
let mut out = [0u8; 32];
simd_xor::xor_blocks_32(&input, &keystream, &mut out);
data[offset..offset + 32].copy_from_slice(&out);
} else {
for i in 0..block_len {
data[offset + i] ^= keystream[i];
}
}
}
#[cfg(not(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon")))]
{
for i in 0..block_len {
data[offset + i] ^= keystream[i];
}
}
offset += block_len;
counter = counter.wrapping_add(1);
}
Ok(())
}
fn decrypt_core(
&self,
key: &AeadKey,
nonce: &Nonce,
ciphertext: &[u8],
associated_data: Option<&[u8]>,
) -> Result<DecryptSemanticOutcome> {
if key.as_bytes().len() != Self::key_size() {
return Err(Error::InvalidKeySize {
expected: Self::key_size(),
actual: key.as_bytes().len(),
});
}
if nonce.as_bytes().len() != Self::nonce_size() {
return Err(Error::InvalidNonceSize {
expected: Self::nonce_size(),
actual: nonce.as_bytes().len(),
});
}
if (ciphertext.len() >> 5) >= 0xFFFFFFFE {
return Err(Error::InvalidMessageSize {
max: 0xFFFFFFFE << 5,
actual: ciphertext.len(),
});
}
if ciphertext.len() < Self::tag_size() {
return Err(Error::aead_ciphertext_shorter_than_tag(
Self::tag_size(),
ciphertext.len(),
));
}
let ad = associated_data.unwrap_or(&[]);
let plaintext_len = ciphertext.len() - 32;
let ciphertext_data = &ciphertext[0..plaintext_len];
let received_tag = &ciphertext[plaintext_len..];
let mut key_staged = Zeroizing::new([0u8; 32]);
key_staged.copy_from_slice(key.as_bytes());
let mut nonce_staged = Zeroizing::new([0u8; 16]);
nonce_staged.copy_from_slice(nonce.as_bytes());
let kb = key_staged.as_slice();
let nb = nonce_staged.as_slice();
let mut tag = self.cascade_init(kb, nb)?;
self.cascade(&mut tag, 2, 3, ad)?;
self.cascade(&mut tag, 4, 5, ciphertext_data)?;
let tag_valid = lib_q_core::Utils::constant_time_compare(&*tag, received_tag);
let mut plaintext = ciphertext_data.to_vec();
if let Err(e) = self.ctr_encrypt(kb, nb, &mut plaintext) {
plaintext.zeroize();
return Err(e);
}
if tag_valid {
Ok(DecryptSemanticOutcome::Success(Zeroizing::new(plaintext)))
} else {
plaintext.zeroize();
Ok(DecryptSemanticOutcome::AuthenticationFailed)
}
}
}
impl Aead for SaturninAead {
fn encrypt(
&self,
key: &AeadKey,
nonce: &Nonce,
plaintext: &[u8],
associated_data: Option<&[u8]>,
) -> Result<Vec<u8>> {
if key.as_bytes().len() != Self::key_size() {
return Err(Error::InvalidKeySize {
expected: Self::key_size(),
actual: key.as_bytes().len(),
});
}
if nonce.as_bytes().len() != Self::nonce_size() {
return Err(Error::InvalidNonceSize {
expected: Self::nonce_size(),
actual: nonce.as_bytes().len(),
});
}
if (plaintext.len() >> 5) >= 0xFFFFFFFD {
return Err(Error::InvalidMessageSize {
max: 0xFFFFFFFD << 5,
actual: plaintext.len(),
});
}
let ad = associated_data.unwrap_or(&[]);
let mut key_staged = Zeroizing::new([0u8; 32]);
key_staged.copy_from_slice(key.as_bytes());
let mut nonce_staged = Zeroizing::new([0u8; 16]);
nonce_staged.copy_from_slice(nonce.as_bytes());
let kb = key_staged.as_slice();
let nb = nonce_staged.as_slice();
let mut tag = self.cascade_init(kb, nb)?;
self.cascade(&mut tag, 2, 3, ad)?;
let mut ciphertext = plaintext.to_vec();
if let Err(e) = self.ctr_encrypt(kb, nb, &mut ciphertext) {
ciphertext.zeroize();
return Err(e);
}
self.cascade(&mut tag, 4, 5, &ciphertext)?;
ciphertext.extend_from_slice(&*tag);
Ok(ciphertext)
}
fn decrypt(
&self,
key: &AeadKey,
nonce: &Nonce,
ciphertext: &[u8],
associated_data: Option<&[u8]>,
) -> Result<Vec<u8>> {
match self.decrypt_core(key, nonce, ciphertext, associated_data) {
Ok(DecryptSemanticOutcome::Success(p)) => Ok(Vec::clone(&*p)),
Ok(DecryptSemanticOutcome::AuthenticationFailed) => Err(Error::VerificationFailed {
operation: "AEAD tag verification".to_string(),
}),
Err(e) => Err(e),
}
}
}
impl AeadDecryptSemantic for SaturninAead {
    /// Verifies and decrypts `ciphertext || tag`.
    ///
    /// A tag mismatch is reported as `Ok(DecryptSemanticOutcome::AuthenticationFailed)`
    /// rather than as an error. Delegates entirely to the shared `decrypt_core` routine.
    fn decrypt_semantic(
        &self,
        key: &AeadKey,
        nonce: &Nonce,
        ciphertext: &[u8],
        associated_data: Option<&[u8]>,
    ) -> Result<DecryptSemanticOutcome> {
        Self::decrypt_core(self, key, nonce, ciphertext, associated_data)
    }
}
impl Default for SaturninAead {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use alloc::vec;

    use super::*;

    /// Construction must not panic: all five domain cores are valid.
    #[test]
    fn test_saturnin_creation() {
        let _aead = SaturninAead::new();
    }

    #[test]
    fn test_saturnin_constants() {
        assert_eq!(SaturninAead::key_size(), 32);
        assert_eq!(SaturninAead::nonce_size(), 16);
        assert_eq!(SaturninAead::tag_size(), 32);
    }

    #[test]
    fn test_saturnin_encrypt_decrypt_round_trip() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![0u8; 32]);
        let nonce = Nonce::new(vec![0u8; 16]);
        let plaintext = b"test";
        let ad: Option<&[u8]> = None;
        let ciphertext = aead.encrypt(&key, &nonce, plaintext, ad)?;
        assert_eq!(ciphertext.len(), plaintext.len() + 32);
        let decrypted = aead.decrypt(&key, &nonce, &ciphertext, ad)?;
        assert_eq!(decrypted, plaintext);
        Ok(())
    }

    /// Round-trips lengths around the 32-byte block boundary and the 256-byte (8-block)
    /// batch boundary. This exercises multi-block CTR, the padded final cascade block and,
    /// with SIMD features enabled, the batched keystream path.
    #[test]
    fn test_saturnin_round_trip_block_boundaries() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new((0u8..32).collect::<Vec<u8>>());
        let nonce = Nonce::new((100u8..116).collect::<Vec<u8>>());
        let ad_bytes = [0xA5u8; 40];
        let ad: Option<&[u8]> = Some(&ad_bytes);
        for &len in &[0usize, 1, 31, 32, 33, 64, 255, 256, 257, 520] {
            let plaintext: Vec<u8> = (0..len).map(|i| (i * 7) as u8).collect();
            let ciphertext = aead.encrypt(&key, &nonce, &plaintext, ad)?;
            assert_eq!(ciphertext.len(), len + SaturninAead::tag_size());
            let decrypted = aead.decrypt(&key, &nonce, &ciphertext, ad)?;
            assert_eq!(decrypted, plaintext);
        }
        Ok(())
    }

    #[test]
    fn test_saturnin_rejects_wrong_key_and_nonce_sizes() {
        let aead = SaturninAead::new();
        let good_key = AeadKey::new(vec![0u8; 32]);
        let good_nonce = Nonce::new(vec![0u8; 16]);
        let short_key = AeadKey::new(vec![0u8; 31]);
        let short_nonce = Nonce::new(vec![0u8; 12]);
        assert!(matches!(
            aead.encrypt(&short_key, &good_nonce, b"x", None),
            Err(Error::InvalidKeySize {
                expected: 32,
                actual: 31
            })
        ));
        assert!(matches!(
            aead.encrypt(&good_key, &short_nonce, b"x", None),
            Err(Error::InvalidNonceSize {
                expected: 16,
                actual: 12
            })
        ));
    }

    #[test]
    fn test_saturnin_decrypt_semantic_bad_tag() -> Result<()> {
        use lib_q_core::AeadDecryptSemantic;
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![7u8; 32]);
        let nonce = Nonce::new(vec![8u8; 16]);
        let ad: Option<&[u8]> = Some(b"ad");
        let ct = aead.encrypt(&key, &nonce, b"m", ad)?;
        let mut bad = ct.clone();
        *bad.last_mut().expect("tag") ^= 0x40;
        let out = aead.decrypt_semantic(&key, &nonce, &bad, ad)?;
        assert_eq!(out, DecryptSemanticOutcome::AuthenticationFailed);
        assert!(matches!(
            aead.decrypt(&key, &nonce, &bad, ad),
            Err(Error::VerificationFailed { .. })
        ));
        match aead.decrypt_semantic(&key, &nonce, &ct, ad)? {
            DecryptSemanticOutcome::Success(pt) => assert_eq!(pt.as_slice(), b"m"),
            DecryptSemanticOutcome::AuthenticationFailed => {
                panic!("unexpected auth failure on good ciphertext")
            }
        }
        Ok(())
    }

    /// Changing the associated data or a ciphertext body byte must fail authentication.
    /// A ciphertext shorter than the tag must be rejected outright.
    #[test]
    fn test_saturnin_detects_tampered_ad_and_body() -> Result<()> {
        use lib_q_core::AeadDecryptSemantic;
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![3u8; 32]);
        let nonce = Nonce::new(vec![9u8; 16]);
        let ad: Option<&[u8]> = Some(b"header");
        let other_ad: Option<&[u8]> = Some(b"headeR");
        let ct = aead.encrypt(&key, &nonce, b"attack at dawn", ad)?;
        assert_eq!(
            aead.decrypt_semantic(&key, &nonce, &ct, other_ad)?,
            DecryptSemanticOutcome::AuthenticationFailed
        );
        let mut flipped = ct.clone();
        flipped[0] ^= 0x01;
        assert_eq!(
            aead.decrypt_semantic(&key, &nonce, &flipped, ad)?,
            DecryptSemanticOutcome::AuthenticationFailed
        );
        assert!(aead
            .decrypt(&key, &nonce, &ct[..SaturninAead::tag_size() - 1], ad)
            .is_err());
        Ok(())
    }
}