memberlist_proto/
encryption.rs

1use std::borrow::Cow;
2
3use aead::{AeadInPlace, KeyInit};
4use aes_gcm::{
5  Aes128Gcm, Aes256Gcm, AesGcm,
6  aes::{Aes192, cipher::consts::U12},
7};
8use bytes::{Buf, BufMut};
9use generic_array::GenericArray;
10use rand::Rng;
11use varing::decode_u32_varint;
12
13use crate::{WireType, utils::merge};
14
15use super::{Data, DataRef, DecodeError, EncodeError};
16
/// Wire tag of [`EncryptionAlgorithm::NoPadding`] (see `as_u8` / `From<u8>`).
const NOPADDING_TAG: u8 = 1;
/// Wire tag of [`EncryptionAlgorithm::Pkcs7`] (see `as_u8` / `From<u8>`).
const PKCS7_TAG: u8 = 2;

/// Nonce length in bytes; every supported algorithm currently uses a 12-byte nonce.
const NONCE_SIZE: usize = 12;
/// AES-GCM authentication tag length in bytes.
const TAG_SIZE: usize = 16;
/// Block size in bytes that PKCS7 padding aligns the plaintext to.
const BLOCK_SIZE: usize = 16;

/// AES-192-GCM with a 12-byte nonce, defined locally next to the imported
/// `Aes128Gcm` / `Aes256Gcm`.
type Aes192Gcm = AesGcm<Aes192, U12>;
25
26use std::str::FromStr;
27
28use base64::{Engine as _, engine::general_purpose::STANDARD as b64};
29
/// Error returned when a string does not name a known [`EncryptionAlgorithm`].
///
/// Carries the rejected input verbatim so it can be reported back to the user.
#[derive(Debug, Clone, PartialEq, Eq, Hash, thiserror::Error)]
#[error("unknown encryption algorithm: {0}")]
pub struct ParseEncryptionAlgorithmError(String);
34
35impl FromStr for EncryptionAlgorithm {
36  type Err = ParseEncryptionAlgorithmError;
37
38  fn from_str(s: &str) -> Result<Self, Self::Err> {
39    Ok(match s {
40      "aes-gcm-no-padding" | "aes-gcm-nopadding" | "nopadding" | "NOPADDING" | "no-padding"
41      | "NoPadding" | "no_padding" => Self::NoPadding,
42      "aes-gcm-pkcs7" | "PKCS7" | "pkcs7" => Self::Pkcs7,
43      s if s.starts_with("unknown") => {
44        let v = s
45          .trim_start_matches("unknown(")
46          .trim_end_matches(')')
47          .parse()
48          .map_err(|_| ParseEncryptionAlgorithmError(s.to_string()))?;
49        Self::Unknown(v)
50      }
51      e => return Err(ParseEncryptionAlgorithmError(e.to_string())),
52    })
53  }
54}
55
/// Error returned when raw key bytes are not 16, 24 or 32 bytes long.
///
/// Holds the offending length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("invalid key length({0}) - must be 16, 24, or 32 bytes for AES-128/192/256")]
pub struct InvalidKeyLength(pub(crate) usize);
60
/// Error returned when parsing a [`SecretKey`] from a base64 string.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ParseSecretKeyError {
  /// The input is not valid base64.
  // With `std` the variant is `transparent` and gets a derived `From`; without `std` it
  // only formats the inner error and `From` is written by hand below.
  // NOTE(review): presumably because `transparent`/`from` need `base64::DecodeError` to
  // implement `std::error::Error`, which is std-only — confirm.
  #[cfg_attr(feature = "std", error(transparent))]
  #[cfg_attr(not(feature = "std"), error("{0}"))]
  Base64(#[cfg_attr(feature = "std", from)] base64::DecodeError),
  /// The decoded key is not 16, 24 or 32 bytes long.
  #[error(transparent)]
  InvalidKeyLength(#[from] InvalidKeyLength),
}

// Manual counterpart of the `#[from]` attribute that is only applied with `std`.
#[cfg(not(feature = "std"))]
impl From<base64::DecodeError> for ParseSecretKeyError {
  fn from(e: base64::DecodeError) -> Self {
    Self::Base64(e)
  }
}
79
80/// The key used while attempting to encrypt/decrypt a message
81#[derive(
82  Debug,
83  Copy,
84  Clone,
85  PartialEq,
86  Eq,
87  PartialOrd,
88  Ord,
89  derive_more::IsVariant,
90  derive_more::TryUnwrap,
91  derive_more::Unwrap,
92)]
93#[unwrap(ref, ref_mut)]
94#[try_unwrap(ref, ref_mut)]
95#[cfg_attr(any(feature = "arbitrary", test), derive(arbitrary::Arbitrary))]
96pub enum SecretKey {
97  /// secret key for AES128
98  Aes128([u8; 16]),
99  /// secret key for AES192
100  Aes192([u8; 24]),
101  /// secret key for AES256
102  Aes256([u8; 32]),
103}
104
105impl SecretKey {
106  /// Returns the base64 encoded secret key
107  ///
108  /// It is recommended to use [`SecretKey::encode_base64`] if you want to encode the key to a buffer
109  /// without allocating a new string.
110  pub fn to_base64(&self) -> String {
111    b64.encode(self)
112  }
113
114  /// Returns the base64 encoded length of the secret key
115  ///
116  /// ## Example
117  ///
118  /// ```rust
119  /// use memberlist_proto::encryption::SecretKey;
120  ///
121  /// let key = SecretKey::random_aes128();
122  /// assert_eq!(key.base64_len(), 24);
123  ///
124  /// let key = SecretKey::random_aes192();
125  /// assert_eq!(key.base64_len(), 32);
126  ///
127  /// let key = SecretKey::random_aes256();
128  /// assert_eq!(key.base64_len(), 44);
129  /// ```
130  #[inline]
131  pub const fn base64_len(&self) -> usize {
132    match self {
133      Self::Aes128(_) => 24,
134      Self::Aes192(_) => 32,
135      Self::Aes256(_) => 44,
136    }
137  }
138
139  /// Encodes the secret key to the buffer in base64 format
140  ///
141  /// ## Example
142  ///
143  /// ```rust
144  /// use memberlist_proto::encryption::SecretKey;
145  ///
146  /// let key = SecretKey::random_aes128();
147  /// let mut buf = [0u8; 24];
148  /// key.encode_base64(&mut buf).unwrap();
149  /// assert_eq!(&buf, key.to_base64().as_bytes());
150  ///
151  /// let key = SecretKey::random_aes192();
152  /// let mut buf = [0u8; 32];
153  /// key.encode_base64(&mut buf).unwrap();
154  /// assert_eq!(&buf, key.to_base64().as_bytes());
155  ///
156  /// let key = SecretKey::random_aes256();
157  /// let mut buf = [0u8; 44];
158  /// key.encode_base64(&mut buf).unwrap();
159  /// assert_eq!(&buf, key.to_base64().as_bytes());
160  /// ```
161  #[inline]
162  pub fn encode_base64(&self, buf: &mut [u8]) -> Result<usize, base64::EncodeSliceError> {
163    b64.encode_slice(self.as_ref(), buf)
164  }
165
166  /// Creates a random secret key
167  #[inline]
168  pub fn random_aes128() -> Self {
169    let mut key = [0u8; 16];
170    rand::rng().fill(&mut key);
171    Self::Aes128(key)
172  }
173
174  /// Creates a random secret key
175  #[inline]
176  pub fn random_aes192() -> Self {
177    let mut key = [0u8; 24];
178    rand::rng().fill(&mut key);
179    Self::Aes192(key)
180  }
181
182  /// Creates a random secret key
183  #[inline]
184  pub fn random_aes256() -> Self {
185    let mut key = [0u8; 32];
186    rand::rng().fill(&mut key);
187    Self::Aes256(key)
188  }
189}
190
191impl TryFrom<&str> for SecretKey {
192  type Error = ParseSecretKeyError;
193
194  fn try_from(s: &str) -> Result<Self, Self::Error> {
195    s.parse()
196  }
197}
198
199impl TryFrom<&[u8]> for SecretKey {
200  type Error = InvalidKeyLength;
201
202  fn try_from(k: &[u8]) -> Result<Self, Self::Error> {
203    Ok(match k.len() {
204      16 => Self::Aes128(k.try_into().unwrap()),
205      24 => Self::Aes192(k.try_into().unwrap()),
206      32 => Self::Aes256(k.try_into().unwrap()),
207      v => return Err(InvalidKeyLength(v)),
208    })
209  }
210}
211
212impl FromStr for SecretKey {
213  type Err = ParseSecretKeyError;
214
215  fn from_str(s: &str) -> Result<Self, Self::Err> {
216    let mut buf = [0u8; 44];
217    let readed = b64.decode_slice(s, &mut buf).map_err(|e| match e {
218      base64::DecodeSliceError::DecodeError(decode_error) => decode_error.into(),
219      base64::DecodeSliceError::OutputSliceTooSmall => {
220        ParseSecretKeyError::InvalidKeyLength(InvalidKeyLength(s.len()))
221      }
222    })?;
223
224    let bytes = &buf[..readed];
225    Ok(match readed {
226      16 => SecretKey::Aes128(bytes[..readed].try_into().unwrap()),
227      24 => SecretKey::Aes192(bytes[..readed].try_into().unwrap()),
228      32 => SecretKey::Aes256(bytes[..readed].try_into().unwrap()),
229      v => return Err(ParseSecretKeyError::InvalidKeyLength(InvalidKeyLength(v))),
230    })
231  }
232}
233
234impl core::hash::Hash for SecretKey {
235  fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
236    self.as_ref().hash(state);
237  }
238}
239
240impl core::borrow::Borrow<[u8]> for SecretKey {
241  fn borrow(&self) -> &[u8] {
242    self.as_ref()
243  }
244}
245
246impl PartialEq<[u8]> for SecretKey {
247  fn eq(&self, other: &[u8]) -> bool {
248    self.as_ref() == other
249  }
250}
251
252impl core::ops::Deref for SecretKey {
253  type Target = [u8];
254
255  fn deref(&self) -> &Self::Target {
256    match self {
257      Self::Aes128(k) => k,
258      Self::Aes192(k) => k,
259      Self::Aes256(k) => k,
260    }
261  }
262}
263
264impl core::ops::DerefMut for SecretKey {
265  fn deref_mut(&mut self) -> &mut Self::Target {
266    match self {
267      Self::Aes128(k) => k,
268      Self::Aes192(k) => k,
269      Self::Aes256(k) => k,
270    }
271  }
272}
273
274impl From<[u8; 16]> for SecretKey {
275  fn from(k: [u8; 16]) -> Self {
276    Self::Aes128(k)
277  }
278}
279
280impl From<[u8; 24]> for SecretKey {
281  fn from(k: [u8; 24]) -> Self {
282    Self::Aes192(k)
283  }
284}
285
286impl From<[u8; 32]> for SecretKey {
287  fn from(k: [u8; 32]) -> Self {
288    Self::Aes256(k)
289  }
290}
291
292impl AsRef<[u8]> for SecretKey {
293  fn as_ref(&self) -> &[u8] {
294    match self {
295      Self::Aes128(k) => k,
296      Self::Aes192(k) => k,
297      Self::Aes256(k) => k,
298    }
299  }
300}
301
302impl AsMut<[u8]> for SecretKey {
303  fn as_mut(&mut self) -> &mut [u8] {
304    match self {
305      Self::Aes128(k) => k,
306      Self::Aes192(k) => k,
307      Self::Aes256(k) => k,
308    }
309  }
310}
311
312impl<'a> DataRef<'a, Self> for SecretKey {
313  fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError>
314  where
315    Self: Sized,
316  {
317    let mut offset = 0;
318    let buf_len = buf.len();
319    let mut key = None;
320
321    while offset < buf_len {
322      match buf[offset] {
323        AES128_BYTE => {
324          if key.is_some() {
325            return Err(DecodeError::duplicate_field("SecretKey", "key", 0));
326          }
327          offset += 1;
328
329          let (bytes_read, val) = decode_u32_varint(&buf[offset..])?;
330          offset += bytes_read.get();
331
332          let val: [u8; 16] = buf[offset..offset + val as usize]
333            .try_into()
334            .map_err(|_| DecodeError::buffer_underflow())?;
335          offset += 16;
336          key = Some(SecretKey::Aes128(val));
337        }
338        AES192_BYTE => {
339          if key.is_some() {
340            return Err(DecodeError::duplicate_field("SecretKey", "key", 0));
341          }
342          offset += 1;
343
344          let (bytes_read, val) = decode_u32_varint(&buf[offset..])?;
345          offset += bytes_read.get();
346
347          let val: [u8; 24] = buf[offset..offset + val as usize]
348            .try_into()
349            .map_err(|_| DecodeError::buffer_underflow())?;
350          offset += 24;
351
352          key = Some(SecretKey::Aes192(val));
353        }
354        AES256_BYTE => {
355          if key.is_some() {
356            return Err(DecodeError::duplicate_field("SecretKey", "key", 0));
357          }
358          offset += 1;
359
360          let (bytes_read, val) = decode_u32_varint(&buf[offset..])?;
361          offset += bytes_read.get();
362
363          let val: [u8; 32] = buf[offset..offset + val as usize]
364            .try_into()
365            .map_err(|_| DecodeError::buffer_underflow())?;
366          offset += 32;
367
368          key = Some(SecretKey::Aes256(val));
369        }
370        _ => offset += super::skip("SecretKey", &buf[offset..])?,
371      }
372    }
373
374    let key = key.ok_or_else(|| DecodeError::missing_field("SecretKey", "key"))?;
375    Ok((offset, key))
376  }
377}
378
// Field tags for the three key variants; the tag alone decides the AES key size.
const AES128_TAG: u8 = 1;
const AES192_TAG: u8 = 2;
const AES256_TAG: u8 = 3;

