1#[cfg(feature = "alloc")]
65use alloc::{
66 string::ToString,
67 vec::Vec,
68};
69
70use lib_q_core::{
71 Aead,
72 AeadDecryptSemantic,
73 AeadKey,
74 DecryptSemanticOutcome,
75 Error,
76 Nonce,
77 Result,
78};
79use zeroize::{
80 Zeroize,
81 Zeroizing,
82};
83
84use crate::core::SaturninCore;
85#[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
86use crate::simd::{
87 encrypt_blocks8_dispatch,
88 simd_xor,
89};
90
91struct SaturninAeadCores {
96 d1: SaturninCore,
97 d2: SaturninCore,
98 d3: SaturninCore,
99 d4: SaturninCore,
100 d5: SaturninCore,
101}
102
103impl SaturninAeadCores {
104 fn new() -> Result<Self> {
105 Ok(Self {
106 d1: SaturninCore::new(10, 1)?,
107 d2: SaturninCore::new(10, 2)?,
108 d3: SaturninCore::new(10, 3)?,
109 d4: SaturninCore::new(10, 4)?,
110 d5: SaturninCore::new(10, 5)?,
111 })
112 }
113
114 #[inline]
115 fn domain(&self, d: u8) -> &SaturninCore {
116 match d {
117 1 => &self.d1,
118 2 => &self.d2,
119 3 => &self.d3,
120 4 => &self.d4,
121 5 => &self.d5,
122 _ => unreachable!("AEAD CTR/cascade only uses domains 1–5"),
123 }
124 }
125}
126
127pub struct SaturninAead {
133 cores: SaturninAeadCores,
134}
135
136impl SaturninAead {
137 pub fn new() -> Self {
139 Self {
140 cores: SaturninAeadCores::new().expect("Saturnin AEAD uses fixed valid domains"),
141 }
142 }
143
144 pub const fn key_size() -> usize {
146 32
147 }
148
149 pub const fn nonce_size() -> usize {
151 16
152 }
153
154 pub const fn tag_size() -> usize {
156 32
157 }
158
159 fn cascade_init(&self, key: &[u8], nonce: &[u8]) -> Result<Zeroizing<[u8; 32]>> {
161 let key32: &[u8; 32] = key.try_into().map_err(|_| Error::InvalidKeySize {
162 expected: 32,
163 actual: key.len(),
164 })?;
165
166 let mut r = Zeroizing::new([0u8; 32]);
167
168 r[0..16].copy_from_slice(nonce);
170 r[16] = 0x80;
171 self.cores.d2.encrypt_block_32(key32, &mut r)?;
175
176 for i in 0..16 {
178 r[i] ^= nonce[i];
179 }
180 r[16] ^= 0x80;
181
182 Ok(r)
183 }
184
185 fn cascade(&self, r: &mut [u8; 32], d1: u8, d2: u8, data: &[u8]) -> Result<()> {
187 let core_d1 = self.cores.domain(d1);
188 let core_d2 = self.cores.domain(d2);
189
190 let mut offset = 0;
191
192 loop {
193 let mut t: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
194 let mut m: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
195 let remaining = data.len() - offset;
196
197 if remaining >= 32 {
198 t.copy_from_slice(&data[offset..offset + 32]);
199 offset += 32;
200
201 m.copy_from_slice(&*t);
203 core_d1.encrypt_block_32(&*r, &mut m)?;
204 } else {
205 t[0..remaining].copy_from_slice(&data[offset..]);
206 t[remaining] = 0x80;
207 m.copy_from_slice(&*t);
211 core_d2.encrypt_block_32(&*r, &mut m)?;
212 }
213
214 #[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
215 {
216 let mut out: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
217 simd_xor::xor_blocks_32(&m, &t, &mut out);
218 r.copy_from_slice(&*out);
219 }
220
221 #[cfg(not(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon")))]
222 {
223 for i in 0..32 {
224 r[i] = m[i] ^ t[i];
225 }
226 }
227
228 if remaining < 32 {
229 break;
230 }
231 }
232
233 Ok(())
234 }
235
236 fn ctr_encrypt(&self, key: &[u8], nonce: &[u8], data: &mut [u8]) -> Result<()> {
238 let key32: &[u8; 32] = key.try_into().map_err(|_| Error::InvalidKeySize {
239 expected: 32,
240 actual: key.len(),
241 })?;
242
243 let core = &self.cores.d1;
244
245 let mut counter = 1u32; let mut offset = 0;
247
248 while offset < data.len() {
249 #[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
250 if data.len() - offset >= 32 * 8 {
251 let mut keystream_blocks = [[0u8; 32]; 8];
252 for (lane, block) in keystream_blocks.iter_mut().enumerate() {
253 let c = counter.wrapping_add(lane as u32);
254 block[0..16].copy_from_slice(nonce);
255 block[16] = 0x80;
256 block[28] = (c >> 24) as u8;
257 block[29] = (c >> 16) as u8;
258 block[30] = (c >> 8) as u8;
259 block[31] = c as u8;
260 }
261
262 encrypt_blocks8_dispatch(10, 1, key, &mut keystream_blocks, Some(core))?;
263
264 for (lane, ks) in keystream_blocks.iter().enumerate() {
265 let start = offset + (lane * 32);
266 let mut input = [0u8; 32];
267 input.copy_from_slice(&data[start..start + 32]);
268 let mut out = [0u8; 32];
269 simd_xor::xor_blocks_32(&input, ks, &mut out);
270 data[start..start + 32].copy_from_slice(&out);
271 }
272
273 offset += 32 * 8;
274 let (next_counter, overflowed) = counter.overflowing_add(8);
275 if overflowed {
276 return Err(Error::InvalidMessageSize {
277 max: usize::MAX,
278 actual: data.len(),
279 });
280 }
281 counter = next_counter;
282 continue;
283 }
284
285 let mut keystream = [0u8; 32];
286
287 keystream[0..16].copy_from_slice(nonce);
289 keystream[16] = 0x80;
290 keystream[28] = (counter >> 24) as u8;
292 keystream[29] = (counter >> 16) as u8;
293 keystream[30] = (counter >> 8) as u8;
294 keystream[31] = counter as u8;
295
296 core.encrypt_block_32(key32, &mut keystream)?;
298
299 let remaining = data.len() - offset;
300 let block_len = remaining.min(32);
301 #[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
302 {
303 if block_len == 32 {
304 let mut input = [0u8; 32];
305 input.copy_from_slice(&data[offset..offset + 32]);
306 let mut out = [0u8; 32];
307 simd_xor::xor_blocks_32(&input, &keystream, &mut out);
308 data[offset..offset + 32].copy_from_slice(&out);
309 } else {
310 for i in 0..block_len {
311 data[offset + i] ^= keystream[i];
312 }
313 }
314 }
315
316 #[cfg(not(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon")))]
317 {
318 for i in 0..block_len {
319 data[offset + i] ^= keystream[i];
320 }
321 }
322
323 offset += block_len;
324 counter = counter.wrapping_add(1);
325 }
326
327 Ok(())
328 }
329
330 fn decrypt_core(
333 &self,
334 key: &AeadKey,
335 nonce: &Nonce,
336 ciphertext: &[u8],
337 associated_data: Option<&[u8]>,
338 ) -> Result<DecryptSemanticOutcome> {
339 if key.as_bytes().len() != Self::key_size() {
340 return Err(Error::InvalidKeySize {
341 expected: Self::key_size(),
342 actual: key.as_bytes().len(),
343 });
344 }
345
346 if nonce.as_bytes().len() != Self::nonce_size() {
347 return Err(Error::InvalidNonceSize {
348 expected: Self::nonce_size(),
349 actual: nonce.as_bytes().len(),
350 });
351 }
352
353 if (ciphertext.len() >> 5) >= 0xFFFFFFFE {
354 return Err(Error::InvalidMessageSize {
355 max: 0xFFFFFFFE << 5,
356 actual: ciphertext.len(),
357 });
358 }
359
360 if ciphertext.len() < Self::tag_size() {
361 return Err(Error::aead_ciphertext_shorter_than_tag(
362 Self::tag_size(),
363 ciphertext.len(),
364 ));
365 }
366
367 let ad = associated_data.unwrap_or(&[]);
368 let plaintext_len = ciphertext.len() - 32;
369 let ciphertext_data = &ciphertext[0..plaintext_len];
370 let received_tag = &ciphertext[plaintext_len..];
371
372 let mut key_staged = Zeroizing::new([0u8; 32]);
373 key_staged.copy_from_slice(key.as_bytes());
374 let mut nonce_staged = Zeroizing::new([0u8; 16]);
375 nonce_staged.copy_from_slice(nonce.as_bytes());
376 let kb = key_staged.as_slice();
377 let nb = nonce_staged.as_slice();
378
379 let mut tag = self.cascade_init(kb, nb)?;
380 self.cascade(&mut tag, 2, 3, ad)?;
381 self.cascade(&mut tag, 4, 5, ciphertext_data)?;
382
383 let tag_valid = lib_q_core::Utils::constant_time_compare(&*tag, received_tag);
384
385 let mut plaintext = ciphertext_data.to_vec();
386 if let Err(e) = self.ctr_encrypt(kb, nb, &mut plaintext) {
387 plaintext.zeroize();
388 return Err(e);
389 }
390
391 if tag_valid {
392 Ok(DecryptSemanticOutcome::Success(Zeroizing::new(plaintext)))
393 } else {
394 plaintext.zeroize();
395 Ok(DecryptSemanticOutcome::AuthenticationFailed)
396 }
397 }
398}
399
400impl Aead for SaturninAead {
401 fn encrypt(
412 &self,
413 key: &AeadKey,
414 nonce: &Nonce,
415 plaintext: &[u8],
416 associated_data: Option<&[u8]>,
417 ) -> Result<Vec<u8>> {
418 if key.as_bytes().len() != Self::key_size() {
419 return Err(Error::InvalidKeySize {
420 expected: Self::key_size(),
421 actual: key.as_bytes().len(),
422 });
423 }
424
425 if nonce.as_bytes().len() != Self::nonce_size() {
426 return Err(Error::InvalidNonceSize {
427 expected: Self::nonce_size(),
428 actual: nonce.as_bytes().len(),
429 });
430 }
431
432 if (plaintext.len() >> 5) >= 0xFFFFFFFD {
434 return Err(Error::InvalidMessageSize {
435 max: 0xFFFFFFFD << 5,
436 actual: plaintext.len(),
437 });
438 }
439
440 let ad = associated_data.unwrap_or(&[]);
441
442 let mut key_staged = Zeroizing::new([0u8; 32]);
443 key_staged.copy_from_slice(key.as_bytes());
444 let mut nonce_staged = Zeroizing::new([0u8; 16]);
445 nonce_staged.copy_from_slice(nonce.as_bytes());
446 let kb = key_staged.as_slice();
447 let nb = nonce_staged.as_slice();
448
449 let mut tag = self.cascade_init(kb, nb)?;
451
452 self.cascade(&mut tag, 2, 3, ad)?;
454
455 let mut ciphertext = plaintext.to_vec();
457 if let Err(e) = self.ctr_encrypt(kb, nb, &mut ciphertext) {
458 ciphertext.zeroize();
459 return Err(e);
460 }
461
462 self.cascade(&mut tag, 4, 5, &ciphertext)?;
464
465 ciphertext.extend_from_slice(&*tag);
467
468 Ok(ciphertext)
469 }
470
471 fn decrypt(
473 &self,
474 key: &AeadKey,
475 nonce: &Nonce,
476 ciphertext: &[u8],
477 associated_data: Option<&[u8]>,
478 ) -> Result<Vec<u8>> {
479 match self.decrypt_core(key, nonce, ciphertext, associated_data) {
480 Ok(DecryptSemanticOutcome::Success(p)) => Ok(Vec::clone(&*p)),
481 Ok(DecryptSemanticOutcome::AuthenticationFailed) => Err(Error::VerificationFailed {
482 operation: "AEAD tag verification".to_string(),
483 }),
484 Err(e) => Err(e),
485 }
486 }
487}
488
489impl AeadDecryptSemantic for SaturninAead {
490 fn decrypt_semantic(
492 &self,
493 key: &AeadKey,
494 nonce: &Nonce,
495 ciphertext: &[u8],
496 associated_data: Option<&[u8]>,
497 ) -> Result<DecryptSemanticOutcome> {
498 self.decrypt_core(key, nonce, ciphertext, associated_data)
499 }
500}
501
502impl Default for SaturninAead {
503 fn default() -> Self {
504 Self::new()
505 }
506}
507
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use alloc::vec;

    use super::*;

    /// Deterministic, non-constant filler so block-boundary mistakes change the output.
    fn pattern(len: usize, seed: u8) -> Vec<u8> {
        (0..len)
            .map(|i| (i as u8).wrapping_mul(31).wrapping_add(seed))
            .collect()
    }

    #[test]
    fn test_saturnin_creation() {
        let _aead = SaturninAead::new();
    }

    #[test]
    fn test_saturnin_constants() {
        assert_eq!(SaturninAead::key_size(), 32);
        assert_eq!(SaturninAead::nonce_size(), 16);
        assert_eq!(SaturninAead::tag_size(), 32);
    }

    #[test]
    fn test_saturnin_encrypt_decrypt_round_trip() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![0u8; 32]);
        let nonce = Nonce::new(vec![0u8; 16]);
        let plaintext = b"test";
        let ad: Option<&[u8]> = None;

        let ciphertext = aead.encrypt(&key, &nonce, plaintext, ad)?;
        assert_eq!(ciphertext.len(), plaintext.len() + 32);
        let decrypted = aead.decrypt(&key, &nonce, &ciphertext, ad)?;
        assert_eq!(decrypted, plaintext);

        Ok(())
    }

    #[test]
    fn test_saturnin_round_trip_block_boundaries() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(pattern(32, 1));
        let nonce = Nonce::new(pattern(16, 2));

        // Straddle the 32-byte block and the 256-byte (8-block) batch boundaries.
        for &len in &[0usize, 1, 31, 32, 33, 63, 64, 255, 256, 257, 511, 512, 600] {
            for &ad_len in &[0usize, 1, 31, 32, 33] {
                let plaintext = pattern(len, 3);
                let ad = pattern(ad_len, 4);
                let ct = aead.encrypt(&key, &nonce, &plaintext, Some(ad.as_slice()))?;
                assert_eq!(ct.len(), len + SaturninAead::tag_size());
                let pt = aead.decrypt(&key, &nonce, &ct, Some(ad.as_slice()))?;
                assert_eq!(pt, plaintext, "len={} ad_len={}", len, ad_len);
            }
        }
        Ok(())
    }

    #[test]
    fn test_saturnin_none_ad_matches_empty_ad() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![9u8; 32]);
        let nonce = Nonce::new(vec![5u8; 16]);
        let empty: &[u8] = &[];

        let with_none = aead.encrypt(&key, &nonce, b"payload", None)?;
        let with_empty = aead.encrypt(&key, &nonce, b"payload", Some(empty))?;
        assert_eq!(with_none, with_empty);
        Ok(())
    }

    #[test]
    fn test_saturnin_is_deterministic_and_nonce_sensitive() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![3u8; 32]);
        let n1 = Nonce::new(vec![1u8; 16]);
        let n2 = Nonce::new(vec![2u8; 16]);

        let a = aead.encrypt(&key, &n1, b"same message", None)?;
        let b = SaturninAead::default().encrypt(&key, &n1, b"same message", None)?;
        let c = aead.encrypt(&key, &n2, b"same message", None)?;
        assert_eq!(a, b);
        assert_ne!(a, c);
        Ok(())
    }

    #[test]
    fn test_saturnin_tampering_is_rejected() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(pattern(32, 10));
        let nonce = Nonce::new(pattern(16, 11));
        let ad: &[u8] = b"header";
        let plaintext = pattern(300, 12);
        let ct = aead.encrypt(&key, &nonce, &plaintext, Some(ad))?;

        let mut bad_body = ct.clone();
        bad_body[0] ^= 0x01;
        assert!(matches!(
            aead.decrypt(&key, &nonce, &bad_body, Some(ad)),
            Err(Error::VerificationFailed { .. })
        ));

        let other_ad: &[u8] = b"headex";
        assert!(matches!(
            aead.decrypt(&key, &nonce, &ct, Some(other_ad)),
            Err(Error::VerificationFailed { .. })
        ));

        let other_nonce = Nonce::new(pattern(16, 13));
        assert!(matches!(
            aead.decrypt(&key, &other_nonce, &ct, Some(ad)),
            Err(Error::VerificationFailed { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_saturnin_rejects_invalid_sizes() {
        let aead = SaturninAead::new();
        let good_key = AeadKey::new(vec![0u8; 32]);
        let good_nonce = Nonce::new(vec![0u8; 16]);
        let short_key = AeadKey::new(vec![0u8; 16]);
        let short_nonce = Nonce::new(vec![0u8; 12]);

        assert!(matches!(
            aead.encrypt(&short_key, &good_nonce, b"x", None),
            Err(Error::InvalidKeySize { .. })
        ));
        assert!(matches!(
            aead.encrypt(&good_key, &short_nonce, b"x", None),
            Err(Error::InvalidNonceSize { .. })
        ));
        assert!(matches!(
            aead.decrypt(&short_key, &good_nonce, &[0u8; 32], None),
            Err(Error::InvalidKeySize { .. })
        ));
        assert!(matches!(
            aead.decrypt(&good_key, &short_nonce, &[0u8; 32], None),
            Err(Error::InvalidNonceSize { .. })
        ));
        // Anything shorter than the 32-byte tag cannot be a valid ciphertext.
        assert!(aead.decrypt(&good_key, &good_nonce, &[0u8; 31], None).is_err());
    }

    #[test]
    fn test_saturnin_decrypt_semantic_bad_tag() -> Result<()> {
        use lib_q_core::AeadDecryptSemantic;

        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![7u8; 32]);
        let nonce = Nonce::new(vec![8u8; 16]);
        let ad: Option<&[u8]> = Some(b"ad");
        let ct = aead.encrypt(&key, &nonce, b"m", ad)?;
        let mut bad = ct.clone();
        *bad.last_mut().expect("tag") ^= 0x40;
        let out = aead.decrypt_semantic(&key, &nonce, &bad, ad)?;
        assert_eq!(out, DecryptSemanticOutcome::AuthenticationFailed);
        assert!(matches!(
            aead.decrypt(&key, &nonce, &bad, ad),
            Err(Error::VerificationFailed { .. })
        ));
        match aead.decrypt_semantic(&key, &nonce, &ct, ad)? {
            DecryptSemanticOutcome::Success(pt) => assert_eq!(pt.as_slice(), b"m"),
            DecryptSemanticOutcome::AuthenticationFailed => {
                panic!("unexpected auth failure on good ciphertext")
            }
        }
        Ok(())
    }
}
573}