1use alloc::{string::ToString, vec::Vec};
11use core::ops::Range;
12
13use miden_crypto_derive::{SilentDebug, SilentDisplay};
14use num::Integer;
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform, Uniform},
18};
19use subtle::ConstantTimeEq;
20
21use crate::{
22 Felt, ONE, Word, ZERO,
23 aead::{AeadScheme, DataType, EncryptionError},
24 hash::poseidon2::Poseidon2,
25 utils::{
26 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
27 bytes_to_elements_exact, bytes_to_elements_with_padding, elements_to_bytes,
28 padded_elements_to_bytes,
29 zeroize::{Zeroize, ZeroizeOnDrop},
30 },
31};
32
// Unit tests for this module are declared out-of-line and only compiled for test builds.
#[cfg(test)]
mod test;
35
36pub const SECRET_KEY_SIZE: usize = 4;
41
42pub const SK_SIZE_BYTES: usize = SECRET_KEY_SIZE * Felt::NUM_BYTES;
44
45pub const NONCE_SIZE: usize = 4;
47
48pub const NONCE_SIZE_BYTES: usize = NONCE_SIZE * Felt::NUM_BYTES;
50
51pub const AUTH_TAG_SIZE: usize = 4;
53
54const STATE_WIDTH: usize = Poseidon2::STATE_WIDTH;
56
57const CAPACITY_RANGE: Range<usize> = Poseidon2::CAPACITY_RANGE;
59
60const RATE_RANGE: Range<usize> = Poseidon2::RATE_RANGE;
62
63const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
65
66const HALF_RATE_WIDTH: usize = (Poseidon2::RATE_RANGE.end - Poseidon2::RATE_RANGE.start) / 2;
68
69const RATE_RANGE_FIRST_HALF: Range<usize> =
71 Poseidon2::RATE_RANGE.start..Poseidon2::RATE_RANGE.start + HALF_RATE_WIDTH;
72
73const RATE_RANGE_SECOND_HALF: Range<usize> =
75 Poseidon2::RATE_RANGE.start + HALF_RATE_WIDTH..Poseidon2::RATE_RANGE.end;
76
77const RATE_START: usize = Poseidon2::RATE_RANGE.start;
79
80const PADDING_BLOCK: [Felt; RATE_WIDTH] = [ONE, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];
82
83#[derive(Debug, PartialEq, Eq)]
88pub struct EncryptedData {
89 data_type: DataType,
91 ciphertext: Vec<Felt>,
93 auth_tag: AuthTag,
96 nonce: Nonce,
98}
99
100impl EncryptedData {
101 pub fn from_parts(
103 data_type: DataType,
104 ciphertext: Vec<Felt>,
105 auth_tag: AuthTag,
106 nonce: Nonce,
107 ) -> Self {
108 Self { data_type, ciphertext, auth_tag, nonce }
109 }
110
111 pub fn data_type(&self) -> DataType {
113 self.data_type
114 }
115
116 pub fn ciphertext(&self) -> &[Felt] {
118 &self.ciphertext
119 }
120
121 pub fn auth_tag(&self) -> &AuthTag {
123 &self.auth_tag
124 }
125
126 pub fn nonce(&self) -> &Nonce {
128 &self.nonce
129 }
130}
131
132#[derive(Debug, Default, Clone)]
134pub struct AuthTag([Felt; AUTH_TAG_SIZE]);
135
136impl AuthTag {
137 pub fn new(elements: [Felt; AUTH_TAG_SIZE]) -> Self {
139 Self(elements)
140 }
141
142 pub fn to_elements(&self) -> [Felt; AUTH_TAG_SIZE] {
144 self.0
145 }
146}
147
148impl PartialEq for AuthTag {
149 fn eq(&self, other: &Self) -> bool {
150 let mut result = true;
153 for (a, b) in self.0.iter().zip(other.0.iter()) {
154 result &= bool::from(a.as_canonical_u64_ct().ct_eq(&b.as_canonical_u64_ct()));
155 }
156 result
157 }
158}
159
160impl Eq for AuthTag {}
161
/// Secret key for the Poseidon2-based AEAD scheme, made of `SECRET_KEY_SIZE` field elements.
///
/// `Debug`/`Display` come from the `SilentDebug`/`SilentDisplay` derives. Presumably these keep
/// the key material out of formatted output (NOTE(review): confirm against the derive macros).
/// The key is zeroized when dropped.
#[derive(Clone, SilentDebug, SilentDisplay)]
pub struct SecretKey([Felt; SECRET_KEY_SIZE]);
165
166impl SecretKey {
167 #[cfg(feature = "std")]
172 #[allow(clippy::new_without_default)]
173 pub fn new() -> Self {
174 let mut rng = rand::rng();
175 Self::with_rng(&mut rng)
176 }
177
178 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
180 rng.sample(StandardUniform)
181 }
182
183 pub fn from_elements(elements: [Felt; SECRET_KEY_SIZE]) -> Self {
191 Self(elements)
192 }
193
194 pub fn to_elements(&self) -> [Felt; SECRET_KEY_SIZE] {
203 self.0
204 }
205
206 #[cfg(feature = "std")]
212 pub fn encrypt_elements(&self, data: &[Felt]) -> Result<EncryptedData, EncryptionError> {
213 self.encrypt_elements_with_associated_data(data, &[])
214 }
215
216 #[cfg(feature = "std")]
219 pub fn encrypt_elements_with_associated_data(
220 &self,
221 data: &[Felt],
222 associated_data: &[Felt],
223 ) -> Result<EncryptedData, EncryptionError> {
224 let mut rng = rand::rng();
225 let nonce = Nonce::with_rng(&mut rng);
226
227 self.encrypt_elements_with_nonce(data, associated_data, nonce)
228 }
229
230 pub fn encrypt_elements_with_nonce(
233 &self,
234 data: &[Felt],
235 associated_data: &[Felt],
236 nonce: Nonce,
237 ) -> Result<EncryptedData, EncryptionError> {
238 let mut sponge = SpongeState::new(self, &nonce);
240
241 let padded_associated_data = pad(associated_data);
243 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
244 sponge.duplex_overwrite(chunk);
245 });
246
247 let mut ciphertext = Vec::with_capacity(data.len() + RATE_WIDTH);
249 let data = pad(data);
250 let mut data_block_iterator = data.chunks_exact(RATE_WIDTH);
251
252 data_block_iterator.by_ref().for_each(|data_block| {
253 let keystream = sponge.duplex_add(data_block);
254 for (i, &plaintext_felt) in data_block.iter().enumerate() {
255 ciphertext.push(plaintext_felt + keystream[i]);
256 }
257 });
258
259 let auth_tag = sponge.squeeze_tag();
261
262 Ok(EncryptedData {
263 data_type: DataType::Elements,
264 ciphertext,
265 auth_tag,
266 nonce,
267 })
268 }
269
270 #[cfg(feature = "std")]
277 pub fn encrypt_bytes(&self, data: &[u8]) -> Result<EncryptedData, EncryptionError> {
278 self.encrypt_bytes_with_associated_data(data, &[])
279 }
280
281 #[cfg(feature = "std")]
287 pub fn encrypt_bytes_with_associated_data(
288 &self,
289 data: &[u8],
290 associated_data: &[u8],
291 ) -> Result<EncryptedData, EncryptionError> {
292 let mut rng = rand::rng();
293 let nonce = Nonce::with_rng(&mut rng);
294
295 self.encrypt_bytes_with_nonce(data, associated_data, nonce)
296 }
297
298 pub fn encrypt_bytes_with_nonce(
304 &self,
305 data: &[u8],
306 associated_data: &[u8],
307 nonce: Nonce,
308 ) -> Result<EncryptedData, EncryptionError> {
309 let data_felt = bytes_to_elements_with_padding(data);
310 let ad_felt = bytes_to_elements_with_padding(associated_data);
311
312 let mut encrypted_data = self.encrypt_elements_with_nonce(&data_felt, &ad_felt, nonce)?;
313 encrypted_data.data_type = DataType::Bytes;
314 Ok(encrypted_data)
315 }
316
317 pub fn decrypt_elements(
326 &self,
327 encrypted_data: &EncryptedData,
328 ) -> Result<Vec<Felt>, EncryptionError> {
329 self.decrypt_elements_with_associated_data(encrypted_data, &[])
330 }
331
332 pub fn decrypt_elements_with_associated_data(
338 &self,
339 encrypted_data: &EncryptedData,
340 associated_data: &[Felt],
341 ) -> Result<Vec<Felt>, EncryptionError> {
342 if encrypted_data.data_type != DataType::Elements {
343 return Err(EncryptionError::InvalidDataType {
344 expected: DataType::Elements,
345 found: encrypted_data.data_type,
346 });
347 }
348 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, associated_data)
349 }
350
351 fn decrypt_elements_with_associated_data_unchecked(
353 &self,
354 encrypted_data: &EncryptedData,
355 associated_data: &[Felt],
356 ) -> Result<Vec<Felt>, EncryptionError> {
357 if !encrypted_data.ciphertext.len().is_multiple_of(RATE_WIDTH) {
358 return Err(EncryptionError::CiphertextLenNotMultipleRate);
359 }
360
361 let mut sponge = SpongeState::new(self, &encrypted_data.nonce);
363
364 let padded_associated_data = pad(associated_data);
366 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
367 sponge.duplex_overwrite(chunk);
368 });
369
370 let mut plaintext = Vec::with_capacity(encrypted_data.ciphertext.len());
372 let mut ciphertext_block_iterator = encrypted_data.ciphertext.chunks_exact(RATE_WIDTH);
373 ciphertext_block_iterator.by_ref().for_each(|ciphertext_data_block| {
374 let keystream = sponge.duplex_add(&[]);
375 for (i, &ciphertext_felt) in ciphertext_data_block.iter().enumerate() {
376 let plaintext_felt = ciphertext_felt - keystream[i];
377 plaintext.push(plaintext_felt);
378 }
379 sponge.state[RATE_RANGE].copy_from_slice(ciphertext_data_block);
380 });
381
382 let computed_tag = sponge.squeeze_tag();
384 if computed_tag != encrypted_data.auth_tag {
385 return Err(EncryptionError::InvalidAuthTag);
386 }
387
388 unpad(plaintext)
390 }
391
392 pub fn decrypt_bytes(
402 &self,
403 encrypted_data: &EncryptedData,
404 ) -> Result<Vec<u8>, EncryptionError> {
405 self.decrypt_bytes_with_associated_data(encrypted_data, &[])
406 }
407
408 pub fn decrypt_bytes_with_associated_data(
415 &self,
416 encrypted_data: &EncryptedData,
417 associated_data: &[u8],
418 ) -> Result<Vec<u8>, EncryptionError> {
419 if encrypted_data.data_type != DataType::Bytes {
420 return Err(EncryptionError::InvalidDataType {
421 expected: DataType::Bytes,
422 found: encrypted_data.data_type,
423 });
424 }
425
426 let ad_felt = bytes_to_elements_with_padding(associated_data);
427 let data_felts =
428 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, &ad_felt)?;
429
430 match padded_elements_to_bytes(&data_felts) {
431 Some(bytes) => Ok(bytes),
432 None => Err(EncryptionError::MalformedPadding),
433 }
434 }
435}
436
437impl Distribution<SecretKey> for StandardUniform {
438 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SecretKey {
439 let mut res = [ZERO; SECRET_KEY_SIZE];
440 let uni_dist =
441 Uniform::new(0, Felt::ORDER).expect("should not fail given the size of the field");
442 for r in res.iter_mut() {
443 let sampled_integer = uni_dist.sample(rng);
444 *r = Felt::new(sampled_integer);
445 }
446 SecretKey(res)
447 }
448}
449
#[cfg(any(test, feature = "testing"))]
impl PartialEq for SecretKey {
    /// Compares two keys in constant time (test/`testing` builds only).
    ///
    /// The canonical values are compared with `subtle`'s slice `ct_eq`, which accumulates a
    /// single `Choice`. That `Choice` is converted to `bool` once, at the end, rather than per
    /// element; per-element conversion lets the optimizer reintroduce data-dependent branches.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.0.map(|element| element.as_canonical_u64_ct());
        let rhs = other.0.map(|element| element.as_canonical_u64_ct());
        bool::from(lhs.as_slice().ct_eq(rhs.as_slice()))
    }
}