// Tag bytes as written by `Data::encode` and matched by `DataRef::decode`.
// NOTE(review): `merge` presumably packs the length-delimited wire type together with
// the field tag into one byte — confirm against `utils::merge`.
const AES128_BYTE: u8 = merge(WireType::LengthDelimited, AES128_TAG);
const AES192_BYTE: u8 = merge(WireType::LengthDelimited, AES192_TAG);
const AES256_BYTE: u8 = merge(WireType::LengthDelimited, AES256_TAG);
386
387impl Data for SecretKey {
388  type Ref<'a> = Self;
389
390  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
391  where
392    Self: Sized,
393  {
394    Ok(val)
395  }
396
397  fn encoded_len(&self) -> usize {
398    1 + varing::encoded_u32_varint_len(self.len() as u32).get() + self.len()
399  }
400
401  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
402    let buf_len = buf.len();
403    let mut offset = 0;
404
405    if buf_len < 1 {
406      return Err(EncodeError::insufficient_buffer(
407        self.encoded_len(),
408        buf_len,
409      ));
410    }
411
412    buf[offset] = match self {
413      Self::Aes128(_) => AES128_BYTE,
414      Self::Aes192(_) => AES192_BYTE,
415      Self::Aes256(_) => AES256_BYTE,
416    };
417    offset += 1;
418
419    let self_len = self.len();
420    let len = varing::encode_u32_varint_to(self_len as u32, &mut buf[offset..])
421      .map_err(|_| EncodeError::insufficient_buffer(self.encoded_len(), buf_len))?;
422    offset += len.get();
423
424    buf[offset..offset + self_len].copy_from_slice(self.as_ref());
425    offset += self_len;
426
427    #[cfg(debug_assertions)]
428    super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
429
430    Ok(offset)
431  }
432}
433
// Generated by `smallvec_wrapper`; the `[SecretKey; 3]` argument gives the inline
// storage. NOTE(review): presumably backed by a `SmallVec` that spills to the heap past
// 3 keys — confirm against the macro.
smallvec_wrapper::smallvec_wrapper!(
  /// A collection of secret keys; use it like a `Vec<SecretKey>`.
  #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
  #[repr(transparent)]
  #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
  #[cfg_attr(feature = "serde", serde(transparent))]
  pub SecretKeys([SecretKey; 3]);
);
442
443impl SecretKeys {
444  /// Returns `true` if the collection is empty
445  ///
446  /// # Example
447  ///
448  /// ```
449  /// use memberlist_proto::encryption::SecretKeys;
450  ///
451  /// let keys = SecretKeys::default();
452  /// assert!(keys.is_empty());
453  /// ```
454  #[inline]
455  pub fn is_empty(&self) -> bool {
456    self.0.is_empty()
457  }
458
459  /// Returns the length of the collection
460  ///
461  /// # Example
462  ///
463  /// ```
464  /// use memberlist_proto::encryption::SecretKeys;
465  ///
466  /// let keys = SecretKeys::default();
467  /// assert_eq!(keys.len(), 0);
468  /// ```
469  #[inline]
470  pub fn len(&self) -> usize {
471    self.0.len()
472  }
473}
474
475/// Security errors
476#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
477pub enum EncryptionError {
478  /// Unknown encryption algorithm
479  #[error("unknown encryption algorithm: {0}")]
480  UnknownAlgorithm(EncryptionAlgorithm),
481  /// Encryt/Decrypt errors
482  #[error("failed to encrypt/decrypt")]
483  Encryptor,
484}
485
486impl From<aead::Error> for EncryptionError {
487  fn from(_: aead::Error) -> Self {
488    Self::Encryptor
489  }
490}
491
/// The encryption algorithm used to encrypt the message.
///
/// Each variant maps to a one-byte wire tag (see `as_u8` and `From<u8>`).
#[derive(
  Debug, Default, Clone, Copy, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display,
)]
#[non_exhaustive]
pub enum EncryptionAlgorithm {
  /// AES-GCM without padding (wire tag 1); the default.
  #[default]
  #[display("aes-gcm-nopadding")]
  NoPadding,
  /// AES-GCM with the plaintext PKCS7-padded to 16-byte blocks (wire tag 2).
  #[display("aes-gcm-pkcs7")]
  Pkcs7,
  /// Unknown encryption algorithm, holding the raw wire tag.
  /// `encrypt` and `decrypt` reject it with `EncryptionError::UnknownAlgorithm`.
  #[display("unknown({_0})")]
  Unknown(u8),
}
509
#[cfg(any(feature = "quickcheck", test))]
const _: () = {
  use quickcheck::Arbitrary;

  impl EncryptionAlgorithm {
    // Smallest and largest known wire tags; together they must bracket every known
    // algorithm so `arbitrary` can reach each of them.
    const MIN: Self = Self::NoPadding;
    const MAX: Self = Self::Pkcs7;
  }

  impl Arbitrary for EncryptionAlgorithm {
    /// Generates one of the known algorithms.
    ///
    /// `Unknown` is not generated: `Unknown(tag)` for a known tag would not round-trip
    /// through `From<u8>`.
    fn arbitrary(g: &mut quickcheck::Gen) -> Self {
      let min = Self::MIN.as_u8();
      // Number of known tags in `MIN..=MAX`.
      let span = Self::MAX.as_u8() - min + 1;
      match u8::arbitrary(g) % span + min {
        NOPADDING_TAG => Self::NoPadding,
        PKCS7_TAG => Self::Pkcs7,
        _ => unreachable!(),
      }
    }
  }
};
530
531impl EncryptionAlgorithm {
532  /// Returns the encryption version as a `u8`.
533  #[inline]
534  pub const fn as_u8(&self) -> u8 {
535    match self {
536      Self::NoPadding => NOPADDING_TAG,
537      Self::Pkcs7 => PKCS7_TAG,
538      Self::Unknown(v) => *v,
539    }
540  }
541
542  /// Returns the encryption version as a `&'static str`.
543  #[inline]
544  pub fn as_str(&self) -> Cow<'static, str> {
545    let val = match self {
546      Self::NoPadding => "aes-gcm-nopadding",
547      Self::Pkcs7 => "aes-gcm-pkcs7",
548      Self::Unknown(e) => return Cow::Owned(format!("unknown({})", e)),
549    };
550    Cow::Borrowed(val)
551  }
552}
553
554impl From<u8> for EncryptionAlgorithm {
555  fn from(value: u8) -> Self {
556    match value {
557      NOPADDING_TAG => Self::NoPadding,
558      PKCS7_TAG => Self::Pkcs7,
559      e => Self::Unknown(e),
560    }
561  }
562}
563
564impl EncryptionAlgorithm {
565  /// Returns the nonce size of the encryption algorithm
566  #[inline]
567  pub const fn nonce_size(&self) -> usize {
568    // only 12 bytes for nonce accepted currently
569    NONCE_SIZE
570  }
571
572  /// Writes the nonce to the buffer, returning the random generated nonce
573  pub fn write_nonce(dst: &mut impl BufMut) -> [u8; NONCE_SIZE] {
574    // Add a random nonce
575    let mut nonce = [0u8; NONCE_SIZE];
576    rand::rng().fill(&mut nonce);
577    dst.put_slice(&nonce);
578
579    nonce
580  }
581
582  /// Generates a random nonce
583  pub fn random_nonce() -> [u8; NONCE_SIZE] {
584    let mut nonce = [0u8; NONCE_SIZE];
585    rand::rng().fill(&mut nonce);
586    nonce
587  }
588
589  /// Reads the nonce from the buffer
590  pub fn read_nonce(src: &mut impl Buf) -> [u8; NONCE_SIZE] {
591    let mut nonce = [0u8; NONCE_SIZE];
592    nonce.copy_from_slice(&src.chunk()[..NONCE_SIZE]);
593    src.advance(NONCE_SIZE);
594    nonce
595  }
596
597  /// Encrypts the data using the provided secret key, nonce, and the authentication data
598  pub fn encrypt<B>(
599    &self,
600    pk: SecretKey,
601    nonce: [u8; NONCE_SIZE],
602    auth_data: &[u8],
603    buf: &mut B,
604  ) -> Result<(), EncryptionError>
605  where
606    B: aead::Buffer,
607  {
608    match self {
609      EncryptionAlgorithm::NoPadding => {}
610      EncryptionAlgorithm::Pkcs7 => {
611        let buf_len = buf.len();
612        pkcs7encode(buf, buf_len, 0)?;
613      }
614      _ => return Err(EncryptionError::UnknownAlgorithm(*self)),
615    }
616
617    match pk {
618      SecretKey::Aes128(pk) => {
619        let gcm = Aes128Gcm::new(GenericArray::from_slice(&pk).as_ref());
620        gcm
621          .encrypt_in_place(GenericArray::from_slice(&nonce).as_ref(), auth_data, buf)
622          .map_err(Into::into)
623      }
624      SecretKey::Aes192(pk) => {
625        let gcm = Aes192Gcm::new(GenericArray::from_slice(&pk).as_ref());
626        gcm
627          .encrypt_in_place(GenericArray::from_slice(&nonce).as_ref(), auth_data, buf)
628          .map_err(Into::into)
629      }
630      SecretKey::Aes256(pk) => {
631        let gcm = Aes256Gcm::new(GenericArray::from_slice(&pk).as_ref());
632        gcm
633          .encrypt_in_place(GenericArray::from_slice(&nonce).as_ref(), auth_data, buf)
634          .map_err(Into::into)
635      }
636    }
637  }
638
639  /// Decrypts the data using the provided secret key, nonce, and the authentication data
640  pub fn decrypt(
641    &self,
642    key: &SecretKey,
643    nonce: &[u8],
644    auth_data: &[u8],
645    dst: &mut impl aead::Buffer,
646  ) -> Result<(), EncryptionError> {
647    if self.is_unknown() {
648      return Err(EncryptionError::UnknownAlgorithm(*self));
649    }
650
651    // Get the AES block cipher
652    match key {
653      SecretKey::Aes128(pk) => {
654        let gcm = Aes128Gcm::new(GenericArray::from_slice(pk).as_ref());
655        gcm
656          .decrypt_in_place(GenericArray::from_slice(nonce).as_ref(), auth_data, dst)
657          .map_err(Into::into)
658      }
659      SecretKey::Aes192(pk) => {
660        let gcm = Aes192Gcm::new(GenericArray::from_slice(pk).as_ref());
661        gcm
662          .decrypt_in_place(GenericArray::from_slice(nonce).as_ref(), auth_data, dst)
663          .map_err(Into::into)
664      }
665      SecretKey::Aes256(pk) => {
666        let gcm = Aes256Gcm::new(GenericArray::from_slice(pk).as_ref());
667        gcm
668          .decrypt_in_place(GenericArray::from_slice(nonce).as_ref(), auth_data, dst)
669          .map_err(Into::into)
670      }
671    }
672    .inspect(|_| {
673      if self.is_pkcs_7() {
674        pkcs7decode(dst);
675      }
676    })
677  }
678
679  /// Returns the overhead of the encryption
680  #[inline]
681  pub(crate) const fn encrypt_overhead(&self) -> usize {
682    match self {
683      EncryptionAlgorithm::Pkcs7 => 44, // IV: 12, Padding: 16, Tag: 16
684      EncryptionAlgorithm::NoPadding => 28, // IV: 12, Tag: 16
685      _ => unreachable!(),
686    }
687  }
688
689  /// Returns the encrypted suffix length of the input size
690  #[inline]
691  pub const fn encrypted_suffix_len(&self, inp: usize) -> usize {
692    match self {
693      EncryptionAlgorithm::Pkcs7 => {
694        // Determine the padding size
695        let padding = BLOCK_SIZE - (inp % BLOCK_SIZE);
696
697        // Sum the extra parts to get total size
698        padding + TAG_SIZE
699      }
700      EncryptionAlgorithm::NoPadding => TAG_SIZE,
701      _ => unreachable!(),
702    }
703  }
704}
705
706/// pkcs7encode is used to pad a byte buffer to a specific block size using
707/// the PKCS7 algorithm. "Ignores" some bytes to compensate for IV
708#[inline]
709fn pkcs7encode(
710  buf: &mut impl aead::Buffer,
711  buf_len: usize,
712  ignore: usize,
713) -> Result<(), aead::Error> {
714  let n = buf_len - ignore;
715  let more = BLOCK_SIZE - (n % BLOCK_SIZE);
716  let mut block_buf = [0u8; BLOCK_SIZE];
717  block_buf
718    .iter_mut()
719    .take(more)
720    .for_each(|b| *b = more as u8);
721  buf.extend_from_slice(&block_buf[..more])
722}
723
724/// pkcs7decode is used to decode a buffer that has been padded
725#[inline]
726fn pkcs7decode(buf: &mut impl aead::Buffer) {
727  if buf.is_empty() {
728    panic!("Cannot decode a PKCS7 buffer of zero length");
729  }
730  let n = buf.len();
731  let last = buf.as_ref()[n - 1];
732  let n = n - (last as usize);
733  buf.truncate(n);
734}
735
#[cfg(feature = "serde")]
const _: () = {
  use serde::{Deserialize, Deserializer, Serialize, Serializer};

  impl Serialize for EncryptionAlgorithm {
    /// Human-readable formats get the algorithm name; binary formats get the wire tag.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
      S: Serializer,
    {
      if serializer.is_human_readable() {
        serializer.serialize_str(self.as_str().as_ref())
      } else {
        serializer.serialize_u8(self.as_u8())
      }
    }
  }

  impl<'de> Deserialize<'de> for EncryptionAlgorithm {
    /// Accepts the name (parsed via `FromStr`) from human-readable formats and the wire
    /// tag (via `From<u8>`) from binary formats.
    fn deserialize<D>(deserializer: D) -> Result<EncryptionAlgorithm, D::Error>
    where
      D: Deserializer<'de>,
    {
      if deserializer.is_human_readable() {
        // Deserialize into an owned `String`. Borrowing `&str` only works when the input
        // can be referenced directly and fails e.g. for escaped JSON strings or readers.
        let name = String::deserialize(deserializer)?;
        name.parse().map_err(serde::de::Error::custom)
      } else {
        u8::deserialize(deserializer).map(EncryptionAlgorithm::from)
      }
    }
  }
};
766
// Unit tests for secret keys, algorithm parsing and AES-GCM encrypt/decrypt round trips.
#[cfg(test)]
mod tests {
  use bytes::BytesMut;

