miden_crypto/aead/aead_rpo/
mod.rs

1//! # Arithmetization Oriented AEAD
2//!
3//! This module implements an AEAD scheme optimized for speed within SNARKs/STARKs.
4//! The design is described in \[1\] and is based on the MonkeySpongeWrap construction and uses
5//! the RPO (Rescue Prime Optimized) permutation, creating an encryption scheme that is highly
6//! efficient when executed within zero-knowledge proof systems.
7//!
8//! \[1\] <https://eprint.iacr.org/2023/1668>
9
10use alloc::{string::ToString, vec::Vec};
11use core::ops::Range;
12
13use miden_crypto_derive::{SilentDebug, SilentDisplay};
14use num::Integer;
15use rand::{
16    Rng,
17    distr::{Distribution, StandardUniform, Uniform},
18};
19use subtle::ConstantTimeEq;
20
21use crate::{
22    Felt, FieldElement, ONE, StarkField, Word, ZERO,
23    aead::{AeadScheme, DataType, EncryptionError},
24    hash::rpo::Rpo256,
25    utils::{
26        ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
27        bytes_to_elements_exact, bytes_to_elements_with_padding, elements_to_bytes,
28        padded_elements_to_bytes,
29    },
30    zeroize::{Zeroize, ZeroizeOnDrop},
31};
32
33#[cfg(test)]
34mod test;
35
36// CONSTANTS
37// ================================================================================================
38
39/// Size of a secret key in field elements
40pub const SECRET_KEY_SIZE: usize = 4;
41
42/// Size of a secret key in bytes
43pub const SK_SIZE_BYTES: usize = SECRET_KEY_SIZE * Felt::ELEMENT_BYTES;
44
45/// Size of a nonce in field elements
46pub const NONCE_SIZE: usize = 4;
47
48/// Size of a nonce in bytes
49pub const NONCE_SIZE_BYTES: usize = NONCE_SIZE * Felt::ELEMENT_BYTES;
50
51/// Size of an authentication tag in field elements
52pub const AUTH_TAG_SIZE: usize = 4;
53
54/// Size of the sponge state field elements
55const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
56
57/// Capacity portion of the sponge state.
58const CAPACITY_RANGE: Range<usize> = Rpo256::CAPACITY_RANGE;
59
60/// Rate portion of the sponge state
61const RATE_RANGE: Range<usize> = Rpo256::RATE_RANGE;
62
63/// Size of the rate portion of the sponge state in field elements
64const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
65
66/// Size of either the 1st or 2nd half of the rate portion of the sponge state in field elements
67const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
68
69/// First half of the rate portion of the sponge state
70const RATE_RANGE_FIRST_HALF: Range<usize> =
71    Rpo256::RATE_RANGE.start..Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH;
72
73/// Second half of the rate portion of the sponge state
74const RATE_RANGE_SECOND_HALF: Range<usize> =
75    Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH..Rpo256::RATE_RANGE.end;
76
77/// Index of the first element of the rate portion of the sponge state
78const RATE_START: usize = Rpo256::RATE_RANGE.start;
79
80/// Padding block used when the length of the data to encrypt is a multiple of `RATE_WIDTH`
81const PADDING_BLOCK: [Felt; RATE_WIDTH] = [ONE, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];
82
83// TYPES AND STRUCTURES
84// ================================================================================================
85
/// Encrypted data with its authentication tag
///
/// Bundles everything a key holder needs to decrypt: the ciphertext, the nonce it was encrypted
/// under, the authentication tag, and a marker recording whether the plaintext was supplied as
/// bytes or as field elements.
#[derive(Debug, PartialEq, Eq)]
pub struct EncryptedData {
    /// Indicates the original format of the data before encryption; the decryption entry points
    /// reject data whose type does not match the one they expect
    data_type: DataType,
    /// The encrypted ciphertext; decryption requires its length to be a multiple of the sponge
    /// rate width
    ciphertext: Vec<Felt>,
    /// The authentication tag attesting to the integrity of the ciphertext, and the associated
    /// data if it exists
    auth_tag: AuthTag,
    /// The nonce used during encryption
    nonce: Nonce,
}
99
100/// An authentication tag represented as 4 field elements
101#[derive(Debug, Default, Clone, PartialEq, Eq)]
102pub struct AuthTag([Felt; AUTH_TAG_SIZE]);
103
/// A 256-bit secret key represented as 4 field elements
///
/// The key is wiped from memory when dropped (see the `Zeroize` and `Drop` implementations in
/// this module) and is compared using constant-time primitives.
// NOTE(review): `SilentDebug` / `SilentDisplay` presumably keep the key material out of formatted
// output — confirm against `miden_crypto_derive`.
#[derive(Clone, SilentDebug, SilentDisplay)]
pub struct SecretKey([Felt; SECRET_KEY_SIZE]);
107
108impl SecretKey {
109    // CONSTRUCTORS
110    // --------------------------------------------------------------------------------------------
111
112    /// Creates a new random secret key using the default random number generator.
113    #[cfg(feature = "std")]
114    #[allow(clippy::new_without_default)]
115    pub fn new() -> Self {
116        use rand::{SeedableRng, rngs::StdRng};
117        let mut rng = StdRng::from_os_rng();
118
119        Self::with_rng(&mut rng)
120    }
121
122    /// Creates a new random secret key using the provided random number generator.
123    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
124        rng.sample(StandardUniform)
125    }
126
127    /// Creates a secret key from the provided array of field elements.
128    ///
129    /// # Security Warning
130    /// This method should be used with caution. Secret keys must be derived from a
131    /// cryptographically secure source of entropy. Do not use predictable or low-entropy
132    /// values as secret key material. Prefer using `new()` or `with_rng()` with a
133    /// cryptographically secure random number generator.
134    pub fn from_elements(elements: [Felt; SECRET_KEY_SIZE]) -> Self {
135        Self(elements)
136    }
137
138    // ACCESSORS
139    // --------------------------------------------------------------------------------------------
140
141    /// Returns the secret key as an array of field elements.
142    ///
143    /// # Security Warning
144    /// This method exposes the raw secret key material. Use with caution and ensure
145    /// proper zeroization of the returned array when no longer needed.
146    pub fn to_elements(&self) -> [Felt; SECRET_KEY_SIZE] {
147        self.0
148    }
149
150    // ELEMENT ENCRYPTION
151    // --------------------------------------------------------------------------------------------
152
153    /// Encrypts and authenticates the provided sequence of field elements using this secret key
154    /// and a random nonce.
155    #[cfg(feature = "std")]
156    pub fn encrypt_elements(&self, data: &[Felt]) -> Result<EncryptedData, EncryptionError> {
157        self.encrypt_elements_with_associated_data(data, &[])
158    }
159
160    /// Encrypts the provided sequence of field elements and authenticates both the ciphertext as
161    /// well as the provided associated data using this secret key and a random nonce.
162    #[cfg(feature = "std")]
163    pub fn encrypt_elements_with_associated_data(
164        &self,
165        data: &[Felt],
166        associated_data: &[Felt],
167    ) -> Result<EncryptedData, EncryptionError> {
168        use rand::{SeedableRng, rngs::StdRng};
169        let mut rng = StdRng::from_os_rng();
170        let nonce = Nonce::with_rng(&mut rng);
171
172        self.encrypt_elements_with_nonce(data, associated_data, nonce)
173    }
174
175    /// Encrypts the provided sequence of field elements and authenticates both the ciphertext as
176    /// well as the provided associated data using this secret key and the specified nonce.
177    pub fn encrypt_elements_with_nonce(
178        &self,
179        data: &[Felt],
180        associated_data: &[Felt],
181        nonce: Nonce,
182    ) -> Result<EncryptedData, EncryptionError> {
183        // Initialize as sponge state with key and nonce
184        let mut sponge = SpongeState::new(self, &nonce);
185
186        // Process the associated data
187        let padded_associated_data = pad(associated_data);
188        padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
189            sponge.duplex_overwrite(chunk);
190        });
191
192        // Encrypt the data
193        let mut ciphertext = Vec::with_capacity(data.len() + RATE_WIDTH);
194        let data = pad(data);
195        let mut data_block_iterator = data.chunks_exact(RATE_WIDTH);
196
197        data_block_iterator.by_ref().for_each(|data_block| {
198            let keystream = sponge.duplex_add(data_block);
199            for (i, &plaintext_felt) in data_block.iter().enumerate() {
200                ciphertext.push(plaintext_felt + keystream[i]);
201            }
202        });
203
204        // Generate authentication tag
205        let auth_tag = sponge.squeeze_tag();
206
207        Ok(EncryptedData {
208            data_type: DataType::Elements,
209            ciphertext,
210            auth_tag,
211            nonce,
212        })
213    }
214
215    // BYTE ENCRYPTION
216    // --------------------------------------------------------------------------------------------
217
218    /// Encrypts and authenticates the provided data using this secret key and a random nonce.
219    ///
220    /// Before encryption, the bytestring is converted to a sequence of field elements.
221    #[cfg(feature = "std")]
222    pub fn encrypt_bytes(&self, data: &[u8]) -> Result<EncryptedData, EncryptionError> {
223        self.encrypt_bytes_with_associated_data(data, &[])
224    }
225
226    /// Encrypts the provided data and authenticates both the ciphertext as well as the provided
227    /// associated data using this secret key and a random nonce.
228    ///
229    /// Before encryption, both the data and the associated data are converted to sequences of
230    /// field elements.
231    #[cfg(feature = "std")]
232    pub fn encrypt_bytes_with_associated_data(
233        &self,
234        data: &[u8],
235        associated_data: &[u8],
236    ) -> Result<EncryptedData, EncryptionError> {
237        use rand::{SeedableRng, rngs::StdRng};
238        let mut rng = StdRng::from_os_rng();
239        let nonce = Nonce::with_rng(&mut rng);
240
241        self.encrypt_bytes_with_nonce(data, associated_data, nonce)
242    }
243
244    /// Encrypts the provided data and authenticates both the ciphertext as well as the provided
245    /// associated data using this secret key and the specified nonce.
246    ///
247    /// Before encryption, both the data and the associated data are converted to sequences of
248    /// field elements.
249    pub fn encrypt_bytes_with_nonce(
250        &self,
251        data: &[u8],
252        associated_data: &[u8],
253        nonce: Nonce,
254    ) -> Result<EncryptedData, EncryptionError> {
255        let data_felt = bytes_to_elements_with_padding(data);
256        let ad_felt = bytes_to_elements_with_padding(associated_data);
257
258        let mut encrypted_data = self.encrypt_elements_with_nonce(&data_felt, &ad_felt, nonce)?;
259        encrypted_data.data_type = DataType::Bytes;
260        Ok(encrypted_data)
261    }
262
263    // ELEMENT DECRYPTION
264    // --------------------------------------------------------------------------------------------
265
266    /// Decrypts the provided encrypted data using this secret key.
267    ///
268    /// # Errors
269    /// Returns an error if decryption fails or if the underlying data was encrypted as bytes
270    /// rather than as field elements.
271    pub fn decrypt_elements(
272        &self,
273        encrypted_data: &EncryptedData,
274    ) -> Result<Vec<Felt>, EncryptionError> {
275        self.decrypt_elements_with_associated_data(encrypted_data, &[])
276    }
277
278    /// Decrypts the provided encrypted data, given some associated data, using this secret key.
279    ///
280    /// # Errors
281    /// Returns an error if decryption fails or if the underlying data was encrypted as bytes
282    /// rather than as field elements.
283    pub fn decrypt_elements_with_associated_data(
284        &self,
285        encrypted_data: &EncryptedData,
286        associated_data: &[Felt],
287    ) -> Result<Vec<Felt>, EncryptionError> {
288        if encrypted_data.data_type != DataType::Elements {
289            return Err(EncryptionError::InvalidDataType {
290                expected: DataType::Elements,
291                found: encrypted_data.data_type,
292            });
293        }
294        self.decrypt_elements_with_associated_data_unchecked(encrypted_data, associated_data)
295    }
296
297    /// Decrypts the provided encrypted data, given some associated data, using this secret key.
298    fn decrypt_elements_with_associated_data_unchecked(
299        &self,
300        encrypted_data: &EncryptedData,
301        associated_data: &[Felt],
302    ) -> Result<Vec<Felt>, EncryptionError> {
303        if !encrypted_data.ciphertext.len().is_multiple_of(RATE_WIDTH) {
304            return Err(EncryptionError::CiphertextLenNotMultipleRate);
305        }
306
307        // Initialize as sponge state with key and nonce
308        let mut sponge = SpongeState::new(self, &encrypted_data.nonce);
309
310        // Process the associated data
311        let padded_associated_data = pad(associated_data);
312        padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
313            sponge.duplex_overwrite(chunk);
314        });
315
316        // Decrypt the data
317        let mut plaintext = Vec::with_capacity(encrypted_data.ciphertext.len());
318        let mut ciphertext_block_iterator = encrypted_data.ciphertext.chunks_exact(RATE_WIDTH);
319        ciphertext_block_iterator.by_ref().for_each(|ciphertext_data_block| {
320            let keystream = sponge.duplex_add(&[]);
321            for (i, &ciphertext_felt) in ciphertext_data_block.iter().enumerate() {
322                let plaintext_felt = ciphertext_felt - keystream[i];
323                plaintext.push(plaintext_felt);
324            }
325            sponge.state[RATE_RANGE].copy_from_slice(ciphertext_data_block);
326        });
327
328        // Verify authentication tag
329        let computed_tag = sponge.squeeze_tag();
330        if computed_tag != encrypted_data.auth_tag {
331            return Err(EncryptionError::InvalidAuthTag);
332        }
333
334        // Remove padding and return
335        unpad(plaintext)
336    }
337
338    // BYTE DECRYPTION
339    // --------------------------------------------------------------------------------------------
340
341    /// Decrypts the provided encrypted data, as bytes, using this secret key.
342    ///
343    ///
344    /// # Errors
345    /// Returns an error if decryption fails or if the underlying data was encrypted as elements
346    /// rather than as bytes.
347    pub fn decrypt_bytes(
348        &self,
349        encrypted_data: &EncryptedData,
350    ) -> Result<Vec<u8>, EncryptionError> {
351        self.decrypt_bytes_with_associated_data(encrypted_data, &[])
352    }
353
354    /// Decrypts the provided encrypted data, as bytes, given some associated data using this
355    /// secret key.
356    ///
357    /// # Errors
358    /// Returns an error if decryption fails or if the underlying data was encrypted as elements
359    /// rather than as bytes.
360    pub fn decrypt_bytes_with_associated_data(
361        &self,
362        encrypted_data: &EncryptedData,
363        associated_data: &[u8],
364    ) -> Result<Vec<u8>, EncryptionError> {
365        if encrypted_data.data_type != DataType::Bytes {
366            return Err(EncryptionError::InvalidDataType {
367                expected: DataType::Bytes,
368                found: encrypted_data.data_type,
369            });
370        }
371
372        let ad_felt = bytes_to_elements_with_padding(associated_data);
373        let data_felts =
374            self.decrypt_elements_with_associated_data_unchecked(encrypted_data, &ad_felt)?;
375
376        match padded_elements_to_bytes(&data_felts) {
377            Some(bytes) => Ok(bytes),
378            None => Err(EncryptionError::MalformedPadding),
379        }
380    }
381}
382
383impl Distribution<SecretKey> for StandardUniform {
384    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SecretKey {
385        let mut res = [ZERO; SECRET_KEY_SIZE];
386        let uni_dist =
387            Uniform::new(0, Felt::MODULUS).expect("should not fail given the size of the field");
388        for r in res.iter_mut() {
389            let sampled_integer = uni_dist.sample(rng);
390            *r = Felt::new(sampled_integer);
391        }
392        SecretKey(res)
393    }
394}
395
396impl PartialEq for SecretKey {
397    fn eq(&self, other: &Self) -> bool {
398        // Use constant-time comparison to prevent timing attacks
399        let mut result = true;
400        for (a, b) in self.0.iter().zip(other.0.iter()) {
401            result &= bool::from(a.as_int().ct_eq(&b.as_int()));
402        }
403        result
404    }
405}
406
407impl Eq for SecretKey {}
408
409impl Zeroize for SecretKey {
410    /// Securely clears the shared secret from memory.
411    ///
412    /// # Security
413    ///
414    /// This implementation follows the same security methodology as the `zeroize` crate to ensure
415    /// that sensitive cryptographic material is reliably cleared from memory:
416    ///
417    /// - **Volatile writes**: Uses `ptr::write_volatile` to prevent dead store elimination and
418    ///   other compiler optimizations that might remove the zeroing operation.
419    /// - **Memory ordering**: Includes a sequentially consistent compiler fence (`SeqCst`) to
420    ///   prevent instruction reordering that could expose the secret data after this function
421    ///   returns.
422    fn zeroize(&mut self) {
423        for element in self.0.iter_mut() {
424            unsafe {
425                core::ptr::write_volatile(element, ZERO);
426            }
427        }
428        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
429    }
430}
431
432// Manual Drop implementation to ensure zeroization on drop.
433impl Drop for SecretKey {
434    fn drop(&mut self) {
435        self.zeroize();
436    }
437}
438
// Marker impl: the implementation body is empty, the actual wiping is performed by the `Drop`
// implementation of `SecretKey`.
impl ZeroizeOnDrop for SecretKey {}
440
441// SPONGE STATE
442// ================================================================================================
443
444/// Internal sponge state
445struct SpongeState {
446    state: [Felt; STATE_WIDTH],
447}
448
449impl SpongeState {
450    /// Creates a new sponge state
451    fn new(sk: &SecretKey, nonce: &Nonce) -> Self {
452        let mut state = [ZERO; STATE_WIDTH];
453
454        state[RATE_RANGE_FIRST_HALF].copy_from_slice(&sk.0);
455        state[RATE_RANGE_SECOND_HALF].copy_from_slice(&nonce.0);
456
457        Self { state }
458    }
459
460    /// Duplex interface as described in Algorithm 2 in [1] with `d = 0`
461    ///
462    ///
463    /// [1]: https://eprint.iacr.org/2023/1668
464    fn duplex_overwrite(&mut self, data: &[Felt]) {
465        self.permute();
466
467        // add 1 to the first capacity element
468        self.state[CAPACITY_RANGE.start] += ONE;
469
470        // overwrite the rate portion with `data`
471        self.state[RATE_RANGE].copy_from_slice(data);
472    }
473
474    /// Duplex interface as described in Algorithm 2 in [1] with `d = 1`
475    ///
476    ///
477    /// [1]: https://eprint.iacr.org/2023/1668
478    fn duplex_add(&mut self, data: &[Felt]) -> [Felt; RATE_WIDTH] {
479        self.permute();
480
481        let squeezed_data = self.squeeze_rate();
482
483        for (idx, &element) in data.iter().enumerate() {
484            self.state[RATE_START + idx] += element;
485        }
486
487        squeezed_data
488    }
489
490    /// Squeezes an authentication tag
491    fn squeeze_tag(&mut self) -> AuthTag {
492        self.permute();
493        AuthTag(
494            self.state[RATE_RANGE_FIRST_HALF]
495                .try_into()
496                .expect("rate first half is exactly AUTH_TAG_SIZE elements"),
497        )
498    }
499
500    /// Applies the RPO permutation to the sponge state
501    fn permute(&mut self) {
502        Rpo256::apply_permutation(&mut self.state);
503    }
504
505    /// Squeeze the rate portion of the state
506    fn squeeze_rate(&self) -> [Felt; RATE_WIDTH] {
507        self.state[RATE_RANGE]
508            .try_into()
509            .expect("rate range is exactly RATE_WIDTH elements")
510    }
511}
512
513// NONCE
514// ================================================================================================
515
/// A 256-bit nonce represented as 4 field elements
///
/// The nonce is placed into the sponge state next to the secret key when the state is
/// initialized, and is stored alongside the ciphertext in `EncryptedData`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Nonce([Felt; NONCE_SIZE]);
519
520impl Nonce {
521    /// Creates a new random nonce using the provided random number generator
522    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
523        rng.sample(StandardUniform)
524    }
525}
526
527impl From<Word> for Nonce {
528    fn from(word: Word) -> Self {
529        Nonce(word.into())
530    }
531}
532
533impl From<[Felt; NONCE_SIZE]> for Nonce {
534    fn from(elements: [Felt; NONCE_SIZE]) -> Self {
535        Nonce(elements)
536    }
537}
538
539impl From<Nonce> for Word {
540    fn from(nonce: Nonce) -> Self {
541        nonce.0.into()
542    }
543}
544
545impl From<Nonce> for [Felt; NONCE_SIZE] {
546    fn from(nonce: Nonce) -> Self {
547        nonce.0
548    }
549}
550
551impl Distribution<Nonce> for StandardUniform {
552    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Nonce {
553        let mut res = [ZERO; NONCE_SIZE];
554        let uni_dist =
555            Uniform::new(0, Felt::MODULUS).expect("should not fail given the size of the field");
556        for r in res.iter_mut() {
557            let sampled_integer = uni_dist.sample(rng);
558            *r = Felt::new(sampled_integer);
559        }
560        Nonce(res)
561    }
562}
563
564// SERIALIZATION / DESERIALIZATION
565// ================================================================================================
566
567impl Serializable for SecretKey {
568    fn write_into<W: ByteWriter>(&self, target: &mut W) {
569        let bytes = elements_to_bytes(&self.0);
570        target.write_bytes(&bytes);
571    }
572}
573
574impl Deserializable for SecretKey {
575    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
576        let bytes: [u8; SK_SIZE_BYTES] = source.read_array()?;
577
578        match bytes_to_elements_exact(&bytes) {
579            Some(inner) => {
580                let inner: [Felt; 4] = inner.try_into().map_err(|_| {
581                    DeserializationError::InvalidValue("malformed secret key".to_string())
582                })?;
583                Ok(Self(inner))
584            },
585            None => Err(DeserializationError::InvalidValue("malformed secret key".to_string())),
586        }
587    }
588}
589
590impl Serializable for Nonce {
591    fn write_into<W: ByteWriter>(&self, target: &mut W) {
592        let bytes = elements_to_bytes(&self.0);
593        target.write_bytes(&bytes);
594    }
595}
596
597impl Deserializable for Nonce {
598    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
599        let bytes: [u8; NONCE_SIZE_BYTES] = source.read_array()?;
600
601        match bytes_to_elements_exact(&bytes) {
602            Some(inner) => {
603                let inner: [Felt; 4] = inner.try_into().map_err(|_| {
604                    DeserializationError::InvalidValue("malformed nonce".to_string())
605                })?;
606                Ok(Self(inner))
607            },
608            None => Err(DeserializationError::InvalidValue("malformed nonce".to_string())),
609        }
610    }
611}
612
613impl Serializable for EncryptedData {
614    fn write_into<W: ByteWriter>(&self, target: &mut W) {
615        // we serialize field elements in their canonical form
616        target.write_u8(self.data_type as u8);
617        target.write_usize(self.ciphertext.len());
618        target.write_many(self.ciphertext.iter().map(Felt::as_int));
619        target.write_many(self.nonce.0.iter().map(Felt::as_int));
620        target.write_many(self.auth_tag.0.iter().map(Felt::as_int));
621    }
622}
623
624impl Deserializable for EncryptedData {
625    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
626        let data_type_value: u8 = source.read_u8()?;
627        let data_type = data_type_value.try_into().map_err(|_| {
628            DeserializationError::InvalidValue("invalid data type value".to_string())
629        })?;
630
631        let ciphertext_len = source.read_usize()?;
632        let ciphertext_bytes = source.read_many(ciphertext_len)?;
633        let ciphertext =
634            felts_from_u64(ciphertext_bytes).map_err(DeserializationError::InvalidValue)?;
635
636        let nonce = source.read_many(NONCE_SIZE)?;
637        let nonce: [Felt; NONCE_SIZE] = felts_from_u64(nonce)
638            .map_err(DeserializationError::InvalidValue)?
639            .try_into()
640            .expect("deserialization reads exactly NONCE_SIZE elements");
641
642        let tag = source.read_many(AUTH_TAG_SIZE)?;
643        let tag: [Felt; AUTH_TAG_SIZE] = felts_from_u64(tag)
644            .map_err(DeserializationError::InvalidValue)?
645            .try_into()
646            .expect("deserialization reads exactly AUTH_TAG_SIZE elements");
647
648        Ok(Self {
649            ciphertext,
650            nonce: Nonce(nonce),
651            auth_tag: AuthTag(tag),
652            data_type,
653        })
654    }
655}
656
657//  HELPERS
658// ================================================================================================
659
660/// Performs padding on either the plaintext or associated data.
661///
662/// # Padding Scheme
663///
664/// This AEAD implementation uses an injective padding scheme to ensure that different plaintexts
665/// always produce different ciphertexts, preventing ambiguity during decryption.
666///
667/// ## Data Padding
668///
669/// Plaintext data is padded using a 10* padding scheme:
670///
671/// - A padding separator (field element `ONE`) is appended to the message.
672/// - The message is then zero-padded to reach the next rate boundary.
673/// - **Security guarantee**: `[ONE]` and `[ONE, ZERO]` will produce different ciphertexts because
674///   after padding they become `[ONE, ONE, 0, 0, ...]` and `[ONE, ZERO, ONE, 0, ...]` respectively,
675///   ensuring injectivity.
676///
677/// ## Associated Data Padding
678///
679/// Associated data follows the same injective padding scheme:
680///
681/// - Padding separator (`ONE`) is appended.
682/// - Zero-padded to rate boundary.
683/// - **Security guarantee**: Different associated data inputs (like `[ONE]` vs `[ONE, ZERO]`)
684///   produce different authentication tags due to the injective padding.
685fn pad(data: &[Felt]) -> Vec<Felt> {
686    // if data length is a multiple of 8, padding_elements will be 8
687    let num_elem_final_block = data.len() % RATE_WIDTH;
688    let padding_elements = RATE_WIDTH - num_elem_final_block;
689
690    let mut result = data.to_vec();
691    result.extend_from_slice(&PADDING_BLOCK[..padding_elements]);
692
693    result
694}
695
696/// Removes the padding from the decoded ciphertext.
697fn unpad(mut plaintext: Vec<Felt>) -> Result<Vec<Felt>, EncryptionError> {
698    let (num_blocks, remainder) = plaintext.len().div_rem(&RATE_WIDTH);
699    assert_eq!(remainder, 0);
700
701    let final_block: &[Felt; RATE_WIDTH] = plaintext.last_chunk().expect("plaintext is empty");
702
703    let pos = match final_block.iter().rposition(|entry| *entry == ONE) {
704        Some(pos) => pos,
705        None => return Err(EncryptionError::MalformedPadding),
706    };
707
708    plaintext.truncate((num_blocks - 1) * RATE_WIDTH + pos);
709
710    Ok(plaintext)
711}
712
713/// Converts a vector of u64 values into a vector of field elements, returning an error if any of
714/// the u64 values is not a valid field element.
715fn felts_from_u64(input: Vec<u64>) -> Result<Vec<Felt>, alloc::string::String> {
716    input.into_iter().map(Felt::try_from).collect()
717}
718
719// AEAD SCHEME IMPLEMENTATION
720// ================================================================================================
721
/// RPO256-based AEAD scheme implementation
///
/// A unit type that carries the `AeadScheme` implementation: keys are `SecretKey`s, and
/// ciphertexts are exchanged as serialized `EncryptedData`.
pub struct AeadRpo;
724
725impl AeadScheme for AeadRpo {
726    const KEY_SIZE: usize = SK_SIZE_BYTES;
727
728    type Key = SecretKey;
729
730    fn key_from_bytes(bytes: &[u8]) -> Result<Self::Key, EncryptionError> {
731        SecretKey::read_from_bytes(bytes).map_err(|_| EncryptionError::FailedOperation)
732    }
733
734    fn encrypt_bytes<R: rand::CryptoRng + rand::RngCore>(
735        key: &Self::Key,
736        rng: &mut R,
737        plaintext: &[u8],
738        associated_data: &[u8],
739    ) -> Result<Vec<u8>, EncryptionError> {
740        let nonce = Nonce::with_rng(rng);
741        let encrypted_data = key
742            .encrypt_bytes_with_nonce(plaintext, associated_data, nonce)
743            .map_err(|_| EncryptionError::FailedOperation)?;
744
745        Ok(encrypted_data.to_bytes())
746    }
747
748    fn decrypt_bytes_with_associated_data(
749        key: &Self::Key,
750        ciphertext: &[u8],
751        associated_data: &[u8],
752    ) -> Result<Vec<u8>, EncryptionError> {
753        let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
754            .map_err(|_| EncryptionError::FailedOperation)?;
755
756        key.decrypt_bytes_with_associated_data(&encrypted_data, associated_data)
757    }
758
759    // OPTIMIZED FELT METHODS
760    // --------------------------------------------------------------------------------------------
761
762    fn encrypt_elements<R: rand::CryptoRng + rand::RngCore>(
763        key: &Self::Key,
764        rng: &mut R,
765        plaintext: &[Felt],
766        associated_data: &[Felt],
767    ) -> Result<Vec<u8>, EncryptionError> {
768        let nonce = Nonce::with_rng(rng);
769        let encrypted_data = key
770            .encrypt_elements_with_nonce(plaintext, associated_data, nonce)
771            .map_err(|_| EncryptionError::FailedOperation)?;
772
773        Ok(encrypted_data.to_bytes())
774    }
775
776    fn decrypt_elements_with_associated_data(
777        key: &Self::Key,
778        ciphertext: &[u8],
779        associated_data: &[Felt],
780    ) -> Result<Vec<Felt>, EncryptionError> {
781        let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
782            .map_err(|_| EncryptionError::FailedOperation)?;
783
784        key.decrypt_elements_with_associated_data(&encrypted_data, associated_data)
785    }
786}