1use alloc::{string::ToString, vec::Vec};
11use core::ops::Range;
12
13use miden_crypto_derive::{SilentDebug, SilentDisplay};
14use num::Integer;
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform, Uniform},
18};
19use subtle::ConstantTimeEq;
20
21use crate::{
22 Felt, FieldElement, ONE, StarkField, Word, ZERO,
23 aead::{AeadScheme, DataType, EncryptionError},
24 hash::rpo::Rpo256,
25 utils::{
26 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
27 bytes_to_elements_exact, bytes_to_elements_with_padding, elements_to_bytes,
28 padded_elements_to_bytes,
29 },
30 zeroize::{Zeroize, ZeroizeOnDrop},
31};
32
33#[cfg(test)]
34mod test;
35
36pub const SECRET_KEY_SIZE: usize = 4;
41
42pub const SK_SIZE_BYTES: usize = SECRET_KEY_SIZE * Felt::ELEMENT_BYTES;
44
45pub const NONCE_SIZE: usize = 4;
47
48pub const NONCE_SIZE_BYTES: usize = NONCE_SIZE * Felt::ELEMENT_BYTES;
50
51pub const AUTH_TAG_SIZE: usize = 4;
53
54const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
56
57const CAPACITY_RANGE: Range<usize> = Rpo256::CAPACITY_RANGE;
59
60const RATE_RANGE: Range<usize> = Rpo256::RATE_RANGE;
62
63const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
65
66const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
68
69const RATE_RANGE_FIRST_HALF: Range<usize> =
71 Rpo256::RATE_RANGE.start..Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH;
72
73const RATE_RANGE_SECOND_HALF: Range<usize> =
75 Rpo256::RATE_RANGE.start + HALF_RATE_WIDTH..Rpo256::RATE_RANGE.end;
76
77const RATE_START: usize = Rpo256::RATE_RANGE.start;
79
80const PADDING_BLOCK: [Felt; RATE_WIDTH] = [ONE, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO];
82
/// Output of an encryption operation: the ciphertext together with the metadata needed to
/// decrypt and authenticate it.
#[derive(Debug, PartialEq, Eq)]
pub struct EncryptedData {
    /// Whether the plaintext was field elements or bytes; decryption rejects a mismatch.
    data_type: DataType,
    /// Encrypted, padded plaintext. When produced by this module its length is a multiple of
    /// the sponge rate width.
    ciphertext: Vec<Felt>,
    /// Tag squeezed from the sponge after the associated data and plaintext were processed.
    auth_tag: AuthTag,
    /// Nonce that, together with the secret key, initialized the sponge.
    nonce: Nonce,
}
99
100#[derive(Debug, Default, Clone, PartialEq, Eq)]
102pub struct AuthTag([Felt; AUTH_TAG_SIZE]);
103
/// Symmetric secret key of `SECRET_KEY_SIZE` field elements.
///
/// `Debug`/`Display` come from the `SilentDebug`/`SilentDisplay` derives (presumably printing a
/// redacted placeholder instead of the key; confirm in `miden_crypto_derive`). Every instance,
/// including clones, is zeroized when dropped via the `Drop` impl below.
#[derive(Clone, SilentDebug, SilentDisplay)]
pub struct SecretKey([Felt; SECRET_KEY_SIZE]);
107
108impl SecretKey {
109 #[cfg(feature = "std")]
114 #[allow(clippy::new_without_default)]
115 pub fn new() -> Self {
116 use rand::{SeedableRng, rngs::StdRng};
117 let mut rng = StdRng::from_os_rng();
118
119 Self::with_rng(&mut rng)
120 }
121
122 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
124 rng.sample(StandardUniform)
125 }
126
127 pub fn from_elements(elements: [Felt; SECRET_KEY_SIZE]) -> Self {
135 Self(elements)
136 }
137
138 pub fn to_elements(&self) -> [Felt; SECRET_KEY_SIZE] {
147 self.0
148 }
149
150 #[cfg(feature = "std")]
156 pub fn encrypt_elements(&self, data: &[Felt]) -> Result<EncryptedData, EncryptionError> {
157 self.encrypt_elements_with_associated_data(data, &[])
158 }
159
160 #[cfg(feature = "std")]
163 pub fn encrypt_elements_with_associated_data(
164 &self,
165 data: &[Felt],
166 associated_data: &[Felt],
167 ) -> Result<EncryptedData, EncryptionError> {
168 use rand::{SeedableRng, rngs::StdRng};
169 let mut rng = StdRng::from_os_rng();
170 let nonce = Nonce::with_rng(&mut rng);
171
172 self.encrypt_elements_with_nonce(data, associated_data, nonce)
173 }
174
175 pub fn encrypt_elements_with_nonce(
178 &self,
179 data: &[Felt],
180 associated_data: &[Felt],
181 nonce: Nonce,
182 ) -> Result<EncryptedData, EncryptionError> {
183 let mut sponge = SpongeState::new(self, &nonce);
185
186 let padded_associated_data = pad(associated_data);
188 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
189 sponge.duplex_overwrite(chunk);
190 });
191
192 let mut ciphertext = Vec::with_capacity(data.len() + RATE_WIDTH);
194 let data = pad(data);
195 let mut data_block_iterator = data.chunks_exact(RATE_WIDTH);
196
197 data_block_iterator.by_ref().for_each(|data_block| {
198 let keystream = sponge.duplex_add(data_block);
199 for (i, &plaintext_felt) in data_block.iter().enumerate() {
200 ciphertext.push(plaintext_felt + keystream[i]);
201 }
202 });
203
204 let auth_tag = sponge.squeeze_tag();
206
207 Ok(EncryptedData {
208 data_type: DataType::Elements,
209 ciphertext,
210 auth_tag,
211 nonce,
212 })
213 }
214
215 #[cfg(feature = "std")]
222 pub fn encrypt_bytes(&self, data: &[u8]) -> Result<EncryptedData, EncryptionError> {
223 self.encrypt_bytes_with_associated_data(data, &[])
224 }
225
226 #[cfg(feature = "std")]
232 pub fn encrypt_bytes_with_associated_data(
233 &self,
234 data: &[u8],
235 associated_data: &[u8],
236 ) -> Result<EncryptedData, EncryptionError> {
237 use rand::{SeedableRng, rngs::StdRng};
238 let mut rng = StdRng::from_os_rng();
239 let nonce = Nonce::with_rng(&mut rng);
240
241 self.encrypt_bytes_with_nonce(data, associated_data, nonce)
242 }
243
244 pub fn encrypt_bytes_with_nonce(
250 &self,
251 data: &[u8],
252 associated_data: &[u8],
253 nonce: Nonce,
254 ) -> Result<EncryptedData, EncryptionError> {
255 let data_felt = bytes_to_elements_with_padding(data);
256 let ad_felt = bytes_to_elements_with_padding(associated_data);
257
258 let mut encrypted_data = self.encrypt_elements_with_nonce(&data_felt, &ad_felt, nonce)?;
259 encrypted_data.data_type = DataType::Bytes;
260 Ok(encrypted_data)
261 }
262
263 pub fn decrypt_elements(
272 &self,
273 encrypted_data: &EncryptedData,
274 ) -> Result<Vec<Felt>, EncryptionError> {
275 self.decrypt_elements_with_associated_data(encrypted_data, &[])
276 }
277
278 pub fn decrypt_elements_with_associated_data(
284 &self,
285 encrypted_data: &EncryptedData,
286 associated_data: &[Felt],
287 ) -> Result<Vec<Felt>, EncryptionError> {
288 if encrypted_data.data_type != DataType::Elements {
289 return Err(EncryptionError::InvalidDataType {
290 expected: DataType::Elements,
291 found: encrypted_data.data_type,
292 });
293 }
294 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, associated_data)
295 }
296
297 fn decrypt_elements_with_associated_data_unchecked(
299 &self,
300 encrypted_data: &EncryptedData,
301 associated_data: &[Felt],
302 ) -> Result<Vec<Felt>, EncryptionError> {
303 if !encrypted_data.ciphertext.len().is_multiple_of(RATE_WIDTH) {
304 return Err(EncryptionError::CiphertextLenNotMultipleRate);
305 }
306
307 let mut sponge = SpongeState::new(self, &encrypted_data.nonce);
309
310 let padded_associated_data = pad(associated_data);
312 padded_associated_data.chunks(RATE_WIDTH).for_each(|chunk| {
313 sponge.duplex_overwrite(chunk);
314 });
315
316 let mut plaintext = Vec::with_capacity(encrypted_data.ciphertext.len());
318 let mut ciphertext_block_iterator = encrypted_data.ciphertext.chunks_exact(RATE_WIDTH);
319 ciphertext_block_iterator.by_ref().for_each(|ciphertext_data_block| {
320 let keystream = sponge.duplex_add(&[]);
321 for (i, &ciphertext_felt) in ciphertext_data_block.iter().enumerate() {
322 let plaintext_felt = ciphertext_felt - keystream[i];
323 plaintext.push(plaintext_felt);
324 }
325 sponge.state[RATE_RANGE].copy_from_slice(ciphertext_data_block);
326 });
327
328 let computed_tag = sponge.squeeze_tag();
330 if computed_tag != encrypted_data.auth_tag {
331 return Err(EncryptionError::InvalidAuthTag);
332 }
333
334 unpad(plaintext)
336 }
337
338 pub fn decrypt_bytes(
348 &self,
349 encrypted_data: &EncryptedData,
350 ) -> Result<Vec<u8>, EncryptionError> {
351 self.decrypt_bytes_with_associated_data(encrypted_data, &[])
352 }
353
354 pub fn decrypt_bytes_with_associated_data(
361 &self,
362 encrypted_data: &EncryptedData,
363 associated_data: &[u8],
364 ) -> Result<Vec<u8>, EncryptionError> {
365 if encrypted_data.data_type != DataType::Bytes {
366 return Err(EncryptionError::InvalidDataType {
367 expected: DataType::Bytes,
368 found: encrypted_data.data_type,
369 });
370 }
371
372 let ad_felt = bytes_to_elements_with_padding(associated_data);
373 let data_felts =
374 self.decrypt_elements_with_associated_data_unchecked(encrypted_data, &ad_felt)?;
375
376 match padded_elements_to_bytes(&data_felts) {
377 Some(bytes) => Ok(bytes),
378 None => Err(EncryptionError::MalformedPadding),
379 }
380 }
381}
382
383impl Distribution<SecretKey> for StandardUniform {
384 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SecretKey {
385 let mut res = [ZERO; SECRET_KEY_SIZE];
386 let uni_dist =
387 Uniform::new(0, Felt::MODULUS).expect("should not fail given the size of the field");
388 for r in res.iter_mut() {
389 let sampled_integer = uni_dist.sample(rng);
390 *r = Felt::new(sampled_integer);
391 }
392 SecretKey(res)
393 }
394}
395
396impl PartialEq for SecretKey {
397 fn eq(&self, other: &Self) -> bool {
398 let mut result = true;
400 for (a, b) in self.0.iter().zip(other.0.iter()) {
401 result &= bool::from(a.as_int().ct_eq(&b.as_int()));
402 }
403 result
404 }
405}
406
407impl Eq for SecretKey {}
408
impl Zeroize for SecretKey {
    /// Overwrites every key element with `ZERO`.
    ///
    /// The stores are volatile and followed by a `SeqCst` compiler fence so the compiler cannot
    /// treat them as dead writes and elide them (e.g. right before the key is dropped).
    fn zeroize(&mut self) {
        for element in self.0.iter_mut() {
            // SAFETY: `element` is a valid, properly aligned, exclusive `&mut Felt` produced by
            // `iter_mut()`, so a volatile write through it is sound.
            unsafe {
                core::ptr::write_volatile(element, ZERO);
            }
        }
        // Prevent the compiler from reordering the wipe past subsequent operations.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
431
432impl Drop for SecretKey {
434 fn drop(&mut self) {
435 self.zeroize();
436 }
437}
438
439impl ZeroizeOnDrop for SecretKey {}
440
/// Duplex sponge over the RPO permutation, used by both encryption and decryption.
struct SpongeState {
    /// Full permutation state, addressed through `CAPACITY_RANGE` and `RATE_RANGE`.
    state: [Felt; STATE_WIDTH],
}
448
449impl SpongeState {
450 fn new(sk: &SecretKey, nonce: &Nonce) -> Self {
452 let mut state = [ZERO; STATE_WIDTH];
453
454 state[RATE_RANGE_FIRST_HALF].copy_from_slice(&sk.0);
455 state[RATE_RANGE_SECOND_HALF].copy_from_slice(&nonce.0);
456
457 Self { state }
458 }
459
460 fn duplex_overwrite(&mut self, data: &[Felt]) {
465 self.permute();
466
467 self.state[CAPACITY_RANGE.start] += ONE;
469
470 self.state[RATE_RANGE].copy_from_slice(data);
472 }
473
474 fn duplex_add(&mut self, data: &[Felt]) -> [Felt; RATE_WIDTH] {
479 self.permute();
480
481 let squeezed_data = self.squeeze_rate();
482
483 for (idx, &element) in data.iter().enumerate() {
484 self.state[RATE_START + idx] += element;
485 }
486
487 squeezed_data
488 }
489
490 fn squeeze_tag(&mut self) -> AuthTag {
492 self.permute();
493 AuthTag(
494 self.state[RATE_RANGE_FIRST_HALF]
495 .try_into()
496 .expect("rate first half is exactly AUTH_TAG_SIZE elements"),
497 )
498 }
499
500 fn permute(&mut self) {
502 Rpo256::apply_permutation(&mut self.state);
503 }
504
505 fn squeeze_rate(&self) -> [Felt; RATE_WIDTH] {
507 self.state[RATE_RANGE]
508 .try_into()
509 .expect("rate range is exactly RATE_WIDTH elements")
510 }
511}
512
/// Per-encryption nonce of `NONCE_SIZE` field elements, loaded into the sponge together with
/// the secret key.
///
/// Encrypting twice with the same key, nonce and associated data starts from the same
/// keystream, so a nonce must not be reused under one key.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Nonce([Felt; NONCE_SIZE]);
519
520impl Nonce {
521 pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
523 rng.sample(StandardUniform)
524 }
525}
526
527impl From<Word> for Nonce {
528 fn from(word: Word) -> Self {
529 Nonce(word.into())
530 }
531}
532
533impl From<[Felt; NONCE_SIZE]> for Nonce {
534 fn from(elements: [Felt; NONCE_SIZE]) -> Self {
535 Nonce(elements)
536 }
537}
538
539impl From<Nonce> for Word {
540 fn from(nonce: Nonce) -> Self {
541 nonce.0.into()
542 }
543}
544
545impl From<Nonce> for [Felt; NONCE_SIZE] {
546 fn from(nonce: Nonce) -> Self {
547 nonce.0
548 }
549}
550
551impl Distribution<Nonce> for StandardUniform {
552 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Nonce {
553 let mut res = [ZERO; NONCE_SIZE];
554 let uni_dist =
555 Uniform::new(0, Felt::MODULUS).expect("should not fail given the size of the field");
556 for r in res.iter_mut() {
557 let sampled_integer = uni_dist.sample(rng);
558 *r = Felt::new(sampled_integer);
559 }
560 Nonce(res)
561 }
562}
563
564impl Serializable for SecretKey {
568 fn write_into<W: ByteWriter>(&self, target: &mut W) {
569 let bytes = elements_to_bytes(&self.0);
570 target.write_bytes(&bytes);
571 }
572}
573
574impl Deserializable for SecretKey {
575 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
576 let bytes: [u8; SK_SIZE_BYTES] = source.read_array()?;
577
578 match bytes_to_elements_exact(&bytes) {
579 Some(inner) => {
580 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
581 DeserializationError::InvalidValue("malformed secret key".to_string())
582 })?;
583 Ok(Self(inner))
584 },
585 None => Err(DeserializationError::InvalidValue("malformed secret key".to_string())),
586 }
587 }
588}
589
590impl Serializable for Nonce {
591 fn write_into<W: ByteWriter>(&self, target: &mut W) {
592 let bytes = elements_to_bytes(&self.0);
593 target.write_bytes(&bytes);
594 }
595}
596
597impl Deserializable for Nonce {
598 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
599 let bytes: [u8; NONCE_SIZE_BYTES] = source.read_array()?;
600
601 match bytes_to_elements_exact(&bytes) {
602 Some(inner) => {
603 let inner: [Felt; 4] = inner.try_into().map_err(|_| {
604 DeserializationError::InvalidValue("malformed nonce".to_string())
605 })?;
606 Ok(Self(inner))
607 },
608 None => Err(DeserializationError::InvalidValue("malformed nonce".to_string())),
609 }
610 }
611}
612
613impl Serializable for EncryptedData {
614 fn write_into<W: ByteWriter>(&self, target: &mut W) {
615 target.write_u8(self.data_type as u8);
617 target.write_usize(self.ciphertext.len());
618 target.write_many(self.ciphertext.iter().map(Felt::as_int));
619 target.write_many(self.nonce.0.iter().map(Felt::as_int));
620 target.write_many(self.auth_tag.0.iter().map(Felt::as_int));
621 }
622}
623
624impl Deserializable for EncryptedData {
625 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
626 let data_type_value: u8 = source.read_u8()?;
627 let data_type = data_type_value.try_into().map_err(|_| {
628 DeserializationError::InvalidValue("invalid data type value".to_string())
629 })?;
630
631 let ciphertext_len = source.read_usize()?;
632 let ciphertext_bytes = source.read_many(ciphertext_len)?;
633 let ciphertext =
634 felts_from_u64(ciphertext_bytes).map_err(DeserializationError::InvalidValue)?;
635
636 let nonce = source.read_many(NONCE_SIZE)?;
637 let nonce: [Felt; NONCE_SIZE] = felts_from_u64(nonce)
638 .map_err(DeserializationError::InvalidValue)?
639 .try_into()
640 .expect("deserialization reads exactly NONCE_SIZE elements");
641
642 let tag = source.read_many(AUTH_TAG_SIZE)?;
643 let tag: [Felt; AUTH_TAG_SIZE] = felts_from_u64(tag)
644 .map_err(DeserializationError::InvalidValue)?
645 .try_into()
646 .expect("deserialization reads exactly AUTH_TAG_SIZE elements");
647
648 Ok(Self {
649 ciphertext,
650 nonce: Nonce(nonce),
651 auth_tag: AuthTag(tag),
652 data_type,
653 })
654 }
655}
656
657fn pad(data: &[Felt]) -> Vec<Felt> {
686 let num_elem_final_block = data.len() % RATE_WIDTH;
688 let padding_elements = RATE_WIDTH - num_elem_final_block;
689
690 let mut result = data.to_vec();
691 result.extend_from_slice(&PADDING_BLOCK[..padding_elements]);
692
693 result
694}
695
696fn unpad(mut plaintext: Vec<Felt>) -> Result<Vec<Felt>, EncryptionError> {
698 let (num_blocks, remainder) = plaintext.len().div_rem(&RATE_WIDTH);
699 assert_eq!(remainder, 0);
700
701 let final_block: &[Felt; RATE_WIDTH] = plaintext.last_chunk().expect("plaintext is empty");
702
703 let pos = match final_block.iter().rposition(|entry| *entry == ONE) {
704 Some(pos) => pos,
705 None => return Err(EncryptionError::MalformedPadding),
706 };
707
708 plaintext.truncate((num_blocks - 1) * RATE_WIDTH + pos);
709
710 Ok(plaintext)
711}
712
713fn felts_from_u64(input: Vec<u64>) -> Result<Vec<Felt>, alloc::string::String> {
716 input.into_iter().map(Felt::try_from).collect()
717}
718
/// Marker type implementing [`AeadScheme`] with the RPO-based authenticated encryption defined
/// in this module.
pub struct AeadRpo;
724
725impl AeadScheme for AeadRpo {
726 const KEY_SIZE: usize = SK_SIZE_BYTES;
727
728 type Key = SecretKey;
729
730 fn key_from_bytes(bytes: &[u8]) -> Result<Self::Key, EncryptionError> {
731 SecretKey::read_from_bytes(bytes).map_err(|_| EncryptionError::FailedOperation)
732 }
733
734 fn encrypt_bytes<R: rand::CryptoRng + rand::RngCore>(
735 key: &Self::Key,
736 rng: &mut R,
737 plaintext: &[u8],
738 associated_data: &[u8],
739 ) -> Result<Vec<u8>, EncryptionError> {
740 let nonce = Nonce::with_rng(rng);
741 let encrypted_data = key
742 .encrypt_bytes_with_nonce(plaintext, associated_data, nonce)
743 .map_err(|_| EncryptionError::FailedOperation)?;
744
745 Ok(encrypted_data.to_bytes())
746 }
747
748 fn decrypt_bytes_with_associated_data(
749 key: &Self::Key,
750 ciphertext: &[u8],
751 associated_data: &[u8],
752 ) -> Result<Vec<u8>, EncryptionError> {
753 let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
754 .map_err(|_| EncryptionError::FailedOperation)?;
755
756 key.decrypt_bytes_with_associated_data(&encrypted_data, associated_data)
757 }
758
759 fn encrypt_elements<R: rand::CryptoRng + rand::RngCore>(
763 key: &Self::Key,
764 rng: &mut R,
765 plaintext: &[Felt],
766 associated_data: &[Felt],
767 ) -> Result<Vec<u8>, EncryptionError> {
768 let nonce = Nonce::with_rng(rng);
769 let encrypted_data = key
770 .encrypt_elements_with_nonce(plaintext, associated_data, nonce)
771 .map_err(|_| EncryptionError::FailedOperation)?;
772
773 Ok(encrypted_data.to_bytes())
774 }
775
776 fn decrypt_elements_with_associated_data(
777 key: &Self::Key,
778 ciphertext: &[u8],
779 associated_data: &[Felt],
780 ) -> Result<Vec<Felt>, EncryptionError> {
781 let encrypted_data = EncryptedData::read_from_bytes(ciphertext)
782 .map_err(|_| EncryptionError::FailedOperation)?;
783
784 key.decrypt_elements_with_associated_data(&encrypted_data, associated_data)
785 }
786}