  use core::ops::{Deref, DerefMut};

  use arbitrary::Arbitrary;

  use super::*;

  impl super::EncryptionAlgorithm {
    /// Returns the expected framed length for `inp` plaintext bytes: a 5-byte header,
    /// the nonce, the plaintext, PKCS7 padding (if any) and the tag.
    ///
    /// NOTE(review): the `4 + 1` header is presumably a 4-byte length plus a 1-byte
    /// algorithm tag; the tests only rely on subtracting these 5 bytes.
    #[inline]
    const fn encrypted_len(&self, inp: usize) -> usize {
      match self {
        EncryptionAlgorithm::Pkcs7 => {
          // Determine the padding size
          let padding = BLOCK_SIZE - (inp % BLOCK_SIZE);

          // Sum the extra parts to get total size
          4 + 1 + NONCE_SIZE + inp + padding + TAG_SIZE
        }
        EncryptionAlgorithm::NoPadding => 4 + 1 + NONCE_SIZE + inp + TAG_SIZE,
        _ => unreachable!(),
      }
    }
  }

  // Any arbitrary-generated key must be one of the three variants.
  #[test]
  fn arbitrary_secret_key() {
    let key = SecretKey::arbitrary(&mut arbitrary::Unstructured::new(&[0; 128])).unwrap();
    assert!(matches!(
      key,
      SecretKey::Aes128(_) | SecretKey::Aes192(_) | SecretKey::Aes256(_)
    ));
  }

