Skip to main content

miden_crypto/aead/aead_poseidon2/
mod.rs

1//! # Arithmetization Oriented AEAD
2//!
3//! This module implements an AEAD scheme optimized for speed within SNARKs/STARKs.
4//! The design is described in \[1\] and is based on the MonkeySpongeWrap construction and uses
5//! the Poseidon2 permutation, creating an encryption scheme that is highly efficient when executed
6//! within zero-knowledge proof systems.
7//!
8//! \[1\] <https://eprint.iacr.org/2023/1668>
9
10use alloc::{string::ToString, vec::Vec};
11use core::ops::Range;
12
13use miden_crypto_derive::{SilentDebug, SilentDisplay};
14use num::Integer;
15use rand::{
16    Rng,
17    distr::{Distribution, StandardUniform, Uniform},
18};
19use subtle::ConstantTimeEq;
20
21use crate::{
22    Felt, ONE, Word, ZERO,
23    aead::{AeadScheme, DataType, EncryptionError},
24    hash::poseidon2::Poseidon2,
25    utils::{
26        ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
27        bytes_to_elements_exact, bytes_to_elements_with_padding, elements_to_bytes,
28        padded_elements_to_bytes,
29        zeroize::{Zeroize, ZeroizeOnDrop},
30    },
31};
32
33#[cfg(test)]
34mod test;
35
36// CONSTANTS
37// ================================================================================================
38
39/// Size of a secret key in field elements
40pub const SECRET_KEY_SIZE: usize = 4;
41
42/// Size of a secret key in bytes
43pub const SK_SIZE_BYTES: usize = SECRET_KEY_SIZE * Felt::NUM_BYTES;
44
45/// Size of a nonce in field elements
46pub const NONCE_SIZE: usize = 4;
47
48/// Size of a nonce in bytes
49pub const NONCE_SIZE_BYTES: usize = NONCE_SIZE * Felt::NUM_BYTES;
50
51/// Size of an authentication tag in field elements
52pub const AUTH_TAG_SIZE: usize = 4;
53
54/// Size of the sponge state field elements
55const STATE_WIDTH: usize = Poseidon2::STATE_WIDTH;
56
57/// Capacity portion of the sponge state.
58const CAPACITY_RANGE: Range<usize> = Poseidon2::CAPACITY_RANGE;
59
60/// Rate portion of the sponge state
61const RATE_RANGE: Range<usize> = Poseidon2::RATE_RANGE;
62
63/// Size of the rate portion of the sponge state in field elements
64const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
65
66/// Size of either the 1st or 2nd half of the rate portion of the sponge state in field elements
67const HALF_RATE_WIDTH: usize = (Poseidon2::RATE_RANGE.end - Poseidon2::RATE_RANGE.start) / 2;
68
69/// First half of the rate portion of the sponge state
70const RATE_RANGE_FIRST_HALF: Range<usize> =
71    Poseidon2::RATE_RANGE.start..Poseidon2::RATE_RANGE.start + HALF_RATE_WIDTH;
72
73/// Second half of the rate portion of the sponge state
74const RATE_RANGE_SECOND_HALF: Range<usize> =
75    Poseidon2::RATE_RANGE.start + HALF_RATE_WIDTH..Poseidon2::RATE_RANGE.end;
76
77/// Index of the first element of the rate portion of the sponge state
78const RATE_START: usize = Poseidon2::RATE_RANGE.start;
79
80/// Padding block used when the length of the data to encrypt is a multiple of `RATE_WIDTH`
81const PADDING_BLOCK: [Felt; RATE_WIDTH] = [ONE, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];
82
83// TYPES AND STRUCTURES
84// ================================================================================================
85
86/// Encrypted data with its authentication tag
87#[derive(Debug, PartialEq, Eq)]
88pub struct EncryptedData {
89    /// Indicates the original format of the data before encryption
90    data_type: DataType,
91    /// The encrypted ciphertext
92    ciphertext: Vec<Felt>,
93    /// The authentication tag attesting to the integrity of the ciphertext, and the associated
94    /// data if it exists
95    auth_tag: AuthTag,
96    /// The nonce used during encryption
97    nonce: Nonce,
98}
99
100impl EncryptedData {
101    /// Constructs an EncryptedData from its component parts.
102    pub fn from_parts(
103        data_type: DataType,
104        ciphertext: Vec<Felt>,
105        auth_tag: AuthTag,
106        nonce: Nonce,
107    ) -> Self {
108        Self { data_type, ciphertext, auth_tag, nonce }
109    }
110
111    /// Returns the data type of the encrypted data
112    pub fn data_type(&self) -> DataType {
113        self.data_type
114    }
115
116    /// Returns a reference to the ciphertext
117    pub fn ciphertext(&self) -> &[Felt] {
118        &self.ciphertext
119    }
120
121    /// Returns a reference to the authentication tag
122    pub fn auth_tag(&self) -> &AuthTag {
123        &self.auth_tag
124    }
125
126    /// Returns a reference to the nonce
127    pub fn nonce(&self) -> &Nonce {
128        &self.nonce
129    }
130}
131
132/// An authentication tag represented as 4 field elements
133#[derive(Debug, Default, Clone)]
134pub struct AuthTag([Felt; AUTH_TAG_SIZE]);
135
136impl AuthTag {
137    /// Constructs an AuthTag from an array of field elements.
138    pub fn new(elements: [Felt; AUTH_TAG_SIZE]) -> Self {
139        Self(elements)
140    }
141
142    /// Returns the authentication tag as an array of field elements
143    pub fn to_elements(&self) -> [Felt; AUTH_TAG_SIZE] {
144        self.0
145    }
146}
147
148impl PartialEq for AuthTag {
149    fn eq(&self, other: &Self) -> bool {
150        // Use constant-time comparison on non-wasm targets to reduce timing leaks. On wasm, the
151        // canonicalization helper is a host call and not guaranteed constant-time.
152        let mut result = true;
153        for (a, b) in self.0.iter().zip(other.0.iter()) {
154            result &= bool::from(a.as_canonical_u64_ct().ct_eq(&b.as_canonical_u64_ct()));
155        }
156        result
157    }
158}
159
160impl Eq for AuthTag {}
161
162/// A 256-bit secret key represented as 4 field elements
163#[derive(Clone, SilentDebug, SilentDisplay)]
164pub struct SecretKey([Felt; SECRET_KEY_SIZE]);
165
166impl SecretKey {
167    // CONSTRUCTORS
168    // --------------------------------------------------------------------------------------------
169
170    /// Creates a new random secret key using the default random number generator.
171    #[cfg(feature = "std")]
172    #[allow(clippy::new_without_default)]
173    pub fn new() -> Self {
174        let mut rng = rand::rng();
175        Self::with_rng(&mut rng)
176    }
177
178    /// Creates a new random secret key using the provided random number generator.
179    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
180        rng.sample(StandardUniform)
181    }
182
183    /// Creates a secret key from the provided array of field elements.
184    ///
185    /// # Security Warning
186    /// This method should be used with caution. Secret keys must be derived from a
187    /// cryptographically secure source of entropy. Do not use predictable or low-entropy
188    /// values as secret key material. Prefer using `new()` or `with_rng()` with a
189    /// cryptographically secure random number generator.
190    pub fn from_elements(elements: [Felt; SECRET_KEY_SIZE]) -> Self {
191        Self(elements)
192    }
193
194    // ACCESSORS
195    // --------------------------------------------------------------------------------------------
196
197    /// Returns the secret key as an array of field elements.
198    ///
199    /// # Security Warning
200    /// This method exposes the raw secret key material. Use with caution and ensure
201    /// proper zeroization of the returned array when no longer needed.
202    pub fn to_elements(&self) -> [Felt; SECRET_KEY_SIZE] {
203        self.0
204    }
205
206    // ELEMENT ENCRYPTION
207    // --------------------------------------------------------------------------------------------
208
209    /// Encrypts and authenticates the provided sequence of field elements using this secret key
210    /// and a random nonce.
211    #[cfg(feature = "std")]
212    pub fn encrypt_elements(&self, data: &[Felt]) -> Result<EncryptedData, EncryptionError> {
213        self.encrypt_elements_with_associated_data(data, &[])
214    }
215
216    /// Encrypts the provided sequence of field elements and authenticates both the ciphertext as
217    /// well as the provided associated data using this secret key and a random nonce.
218    #[cfg(feature = "std")]
219    pub fn encrypt_elements_with_associated_data(
220        &self,
221        data: &[Felt],
222        associated_data: &[Felt],
223    ) -> Result<EncryptedData, EncryptionError> {
224        let mut rng = rand::rng();
225        let nonce = Nonce::with_rng(&mut rng);
226
227        self.encrypt_elements_with_nonce(data, associated_data, nonce)
228    }
229
230    /// Encrypts the provided sequence of field elements and authenticates both the ciphertext as
231    /// well as the provided associated data using this secret key and the specified nonce.
232    pub fn encrypt_elements_with_nonce(
233        &self,
234        data: &[Felt],
235        associated_data: &[Felt],
236        nonce: Nonce,
237    ) -> Result<EncryptedData, EncryptionError> {
238        // Initialize as sponge state with key and nonce
239        let mut sponge = SpongeState::new(self, &nonce);
240
241        // Process the associated data
242        let padded_associated_data = pad(associated_data);
243        padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
244            sponge.duplex_overwrite(chunk);
245        });
246
247        // Encrypt the data
248        let mut ciphertext = Vec::with_capacity(data.len() + RATE_WIDTH);
249        let data = pad(data);
250        let mut data_block_iterator = data.chunks_exact(RATE_WIDTH);
251
252        data_block_iterator.by_ref().for_each(|data_block| {
253            let keystream = sponge.duplex_add(data_block);
254            for (i, &plaintext_felt) in data_block.iter().enumerate() {
255                ciphertext.push(plaintext_felt + keystream[i]);
256            }
257        });
258
259        // Generate authentication tag
260        let auth_tag = sponge.squeeze_tag();
261
262        Ok(EncryptedData {
263            data_type: DataType::Elements,
264            ciphertext,
265            auth_tag,
266            nonce,
267        })
268    }
269
270    // BYTE ENCRYPTION
271    // --------------------------------------------------------------------------------------------
272
273    /// Encrypts and authenticates the provided data using this secret key and a random nonce.
274    ///
275    /// Before encryption, the bytestring is converted to a sequence of field elements.
276    #[cfg(feature = "std")]
277    pub fn encrypt_bytes(&self, data: &[u8]) -> Result<EncryptedData, EncryptionError> {
278        self.encrypt_bytes_with_associated_data(data, &[])
279    }
280
281    /// Encrypts the provided data and authenticates both the ciphertext as well as the provided
282    /// associated data using this secret key and a random nonce.
283    ///
284    /// Before encryption, both the data and the associated data are converted to sequences of
285    /// field elements.
286    #[cfg(feature = "std")]
287    pub fn encrypt_bytes_with_associated_data(
288        &self,
289        data: &[u8],
290        associated_data: &[u8],
291    ) -> Result<EncryptedData, EncryptionError> {
292        let mut rng = rand::rng();
293        let nonce = Nonce::with_rng(&mut rng);
294
295        self.encrypt_bytes_with_nonce(data, associated_data, nonce)
296    }
297
298    /// Encrypts the provided data and authenticates both the ciphertext as well as the provided
299    /// associated data using this secret key and the specified nonce.
300    ///
301    /// Before encryption, both the data and the associated data are converted to sequences of
302    /// field elements.
303    pub fn encrypt_bytes_with_nonce(
304        &self,
305        data: &[u8],
306        associated_data: &[u8],
307        nonce: Nonce,
308    ) -> Result<EncryptedData, EncryptionError> {
309        let data_felt = bytes_to_elements_with_padding(data);
310        let ad_felt = bytes_to_elements_with_padding(associated_data);
311
312        let mut encrypted_data = self.encrypt_elements_with_nonce(&data_felt, &ad_felt, nonce)?;
313        encrypted_data.data_type = DataType::Bytes;
314        Ok(encrypted_data)
315    }
316
317    // ELEMENT DECRYPTION
318    // --------------------------------------------------------------------------------------------
319
320    /// Decrypts the provided encrypted data using this secret key.
321    ///
322    /// # Errors
323    /// Returns an error if decryption fails or if the underlying data was encrypted as bytes
324    /// rather than as field elements.
325    pub fn decrypt_elements(
326        &self,
327        encrypted_data: &EncryptedData,
328    ) -> Result<Vec<Felt>, EncryptionError> {
329        self.decrypt_elements_with_associated_data(encrypted_data, &[])
330    }
331
332    /// Decrypts the provided encrypted data, given some associated data, using this secret key.
333    ///
334    /// # Errors
335    /// Returns an error if decryption fails or if the underlying data was encrypted as bytes
336    /// rather than as field elements.
337    pub fn decrypt_elements_with_associated_data(
338        &self,
339        encrypted_data: &EncryptedData,
340        associated_data: &[Felt],
341    ) -> Result<Vec<Felt>, EncryptionError> {
342        if encrypted_data.data_type != DataType::Elements {
343            return Err(EncryptionError::InvalidDataType {
344                expected: DataType::Elements,
345                found: encrypted_data.data_type,
346            });
347        }
348        self.decrypt_elements_with_associated_data_unchecked(encrypted_data, associated_data)
349    }
350
351    /// Decrypts the provided encrypted data, given some associated data, using this secret key.
352    fn decrypt_elements_with_associated_data_unchecked(
353        &self,
354        encrypted_data: &EncryptedData,
355        associated_data: &[Felt],
356    ) -> Result<Vec<Felt>, EncryptionError> {
357        if !encrypted_data.ciphertext.len().is_multiple_of(RATE_WIDTH) {
358            return Err(EncryptionError::CiphertextLenNotMultipleRate);
359        }
360
361        // Initialize as sponge state with key and nonce
362        let mut sponge = SpongeState::new(self, &encrypted_data.nonce);
363
364        // Process the associated data
365        let padded_associated_data = pad(associated_data);
366        padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
367            sponge.duplex_overwrite(chunk);
368        });
369
370        // Decrypt the data
371        let mut plaintext = Vec::with_capacity(encrypted_data.ciphertext.len());
372        let mut ciphertext_block_iterator = encrypted_data.ciphertext.chunks_exact(RATE_WIDTH);
373        ciphertext_block_iterator.by_ref().for_each(|ciphertext_data_block| {
374            let keystream = sponge.duplex_add(&[]);
375            for (i, &ciphertext_felt) in ciphertext_data_block.iter().enumerate() {
376                let plaintext_felt = ciphertext_felt - keystream[i];
377                plaintext.push(plaintext_felt);
378            }
379            sponge.state[RATE_RANGE].copy_from_slice(ciphertext_data_block);
380        });
381
382        // Verify authentication tag
383        let computed_tag = sponge.squeeze_tag();
384        if computed_tag != encrypted_data.auth_tag {
385            return Err(EncryptionError::InvalidAuthTag);
386        }
387
388        // Remove padding and return
389        unpad(plaintext)
390    }
391
392    // BYTE DECRYPTION
393    // --------------------------------------------------------------------------------------------
394
395    /// Decrypts the provided encrypted data, as bytes, using this secret key.
396    ///
397    ///
398    /// # Errors
399    /// Returns an error if decryption fails or if the underlying data was encrypted as elements
400    /// rather than as bytes.
401    pub fn decrypt_bytes(
402        &self,
403        encrypted_data: &EncryptedData,
404    ) -> Result<Vec<u8>, EncryptionError> {
405        self.decrypt_bytes_with_associated_data(encrypted_data, &[])
406    }
407
408    /// Decrypts the provided encrypted data, as bytes, given some associated data using this
409    /// secret key.
410    ///
411    /// # Errors
412    /// Returns an error if decryption fails or if the underlying data was encrypted as elements
413    /// rather than as bytes.
414    pub fn decrypt_bytes_with_associated_data(
415        &self,
416        encrypted_data: &EncryptedData,
417        associated_data: &[u8],
418    ) -> Result<Vec<u8>, EncryptionError> {
419        if encrypted_data.data_type != DataType::Bytes {
420            return Err(EncryptionError::InvalidDataType {
421                expected: DataType::Bytes,
422                found: encrypted_data.data_type,
423            });
424        }
425
426        let ad_felt = bytes_to_elements_with_padding(associated_data);
427        let data_felts =
428            self.decrypt_elements_with_associated_data_unchecked(encrypted_data, &ad_felt)?;
429
430        match padded_elements_to_bytes(&data_felts) {
431            Some(bytes) => Ok(bytes),
432            None => Err(EncryptionError::MalformedPadding),
433        }
434    }
435}
436
437impl Distribution<SecretKey> for StandardUniform {
438    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SecretKey {
439        let mut res = [ZERO; SECRET_KEY_SIZE];
440        let uni_dist =
441            Uniform::new(0, Felt::ORDER).expect("should not fail given the size of the field");
442        for r in res.iter_mut() {
443            let sampled_integer = uni_dist.sample(rng);
444            *r = Felt::new(sampled_integer);
445        }
446        SecretKey(res)
447    }
448}
449
#[cfg(any(test, feature = "testing"))]
impl PartialEq for SecretKey {
    /// Compares two keys without short-circuiting on the first mismatching element.
    ///
    /// Per-element results are combined as a [`subtle::Choice`] and converted to `bool` once,
    /// after every element has been compared. Converting each comparison to `bool` would let
    /// the compiler introduce data-dependent branches.
    fn eq(&self, other: &Self) -> bool {
        let mut equal = subtle::Choice::from(1u8);
        for (a, b) in self.0.iter().zip(other.0.iter()) {
            equal &= a.as_canonical_u64_ct().ct_eq(&b.as_canonical_u64_ct());
        }
        equal.into()
    }
}

