1use alloc::{string::ToString, vec::Vec};
11use core::ops::Range;
12
13use miden_crypto_derive::{SilentDebug, SilentDisplay};
14use num::Integer;
15use p3_field::{PrimeField64, RawDataSerializable};
16use rand::{
17 Rng,
18 distr::{Distribution, StandardUniform, Uniform},
19};
20use subtle::ConstantTimeEq;
21
22use crate::{
23 Felt, ONE, Word, ZERO,
24 aead::{AeadScheme, DataType, EncryptionError},
25 hash::rpo::Rpo256,
26 utils::{
27 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
28 bytes_to_elements_exact, bytes_to_elements_with_padding, elements_to_bytes,
29 padded_elements_to_bytes,
30 zeroize::{Zeroize, ZeroizeOnDrop},
31 },
32};
33
34#[cfg(test)]
35mod test;
36
/// Number of field elements in a [`SecretKey`].
pub const SECRET_KEY_SIZE: usize = 4;

/// Size in bytes of a serialized [`SecretKey`].
pub const SK_SIZE_BYTES: usize = SECRET_KEY_SIZE * Felt::NUM_BYTES;

/// Number of field elements in a [`Nonce`].
pub const NONCE_SIZE: usize = 4;

/// Size in bytes of a serialized [`Nonce`].
pub const NONCE_SIZE_BYTES: usize = NONCE_SIZE * Felt::NUM_BYTES;

/// Number of field elements in an [`AuthTag`].
pub const AUTH_TAG_SIZE: usize = 4;

/// Number of field elements in the sponge state (taken from the RPO permutation).
const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;

/// Index range of the capacity portion of the sponge state.
const CAPACITY_RANGE: Range<usize> = Rpo256::CAPACITY_RANGE;

/// Index range of the rate portion of the sponge state.
const RATE_RANGE: Range<usize> = Rpo256::RATE_RANGE;

/// Number of field elements in the rate portion of the sponge state.
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;

/// Half of the rate width.
const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;

/// First half of the rate: initialized with the secret key, and the part squeezed out as the
/// authentication tag.
const RATE_RANGE_FIRST_HALF: Range<usize> =
    Rpo256::RATE_RANGE.start..Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH;

/// Second half of the rate: initialized with the nonce.
const RATE_RANGE_SECOND_HALF: Range<usize> =
    Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH..Rpo256::RATE_RANGE.end;

/// Index of the first rate element in the sponge state.
const RATE_START: usize = Rpo256::RATE_RANGE.start;