  // Byte-view accessors (Deref, AsRef, AsMut, to_vec) agree for every key size.
  #[test]
  fn test_secret_key() {
    let mut key = SecretKey::from([0; 16]);
    assert_eq!(key.deref(), &[0; 16]);
    assert_eq!(key.deref_mut(), &mut [0; 16]);
    assert_eq!(key.as_ref(), &[0; 16]);
    assert_eq!(key.as_mut(), &mut [0; 16]);
    assert_eq!(key.len(), 16);
    assert!(!key.is_empty());
    assert_eq!(key.to_vec(), vec![0; 16]);

    let mut key = SecretKey::from([0; 24]);
    assert_eq!(key.deref(), &[0; 24]);
    assert_eq!(key.deref_mut(), &mut [0; 24]);
    assert_eq!(key.as_ref(), &[0; 24]);
    assert_eq!(key.as_mut(), &mut [0; 24]);
    assert_eq!(key.len(), 24);
    assert!(!key.is_empty());
    assert_eq!(key.to_vec(), vec![0; 24]);

    let mut key = SecretKey::from([0; 32]);
    assert_eq!(key.deref(), &[0; 32]);
    assert_eq!(key.deref_mut(), &mut [0; 32]);
    assert_eq!(key.as_ref(), &[0; 32]);
    assert_eq!(key.as_mut(), &mut [0; 32]);
    assert_eq!(key.len(), 32);
    assert!(!key.is_empty());
    assert_eq!(key.to_vec(), vec![0; 32]);

    let mut key = SecretKey::from([0; 16]);
    assert_eq!(key.as_ref(), &[0; 16]);
    assert_eq!(key.as_mut(), &mut [0; 16]);

    let mut key = SecretKey::from([0; 24]);
    assert_eq!(key.as_ref(), &[0; 24]);
    assert_eq!(key.as_mut(), &mut [0; 24]);

    let mut key = SecretKey::from([0; 32]);
    assert_eq!(key.as_ref(), &[0; 32]);
    assert_eq!(key.as_mut(), &mut [0; 32]);

    let key = SecretKey::Aes128([0; 16]);
    assert_eq!(key.to_vec(), vec![0; 16]);

    let key = SecretKey::Aes192([0; 24]);
    assert_eq!(key.to_vec(), vec![0; 24]);

    let key = SecretKey::Aes256([0; 32]);
    assert_eq!(key.to_vec(), vec![0; 32]);
  }