#[cfg(any(test, feature = "testing"))]
impl Eq for SecretKey {}
464
impl Zeroize for SecretKey {
    /// Overwrites every key element with `ZERO`.
    ///
    /// The stores are volatile and followed by a compiler fence, so the compiler cannot drop
    /// them as dead writes just before the key's memory is released.
    fn zeroize(&mut self) {
        for element in self.0.iter_mut() {
            // SAFETY: `element` comes from `iter_mut()` over `self.0`, so it is a live, non-null,
            // properly aligned `&mut Felt` that is valid for writes. `Felt` is `Copy` (see
            // `to_elements`), so overwriting it without running a destructor leaks nothing.
            unsafe {
                core::ptr::write_volatile(element, ZERO);
            }
        }
        // Prevent the compiler from reordering the volatile stores past this point.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}

/// Wipes the key material whenever a `SecretKey` goes out of scope.
impl Drop for SecretKey {
    fn drop(&mut self) {
        self.zeroize();
    }
}

/// Marker: `SecretKey` zeroizes itself on drop (see the `Drop` impl above).
impl ZeroizeOnDrop for SecretKey {}
496
497struct SpongeState {
502 state: [Felt; STATE_WIDTH],
503}
504
505impl SpongeState {
506 fn new(sk: &SecretKey, nonce: &Nonce) -> Self {
508 let mut state = [ZERO; STATE_WIDTH];
509
510 state[RATE_RANGE_FIRST_HALF].copy_from_slice(&sk.0);
511 state[RATE_RANGE_SECOND_HALF].copy_from_slice(&nonce.0);
512
513 Self { state }
514 }
515
516 fn duplex_overwrite(&mut self, data: &[Felt]) {
521 self.permute();
522
523 self.state[CAPACITY_RANGE.start] += ONE;
525
526 self.state[RATE_RANGE].copy_from_slice(data);
528 }
529
530 fn duplex_add(&mut self, data: &[Felt]) -> [Felt; RATE_WIDTH] {
535 self.permute();
536
537 let squeezed_data = self.squeeze_rate();
538
539 for (idx, &element) in data.iter().enumerate() {
540 self.state[RATE_START + idx] += element;
541 }
542
543 squeezed_data
544 }
545
546 fn squeeze_tag(&mut self) -> AuthTag {
548 self.permute();
549 AuthTag(
550 self.state[RATE_RANGE_FIRST_HALF]
551 .try_into()
552 .expect("rate first half is exactly AUTH_TAG_SIZE elements"),
553 )
554 }
555
556 fn permute(&mut self) {
558 Poseidon2::apply_permutation(&mut self.state);
559 }
560
561 fn squeeze_rate(&self) -> [Felt; RATE_WIDTH] {
563 self.state[RATE_RANGE]
564 .try_into()
565 .expect("rate range is exactly RATE_WIDTH elements")
566 }
567}
568
569#[derive(Clone, Debug, PartialEq, Eq)]
574pub struct Nonce([Felt; NONCE_SIZE]);
575
576impl Nonce {
577 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
579 rng.sample(StandardUniform)
580 }
581}
582
583impl From<Word> for Nonce {
584 fn from(word: Word) -> Self {
585 Nonce(word.into())
586 }
587}
588
589impl From<[Felt; NONCE_SIZE]> for Nonce {
590 fn from(elements: [Felt; NONCE_SIZE]) -> Self {
591 Nonce(elements)
592 }
593}
594
595impl From<Nonce> for Word {
596 fn from(nonce: Nonce) -> Self {
597 nonce.0.into()
598 }
599}
600
601impl From<Nonce> for [Felt; NONCE_SIZE] {
602 fn from(nonce: Nonce) -> Self {
603 nonce.0
604 }
605}
606
607impl Distribution<Nonce> for StandardUniform {
608 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Nonce {
609 let mut res = [ZERO; NONCE_SIZE];
610 let uni_dist =
611 Uniform::new(0, Felt::ORDER).expect("should not fail given the size of the field");
612 for r in res.iter_mut() {
613 let sampled_integer = uni_dist.sample(rng);
614 *r = Felt::new(sampled_integer);
615 }
616 Nonce(res)
617 }
618}
619
620impl Serializable for SecretKey {
624 fn write_into<W: ByteWriter>(&self, target: &mut W) {
625 let bytes = elements_to_bytes(&self.0);
626 target.write_bytes(&bytes);
627 }
628}
629
630impl Deserializable for SecretKey {
631 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
632 let bytes: [u8; SK_SIZE_BYTES] = source.read_array()?;
633
634 match bytes_to_elements_exact(&bytes) {
635 Some(inner) => {
636 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
637 DeserializationError::InvalidValue("malformed secret key".to_string())
638 })?;
639 Ok(Self(inner))
640 },
641 None => Err(DeserializationError::InvalidValue("malformed secret key".to_string())),
642 }
643 }
644}
645
646impl Serializable for Nonce {
647 fn write_into<W: ByteWriter>(&self, target: &mut W) {
648 let bytes = elements_to_bytes(&self.0);
649 target.write_bytes(&bytes);
650 }
651}
652
653impl Deserializable for Nonce {
654 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
655 let bytes: [u8; NONCE_SIZE_BYTES] = source.read_array()?;
656
657 match bytes_to_elements_exact(&bytes) {
658 Some(inner) => {
659 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
660 DeserializationError::InvalidValue("malformed nonce".to_string())
661 })?;
662 Ok(Self(inner))
663 },
664 None => Err(DeserializationError::InvalidValue("malformed nonce".to_string())),
665 }
666 }
667}
668
669impl Serializable for EncryptedData {
670 fn write_into<W: ByteWriter>(&self, target: &mut W) {
671 target.write_u8(self.data_type as u8);
672 self.ciphertext.write_into(target);
673 target.write_many(self.nonce.0);
674 target.write_many(self.auth_tag.0);
675 }
676}
677
678impl Deserializable for EncryptedData {
679 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
680 let data_type_value: u8 = source.read_u8()?;
681 let data_type = data_type_value.try_into().map_err(|_| {
682 DeserializationError::InvalidValue("invalid data type value".to_string())
683 })?;
684
685 let ciphertext = Vec::<Felt>::read_from(source)?;
686 let nonce: [Felt; NONCE_SIZE] = source.read()?;
687 let auth_tag: [Felt; AUTH_TAG_SIZE] = source.read()?;
688
689 Ok(Self {
690 ciphertext,
691 nonce: Nonce(nonce),
692 auth_tag: AuthTag(auth_tag),
693 data_type,
694 })
695 }
696}
697
698fn pad(data: &[Felt]) -> Vec<Felt> {
727 let num_elem_final_block = data.len() % RATE_WIDTH;
729 let padding_elements = RATE_WIDTH - num_elem_final_block;
730
731 let mut result = data.to_vec();
732 result.extend_from_slice(&PADDING_BLOCK[..padding_elements]);
733
734 result
735}
736
737fn unpad(mut plaintext: Vec<Felt>) -> Result<Vec<Felt>, EncryptionError> {
739 let (num_blocks, remainder) = plaintext.len().div_rem(&RATE_WIDTH);
740 assert_eq!(remainder, 0);
741
742 let final_block: &[Felt; RATE_WIDTH] =
743 plaintext.last_chunk().ok_or(EncryptionError::MalformedPadding)?;
744
745 let pos = match final_block.iter().rposition(|entry| *entry == ONE) {
746 Some(pos) => pos,
747 None => return Err(EncryptionError::MalformedPadding),
748 };
749
750 plaintext.truncate((num_blocks - 1) * RATE_WIDTH + pos);
751
752 Ok(plaintext)
753}
754
755pub struct AeadPoseidon2;
760
761impl AeadScheme for AeadPoseidon2 {
762 const KEY_SIZE: usize = SK_SIZE_BYTES;
763
764 type Key = SecretKey;
765
766 fn key_from_bytes(bytes: &[u8]) -> Result<Self::Key, EncryptionError> {
767 if bytes.len() != SK_SIZE_BYTES {
768 return Err(EncryptionError::FailedOperation);
769 }
770
771 SecretKey::read_from_bytes_with_budget(bytes, SK_SIZE_BYTES)
772 .map_err(|_| EncryptionError::FailedOperation)
773 }
774
775 fn encrypt_bytes<R: rand::CryptoRng + rand::RngCore>(
776 key: &Self::Key,
777 rng: &mut R,
778 plaintext: &[u8],
779 associated_data: &[u8],
780 ) -> Result<Vec<u8>, EncryptionError> {
781 let nonce = Nonce::with_rng(rng);
782 let encrypted_data = key
783 .encrypt_bytes_with_nonce(plaintext, associated_data, nonce)
784 .map_err(|_| EncryptionError::FailedOperation)?;
785
786 Ok(encrypted_data.to_bytes())
787 }
788
789 fn decrypt_bytes_with_associated_data(
790 key: &Self::Key,
791 ciphertext: &[u8],
792 associated_data: &[u8],
793 ) -> Result<Vec<u8>, EncryptionError> {
794 let encrypted_data =
795 EncryptedData::read_from_bytes_with_budget(ciphertext, ciphertext.len())
796 .map_err(|_| EncryptionError::FailedOperation)?;
797
798 key.decrypt_bytes_with_associated_data(&encrypted_data, associated_data)
799 }
800
801 fn encrypt_elements<R: rand::CryptoRng + rand::RngCore>(
805 key: &Self::Key,
806 rng: &mut R,
807 plaintext: &[Felt],
808 associated_data: &[Felt],
809 ) -> Result<Vec<u8>, EncryptionError> {
810 let nonce = Nonce::with_rng(rng);
811 let encrypted_data = key
812 .encrypt_elements_with_nonce(plaintext, associated_data, nonce)
813 .map_err(|_| EncryptionError::FailedOperation)?;
814
815 Ok(encrypted_data.to_bytes())
816 }
817
818 fn decrypt_elements_with_associated_data(
819 key: &Self::Key,
820 ciphertext: &[u8],
821 associated_data: &[Felt],
822 ) -> Result<Vec<Felt>, EncryptionError> {
823 let encrypted_data =
824 EncryptedData::read_from_bytes_with_budget(ciphertext, ciphertext.len())
825 .map_err(|_| EncryptionError::FailedOperation)?;
826
827 key.decrypt_elements_with_associated_data(&encrypted_data, associated_data)
828 }
829}