Skip to main content

lib_q_saturnin/
aead.rs

1//! Saturnin AEAD implementation
2//!
3//! Saturnin is a lightweight post-quantum symmetric algorithm suite designed
4//! for IoT and constrained devices, providing authenticated encryption and
5//! hashing modes with superior post-quantum security.
6//!
7//! ## Usage Example
8//!
9//! ```rust
10//! use lib_q_saturnin::{
11//!     Aead,
12//!     AeadKey,
13//!     Nonce,
14//!     SaturninAead,
15//! };
16//!
17//! // Create AEAD instance
18//! let aead = SaturninAead::new();
19//!
20//! // Generate key and nonce (in practice, use secure random generation)
21//! let key = AeadKey::new(vec![0u8; 32]);
22//! let nonce = Nonce::new(vec![0u8; 16]);
23//!
24//! let plaintext = b"Secret message";
25//! let associated_data = b"metadata";
26//!
27//! // Encrypt with associated data
28//! let ciphertext = aead
29//!     .encrypt(&key, &nonce, plaintext, Some(associated_data))
30//!     .unwrap();
31//!
32//! // Decrypt and verify authenticity
33//! let decrypted = aead
34//!     .decrypt(&key, &nonce, &ciphertext, Some(associated_data))
35//!     .unwrap();
36//! assert_eq!(decrypted, plaintext);
37//! ```
38//!
39//! ## Performance Notes
40//!
41//! - **Key size**: 256 bits (32 bytes)
42//! - **Nonce size**: 128 bits (16 bytes)  
43//! - **Tag size**: 256 bits (32 bytes)
44//! - **Throughput**: ~100-500 MB/s on modern hardware
45//! - **Memory usage**: Small fixed state (pre-built cipher cores for domains 1–5); per-message
46//!   key/nonce are staged in zeroizing buffers at the `Aead` boundary, and the cascade running tag
47//!   plus per-iteration cascade blocks (`t`, `m`, and SIMD xor staging) are held in `Zeroizing`
48//!   buffers so they are cleared on drop.
49//!
50//! ## Verification timing
51//!
52//! Decrypt computes the expected tag over AAD and ciphertext (cascade), compares it to the
53//! appended tag with [`lib_q_core::Utils::constant_time_compare`](lib_q_core::Utils::constant_time_compare),
//! then **always** runs full CTR on the ciphertext body. Only after that work is done does the
//! API report the outcome: Layer A returns `Ok(plaintext)` for a valid tag and
//! `Err(Error::VerificationFailed)` for an invalid one; Layer B returns
//! `Ok(DecryptSemanticOutcome::Success(..))` or `Ok(DecryptSemanticOutcome::AuthenticationFailed)`.
//! Ciphertext shorter than the tag is rejected up front as `Err(Error::InvalidCiphertextSize)`
//! (operational). Whenever decryption fails, the plaintext buffer is zeroized before
//! returning. This matches the [`lib_q_core::Aead`] contract in
59//! `lib-q-core`: bulk symmetric work is not skipped on auth failure; the public `Result` / outcome
60//! still discriminates at the boundary. For semantic decrypt without plaintext on authentication
61//! failure, see [`lib_q_core::AeadDecryptSemantic`]. See this crate’s
62//! `SECURITY.md` for Saturnin-Short specifics.
63
64#[cfg(feature = "alloc")]
65use alloc::{
66    string::ToString,
67    vec::Vec,
68};
69
70use lib_q_core::{
71    Aead,
72    AeadDecryptSemantic,
73    AeadKey,
74    DecryptSemanticOutcome,
75    Error,
76    Nonce,
77    Result,
78};
79use zeroize::{
80    Zeroize,
81    Zeroizing,
82};
83
84use crate::core::SaturninCore;
85#[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
86use crate::simd::{
87    encrypt_blocks8_dispatch,
88    simd_xor,
89};
90
91/// Pre-built Saturnin cores for CTR-Cascade AEAD (10 super-rounds, domains 1–5).
92///
93/// Building these once per [`SaturninAead`] avoids repeated `Vec` allocation of round constants
94/// on every encrypt/decrypt (domains 1–5 cover CTR and all cascade steps).
95struct SaturninAeadCores {
96    d1: SaturninCore,
97    d2: SaturninCore,
98    d3: SaturninCore,
99    d4: SaturninCore,
100    d5: SaturninCore,
101}
102
103impl SaturninAeadCores {
104    fn new() -> Result<Self> {
105        Ok(Self {
106            d1: SaturninCore::new(10, 1)?,
107            d2: SaturninCore::new(10, 2)?,
108            d3: SaturninCore::new(10, 3)?,
109            d4: SaturninCore::new(10, 4)?,
110            d5: SaturninCore::new(10, 5)?,
111        })
112    }
113
114    #[inline]
115    fn domain(&self, d: u8) -> &SaturninCore {
116        match d {
117            1 => &self.d1,
118            2 => &self.d2,
119            3 => &self.d3,
120            4 => &self.d4,
121            5 => &self.d5,
122            _ => unreachable!("AEAD CTR/cascade only uses domains 1–5"),
123        }
124    }
125}
126
/// Saturnin AEAD implementation
///
/// Provides authenticated encryption using the Saturnin CTR-Cascade mode.
/// This is the full AEAD mode that supports associated data and arbitrary
/// length plaintexts.
///
/// Keys are 32 bytes, nonces are 16 bytes, and a 32-byte tag is appended to the
/// ciphertext. The value itself carries no key material: key and nonce are passed
/// to each encrypt/decrypt call.
pub struct SaturninAead {
    // Cipher cores for domains 1–5, built once at construction and reused by every call.
    cores: SaturninAeadCores,
}
135
136impl SaturninAead {
137    /// Create a new Saturnin AEAD instance
138    pub fn new() -> Self {
139        Self {
140            cores: SaturninAeadCores::new().expect("Saturnin AEAD uses fixed valid domains"),
141        }
142    }
143
144    /// Get the key size in bytes (256 bits = 32 bytes)
145    pub const fn key_size() -> usize {
146        32
147    }
148
149    /// Get the nonce size in bytes (128 bits = 16 bytes)
150    pub const fn nonce_size() -> usize {
151        16
152    }
153
154    /// Get the tag size in bytes (256 bits = 32 bytes)
155    pub const fn tag_size() -> usize {
156        32
157    }
158
159    /// Initialize the cascade state
160    fn cascade_init(&self, key: &[u8], nonce: &[u8]) -> Result<Zeroizing<[u8; 32]>> {
161        let key32: &[u8; 32] = key.try_into().map_err(|_| Error::InvalidKeySize {
162            expected: 32,
163            actual: key.len(),
164        })?;
165
166        let mut r = Zeroizing::new([0u8; 32]);
167
168        // Copy nonce to first 16 bytes
169        r[0..16].copy_from_slice(nonce);
170        r[16] = 0x80;
171        // Remaining bytes are already zero
172
173        // Encrypt with cascade parameters: 10 super-rounds, domain 2 (AAD1)
174        self.cores.d2.encrypt_block_32(key32, &mut r)?;
175
176        // XOR with nonce
177        for i in 0..16 {
178            r[i] ^= nonce[i];
179        }
180        r[16] ^= 0x80;
181
182        Ok(r)
183    }
184
185    /// Apply cascade construction to data (optimized)
186    fn cascade(&self, r: &mut [u8; 32], d1: u8, d2: u8, data: &[u8]) -> Result<()> {
187        let core_d1 = self.cores.domain(d1);
188        let core_d2 = self.cores.domain(d2);
189
190        let mut offset = 0;
191
192        loop {
193            let mut t: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
194            let mut m: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
195            let remaining = data.len() - offset;
196
197            if remaining >= 32 {
198                t.copy_from_slice(&data[offset..offset + 32]);
199                offset += 32;
200
201                // Use pre-allocated core for d1
202                m.copy_from_slice(&*t);
203                core_d1.encrypt_block_32(&*r, &mut m)?;
204            } else {
205                t[0..remaining].copy_from_slice(&data[offset..]);
206                t[remaining] = 0x80;
207                // Remaining bytes are already zero
208
209                // Use pre-allocated core for d2
210                m.copy_from_slice(&*t);
211                core_d2.encrypt_block_32(&*r, &mut m)?;
212            }
213
214            #[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
215            {
216                let mut out: Zeroizing<[u8; 32]> = Zeroizing::new([0u8; 32]);
217                simd_xor::xor_blocks_32(&m, &t, &mut out);
218                r.copy_from_slice(&*out);
219            }
220
221            #[cfg(not(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon")))]
222            {
223                for i in 0..32 {
224                    r[i] = m[i] ^ t[i];
225                }
226            }
227
228            if remaining < 32 {
229                break;
230            }
231        }
232
233        Ok(())
234    }
235
236    /// CTR encryption/decryption (optimized)
237    fn ctr_encrypt(&self, key: &[u8], nonce: &[u8], data: &mut [u8]) -> Result<()> {
238        let key32: &[u8; 32] = key.try_into().map_err(|_| Error::InvalidKeySize {
239            expected: 32,
240            actual: key.len(),
241        })?;
242
243        let core = &self.cores.d1;
244
245        let mut counter = 1u32; // Counter starts at 1
246        let mut offset = 0;
247
248        while offset < data.len() {
249            #[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
250            if data.len() - offset >= 32 * 8 {
251                let mut keystream_blocks = [[0u8; 32]; 8];
252                for (lane, block) in keystream_blocks.iter_mut().enumerate() {
253                    let c = counter.wrapping_add(lane as u32);
254                    block[0..16].copy_from_slice(nonce);
255                    block[16] = 0x80;
256                    block[28] = (c >> 24) as u8;
257                    block[29] = (c >> 16) as u8;
258                    block[30] = (c >> 8) as u8;
259                    block[31] = c as u8;
260                }
261
262                encrypt_blocks8_dispatch(10, 1, key, &mut keystream_blocks, Some(core))?;
263
264                for (lane, ks) in keystream_blocks.iter().enumerate() {
265                    let start = offset + (lane * 32);
266                    let mut input = [0u8; 32];
267                    input.copy_from_slice(&data[start..start + 32]);
268                    let mut out = [0u8; 32];
269                    simd_xor::xor_blocks_32(&input, ks, &mut out);
270                    data[start..start + 32].copy_from_slice(&out);
271                }
272
273                offset += 32 * 8;
274                let (next_counter, overflowed) = counter.overflowing_add(8);
275                if overflowed {
276                    return Err(Error::InvalidMessageSize {
277                        max: usize::MAX,
278                        actual: data.len(),
279                    });
280                }
281                counter = next_counter;
282                continue;
283            }
284
285            let mut keystream = [0u8; 32];
286
287            // Build counter block efficiently
288            keystream[0..16].copy_from_slice(nonce);
289            keystream[16] = 0x80;
290            // Bytes 17-27 are zero
291            keystream[28] = (counter >> 24) as u8;
292            keystream[29] = (counter >> 16) as u8;
293            keystream[30] = (counter >> 8) as u8;
294            keystream[31] = counter as u8;
295
296            // Encrypt to get keystream
297            core.encrypt_block_32(key32, &mut keystream)?;
298
299            let remaining = data.len() - offset;
300            let block_len = remaining.min(32);
301            #[cfg(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon"))]
302            {
303                if block_len == 32 {
304                    let mut input = [0u8; 32];
305                    input.copy_from_slice(&data[offset..offset + 32]);
306                    let mut out = [0u8; 32];
307                    simd_xor::xor_blocks_32(&input, &keystream, &mut out);
308                    data[offset..offset + 32].copy_from_slice(&out);
309                } else {
310                    for i in 0..block_len {
311                        data[offset + i] ^= keystream[i];
312                    }
313                }
314            }
315
316            #[cfg(not(any(feature = "simd", feature = "simd-avx2", feature = "simd-neon")))]
317            {
318                for i in 0..block_len {
319                    data[offset + i] ^= keystream[i];
320                }
321            }
322
323            offset += block_len;
324            counter = counter.wrapping_add(1);
325        }
326
327        Ok(())
328    }
329
330    /// Shared decrypt core for Layer A ([`Aead::decrypt`](lib_q_core::Aead::decrypt)) and Layer B
331    /// ([`AeadDecryptSemantic::decrypt_semantic`](lib_q_core::AeadDecryptSemantic::decrypt_semantic)).
332    fn decrypt_core(
333        &self,
334        key: &AeadKey,
335        nonce: &Nonce,
336        ciphertext: &[u8],
337        associated_data: Option<&[u8]>,
338    ) -> Result<DecryptSemanticOutcome> {
339        if key.as_bytes().len() != Self::key_size() {
340            return Err(Error::InvalidKeySize {
341                expected: Self::key_size(),
342                actual: key.as_bytes().len(),
343            });
344        }
345
346        if nonce.as_bytes().len() != Self::nonce_size() {
347            return Err(Error::InvalidNonceSize {
348                expected: Self::nonce_size(),
349                actual: nonce.as_bytes().len(),
350            });
351        }
352
353        if (ciphertext.len() >> 5) >= 0xFFFFFFFE {
354            return Err(Error::InvalidMessageSize {
355                max: 0xFFFFFFFE << 5,
356                actual: ciphertext.len(),
357            });
358        }
359
360        if ciphertext.len() < Self::tag_size() {
361            return Err(Error::aead_ciphertext_shorter_than_tag(
362                Self::tag_size(),
363                ciphertext.len(),
364            ));
365        }
366
367        let ad = associated_data.unwrap_or(&[]);
368        let plaintext_len = ciphertext.len() - 32;
369        let ciphertext_data = &ciphertext[0..plaintext_len];
370        let received_tag = &ciphertext[plaintext_len..];
371
372        let mut key_staged = Zeroizing::new([0u8; 32]);
373        key_staged.copy_from_slice(key.as_bytes());
374        let mut nonce_staged = Zeroizing::new([0u8; 16]);
375        nonce_staged.copy_from_slice(nonce.as_bytes());
376        let kb = key_staged.as_slice();
377        let nb = nonce_staged.as_slice();
378
379        let mut tag = self.cascade_init(kb, nb)?;
380        self.cascade(&mut tag, 2, 3, ad)?;
381        self.cascade(&mut tag, 4, 5, ciphertext_data)?;
382
383        let tag_valid = lib_q_core::Utils::constant_time_compare(&*tag, received_tag);
384
385        let mut plaintext = ciphertext_data.to_vec();
386        if let Err(e) = self.ctr_encrypt(kb, nb, &mut plaintext) {
387            plaintext.zeroize();
388            return Err(e);
389        }
390
391        if tag_valid {
392            Ok(DecryptSemanticOutcome::Success(Zeroizing::new(plaintext)))
393        } else {
394            plaintext.zeroize();
395            Ok(DecryptSemanticOutcome::AuthenticationFailed)
396        }
397    }
398}
399
400impl Aead for SaturninAead {
401    /// Encrypt data with authentication
402    ///
403    /// # Arguments
404    /// * `key` - 256-bit encryption key
405    /// * `nonce` - 128-bit nonce
406    /// * `plaintext` - Data to encrypt
407    /// * `associated_data` - Additional authenticated data
408    ///
409    /// # Returns
410    /// Encrypted data with authentication tag appended
411    fn encrypt(
412        &self,
413        key: &AeadKey,
414        nonce: &Nonce,
415        plaintext: &[u8],
416        associated_data: Option<&[u8]>,
417    ) -> Result<Vec<u8>> {
418        if key.as_bytes().len() != Self::key_size() {
419            return Err(Error::InvalidKeySize {
420                expected: Self::key_size(),
421                actual: key.as_bytes().len(),
422            });
423        }
424
425        if nonce.as_bytes().len() != Self::nonce_size() {
426            return Err(Error::InvalidNonceSize {
427                expected: Self::nonce_size(),
428                actual: nonce.as_bytes().len(),
429            });
430        }
431
432        // Check length limits (about 137.4 GB)
433        if (plaintext.len() >> 5) >= 0xFFFFFFFD {
434            return Err(Error::InvalidMessageSize {
435                max: 0xFFFFFFFD << 5,
436                actual: plaintext.len(),
437            });
438        }
439
440        let ad = associated_data.unwrap_or(&[]);
441
442        let mut key_staged = Zeroizing::new([0u8; 32]);
443        key_staged.copy_from_slice(key.as_bytes());
444        let mut nonce_staged = Zeroizing::new([0u8; 16]);
445        nonce_staged.copy_from_slice(nonce.as_bytes());
446        let kb = key_staged.as_slice();
447        let nb = nonce_staged.as_slice();
448
449        // Initialize cascade state
450        let mut tag = self.cascade_init(kb, nb)?;
451
452        // Process associated data
453        self.cascade(&mut tag, 2, 3, ad)?;
454
455        // Encrypt plaintext with CTR
456        let mut ciphertext = plaintext.to_vec();
457        if let Err(e) = self.ctr_encrypt(kb, nb, &mut ciphertext) {
458            ciphertext.zeroize();
459            return Err(e);
460        }
461
462        // Continue cascade on ciphertext
463        self.cascade(&mut tag, 4, 5, &ciphertext)?;
464
465        // Append tag
466        ciphertext.extend_from_slice(&*tag);
467
468        Ok(ciphertext)
469    }
470
471    /// Decrypt and verify data (Layer A); shares one decrypt core with [`lib_q_core::AeadDecryptSemantic`].
472    fn decrypt(
473        &self,
474        key: &AeadKey,
475        nonce: &Nonce,
476        ciphertext: &[u8],
477        associated_data: Option<&[u8]>,
478    ) -> Result<Vec<u8>> {
479        match self.decrypt_core(key, nonce, ciphertext, associated_data) {
480            Ok(DecryptSemanticOutcome::Success(p)) => Ok(Vec::clone(&*p)),
481            Ok(DecryptSemanticOutcome::AuthenticationFailed) => Err(Error::VerificationFailed {
482                operation: "AEAD tag verification".to_string(),
483            }),
484            Err(e) => Err(e),
485        }
486    }
487}
488
489impl AeadDecryptSemantic for SaturninAead {
490    /// Layer B semantic decrypt; see `docs/adr/003-aead-decrypt-layers.md`.
491    fn decrypt_semantic(
492        &self,
493        key: &AeadKey,
494        nonce: &Nonce,
495        ciphertext: &[u8],
496        associated_data: Option<&[u8]>,
497    ) -> Result<DecryptSemanticOutcome> {
498        self.decrypt_core(key, nonce, ciphertext, associated_data)
499    }
500}
501
502impl Default for SaturninAead {
503    fn default() -> Self {
504        Self::new()
505    }
506}
507
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use alloc::vec;

    use super::*;

    /// Construction must not panic.
    #[test]
    fn test_saturnin_creation() {
        let _aead = SaturninAead::new();
        // Saturnin implementation created successfully
        // Test passes if we reach this point without panicking
    }

    /// The advertised key, nonce and tag sizes.
    #[test]
    fn test_saturnin_constants() {
        assert_eq!(SaturninAead::key_size(), 32);
        assert_eq!(SaturninAead::nonce_size(), 16);
        assert_eq!(SaturninAead::tag_size(), 32);
    }

    /// A short message without associated data survives encrypt followed by decrypt.
    #[test]
    fn test_saturnin_encrypt_decrypt_round_trip() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![0u8; 32]);
        let nonce = Nonce::new(vec![0u8; 16]);
        let plaintext = b"test"; // 4 bytes
        let ad: Option<&[u8]> = None;

        // Test encryption
        let ciphertext = aead.encrypt(&key, &nonce, plaintext, ad)?;
        assert_eq!(ciphertext.len(), plaintext.len() + 32); // plaintext + 32-byte tag

        // Test decryption
        let decrypted = aead.decrypt(&key, &nonce, &ciphertext, ad)?;
        assert_eq!(decrypted, plaintext);

        Ok(())
    }

    /// A flipped tag bit is reported as an authentication failure by both decrypt layers,
    /// while the untouched ciphertext still decrypts.
    #[test]
    fn test_saturnin_decrypt_semantic_bad_tag() -> Result<()> {
        use lib_q_core::AeadDecryptSemantic;

        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![7u8; 32]);
        let nonce = Nonce::new(vec![8u8; 16]);
        let ad: Option<&[u8]> = Some(b"ad");
        let ct = aead.encrypt(&key, &nonce, b"m", ad)?;
        let mut bad = ct.clone();
        *bad.last_mut().expect("tag") ^= 0x40;
        let out = aead.decrypt_semantic(&key, &nonce, &bad, ad)?;
        assert_eq!(out, DecryptSemanticOutcome::AuthenticationFailed);
        assert!(matches!(
            aead.decrypt(&key, &nonce, &bad, ad),
            Err(Error::VerificationFailed { .. })
        ));
        match aead.decrypt_semantic(&key, &nonce, &ct, ad)? {
            DecryptSemanticOutcome::Success(pt) => assert_eq!(pt.as_slice(), b"m"),
            DecryptSemanticOutcome::AuthenticationFailed => {
                panic!("unexpected auth failure on good ciphertext")
            }
        }
        Ok(())
    }

    /// An empty plaintext yields a tag-only ciphertext and decrypts back to empty.
    #[test]
    fn test_saturnin_empty_plaintext_round_trip() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![1u8; 32]);
        let nonce = Nonce::new(vec![2u8; 16]);
        let ad: Option<&[u8]> = Some(b"header");

        let ciphertext = aead.encrypt(&key, &nonce, b"", ad)?;
        assert_eq!(ciphertext.len(), SaturninAead::tag_size());

        let decrypted = aead.decrypt(&key, &nonce, &ciphertext, ad)?;
        assert!(decrypted.is_empty());

        Ok(())
    }

    /// A message spanning several blocks (more than eight, with a partial tail) round-trips,
    /// covering the multi-block CTR and cascade paths.
    #[test]
    fn test_saturnin_multi_block_round_trip() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![3u8; 32]);
        let nonce = Nonce::new(vec![4u8; 16]);
        let ad: Option<&[u8]> = Some(b"multi-block associated data");
        let plaintext = (0..300usize).map(|i| (i % 251) as u8).collect::<Vec<u8>>();

        let ciphertext = aead.encrypt(&key, &nonce, &plaintext, ad)?;
        assert_eq!(ciphertext.len(), plaintext.len() + SaturninAead::tag_size());
        assert_ne!(&ciphertext[..plaintext.len()], plaintext.as_slice());

        let decrypted = aead.decrypt(&key, &nonce, &ciphertext, ad)?;
        assert_eq!(decrypted, plaintext);

        Ok(())
    }

    /// Decrypting with different associated data must fail verification.
    #[test]
    fn test_saturnin_rejects_modified_associated_data() -> Result<()> {
        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![5u8; 32]);
        let nonce = Nonce::new(vec![6u8; 16]);

        let ciphertext = aead.encrypt(&key, &nonce, b"payload", Some(b"ad-1"))?;
        assert!(matches!(
            aead.decrypt(&key, &nonce, &ciphertext, Some(b"ad-2")),
            Err(Error::VerificationFailed { .. })
        ));
        assert!(matches!(
            aead.decrypt(&key, &nonce, &ciphertext, None),
            Err(Error::VerificationFailed { .. })
        ));

        Ok(())
    }

    /// Input shorter than the tag is an operational error in both decrypt layers.
    #[test]
    fn test_saturnin_rejects_ciphertext_shorter_than_tag() {
        use lib_q_core::AeadDecryptSemantic;

        let aead = SaturninAead::new();
        let key = AeadKey::new(vec![9u8; 32]);
        let nonce = Nonce::new(vec![10u8; 16]);
        let short = [0u8; 31];

        assert!(aead.decrypt(&key, &nonce, &short, None).is_err());
        assert!(aead.decrypt_semantic(&key, &nonce, &short, None).is_err());
    }
}