#[cfg(any(test, feature = "testing"))]
impl Eq for SecretKey {}
464
465impl Zeroize for SecretKey {
466    /// Securely clears the shared secret from memory.
467    ///
468    /// # Security
469    ///
470    /// This implementation follows the same security methodology as the `zeroize` crate to ensure
471    /// that sensitive cryptographic material is reliably cleared from memory:
472    ///
473    /// - **Volatile writes**: Uses `ptr::write_volatile` to prevent dead store elimination and
474    ///   other compiler optimizations that might remove the zeroing operation.
475    /// - **Memory ordering**: Includes a sequentially consistent compiler fence (`SeqCst`) to
476    ///   prevent instruction reordering that could expose the secret data after this function
477    ///   returns.
478    fn zeroize(&mut self) {
479        for element in self.0.iter_mut() {
480            unsafe {
481                core::ptr::write_volatile(element, ZERO);
482            }
483        }
484        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
485    }
486}
487
488// Manual Drop implementation to ensure zeroization on drop.
489impl Drop for SecretKey {
490    fn drop(&mut self) {
491        self.zeroize();
492    }
493}
494
495impl ZeroizeOnDrop for SecretKey {}
496
497// SPONGE STATE
498// ================================================================================================
499
500/// Internal sponge state
501struct SpongeState {
502    state: [Felt; STATE_WIDTH],
503}
504
505impl SpongeState {
506    /// Creates a new sponge state
507    fn new(sk: &SecretKey, nonce: &Nonce) -> Self {
508        let mut state = [ZERO; STATE_WIDTH];
509
510        state[RATE_RANGE_FIRST_HALF].copy_from_slice(&sk.0);
511        state[RATE_RANGE_SECOND_HALF].copy_from_slice(&nonce.0);
512
513        Self { state }
514    }
515
516    /// Duplex interface as described in Algorithm 2 in [1] with `d = 0`
517    ///
518    ///
519    /// [1]: <https://eprint.iacr.org/2023/1668>
520    fn duplex_overwrite(&mut self, data: &[Felt]) {
521        self.permute();
522
523        // add 1 to the first capacity element
524        self.state[CAPACITY_RANGE.start] += ONE;
525
526        // overwrite the rate portion with `data`
527        self.state[RATE_RANGE].copy_from_slice(data);
528    }
529
530    /// Duplex interface as described in Algorithm 2 in [1] with `d = 1`
531    ///
532    ///
533    /// [1]: <https://eprint.iacr.org/2023/1668>
534    fn duplex_add(&mut self, data: &[Felt]) -> [Felt; RATE_WIDTH] {
535        self.permute();
536
537        let squeezed_data = self.squeeze_rate();
538
539        for (idx, &element) in data.iter().enumerate() {
540            self.state[RATE_START + idx] += element;
541        }
542
543        squeezed_data
544    }
545
546    /// Squeezes an authentication tag (from RATE0, the first rate word)
547    fn squeeze_tag(&mut self) -> AuthTag {
548        self.permute();
549        AuthTag(
550            self.state[RATE_RANGE_FIRST_HALF]
551                .try_into()
552                .expect("rate first half is exactly AUTH_TAG_SIZE elements"),
553        )
554    }
555
556    /// Applies the Poseidon2 permutation to the sponge state
557    fn permute(&mut self) {
558        Poseidon2::apply_permutation(&mut self.state);
559    }
560
561    /// Squeeze the rate portion of the state
562    fn squeeze_rate(&self) -> [Felt; RATE_WIDTH] {
563        self.state[RATE_RANGE]
564            .try_into()
565            .expect("rate range is exactly RATE_WIDTH elements")
566    }
567}
568
569// NONCE
570// ================================================================================================
571
572/// A 256-bit nonce represented as 4 field elements
573#[derive(Clone, Debug, PartialEq, Eq)]
574pub struct Nonce([Felt; NONCE_SIZE]);
575
576impl Nonce {
577    /// Creates a new random nonce using the provided random number generator
578    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
579        rng.sample(StandardUniform)
580    }
581}
582
583impl From<Word> for Nonce {
584    fn from(word: Word) -> Self {
585        Nonce(word.into())
586    }
587}
588
589impl From<[Felt; NONCE_SIZE]> for Nonce {
590    fn from(elements: [Felt; NONCE_SIZE]) -> Self {
591        Nonce(elements)
592    }
593}
594
595impl From<Nonce> for Word {
596    fn from(nonce: Nonce) -> Self {
597        nonce.0.into()
598    }
599}
600
601impl From<Nonce> for [Felt; NONCE_SIZE] {
602    fn from(nonce: Nonce) -> Self {
603        nonce.0
604    }
605}
606
607impl Distribution<Nonce> for StandardUniform {
608    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Nonce {
609        let mut res = [ZERO; NONCE_SIZE];
610        let uni_dist =
611            Uniform::new(0, Felt::ORDER).expect("should not fail given the size of the field");
612        for r in res.iter_mut() {
613            let sampled_integer = uni_dist.sample(rng);
614            *r = Felt::new(sampled_integer);
615        }
616        Nonce(res)
617    }
618}
619
620// SERIALIZATION / DESERIALIZATION
621// ================================================================================================
622
623impl Serializable for SecretKey {
624    fn write_into<W: ByteWriter>(&self, target: &mut W) {
625        let bytes = elements_to_bytes(&self.0);
626        target.write_bytes(&bytes);
627    }
628}
629
630impl Deserializable for SecretKey {
631    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
632        let bytes: [u8; SK_SIZE_BYTES] = source.read_array()?;
633
634        match bytes_to_elements_exact(&bytes) {
635            Some(inner) => {
636                let inner: [Felt; 4] = inner.try_into().map_err(|_| {
637                    DeserializationError::InvalidValue("malformed secret key".to_string())
638                })?;
639                Ok(Self(inner))
640            },
641            None => Err(DeserializationError::InvalidValue("malformed secret key".to_string())),
642        }
643    }
644}
645
646impl Serializable for Nonce {
647    fn write_into<W: ByteWriter>(&self, target: &mut W) {
648        let bytes = elements_to_bytes(&self.0);
649        target.write_bytes(&bytes);
650    }
651}
652
653impl Deserializable for Nonce {
654    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
655        let bytes: [u8; NONCE_SIZE_BYTES] = source.read_array()?;
656
657        match bytes_to_elements_exact(&bytes) {
658            Some(inner) => {
659                let inner: [Felt; 4] = inner.try_into().map_err(|_| {
660                    DeserializationError::InvalidValue("malformed nonce".to_string())
661                })?;
662                Ok(Self(inner))
663            },
664            None => Err(DeserializationError::InvalidValue("malformed nonce".to_string())),
665        }
666    }
667}
668
669impl Serializable for EncryptedData {
670    fn write_into<W: ByteWriter>(&self, target: &mut W) {
671        target.write_u8(self.data_type as u8);
672        self.ciphertext.write_into(target);
673        target.write_many(self.nonce.0);
674        target.write_many(self.auth_tag.0);
675    }
676}
677
678impl Deserializable for EncryptedData {
679    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
680        let data_type_value: u8 = source.read_u8()?;
681        let data_type = data_type_value.try_into().map_err(|_| {
682            DeserializationError::InvalidValue("invalid data type value".to_string())
683        })?;
684
685        let ciphertext = Vec::<Felt>::read_from(source)?;
686        let nonce: [Felt; NONCE_SIZE] = source.read()?;
687        let auth_tag: [Felt; AUTH_TAG_SIZE] = source.read()?;
688
689        Ok(Self {
690            ciphertext,
691            nonce: Nonce(nonce),
692            auth_tag: AuthTag(auth_tag),
693            data_type,
694        })
695    }
696}
697
698//  HELPERS
699// ================================================================================================
700
701/// Performs padding on either the plaintext or associated data.
702///
703/// # Padding Scheme
704///
705/// This AEAD implementation uses an injective padding scheme to ensure that different plaintexts
706/// always produce different ciphertexts, preventing ambiguity during decryption.
707///
708/// ## Data Padding
709///
710/// Plaintext data is padded using a 10* padding scheme:
711///
712/// - A padding separator (field element `ONE`) is appended to the message.
713/// - The message is then zero-padded to reach the next rate boundary.
714/// - **Security guarantee**: `[ONE]` and `[ONE, ZERO]` will produce different ciphertexts because
715///   after padding they become `[ONE, ONE, 0, 0, ...]` and `[ONE, ZERO, ONE, 0, ...]` respectively,
716///   ensuring injectivity.
717///
718/// ## Associated Data Padding
719///
720/// Associated data follows the same injective padding scheme:
721///
722/// - Padding separator (`ONE`) is appended.
723/// - Zero-padded to rate boundary.
724/// - **Security guarantee**: Different associated data inputs (like `[ONE]` vs `[ONE, ZERO]`)
725///   produce different authentication tags due to the injective padding.
726fn pad(data: &[Felt]) -> Vec<Felt> {
727    // if data length is a multiple of 8, padding_elements will be 8
728    let num_elem_final_block = data.len() % RATE_WIDTH;
729    let padding_elements = RATE_WIDTH - num_elem_final_block;
730
731    let mut result = data.to_vec();
732    result.extend_from_slice(&PADDING_BLOCK[..padding_elements]);
733
734    result
735}
736
737/// Removes the padding from the decoded ciphertext.
738fn unpad(mut plaintext: Vec<Felt>) -> Result<Vec<Felt>, EncryptionError> {
739    let (num_blocks, remainder) = plaintext.len().div_rem(&RATE_WIDTH);
740    assert_eq!(remainder, 0);
741
742    let final_block: &[Felt; RATE_WIDTH] =
743        plaintext.last_chunk().ok_or(EncryptionError::MalformedPadding)?;
744
745    let pos = match final_block.iter().rposition(|entry| *entry == ONE) {
746        Some(pos) => pos,
747        None => return Err(EncryptionError::MalformedPadding),
748    };
749
750    plaintext.truncate((num_blocks - 1) * RATE_WIDTH + pos);
751
752    Ok(plaintext)
753}
754
755// AEAD SCHEME IMPLEMENTATION
756// ================================================================================================
757
758/// Poseidon2-based AEAD scheme implementation
759pub struct AeadPoseidon2;
760
761impl AeadScheme for AeadPoseidon2 {
762    const KEY_SIZE: usize = SK_SIZE_BYTES;
763
764    type Key = SecretKey;
765
766    fn key_from_bytes(bytes: &[u8]) -> Result<Self::Key, EncryptionError> {
767        if bytes.len() != SK_SIZE_BYTES {
768            return Err(EncryptionError::FailedOperation);
769        }
770
771        SecretKey::read_from_bytes_with_budget(bytes, SK_SIZE_BYTES)
772            .map_err(|_| EncryptionError::FailedOperation)
773    }
774
775    fn encrypt_bytes<R: rand::CryptoRng + rand::RngCore>(
776        key: &Self::Key,
777        rng: &mut R,
778        plaintext: &[u8],
779        associated_data: &[u8],
780    ) -> Result<Vec<u8>, EncryptionError> {
781        let nonce = Nonce::with_rng(rng);
782        let encrypted_data = key
783            .encrypt_bytes_with_nonce(plaintext, associated_data, nonce)
784            .map_err(|_| EncryptionError::FailedOperation)?;
785
786        Ok(encrypted_data.to_bytes())
787    }
788
789    fn decrypt_bytes_with_associated_data(
790        key: &Self::Key,
791        ciphertext: &[u8],
792        associated_data: &[u8],
793    ) -> Result<Vec<u8>, EncryptionError> {
794        let encrypted_data =
795            EncryptedData::read_from_bytes_with_budget(ciphertext, ciphertext.len())
796                .map_err(|_| EncryptionError::FailedOperation)?;
797
798        key.decrypt_bytes_with_associated_data(&encrypted_data, associated_data)
799    }
800
801    // OPTIMIZED FELT METHODS
802    // --------------------------------------------------------------------------------------------
803
804    fn encrypt_elements<R: rand::CryptoRng + rand::RngCore>(
805        key: &Self::Key,
806        rng: &mut R,
807        plaintext: &[Felt],
808        associated_data: &[Felt],
809    ) -> Result<Vec<u8>, EncryptionError> {
810        let nonce = Nonce::with_rng(rng);
811        let encrypted_data = key
812            .encrypt_elements_with_nonce(plaintext, associated_data, nonce)
813            .map_err(|_| EncryptionError::FailedOperation)?;
814
815        Ok(encrypted_data.to_bytes())
816    }
817
818    fn decrypt_elements_with_associated_data(
819        key: &Self::Key,
820        ciphertext: &[u8],
821        associated_data: &[Felt],
822    ) -> Result<Vec<Felt>, EncryptionError> {
823        let encrypted_data =
824            EncryptedData::read_from_bytes_with_budget(ciphertext, ciphertext.len())
825                .map_err(|_| EncryptionError::FailedOperation)?;
826
827        key.decrypt_elements_with_associated_data(&encrypted_data, associated_data)
828    }
829}