  // Only 16/24/32-byte slices convert into a key.
  #[test]
  fn test_try_from() {
    assert!(SecretKey::try_from([0; 15].as_slice()).is_err());
    assert!(SecretKey::try_from([0; 16].as_slice()).is_ok());
    assert!(SecretKey::try_from([0; 23].as_slice()).is_err());
    assert!(SecretKey::try_from([0; 24].as_slice()).is_ok());
    assert!(SecretKey::try_from([0; 31].as_slice()).is_err());
    assert!(SecretKey::try_from([0; 32].as_slice()).is_ok());
  }

  // Round trip: write nonce + plaintext, encrypt the plaintext part in place, check the
  // length against `encrypted_len`, then read the nonce back and decrypt.
  fn encrypt_decrypt_versioned(vsn: EncryptionAlgorithm) {
    let k1 = SecretKey::Aes128([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
    let plain_text = b"this is a plain text message";
    let extra = b"random data";

    let mut encrypted = BytesMut::new();
    let nonce = EncryptionAlgorithm::write_nonce(&mut encrypted);
    let data_offset = encrypted.len();
    encrypted.put_slice(plain_text);

    // Split off the plaintext so only it is encrypted; the nonce prefix stays in `encrypted`.
    let mut dst = encrypted.split_off(data_offset);
    println!("before encrypted: {} {:?}", dst.len(), dst.as_ref());
    vsn.encrypt(k1, nonce, extra, &mut dst).unwrap();
    println!("encrypted: {} {:?}", dst.len(), dst.as_ref());
    encrypted.unsplit(dst);

    let exp_len = vsn.encrypted_len(plain_text.len());
    assert_eq!(encrypted.len(), exp_len - 5); // minus 5 for header

    // Consume the nonce prefix so `encrypted` holds only ciphertext + tag.
    EncryptionAlgorithm::read_nonce(&mut encrypted);
    vsn.decrypt(&k1, &nonce, extra, &mut encrypted).unwrap();
    assert_eq!(encrypted.as_ref(), plain_text);
  }

  // Encrypts with TEST_KEYS[0], then tries TEST_KEYS in reverse order: TEST_KEYS[2] and
  // TEST_KEYS[1] must fail authentication, and TEST_KEYS[0] (the same bytes as `k1`)
  // must decrypt.
  fn decrypt_by_other_key(algo: EncryptionAlgorithm) {
    let k1 = SecretKey::Aes128([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
    let plain_text = b"this is a plain text message";
    let extra = b"random data";

    let mut encrypted = BytesMut::new();
    let nonce = EncryptionAlgorithm::write_nonce(&mut encrypted);
    let data_offset = encrypted.len();
    encrypted.put_slice(plain_text);

    let mut dst = encrypted.split_off(data_offset);
    algo.encrypt(k1, nonce, extra, &mut dst).unwrap();
    encrypted.unsplit(dst);
    let exp_len = algo.encrypted_len(plain_text.len());
    assert_eq!(encrypted.len(), exp_len - 5); // minus 5 for header
    EncryptionAlgorithm::read_nonce(&mut encrypted);

    for (idx, k) in TEST_KEYS.iter().rev().enumerate() {
      if idx == TEST_KEYS.len() - 1 {
        algo.decrypt(k, &nonce, extra, &mut encrypted).unwrap();
        assert_eq!(encrypted.as_ref(), plain_text);
        return;
      }
      let e = algo.decrypt(k, &nonce, extra, &mut encrypted).unwrap_err();
      assert_eq!(e.to_string(), "failed to encrypt/decrypt");
    }
  }

  // "v0" exercises Pkcs7, "v1" exercises NoPadding.
  #[test]
  fn test_encrypt_decrypt_v0() {
    encrypt_decrypt_versioned(EncryptionAlgorithm::Pkcs7);
  }

  #[test]
  fn test_encrypt_decrypt_v1() {
    encrypt_decrypt_versioned(EncryptionAlgorithm::NoPadding);
  }

  #[test]
  fn test_decrypt_by_other_key_v0() {
    let algo = EncryptionAlgorithm::Pkcs7;
    decrypt_by_other_key(algo);
  }

  #[test]
  fn test_decrypt_by_other_key_v1() {
    let algo = EncryptionAlgorithm::NoPadding;
    decrypt_by_other_key(algo);
  }

  // Accepted names and aliases, plus the `unknown(<u8>)` form.
  #[test]
  fn test_encrypt_algorithm_from_str() {
    assert_eq!(
      "aes-gcm-no-padding".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::NoPadding
    );
    assert_eq!(
      "aes-gcm-nopadding".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::NoPadding
    );
    assert_eq!(
      "aes-gcm-pkcs7".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::Pkcs7
    );
    assert_eq!(
      "NoPadding".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::NoPadding
    );
    assert_eq!(
      "no-padding".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::NoPadding
    );
    assert_eq!(
      "nopadding".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::NoPadding
    );
    assert_eq!(
      "no_padding".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::NoPadding
    );
    assert_eq!(
      "NOPADDING".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::NoPadding
    );
    assert_eq!(
      "unknown(33)".parse::<EncryptionAlgorithm>().unwrap(),
      EncryptionAlgorithm::Unknown(33)
    );
    assert!("unknown".parse::<EncryptionAlgorithm>().is_err());
  }

  // Property: serde round trips in both human-readable (JSON) and binary (bincode) formats.
  #[cfg(feature = "serde")]
  #[quickcheck_macros::quickcheck]
  fn encryption_algorithm_serde(algo: EncryptionAlgorithm) -> bool {
    use bincode::config::standard;

    let Ok(serialized) = serde_json::to_string(&algo) else {
      return false;
    };
    let Ok(deserialized) = serde_json::from_str(&serialized) else {
      return false;
    };
    if algo != deserialized {
      return false;
    }

    let Ok(serialized) = bincode::serde::encode_to_vec(algo, standard()) else {
      return false;
    };

    let Ok((deserialized, _)) = bincode::serde::decode_from_slice(&serialized, standard()) else {
      return false;
    };

    algo == deserialized
  }

  // `encode_base64` into a buffer of `base64_len()` matches `to_base64`.
  #[test]
  fn test_encode_base64() {
    for k in [
      SecretKey::random_aes128(),
      SecretKey::random_aes192(),
      SecretKey::random_aes256(),
    ] {
      let mut buf = vec![0; k.base64_len()];
      k.encode_base64(&mut buf).unwrap();
      assert_eq!(&buf, k.to_base64().as_bytes());
    }
  }

  // base64 round trip, plus rejection of invalid base64 and of an over-long input.
  #[test]
  fn test_try_from_str() {
    for k in &[
      SecretKey::random_aes128(),
      SecretKey::random_aes192(),
      SecretKey::random_aes256(),
    ] {
      let s = k.to_base64();
      let key = SecretKey::try_from(s.as_str()).unwrap();
      assert_eq!(k, key.as_ref());
    }

    let buf = "invalid base64 string";
    let key = SecretKey::try_from(buf);
    assert!(key.is_err());

    let mut buf = SecretKey::random_aes256().to_base64();
    buf.push_str(SecretKey::random_aes128().to_base64().as_str());
    let key = SecretKey::try_from(buf.as_str());
    assert!(key.is_err());
  }

  // Distinct AES-128 keys; TEST_KEYS[0] has the same bytes as `k1` in the helpers above.
  const TEST_KEYS: &[SecretKey] = &[
    SecretKey::Aes128([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
    SecretKey::Aes128([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
    SecretKey::Aes128([8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7]),
  ];
}