1use alloc::{string::ToString, vec::Vec};
11use core::ops::Range;
12
13use miden_crypto_derive::{SilentDebug, SilentDisplay};
14use num::Integer;
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform, Uniform},
18};
19use subtle::ConstantTimeEq;
20
21use crate::{
22 Felt, FieldElement, ONE, StarkField, Word, ZERO,
23 aead::{AeadScheme, DataType, EncryptionError},
24 hash::rpo::Rpo256,
25 utils::{
26 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
27 bytes_to_elements_exact, bytes_to_elements_with_padding, elements_to_bytes,
28 padded_elements_to_bytes,
29 },
30 zeroize::{Zeroize, ZeroizeOnDrop},
31};
32
33#[cfg(test)]
34mod test;
35
36pub const SECRET_KEY_SIZE: usize = 4;
41
42pub const SK_SIZE_BYTES: usize = SECRET_KEY_SIZE * Felt::ELEMENT_BYTES;
44
45pub const NONCE_SIZE: usize = 4;
47
48pub const NONCE_SIZE_BYTES: usize = NONCE_SIZE * Felt::ELEMENT_BYTES;
50
51pub const AUTH_TAG_SIZE: usize = 4;
53
54const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
56
57const CAPACITY_RANGE: Range<usize> = Rpo256::CAPACITY_RANGE;
59
60const RATE_RANGE: Range<usize> = Rpo256::RATE_RANGE;
62
63const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
65
66const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
68
69const RATE_RANGE_FIRST_HALF: Range<usize> =
71 Rpo256::RATE_RANGE.start..Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH;
72
73const RATE_RANGE_SECOND_HALF: Range<usize> =
75 Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH..Rpo256::RATE_RANGE.end;
76
77const RATE_START: usize = Rpo256::RATE_RANGE.start;
79
80const PADDING_BLOCK: [Felt; RATE_WIDTH] = [ONE, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];
82
83#[derive(Debug, PartialEq, Eq)]
88pub struct EncryptedData {
89 data_type: DataType,
91 ciphertext: Vec<Felt>,
93 auth_tag: AuthTag,
96 nonce: Nonce,
98}
99
100impl EncryptedData {
101 pub fn from_parts(
103 data_type: DataType,
104 ciphertext: Vec<Felt>,
105 auth_tag: AuthTag,
106 nonce: Nonce,
107 ) -> Self {
108 Self { data_type, ciphertext, auth_tag, nonce }
109 }
110
111 pub fn data_type(&self) -> DataType {
113 self.data_type
114 }
115
116 pub fn ciphertext(&self) -> &[Felt] {
118 &self.ciphertext
119 }
120
121 pub fn auth_tag(&self) -> &AuthTag {
123 &self.auth_tag
124 }
125
126 pub fn nonce(&self) -> &Nonce {
128 &self.nonce
129 }
130}
131
132#[derive(Debug, Default, Clone, PartialEq, Eq)]
134pub struct AuthTag([Felt; AUTH_TAG_SIZE]);
135
136impl AuthTag {
137 pub fn new(elements: [Felt; AUTH_TAG_SIZE]) -> Self {
139 Self(elements)
140 }
141
142 pub fn to_elements(&self) -> [Felt; AUTH_TAG_SIZE] {
144 self.0
145 }
146}
147
148#[derive(Clone, SilentDebug, SilentDisplay)]
150pub struct SecretKey([Felt; SECRET_KEY_SIZE]);
151
152impl SecretKey {
153 #[cfg(feature = "std")]
158 #[allow(clippy::new_without_default)]
159 pub fn new() -> Self {
160 let mut rng = rand::rng();
161 Self::with_rng(&mut rng)
162 }
163
164 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
166 rng.sample(StandardUniform)
167 }
168
169 pub fn from_elements(elements: [Felt; SECRET_KEY_SIZE]) -> Self {
177 Self(elements)
178 }
179
180 pub fn to_elements(&self) -> [Felt; SECRET_KEY_SIZE] {
189 self.0
190 }
191
192 #[cfg(feature = "std")]
198 pub fn encrypt_elements(&self, data: &[Felt]) -> Result<EncryptedData, EncryptionError> {
199 self.encrypt_elements_with_associated_data(data, &[])
200 }
201
202 #[cfg(feature = "std")]
205 pub fn encrypt_elements_with_associated_data(
206 &self,
207 data: &[Felt],
208 associated_data: &[Felt],
209 ) -> Result<EncryptedData, EncryptionError> {
210 let mut rng = rand::rng();
211 let nonce = Nonce::with_rng(&mut rng);
212
213 self.encrypt_elements_with_nonce(data, associated_data, nonce)
214 }
215
216 pub fn encrypt_elements_with_nonce(
219 &self,
220 data: &[Felt],
221 associated_data: &[Felt],
222 nonce: Nonce,
223 ) -> Result<EncryptedData, EncryptionError> {
224 let mut sponge = SpongeState::new(self, &nonce);
226
227 let padded_associated_data = pad(associated_data);
229 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
230 sponge.duplex_overwrite(chunk);
231 });
232
233 let mut ciphertext = Vec::with_capacity(data.len() + RATE_WIDTH);
235 let data = pad(data);
236 let mut data_block_iterator = data.chunks_exact(RATE_WIDTH);
237
238 data_block_iterator.by_ref().for_each(|data_block| {
239 let keystream = sponge.duplex_add(data_block);
240 for (i, &plaintext_felt) in data_block.iter().enumerate() {
241 ciphertext.push(plaintext_felt + keystream[i]);
242 }
243 });
244
245 let auth_tag = sponge.squeeze_tag();
247
248 Ok(EncryptedData {
249 data_type: DataType::Elements,
250 ciphertext,
251 auth_tag,
252 nonce,
253 })
254 }
255
256 #[cfg(feature = "std")]
263 pub fn encrypt_bytes(&self, data: &[u8]) -> Result<EncryptedData, EncryptionError> {
264 self.encrypt_bytes_with_associated_data(data, &[])
265 }
266
267 #[cfg(feature = "std")]
273 pub fn encrypt_bytes_with_associated_data(
274 &self,
275 data: &[u8],
276 associated_data: &[u8],
277 ) -> Result<EncryptedData, EncryptionError> {
278 let mut rng = rand::rng();
279 let nonce = Nonce::with_rng(&mut rng);
280
281 self.encrypt_bytes_with_nonce(data, associated_data, nonce)
282 }
283
284 pub fn encrypt_bytes_with_nonce(
290 &self,
291 data: &[u8],
292 associated_data: &[u8],
293 nonce: Nonce,
294 ) -> Result<EncryptedData, EncryptionError> {
295 let data_felt = bytes_to_elements_with_padding(data);
296 let ad_felt = bytes_to_elements_with_padding(associated_data);
297
298 let mut encrypted_data = self.encrypt_elements_with_nonce(&data_felt, &ad_felt, nonce)?;
299 encrypted_data.data_type = DataType::Bytes;
300 Ok(encrypted_data)
301 }
302
303 pub fn decrypt_elements(
312 &self,
313 encrypted_data: &EncryptedData,
314 ) -> Result<Vec<Felt>, EncryptionError> {
315 self.decrypt_elements_with_associated_data(encrypted_data, &[])
316 }
317
318 pub fn decrypt_elements_with_associated_data(
324 &self,
325 encrypted_data: &EncryptedData,
326 associated_data: &[Felt],
327 ) -> Result<Vec<Felt>, EncryptionError> {
328 if encrypted_data.data_type != DataType::Elements {
329 return Err(EncryptionError::InvalidDataType {
330 expected: DataType::Elements,
331 found: encrypted_data.data_type,
332 });
333 }
334 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, associated_data)
335 }
336
337 fn decrypt_elements_with_associated_data_unchecked(
339 &self,
340 encrypted_data: &EncryptedData,
341 associated_data: &[Felt],
342 ) -> Result<Vec<Felt>, EncryptionError> {
343 if !encrypted_data.ciphertext.len().is_multiple_of(RATE_WIDTH) {
344 return Err(EncryptionError::CiphertextLenNotMultipleRate);
345 }
346
347 let mut sponge = SpongeState::new(self, &encrypted_data.nonce);
349
350 let padded_associated_data = pad(associated_data);
352 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
353 sponge.duplex_overwrite(chunk);
354 });
355
356 let mut plaintext = Vec::with_capacity(encrypted_data.ciphertext.len());
358 let mut ciphertext_block_iterator = encrypted_data.ciphertext.chunks_exact(RATE_WIDTH);
359 ciphertext_block_iterator.by_ref().for_each(|ciphertext_data_block| {
360 let keystream = sponge.duplex_add(&[]);
361 for (i, &ciphertext_felt) in ciphertext_data_block.iter().enumerate() {
362 let plaintext_felt = ciphertext_felt - keystream[i];
363 plaintext.push(plaintext_felt);
364 }
365 sponge.state[RATE_RANGE].copy_from_slice(ciphertext_data_block);
366 });
367
368 let computed_tag = sponge.squeeze_tag();
370 if computed_tag != encrypted_data.auth_tag {
371 return Err(EncryptionError::InvalidAuthTag);
372 }
373
374 unpad(plaintext)
376 }
377
378 pub fn decrypt_bytes(
388 &self,
389 encrypted_data: &EncryptedData,
390 ) -> Result<Vec<u8>, EncryptionError> {
391 self.decrypt_bytes_with_associated_data(encrypted_data, &[])
392 }
393
394 pub fn decrypt_bytes_with_associated_data(
401 &self,
402 encrypted_data: &EncryptedData,
403 associated_data: &[u8],
404 ) -> Result<Vec<u8>, EncryptionError> {
405 if encrypted_data.data_type != DataType::Bytes {
406 return Err(EncryptionError::InvalidDataType {
407 expected: DataType::Bytes,
408 found: encrypted_data.data_type,
409 });
410 }
411
412 let ad_felt = bytes_to_elements_with_padding(associated_data);
413 let data_felts =
414 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, &ad_felt)?;
415
416 match padded_elements_to_bytes(&data_felts) {
417 Some(bytes) => Ok(bytes),
418 None => Err(EncryptionError::MalformedPadding),
419 }
420 }
421}
422
423impl Distribution<SecretKey> for StandardUniform {
424 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SecretKey {
425 let mut res = [ZERO; SECRET_KEY_SIZE];
426 let uni_dist =
427 Uniform::new(0, Felt::MODULUS).expect("should not fail given the size of the field");
428 for r in res.iter_mut() {
429 let sampled_integer = uni_dist.sample(rng);
430 *r = Felt::new(sampled_integer);
431 }
432 SecretKey(res)
433 }
434}
435
436impl PartialEq for SecretKey {
437 fn eq(&self, other: &Self) -> bool {
438 let mut result = true;
440 for (a, b) in self.0.iter().zip(other.0.iter()) {
441 result &= bool::from(a.as_int().ct_eq(&b.as_int()));
442 }
443 result
444 }
445}
446
447impl Eq for SecretKey {}
448
impl Zeroize for SecretKey {
    /// Overwrites every key element with `ZERO`.
    ///
    /// Volatile writes followed by a `SeqCst` compiler fence are used so the stores are not
    /// elided, even when the key is about to be dropped.
    fn zeroize(&mut self) {
        for element in self.0.iter_mut() {
            // SAFETY: `element` is a `&mut Felt` from `iter_mut`, so it is non-null, aligned and
            // valid for writes. The old value is not dropped, which is fine because `Felt` is
            // `Copy` (the key array is returned by value from `&self` in `to_elements`).
            unsafe {
                core::ptr::write_volatile(element, ZERO);
            }
        }
        // Prevent the compiler from reordering the volatile stores past later operations.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
471
472impl Drop for SecretKey {
474 fn drop(&mut self) {
475 self.zeroize();
476 }
477}
478
479impl ZeroizeOnDrop for SecretKey {}
480
/// Duplex sponge over the RPO permutation used by the encryption scheme.
///
/// NOTE(review): the state is initialized from the secret key, but no zeroization on drop is
/// visible for this type — confirm whether key-derived state should be wiped.
struct SpongeState {
    /// Full permutation state (capacity + rate).
    state: [Felt; STATE_WIDTH],
}

impl SpongeState {
    /// Creates the initial state: the key in the first rate half, the nonce in the second half,
    /// and everything else (including the capacity) set to `ZERO`.
    fn new(sk: &SecretKey, nonce: &Nonce) -> Self {
        let mut state = [ZERO; STATE_WIDTH];

        state[RATE_RANGE_FIRST_HALF].copy_from_slice(&sk.0);
        state[RATE_RANGE_SECOND_HALF].copy_from_slice(&nonce.0);

        Self { state }
    }

    /// Permutes, increments the first capacity element, then overwrites the rate with `data`.
    ///
    /// Used to absorb associated data. `data` must be exactly `RATE_WIDTH` elements long, or
    /// `copy_from_slice` panics.
    fn duplex_overwrite(&mut self, data: &[Felt]) {
        self.permute();

        // Presumably a domain separator distinguishing associated-data blocks from message
        // blocks, since `duplex_add` does not touch the capacity — confirm with the spec.
        self.state[CAPACITY_RANGE.start] += ONE;

        self.state[RATE_RANGE].copy_from_slice(data);
    }

    /// Permutes, returns the rate as keystream, then adds `data` element-wise into the rate.
    ///
    /// `data` may be shorter than the rate (decryption passes an empty slice). Elements beyond
    /// `data.len()` are left untouched.
    fn duplex_add(&mut self, data: &[Felt]) -> [Felt; RATE_WIDTH] {
        self.permute();

        // Capture the keystream before `data` is mixed into the rate.
        let squeezed_data = self.squeeze_rate();

        for (idx, &element) in data.iter().enumerate() {
            self.state[RATE_START + idx] += element;
        }

        squeezed_data
    }

    /// Permutes once more and returns the first half of the rate as the authentication tag.
    fn squeeze_tag(&mut self) -> AuthTag {
        self.permute();
        AuthTag(
            self.state[RATE_RANGE_FIRST_HALF]
                .try_into()
                .expect("rate first half is exactly AUTH_TAG_SIZE elements"),
        )
    }

    /// Applies the RPO permutation to the whole state in place.
    fn permute(&mut self) {
        Rpo256::apply_permutation(&mut self.state);
    }

    /// Returns a copy of the current rate, without permuting.
    fn squeeze_rate(&self) -> [Felt; RATE_WIDTH] {
        self.state[RATE_RANGE]
            .try_into()
            .expect("rate range is exactly RATE_WIDTH elements")
    }
}
552
553#[derive(Clone, Debug, PartialEq, Eq)]
558pub struct Nonce([Felt; NONCE_SIZE]);
559
560impl Nonce {
561 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
563 rng.sample(StandardUniform)
564 }
565}
566
567impl From<Word> for Nonce {
568 fn from(word: Word) -> Self {
569 Nonce(word.into())
570 }
571}
572
573impl From<[Felt; NONCE_SIZE]> for Nonce {
574 fn from(elements: [Felt; NONCE_SIZE]) -> Self {
575 Nonce(elements)
576 }
577}
578
579impl From<Nonce> for Word {
580 fn from(nonce: Nonce) -> Self {
581 nonce.0.into()
582 }
583}
584
585impl From<Nonce> for [Felt; NONCE_SIZE] {
586 fn from(nonce: Nonce) -> Self {
587 nonce.0
588 }
589}
590
591impl Distribution<Nonce> for StandardUniform {
592 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Nonce {
593 let mut res = [ZERO; NONCE_SIZE];
594 let uni_dist =
595 Uniform::new(0, Felt::MODULUS).expect("should not fail given the size of the field");
596 for r in res.iter_mut() {
597 let sampled_integer = uni_dist.sample(rng);
598 *r = Felt::new(sampled_integer);
599 }
600 Nonce(res)
601 }
602}
603
604impl Serializable for SecretKey {
608 fn write_into<W: ByteWriter>(&self, target: &mut W) {
609 let bytes = elements_to_bytes(&self.0);
610 target.write_bytes(&bytes);
611 }
612}
613
614impl Deserializable for SecretKey {
615 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
616 let bytes: [u8; SK_SIZE_BYTES] = source.read_array()?;
617
618 match bytes_to_elements_exact(&bytes) {
619 Some(inner) => {
620 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
621 DeserializationError::InvalidValue("malformed secret key".to_string())
622 })?;
623 Ok(Self(inner))
624 },
625 None => Err(DeserializationError::InvalidValue("malformed secret key".to_string())),
626 }
627 }
628}
629
630impl Serializable for Nonce {
631 fn write_into<W: ByteWriter>(&self, target: &mut W) {
632 let bytes = elements_to_bytes(&self.0);
633 target.write_bytes(&bytes);
634 }
635}
636
637impl Deserializable for Nonce {
638 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
639 let bytes: [u8; NONCE_SIZE_BYTES] = source.read_array()?;
640
641 match bytes_to_elements_exact(&bytes) {
642 Some(inner) => {
643 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
644 DeserializationError::InvalidValue("malformed nonce".to_string())
645 })?;
646 Ok(Self(inner))
647 },
648 None => Err(DeserializationError::InvalidValue("malformed nonce".to_string())),
649 }
650 }
651}
652
653impl Serializable for EncryptedData {
654 fn write_into<W: ByteWriter>(&self, target: &mut W) {
655 target.write_u8(self.data_type as u8);
657 target.write_usize(self.ciphertext.len());
658 target.write_many(self.ciphertext.iter().map(Felt::as_int));
659 target.write_many(self.nonce.0.iter().map(Felt::as_int));
660 target.write_many(self.auth_tag.0.iter().map(Felt::as_int));
661 }
662}
663
664impl Deserializable for EncryptedData {
665 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
666 let data_type_value: u8 = source.read_u8()?;
667 let data_type = data_type_value.try_into().map_err(|_| {
668 DeserializationError::InvalidValue("invalid data type value".to_string())
669 })?;
670
671 let ciphertext_len = source.read_usize()?;
672 let ciphertext_bytes = source.read_many(ciphertext_len)?;
673 let ciphertext =
674 felts_from_u64(ciphertext_bytes).map_err(DeserializationError::InvalidValue)?;
675
676 let nonce = source.read_many(NONCE_SIZE)?;
677 let nonce: [Felt; NONCE_SIZE] = felts_from_u64(nonce)
678 .map_err(DeserializationError::InvalidValue)?
679 .try_into()
680 .map_err(|_| {
681 DeserializationError::InvalidValue("nonce conversion failed".to_string())
682 })?;
683
684 let tag = source.read_many(AUTH_TAG_SIZE)?;
685 let tag: [Felt; AUTH_TAG_SIZE] = felts_from_u64(tag)
686 .map_err(DeserializationError::InvalidValue)?
687 .try_into()
688 .expect("deserialization reads exactly AUTH_TAG_SIZE elements");
689
690 Ok(Self {
691 ciphertext,
692 nonce: Nonce(nonce),
693 auth_tag: AuthTag(tag),
694 data_type,
695 })
696 }
697}
698
699fn pad(data: &[Felt]) -> Vec<Felt> {
728 let num_elem_final_block = data.len() % RATE_WIDTH;
730 let padding_elements = RATE_WIDTH - num_elem_final_block;
731
732 let mut result = data.to_vec();
733 result.extend_from_slice(&PADDING_BLOCK[..padding_elements]);
734
735 result
736}
737
738fn unpad(mut plaintext: Vec<Felt>) -> Result<Vec<Felt>, EncryptionError> {
740 let (num_blocks, remainder) = plaintext.len().div_rem(&RATE_WIDTH);
741 assert_eq!(remainder, 0);
742
743 let final_block: &[Felt; RATE_WIDTH] =
744 plaintext.last_chunk().ok_or(EncryptionError::MalformedPadding)?;
745
746 let pos = match final_block.iter().rposition(|entry| *entry == ONE) {
747 Some(pos) => pos,
748 None => return Err(EncryptionError::MalformedPadding),
749 };
750
751 plaintext.truncate((num_blocks - 1) * RATE_WIDTH + pos);
752
753 Ok(plaintext)
754}
755
756fn felts_from_u64(input: Vec<u64>) -> Result<Vec<Felt>, alloc::string::String> {
759 input.into_iter().map(Felt::try_from).collect()
760}
761
762pub struct AeadRpo;
767
768impl AeadScheme for AeadRpo {
769 const KEY_SIZE: usize = SK_SIZE_BYTES;
770
771 type Key = SecretKey;
772
773 fn key_from_bytes(bytes: &[u8]) -> Result<Self::Key, EncryptionError> {
774 SecretKey::read_from_bytes(bytes).map_err(|_| EncryptionError::FailedOperation)
775 }
776
777 fn encrypt_bytes<R: rand::CryptoRng + rand::RngCore>(
778 key: &Self::Key,
779 rng: &mut R,
780 plaintext: &[u8],
781 associated_data: &[u8],
782 ) -> Result<Vec<u8>, EncryptionError> {
783 let nonce = Nonce::with_rng(rng);
784 let encrypted_data = key
785 .encrypt_bytes_with_nonce(plaintext, associated_data, nonce)
786 .map_err(|_| EncryptionError::FailedOperation)?;
787
788 Ok(encrypted_data.to_bytes())
789 }
790
791 fn decrypt_bytes_with_associated_data(
792 key: &Self::Key,
793 ciphertext: &[u8],
794 associated_data: &[u8],
795 ) -> Result<Vec<u8>, EncryptionError> {
796 let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
797 .map_err(|_| EncryptionError::FailedOperation)?;
798
799 key.decrypt_bytes_with_associated_data(&encrypted_data, associated_data)
800 }
801
802 fn encrypt_elements<R: rand::CryptoRng + rand::RngCore>(
806 key: &Self::Key,
807 rng: &mut R,
808 plaintext: &[Felt],
809 associated_data: &[Felt],
810 ) -> Result<Vec<u8>, EncryptionError> {
811 let nonce = Nonce::with_rng(rng);
812 let encrypted_data = key
813 .encrypt_elements_with_nonce(plaintext, associated_data, nonce)
814 .map_err(|_| EncryptionError::FailedOperation)?;
815
816 Ok(encrypted_data.to_bytes())
817 }
818
819 fn decrypt_elements_with_associated_data(
820 key: &Self::Key,
821 ciphertext: &[u8],
822 associated_data: &[Felt],
823 ) -> Result<Vec<Felt>, EncryptionError> {
824 let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
825 .map_err(|_| EncryptionError::FailedOperation)?;
826
827 key.decrypt_elements_with_associated_data(&encrypted_data, associated_data)
828 }
829}