/// Padding pattern: a single `ONE` followed by `ZERO`s. `pad` appends a prefix of this block so
/// that the padded length is a multiple of `RATE_WIDTH`.
const PADDING_BLOCK: [Felt; RATE_WIDTH] = [ONE, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];
83
84#[derive(Debug, PartialEq, Eq)]
89pub struct EncryptedData {
90 data_type: DataType,
92 ciphertext: Vec<Felt>,
94 auth_tag: AuthTag,
97 nonce: Nonce,
99}
100
101impl EncryptedData {
102 pub fn from_parts(
104 data_type: DataType,
105 ciphertext: Vec<Felt>,
106 auth_tag: AuthTag,
107 nonce: Nonce,
108 ) -> Self {
109 Self { data_type, ciphertext, auth_tag, nonce }
110 }
111
112 pub fn data_type(&self) -> DataType {
114 self.data_type
115 }
116
117 pub fn ciphertext(&self) -> &[Felt] {
119 &self.ciphertext
120 }
121
122 pub fn auth_tag(&self) -> &AuthTag {
124 &self.auth_tag
125 }
126
127 pub fn nonce(&self) -> &Nonce {
129 &self.nonce
130 }
131}
132
133#[derive(Debug, Default, Clone, PartialEq, Eq)]
135pub struct AuthTag([Felt; AUTH_TAG_SIZE]);
136
137impl AuthTag {
138 pub fn new(elements: [Felt; AUTH_TAG_SIZE]) -> Self {
140 Self(elements)
141 }
142
143 pub fn to_elements(&self) -> [Felt; AUTH_TAG_SIZE] {
145 self.0
146 }
147}
148
/// A symmetric secret key consisting of `SECRET_KEY_SIZE` field elements.
///
/// The key is overwritten with zeros when dropped (see the `Zeroize` and `Drop` impls), and
/// equality is evaluated without short-circuiting (see the `PartialEq` impl).
// NOTE(review): `SilentDebug`/`SilentDisplay` presumably avoid printing key material — confirm
// against `miden_crypto_derive`.
#[derive(Clone, SilentDebug, SilentDisplay)]
pub struct SecretKey([Felt; SECRET_KEY_SIZE]);
152
153impl SecretKey {
154 #[cfg(feature = "std")]
159 #[allow(clippy::new_without_default)]
160 pub fn new() -> Self {
161 let mut rng = rand::rng();
162 Self::with_rng(&mut rng)
163 }
164
165 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
167 rng.sample(StandardUniform)
168 }
169
170 pub fn from_elements(elements: [Felt; SECRET_KEY_SIZE]) -> Self {
178 Self(elements)
179 }
180
181 pub fn to_elements(&self) -> [Felt; SECRET_KEY_SIZE] {
190 self.0
191 }
192
193 #[cfg(feature = "std")]
199 pub fn encrypt_elements(&self, data: &[Felt]) -> Result<EncryptedData, EncryptionError> {
200 self.encrypt_elements_with_associated_data(data, &[])
201 }
202
203 #[cfg(feature = "std")]
206 pub fn encrypt_elements_with_associated_data(
207 &self,
208 data: &[Felt],
209 associated_data: &[Felt],
210 ) -> Result<EncryptedData, EncryptionError> {
211 let mut rng = rand::rng();
212 let nonce = Nonce::with_rng(&mut rng);
213
214 self.encrypt_elements_with_nonce(data, associated_data, nonce)
215 }
216
217 pub fn encrypt_elements_with_nonce(
220 &self,
221 data: &[Felt],
222 associated_data: &[Felt],
223 nonce: Nonce,
224 ) -> Result<EncryptedData, EncryptionError> {
225 let mut sponge = SpongeState::new(self, &nonce);
227
228 let padded_associated_data = pad(associated_data);
230 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
231 sponge.duplex_overwrite(chunk);
232 });
233
234 let mut ciphertext = Vec::with_capacity(data.len() + RATE_WIDTH);
236 let data = pad(data);
237 let mut data_block_iterator = data.chunks_exact(RATE_WIDTH);
238
239 data_block_iterator.by_ref().for_each(|data_block| {
240 let keystream = sponge.duplex_add(data_block);
241 for (i, &plaintext_felt) in data_block.iter().enumerate() {
242 ciphertext.push(plaintext_felt + keystream[i]);
243 }
244 });
245
246 let auth_tag = sponge.squeeze_tag();
248
249 Ok(EncryptedData {
250 data_type: DataType::Elements,
251 ciphertext,
252 auth_tag,
253 nonce,
254 })
255 }
256
257 #[cfg(feature = "std")]
264 pub fn encrypt_bytes(&self, data: &[u8]) -> Result<EncryptedData, EncryptionError> {
265 self.encrypt_bytes_with_associated_data(data, &[])
266 }
267
268 #[cfg(feature = "std")]
274 pub fn encrypt_bytes_with_associated_data(
275 &self,
276 data: &[u8],
277 associated_data: &[u8],
278 ) -> Result<EncryptedData, EncryptionError> {
279 let mut rng = rand::rng();
280 let nonce = Nonce::with_rng(&mut rng);
281
282 self.encrypt_bytes_with_nonce(data, associated_data, nonce)
283 }
284
285 pub fn encrypt_bytes_with_nonce(
291 &self,
292 data: &[u8],
293 associated_data: &[u8],
294 nonce: Nonce,
295 ) -> Result<EncryptedData, EncryptionError> {
296 let data_felt = bytes_to_elements_with_padding(data);
297 let ad_felt = bytes_to_elements_with_padding(associated_data);
298
299 let mut encrypted_data = self.encrypt_elements_with_nonce(&data_felt, &ad_felt, nonce)?;
300 encrypted_data.data_type = DataType::Bytes;
301 Ok(encrypted_data)
302 }
303
304 pub fn decrypt_elements(
313 &self,
314 encrypted_data: &EncryptedData,
315 ) -> Result<Vec<Felt>, EncryptionError> {
316 self.decrypt_elements_with_associated_data(encrypted_data, &[])
317 }
318
319 pub fn decrypt_elements_with_associated_data(
325 &self,
326 encrypted_data: &EncryptedData,
327 associated_data: &[Felt],
328 ) -> Result<Vec<Felt>, EncryptionError> {
329 if encrypted_data.data_type != DataType::Elements {
330 return Err(EncryptionError::InvalidDataType {
331 expected: DataType::Elements,
332 found: encrypted_data.data_type,
333 });
334 }
335 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, associated_data)
336 }
337
338 fn decrypt_elements_with_associated_data_unchecked(
340 &self,
341 encrypted_data: &EncryptedData,
342 associated_data: &[Felt],
343 ) -> Result<Vec<Felt>, EncryptionError> {
344 if !encrypted_data.ciphertext.len().is_multiple_of(RATE_WIDTH) {
345 return Err(EncryptionError::CiphertextLenNotMultipleRate);
346 }
347
348 let mut sponge = SpongeState::new(self, &encrypted_data.nonce);
350
351 let padded_associated_data = pad(associated_data);
353 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
354 sponge.duplex_overwrite(chunk);
355 });
356
357 let mut plaintext = Vec::with_capacity(encrypted_data.ciphertext.len());
359 let mut ciphertext_block_iterator = encrypted_data.ciphertext.chunks_exact(RATE_WIDTH);
360 ciphertext_block_iterator.by_ref().for_each(|ciphertext_data_block| {
361 let keystream = sponge.duplex_add(&[]);
362 for (i, &ciphertext_felt) in ciphertext_data_block.iter().enumerate() {
363 let plaintext_felt = ciphertext_felt - keystream[i];
364 plaintext.push(plaintext_felt);
365 }
366 sponge.state[RATE_RANGE].copy_from_slice(ciphertext_data_block);
367 });
368
369 let computed_tag = sponge.squeeze_tag();
371 if computed_tag != encrypted_data.auth_tag {
372 return Err(EncryptionError::InvalidAuthTag);
373 }
374
375 unpad(plaintext)
377 }
378
379 pub fn decrypt_bytes(
389 &self,
390 encrypted_data: &EncryptedData,
391 ) -> Result<Vec<u8>, EncryptionError> {
392 self.decrypt_bytes_with_associated_data(encrypted_data, &[])
393 }
394
395 pub fn decrypt_bytes_with_associated_data(
402 &self,
403 encrypted_data: &EncryptedData,
404 associated_data: &[u8],
405 ) -> Result<Vec<u8>, EncryptionError> {
406 if encrypted_data.data_type != DataType::Bytes {
407 return Err(EncryptionError::InvalidDataType {
408 expected: DataType::Bytes,
409 found: encrypted_data.data_type,
410 });
411 }
412
413 let ad_felt = bytes_to_elements_with_padding(associated_data);
414 let data_felts =
415 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, &ad_felt)?;
416
417 match padded_elements_to_bytes(&data_felts) {
418 Some(bytes) => Ok(bytes),
419 None => Err(EncryptionError::MalformedPadding),
420 }
421 }
422}
423
424impl Distribution<SecretKey> for StandardUniform {
425 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SecretKey {
426 let mut res = [ZERO; SECRET_KEY_SIZE];
427 let uni_dist =
428 Uniform::new(0, Felt::ORDER_U64).expect("should not fail given the size of the field");
429 for r in res.iter_mut() {
430 let sampled_integer = uni_dist.sample(rng);
431 *r = Felt::new(sampled_integer);
432 }
433 SecretKey(res)
434 }
435}
436
437impl PartialEq for SecretKey {
438 fn eq(&self, other: &Self) -> bool {
439 let mut result = true;
441 for (a, b) in self.0.iter().zip(other.0.iter()) {
442 result &= bool::from(a.as_canonical_u64().ct_eq(&b.as_canonical_u64()));
443 }
444 result
445 }
446}
447
448impl Eq for SecretKey {}
449
/// Overwrites the key material with zeros.
impl Zeroize for SecretKey {
    fn zeroize(&mut self) {
        // Volatile writes are used so that the compiler cannot elide the stores as dead writes
        // to memory that is about to be released.
        for element in self.0.iter_mut() {
            // SAFETY: `element` is a `&mut Felt` produced by `iter_mut`, so it is valid for
            // writes, properly aligned, and exclusively borrowed for the duration of the write.
            unsafe {
                core::ptr::write_volatile(element, ZERO);
            }
        }
        // Prevents the compiler from reordering memory accesses across this point.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}

/// Wipes the key material when the key goes out of scope.
impl Drop for SecretKey {
    fn drop(&mut self) {
        self.zeroize();
    }
}

// Marker impl: the `Drop` impl above performs the actual zeroization.
impl ZeroizeOnDrop for SecretKey {}
481
/// Internal duplex-sponge state driven by the RPO permutation.
// NOTE(review): this state is derived from the secret key but is not zeroized on drop — confirm
// whether that is intentional.
struct SpongeState {
    /// The full permutation state (capacity and rate portions).
    state: [Felt; STATE_WIDTH],
}

impl SpongeState {
    /// Creates a sponge whose rate holds the secret key (first half) followed by the nonce
    /// (second half); all remaining elements are zero.
    fn new(sk: &SecretKey, nonce: &Nonce) -> Self {
        let mut state = [ZERO; STATE_WIDTH];

        state[RATE_RANGE_FIRST_HALF].copy_from_slice(&sk.0);
        state[RATE_RANGE_SECOND_HALF].copy_from_slice(&nonce.0);

        Self { state }
    }

    /// Permutes the state, increments the first capacity element by one, and then overwrites
    /// the rate with `data`.
    ///
    /// # Panics
    /// Panics if `data.len() != RATE_WIDTH` (`copy_from_slice` requires equal lengths).
    fn duplex_overwrite(&mut self, data: &[Felt]) {
        self.permute();

        // Used while absorbing associated data.
        // NOTE(review): presumably a domain-separation marker — confirm against the scheme's
        // specification.
        self.state[CAPACITY_RANGE.start] += ONE;

        self.state[RATE_RANGE].copy_from_slice(data);
    }

    /// Permutes the state, captures the resulting rate, and then adds `data` element-wise into
    /// the rate starting at its first element.
    ///
    /// Returns the rate as it was right after the permutation, i.e. before `data` was added.
    /// Callers pass at most `RATE_WIDTH` elements (an empty slice leaves the rate untouched).
    fn duplex_add(&mut self, data: &[Felt]) -> [Felt; RATE_WIDTH] {
        self.permute();

        let squeezed_data = self.squeeze_rate();

        for (idx, &element) in data.iter().enumerate() {
            self.state[RATE_START + idx] += element;
        }

        squeezed_data
    }

    /// Permutes the state and returns the first half of the rate as the authentication tag.
    fn squeeze_tag(&mut self) -> AuthTag {
        self.permute();
        AuthTag(
            self.state[RATE_RANGE_FIRST_HALF]
                .try_into()
                .expect("rate first half is exactly AUTH_TAG_SIZE elements"),
        )
    }

    /// Applies the RPO permutation to the state in place.
    fn permute(&mut self) {
        Rpo256::apply_permutation(&mut self.state);
    }

    /// Returns a copy of the rate portion of the state without modifying it.
    fn squeeze_rate(&self) -> [Felt; RATE_WIDTH] {
        self.state[RATE_RANGE]
            .try_into()
            .expect("rate range is exactly RATE_WIDTH elements")
    }
}
553
554#[derive(Clone, Debug, PartialEq, Eq)]
559pub struct Nonce([Felt; NONCE_SIZE]);
560
561impl Nonce {
562 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
564 rng.sample(StandardUniform)
565 }
566}
567
568impl From<Word> for Nonce {
569 fn from(word: Word) -> Self {
570 Nonce(word.into())
571 }
572}
573
574impl From<[Felt; NONCE_SIZE]> for Nonce {
575 fn from(elements: [Felt; NONCE_SIZE]) -> Self {
576 Nonce(elements)
577 }
578}
579
580impl From<Nonce> for Word {
581 fn from(nonce: Nonce) -> Self {
582 nonce.0.into()
583 }
584}
585
586impl From<Nonce> for [Felt; NONCE_SIZE] {
587 fn from(nonce: Nonce) -> Self {
588 nonce.0
589 }
590}
591
592impl Distribution<Nonce> for StandardUniform {
593 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Nonce {
594 let mut res = [ZERO; NONCE_SIZE];
595 let uni_dist =
596 Uniform::new(0, Felt::ORDER_U64).expect("should not fail given the size of the field");
597 for r in res.iter_mut() {
598 let sampled_integer = uni_dist.sample(rng);
599 *r = Felt::new(sampled_integer);
600 }
601 Nonce(res)
602 }
603}
604
605impl Serializable for SecretKey {
609 fn write_into<W: ByteWriter>(&self, target: &mut W) {
610 let bytes = elements_to_bytes(&self.0);
611 target.write_bytes(&bytes);
612 }
613}
614
615impl Deserializable for SecretKey {
616 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
617 let bytes: [u8; SK_SIZE_BYTES] = source.read_array()?;
618
619 match bytes_to_elements_exact(&bytes) {
620 Some(inner) => {
621 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
622 DeserializationError::InvalidValue("malformed secret key".to_string())
623 })?;
624 Ok(Self(inner))
625 },
626 None => Err(DeserializationError::InvalidValue("malformed secret key".to_string())),
627 }
628 }
629}
630
631impl Serializable for Nonce {
632 fn write_into<W: ByteWriter>(&self, target: &mut W) {
633 let bytes = elements_to_bytes(&self.0);
634 target.write_bytes(&bytes);
635 }
636}
637
638impl Deserializable for Nonce {
639 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
640 let bytes: [u8; NONCE_SIZE_BYTES] = source.read_array()?;
641
642 match bytes_to_elements_exact(&bytes) {
643 Some(inner) => {
644 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
645 DeserializationError::InvalidValue("malformed nonce".to_string())
646 })?;
647 Ok(Self(inner))
648 },
649 None => Err(DeserializationError::InvalidValue("malformed nonce".to_string())),
650 }
651 }
652}
653
654impl Serializable for EncryptedData {
655 fn write_into<W: ByteWriter>(&self, target: &mut W) {
656 target.write_u8(self.data_type as u8);
657 self.ciphertext.write_into(target);
658 target.write_many(self.nonce.0);
659 target.write_many(self.auth_tag.0);
660 }
661}
662
663impl Deserializable for EncryptedData {
664 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
665 let data_type_value: u8 = source.read_u8()?;
666 let data_type = data_type_value.try_into().map_err(|_| {
667 DeserializationError::InvalidValue("invalid data type value".to_string())
668 })?;
669
670 let ciphertext = Vec::<Felt>::read_from(source)?;
671 let nonce: [Felt; NONCE_SIZE] = source.read()?;
672 let auth_tag: [Felt; AUTH_TAG_SIZE] = source.read()?;
673
674 Ok(Self {
675 ciphertext,
676 nonce: Nonce(nonce),
677 auth_tag: AuthTag(auth_tag),
678 data_type,
679 })
680 }
681}
682
683fn pad(data: &[Felt]) -> Vec<Felt> {
712 let num_elem_final_block = data.len() % RATE_WIDTH;
714 let padding_elements = RATE_WIDTH - num_elem_final_block;
715
716 let mut result = data.to_vec();
717 result.extend_from_slice(&PADDING_BLOCK[..padding_elements]);
718
719 result
720}
721
722fn unpad(mut plaintext: Vec<Felt>) -> Result<Vec<Felt>, EncryptionError> {
724 let (num_blocks, remainder) = plaintext.len().div_rem(&RATE_WIDTH);
725 assert_eq!(remainder, 0);
726
727 let final_block: &[Felt; RATE_WIDTH] =
728 plaintext.last_chunk().ok_or(EncryptionError::MalformedPadding)?;
729
730 let pos = match final_block.iter().rposition(|entry| *entry == ONE) {
731 Some(pos) => pos,
732 None => return Err(EncryptionError::MalformedPadding),
733 };
734
735 plaintext.truncate((num_blocks - 1) * RATE_WIDTH + pos);
736
737 Ok(plaintext)
738}
739
740pub struct AeadRpo;
745
746impl AeadScheme for AeadRpo {
747 const KEY_SIZE: usize = SK_SIZE_BYTES;
748
749 type Key = SecretKey;
750
751 fn key_from_bytes(bytes: &[u8]) -> Result<Self::Key, EncryptionError> {
752 SecretKey::read_from_bytes(bytes).map_err(|_| EncryptionError::FailedOperation)
753 }
754
755 fn encrypt_bytes<R: rand::CryptoRng + rand::RngCore>(
756 key: &Self::Key,
757 rng: &mut R,
758 plaintext: &[u8],
759 associated_data: &[u8],
760 ) -> Result<Vec<u8>, EncryptionError> {
761 let nonce = Nonce::with_rng(rng);
762 let encrypted_data = key
763 .encrypt_bytes_with_nonce(plaintext, associated_data, nonce)
764 .map_err(|_| EncryptionError::FailedOperation)?;
765
766 Ok(encrypted_data.to_bytes())
767 }
768
769 fn decrypt_bytes_with_associated_data(
770 key: &Self::Key,
771 ciphertext: &[u8],
772 associated_data: &[u8],
773 ) -> Result<Vec<u8>, EncryptionError> {
774 let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
775 .map_err(|_| EncryptionError::FailedOperation)?;
776
777 key.decrypt_bytes_with_associated_data(&encrypted_data, associated_data)
778 }
779
780 fn encrypt_elements<R: rand::CryptoRng + rand::RngCore>(
784 key: &Self::Key,
785 rng: &mut R,
786 plaintext: &[Felt],
787 associated_data: &[Felt],
788 ) -> Result<Vec<u8>, EncryptionError> {
789 let nonce = Nonce::with_rng(rng);
790 let encrypted_data = key
791 .encrypt_elements_with_nonce(plaintext, associated_data, nonce)
792 .map_err(|_| EncryptionError::FailedOperation)?;
793
794 Ok(encrypted_data.to_bytes())
795 }
796
797 fn decrypt_elements_with_associated_data(
798 key: &Self::Key,
799 ciphertext: &[u8],
800 associated_data: &[Felt],
801 ) -> Result<Vec<Felt>, EncryptionError> {
802 let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
803 .map_err(|_| EncryptionError::FailedOperation)?;
804
805 key.decrypt_elements_with_associated_data(&encrypted_data, associated_data)
806 }
807}