1use std::borrow::Cow;
2
3use aead::{AeadInPlace, KeyInit};
4use aes_gcm::{
5 Aes128Gcm, Aes256Gcm, AesGcm,
6 aes::{Aes192, cipher::consts::U12},
7};
8use bytes::{Buf, BufMut};
9use generic_array::GenericArray;
10use rand::Rng;
11use varing::decode_u32_varint;
12
13use crate::{WireType, utils::merge};
14
15use super::{Data, DataRef, DecodeError, EncodeError};
16
/// Wire tag identifying [`EncryptionAlgorithm::NoPadding`].
const NOPADDING_TAG: u8 = 1;
/// Wire tag identifying [`EncryptionAlgorithm::Pkcs7`].
const PKCS7_TAG: u8 = 2;

/// Length in bytes of the AES-GCM nonce.
const NONCE_SIZE: usize = 12;
/// Length in bytes of the AES-GCM authentication tag appended to ciphertext.
const TAG_SIZE: usize = 16;
/// AES block size in bytes; also the granularity of PKCS#7 padding.
const BLOCK_SIZE: usize = 16;

/// AES-192 in GCM mode with a 12-byte (`U12`) nonce; `aes_gcm` exports only the 128/256 aliases.
type Aes192Gcm = AesGcm<Aes192, U12>;
25
26use std::str::FromStr;
27
28use base64::{Engine as _, engine::general_purpose::STANDARD as b64};
29
/// Error returned when a string does not name a known [`EncryptionAlgorithm`].
///
/// Carries the rejected input verbatim.
#[derive(Debug, Clone, PartialEq, Eq, Hash, thiserror::Error)]
#[error("unknown encryption algorithm: {0}")]
pub struct ParseEncryptionAlgorithmError(String);
34
35impl FromStr for EncryptionAlgorithm {
36 type Err = ParseEncryptionAlgorithmError;
37
38 fn from_str(s: &str) -> Result<Self, Self::Err> {
39 Ok(match s {
40 "aes-gcm-no-padding" | "aes-gcm-nopadding" | "nopadding" | "NOPADDING" | "no-padding"
41 | "NoPadding" | "no_padding" => Self::NoPadding,
42 "aes-gcm-pkcs7" | "PKCS7" | "pkcs7" => Self::Pkcs7,
43 s if s.starts_with("unknown") => {
44 let v = s
45 .trim_start_matches("unknown(")
46 .trim_end_matches(')')
47 .parse()
48 .map_err(|_| ParseEncryptionAlgorithmError(s.to_string()))?;
49 Self::Unknown(v)
50 }
51 e => return Err(ParseEncryptionAlgorithmError(e.to_string())),
52 })
53 }
54}
55
/// Error returned when key material is not 16, 24, or 32 bytes long.
///
/// The wrapped value is the length that was rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("invalid key length({0}) - must be 16, 24, or 32 bytes for AES-128/192/256")]
pub struct InvalidKeyLength(pub(crate) usize);
60
/// Error returned when parsing a base64-encoded [`SecretKey`] fails.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ParseSecretKeyError {
    /// The input was not valid base64.
    #[cfg_attr(feature = "std", error(transparent))]
    #[cfg_attr(not(feature = "std"), error("{0}"))]
    Base64(#[cfg_attr(feature = "std", from)] base64::DecodeError),
    /// The input decoded to (or was too long for) an unsupported key length.
    #[error(transparent)]
    InvalidKeyLength(#[from] InvalidKeyLength),
}

// Without `std` the `#[from]` on `Base64` is not applied, so the conversion used by `?`
// and `.into()` is provided by hand here.
#[cfg(not(feature = "std"))]
impl From<base64::DecodeError> for ParseSecretKeyError {
    fn from(e: base64::DecodeError) -> Self {
        Self::Base64(e)
    }
}
79
80#[derive(
82 Debug,
83 Copy,
84 Clone,
85 PartialEq,
86 Eq,
87 PartialOrd,
88 Ord,
89 derive_more::IsVariant,
90 derive_more::TryUnwrap,
91 derive_more::Unwrap,
92)]
93#[unwrap(ref, ref_mut)]
94#[try_unwrap(ref, ref_mut)]
95#[cfg_attr(any(feature = "arbitrary", test), derive(arbitrary::Arbitrary))]
96pub enum SecretKey {
97 Aes128([u8; 16]),
99 Aes192([u8; 24]),
101 Aes256([u8; 32]),
103}
104
105impl SecretKey {
106 pub fn to_base64(&self) -> String {
111 b64.encode(self)
112 }
113
114 #[inline]
131 pub const fn base64_len(&self) -> usize {
132 match self {
133 Self::Aes128(_) => 24,
134 Self::Aes192(_) => 32,
135 Self::Aes256(_) => 44,
136 }
137 }
138
139 #[inline]
162 pub fn encode_base64(&self, buf: &mut [u8]) -> Result<usize, base64::EncodeSliceError> {
163 b64.encode_slice(self.as_ref(), buf)
164 }
165
166 #[inline]
168 pub fn random_aes128() -> Self {
169 let mut key = [0u8; 16];
170 rand::rng().fill(&mut key);
171 Self::Aes128(key)
172 }
173
174 #[inline]
176 pub fn random_aes192() -> Self {
177 let mut key = [0u8; 24];
178 rand::rng().fill(&mut key);
179 Self::Aes192(key)
180 }
181
182 #[inline]
184 pub fn random_aes256() -> Self {
185 let mut key = [0u8; 32];
186 rand::rng().fill(&mut key);
187 Self::Aes256(key)
188 }
189}
190
191impl TryFrom<&str> for SecretKey {
192 type Error = ParseSecretKeyError;
193
194 fn try_from(s: &str) -> Result<Self, Self::Error> {
195 s.parse()
196 }
197}
198
199impl TryFrom<&[u8]> for SecretKey {
200 type Error = InvalidKeyLength;
201
202 fn try_from(k: &[u8]) -> Result<Self, Self::Error> {
203 Ok(match k.len() {
204 16 => Self::Aes128(k.try_into().unwrap()),
205 24 => Self::Aes192(k.try_into().unwrap()),
206 32 => Self::Aes256(k.try_into().unwrap()),
207 v => return Err(InvalidKeyLength(v)),
208 })
209 }
210}
211
212impl FromStr for SecretKey {
213 type Err = ParseSecretKeyError;
214
215 fn from_str(s: &str) -> Result<Self, Self::Err> {
216 let mut buf = [0u8; 44];
217 let readed = b64.decode_slice(s, &mut buf).map_err(|e| match e {
218 base64::DecodeSliceError::DecodeError(decode_error) => decode_error.into(),
219 base64::DecodeSliceError::OutputSliceTooSmall => {
220 ParseSecretKeyError::InvalidKeyLength(InvalidKeyLength(s.len()))
221 }
222 })?;
223
224 let bytes = &buf[..readed];
225 Ok(match readed {
226 16 => SecretKey::Aes128(bytes[..readed].try_into().unwrap()),
227 24 => SecretKey::Aes192(bytes[..readed].try_into().unwrap()),
228 32 => SecretKey::Aes256(bytes[..readed].try_into().unwrap()),
229 v => return Err(ParseSecretKeyError::InvalidKeyLength(InvalidKeyLength(v))),
230 })
231 }
232}
233
234impl core::hash::Hash for SecretKey {
235 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
236 self.as_ref().hash(state);
237 }
238}
239
240impl core::borrow::Borrow<[u8]> for SecretKey {
241 fn borrow(&self) -> &[u8] {
242 self.as_ref()
243 }
244}
245
246impl PartialEq<[u8]> for SecretKey {
247 fn eq(&self, other: &[u8]) -> bool {
248 self.as_ref() == other
249 }
250}
251
252impl core::ops::Deref for SecretKey {
253 type Target = [u8];
254
255 fn deref(&self) -> &Self::Target {
256 match self {
257 Self::Aes128(k) => k,
258 Self::Aes192(k) => k,
259 Self::Aes256(k) => k,
260 }
261 }
262}
263
264impl core::ops::DerefMut for SecretKey {
265 fn deref_mut(&mut self) -> &mut Self::Target {
266 match self {
267 Self::Aes128(k) => k,
268 Self::Aes192(k) => k,
269 Self::Aes256(k) => k,
270 }
271 }
272}
273
274impl From<[u8; 16]> for SecretKey {
275 fn from(k: [u8; 16]) -> Self {
276 Self::Aes128(k)
277 }
278}
279
280impl From<[u8; 24]> for SecretKey {
281 fn from(k: [u8; 24]) -> Self {
282 Self::Aes192(k)
283 }
284}
285
286impl From<[u8; 32]> for SecretKey {
287 fn from(k: [u8; 32]) -> Self {
288 Self::Aes256(k)
289 }
290}
291
292impl AsRef<[u8]> for SecretKey {
293 fn as_ref(&self) -> &[u8] {
294 match self {
295 Self::Aes128(k) => k,
296 Self::Aes192(k) => k,
297 Self::Aes256(k) => k,
298 }
299 }
300}
301
302impl AsMut<[u8]> for SecretKey {
303 fn as_mut(&mut self) -> &mut [u8] {
304 match self {
305 Self::Aes128(k) => k,
306 Self::Aes192(k) => k,
307 Self::Aes256(k) => k,
308 }
309 }
310}
311
312impl<'a> DataRef<'a, Self> for SecretKey {
313 fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError>
314 where
315 Self: Sized,
316 {
317 let mut offset = 0;
318 let buf_len = buf.len();
319 let mut key = None;
320
321 while offset < buf_len {
322 match buf[offset] {
323 AES128_BYTE => {
324 if key.is_some() {
325 return Err(DecodeError::duplicate_field("SecretKey", "key", 0));
326 }
327 offset += 1;
328
329 let (bytes_read, val) = decode_u32_varint(&buf[offset..])?;
330 offset += bytes_read.get();
331
332 let val: [u8; 16] = buf[offset..offset + val as usize]
333 .try_into()
334 .map_err(|_| DecodeError::buffer_underflow())?;
335 offset += 16;
336 key = Some(SecretKey::Aes128(val));
337 }
338 AES192_BYTE => {
339 if key.is_some() {
340 return Err(DecodeError::duplicate_field("SecretKey", "key", 0));
341 }
342 offset += 1;
343
344 let (bytes_read, val) = decode_u32_varint(&buf[offset..])?;
345 offset += bytes_read.get();
346
347 let val: [u8; 24] = buf[offset..offset + val as usize]
348 .try_into()
349 .map_err(|_| DecodeError::buffer_underflow())?;
350 offset += 24;
351
352 key = Some(SecretKey::Aes192(val));
353 }
354 AES256_BYTE => {
355 if key.is_some() {
356 return Err(DecodeError::duplicate_field("SecretKey", "key", 0));
357 }
358 offset += 1;
359
360 let (bytes_read, val) = decode_u32_varint(&buf[offset..])?;
361 offset += bytes_read.get();
362
363 let val: [u8; 32] = buf[offset..offset + val as usize]
364 .try_into()
365 .map_err(|_| DecodeError::buffer_underflow())?;
366 offset += 32;
367
368 key = Some(SecretKey::Aes256(val));
369 }
370 _ => offset += super::skip("SecretKey", &buf[offset..])?,
371 }
372 }
373
374 let key = key.ok_or_else(|| DecodeError::missing_field("SecretKey", "key"))?;
375 Ok((offset, key))
376 }
377}
378
/// Tag of the field carrying a 16-byte (AES-128) key.
const AES128_TAG: u8 = 1;
/// Tag of the field carrying a 24-byte (AES-192) key.
const AES192_TAG: u8 = 2;
/// Tag of the field carrying a 32-byte (AES-256) key.
const AES256_TAG: u8 = 3;

// Leading byte of each encoded key field: the length-delimited wire type merged with the
// field tag. `SecretKey::encode` writes these and `SecretKey::decode` matches on them.
const AES128_BYTE: u8 = merge(WireType::LengthDelimited, AES128_TAG);
const AES192_BYTE: u8 = merge(WireType::LengthDelimited, AES192_TAG);
const AES256_BYTE: u8 = merge(WireType::LengthDelimited, AES256_TAG);
386
387impl Data for SecretKey {
388 type Ref<'a> = Self;
389
390 fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
391 where
392 Self: Sized,
393 {
394 Ok(val)
395 }
396
397 fn encoded_len(&self) -> usize {
398 1 + varing::encoded_u32_varint_len(self.len() as u32).get() + self.len()
399 }
400
401 fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
402 let buf_len = buf.len();
403 let mut offset = 0;
404
405 if buf_len < 1 {
406 return Err(EncodeError::insufficient_buffer(
407 self.encoded_len(),
408 buf_len,
409 ));
410 }
411
412 buf[offset] = match self {
413 Self::Aes128(_) => AES128_BYTE,
414 Self::Aes192(_) => AES192_BYTE,
415 Self::Aes256(_) => AES256_BYTE,
416 };
417 offset += 1;
418
419 let self_len = self.len();
420 let len = varing::encode_u32_varint_to(self_len as u32, &mut buf[offset..])
421 .map_err(|_| EncodeError::insufficient_buffer(self.encoded_len(), buf_len))?;
422 offset += len.get();
423
424 buf[offset..offset + self_len].copy_from_slice(self.as_ref());
425 offset += self_len;
426
427 #[cfg(debug_assertions)]
428 super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
429
430 Ok(offset)
431 }
432}
433
smallvec_wrapper::smallvec_wrapper!(
    /// A collection of [`SecretKey`]s.
    // NOTE(review): `[SecretKey; 3]` presumably sets an inline capacity of three keys before
    // spilling to the heap; confirm against the `smallvec_wrapper` macro.
    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    #[repr(transparent)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    #[cfg_attr(feature = "serde", serde(transparent))]
    pub SecretKeys([SecretKey; 3]);
);
442
443impl SecretKeys {
444 #[inline]
455 pub fn is_empty(&self) -> bool {
456 self.0.is_empty()
457 }
458
459 #[inline]
470 pub fn len(&self) -> usize {
471 self.0.len()
472 }
473}
474
/// Errors produced by [`EncryptionAlgorithm::encrypt`] and [`EncryptionAlgorithm::decrypt`].
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
pub enum EncryptionError {
    /// The algorithm is [`EncryptionAlgorithm::Unknown`] and cannot be applied.
    #[error("unknown encryption algorithm: {0}")]
    UnknownAlgorithm(EncryptionAlgorithm),
    /// The underlying AEAD operation failed, e.g. decrypting with the wrong key.
    #[error("failed to encrypt/decrypt")]
    Encryptor,
}

/// Collapses the opaque `aead::Error` into [`EncryptionError::Encryptor`].
impl From<aead::Error> for EncryptionError {
    fn from(_: aead::Error) -> Self {
        Self::Encryptor
    }
}
491
/// Symmetric encryption scheme: AES-GCM with or without PKCS#7 padding.
///
/// The key size (128/192/256 bits) is chosen by the [`SecretKey`] variant, not by this enum.
#[derive(
    Debug, Default, Clone, Copy, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display,
)]
#[non_exhaustive]
pub enum EncryptionAlgorithm {
    /// AES-GCM over the plaintext as-is (wire tag 1). This is the default.
    #[default]
    #[display("aes-gcm-nopadding")]
    NoPadding,
    /// AES-GCM over plaintext PKCS#7-padded to a 16-byte multiple (wire tag 2).
    #[display("aes-gcm-pkcs7")]
    Pkcs7,
    /// A wire tag this build does not recognise; encrypt/decrypt reject it.
    #[display("unknown({_0})")]
    Unknown(u8),
}
509
#[cfg(any(feature = "quickcheck", test))]
const _: () = {
    use quickcheck::Arbitrary;

    impl EncryptionAlgorithm {
        /// Known algorithm with the largest wire tag.
        const MAX: Self = Self::Pkcs7;
        /// Known algorithm with the smallest wire tag.
        const MIN: Self = Self::NoPadding;
    }

    /// Generates only known algorithms; `Unknown` tags are never produced.
    impl Arbitrary for EncryptionAlgorithm {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            // Map a random byte onto the inclusive tag range [MIN, MAX]. The previous
            // swapped MIN/MAX computed `x % 1 + 2`, which always produced `Pkcs7`.
            let span = Self::MAX.as_u8() - Self::MIN.as_u8() + 1;
            let val = (u8::arbitrary(g) % span) + Self::MIN.as_u8();
            match val {
                NOPADDING_TAG => Self::NoPadding,
                PKCS7_TAG => Self::Pkcs7,
                _ => unreachable!(),
            }
        }
    }
};
530
531impl EncryptionAlgorithm {
532 #[inline]
534 pub const fn as_u8(&self) -> u8 {
535 match self {
536 Self::NoPadding => NOPADDING_TAG,
537 Self::Pkcs7 => PKCS7_TAG,
538 Self::Unknown(v) => *v,
539 }
540 }
541
542 #[inline]
544 pub fn as_str(&self) -> Cow<'static, str> {
545 let val = match self {
546 Self::NoPadding => "aes-gcm-nopadding",
547 Self::Pkcs7 => "aes-gcm-pkcs7",
548 Self::Unknown(e) => return Cow::Owned(format!("unknown({})", e)),
549 };
550 Cow::Borrowed(val)
551 }
552}
553
554impl From<u8> for EncryptionAlgorithm {
555 fn from(value: u8) -> Self {
556 match value {
557 NOPADDING_TAG => Self::NoPadding,
558 PKCS7_TAG => Self::Pkcs7,
559 e => Self::Unknown(e),
560 }
561 }
562}
563
564impl EncryptionAlgorithm {
565 #[inline]
567 pub const fn nonce_size(&self) -> usize {
568 NONCE_SIZE
570 }
571
572 pub fn write_nonce(dst: &mut impl BufMut) -> [u8; NONCE_SIZE] {
574 let mut nonce = [0u8; NONCE_SIZE];
576 rand::rng().fill(&mut nonce);
577 dst.put_slice(&nonce);
578
579 nonce
580 }
581
582 pub fn random_nonce() -> [u8; NONCE_SIZE] {
584 let mut nonce = [0u8; NONCE_SIZE];
585 rand::rng().fill(&mut nonce);
586 nonce
587 }
588
589 pub fn read_nonce(src: &mut impl Buf) -> [u8; NONCE_SIZE] {
591 let mut nonce = [0u8; NONCE_SIZE];
592 nonce.copy_from_slice(&src.chunk()[..NONCE_SIZE]);
593 src.advance(NONCE_SIZE);
594 nonce
595 }
596
597 pub fn encrypt<B>(
599 &self,
600 pk: SecretKey,
601 nonce: [u8; NONCE_SIZE],
602 auth_data: &[u8],
603 buf: &mut B,
604 ) -> Result<(), EncryptionError>
605 where
606 B: aead::Buffer,
607 {
608 match self {
609 EncryptionAlgorithm::NoPadding => {}
610 EncryptionAlgorithm::Pkcs7 => {
611 let buf_len = buf.len();
612 pkcs7encode(buf, buf_len, 0)?;
613 }
614 _ => return Err(EncryptionError::UnknownAlgorithm(*self)),
615 }
616
617 match pk {
618 SecretKey::Aes128(pk) => {
619 let gcm = Aes128Gcm::new(GenericArray::from_slice(&pk).as_ref());
620 gcm
621 .encrypt_in_place(GenericArray::from_slice(&nonce).as_ref(), auth_data, buf)
622 .map_err(Into::into)
623 }
624 SecretKey::Aes192(pk) => {
625 let gcm = Aes192Gcm::new(GenericArray::from_slice(&pk).as_ref());
626 gcm
627 .encrypt_in_place(GenericArray::from_slice(&nonce).as_ref(), auth_data, buf)
628 .map_err(Into::into)
629 }
630 SecretKey::Aes256(pk) => {
631 let gcm = Aes256Gcm::new(GenericArray::from_slice(&pk).as_ref());
632 gcm
633 .encrypt_in_place(GenericArray::from_slice(&nonce).as_ref(), auth_data, buf)
634 .map_err(Into::into)
635 }
636 }
637 }
638
639 pub fn decrypt(
641 &self,
642 key: &SecretKey,
643 nonce: &[u8],
644 auth_data: &[u8],
645 dst: &mut impl aead::Buffer,
646 ) -> Result<(), EncryptionError> {
647 if self.is_unknown() {
648 return Err(EncryptionError::UnknownAlgorithm(*self));
649 }
650
651 match key {
653 SecretKey::Aes128(pk) => {
654 let gcm = Aes128Gcm::new(GenericArray::from_slice(pk).as_ref());
655 gcm
656 .decrypt_in_place(GenericArray::from_slice(nonce).as_ref(), auth_data, dst)
657 .map_err(Into::into)
658 }
659 SecretKey::Aes192(pk) => {
660 let gcm = Aes192Gcm::new(GenericArray::from_slice(pk).as_ref());
661 gcm
662 .decrypt_in_place(GenericArray::from_slice(nonce).as_ref(), auth_data, dst)
663 .map_err(Into::into)
664 }
665 SecretKey::Aes256(pk) => {
666 let gcm = Aes256Gcm::new(GenericArray::from_slice(pk).as_ref());
667 gcm
668 .decrypt_in_place(GenericArray::from_slice(nonce).as_ref(), auth_data, dst)
669 .map_err(Into::into)
670 }
671 }
672 .inspect(|_| {
673 if self.is_pkcs_7() {
674 pkcs7decode(dst);
675 }
676 })
677 }
678
679 #[inline]
681 pub(crate) const fn encrypt_overhead(&self) -> usize {
682 match self {
683 EncryptionAlgorithm::Pkcs7 => 44, EncryptionAlgorithm::NoPadding => 28, _ => unreachable!(),
686 }
687 }
688
689 #[inline]
691 pub const fn encrypted_suffix_len(&self, inp: usize) -> usize {
692 match self {
693 EncryptionAlgorithm::Pkcs7 => {
694 let padding = BLOCK_SIZE - (inp % BLOCK_SIZE);
696
697 padding + TAG_SIZE
699 }
700 EncryptionAlgorithm::NoPadding => TAG_SIZE,
701 _ => unreachable!(),
702 }
703 }
704}
705
706#[inline]
709fn pkcs7encode(
710 buf: &mut impl aead::Buffer,
711 buf_len: usize,
712 ignore: usize,
713) -> Result<(), aead::Error> {
714 let n = buf_len - ignore;
715 let more = BLOCK_SIZE - (n % BLOCK_SIZE);
716 let mut block_buf = [0u8; BLOCK_SIZE];
717 block_buf
718 .iter_mut()
719 .take(more)
720 .for_each(|b| *b = more as u8);
721 buf.extend_from_slice(&block_buf[..more])
722}
723
724#[inline]
726fn pkcs7decode(buf: &mut impl aead::Buffer) {
727 if buf.is_empty() {
728 panic!("Cannot decode a PKCS7 buffer of zero length");
729 }
730 let n = buf.len();
731 let last = buf.as_ref()[n - 1];
732 let n = n - (last as usize);
733 buf.truncate(n);
734}
735
#[cfg(feature = "serde")]
const _: () = {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Human-readable formats use the canonical name; binary formats use the one-byte tag.
    impl Serialize for EncryptionAlgorithm {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                serializer.serialize_str(self.as_str().as_ref())
            } else {
                serializer.serialize_u8(self.as_u8())
            }
        }
    }

    /// Accepts any name `FromStr` accepts (human-readable) or a one-byte tag (binary).
    impl<'de> Deserialize<'de> for EncryptionAlgorithm {
        fn deserialize<D>(deserializer: D) -> Result<EncryptionAlgorithm, D::Error>
        where
            D: Deserializer<'de>,
        {
            if deserializer.is_human_readable() {
                // Deserialize an owned `String`. Borrowing `&str` fails whenever the input
                // cannot lend borrowed data (escaped JSON strings, `from_reader`, many config
                // formats).
                let name = String::deserialize(deserializer)?;
                name.parse().map_err(serde::de::Error::custom)
            } else {
                u8::deserialize(deserializer).map(EncryptionAlgorithm::from)
            }
        }
    }
};
766
#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use core::ops::{Deref, DerefMut};

    use arbitrary::Arbitrary;

    use super::*;

    // Test-only helper attached to the production type.
    impl super::EncryptionAlgorithm {
        /// Expected size of an encrypted frame for `inp` plaintext bytes: a `4 + 1` byte
        /// header, the nonce, PKCS#7 padding (for `Pkcs7`), the plaintext and the GCM tag.
        ///
        /// The header is not produced by `encrypt`; the tests below subtract it again
        /// (`exp_len - 5`) before comparing.
        #[inline]
        const fn encrypted_len(&self, inp: usize) -> usize {
            match self {
                EncryptionAlgorithm::Pkcs7 => {
                    // PKCS#7 always adds 1..=BLOCK_SIZE bytes.
                    let padding = BLOCK_SIZE - (inp % BLOCK_SIZE);

                    4 + 1 + NONCE_SIZE + inp + padding + TAG_SIZE
                }
                EncryptionAlgorithm::NoPadding => 4 + 1 + NONCE_SIZE + inp + TAG_SIZE,
                _ => unreachable!(),
            }
        }
    }

    /// The `arbitrary` derive only yields the three defined key variants.
    #[test]
    fn arbitrary_secret_key() {
        let key = SecretKey::arbitrary(&mut arbitrary::Unstructured::new(&[0; 128])).unwrap();
        assert!(matches!(
            key,
            SecretKey::Aes128(_) | SecretKey::Aes192(_) | SecretKey::Aes256(_)
        ));
    }

    /// Deref/AsRef/AsMut and slice helpers expose exactly the key bytes for every size.
    #[test]
    fn test_secret_key() {
        let mut key = SecretKey::from([0; 16]);
        assert_eq!(key.deref(), &[0; 16]);
        assert_eq!(key.deref_mut(), &mut [0; 16]);
        assert_eq!(key.as_ref(), &[0; 16]);
        assert_eq!(key.as_mut(), &mut [0; 16]);
        assert_eq!(key.len(), 16);
        assert!(!key.is_empty());
        assert_eq!(key.to_vec(), vec![0; 16]);

        let mut key = SecretKey::from([0; 24]);
        assert_eq!(key.deref(), &[0; 24]);
        assert_eq!(key.deref_mut(), &mut [0; 24]);
        assert_eq!(key.as_ref(), &[0; 24]);
        assert_eq!(key.as_mut(), &mut [0; 24]);
        assert_eq!(key.len(), 24);
        assert!(!key.is_empty());
        assert_eq!(key.to_vec(), vec![0; 24]);

        let mut key = SecretKey::from([0; 32]);
        assert_eq!(key.deref(), &[0; 32]);
        assert_eq!(key.deref_mut(), &mut [0; 32]);
        assert_eq!(key.as_ref(), &[0; 32]);
        assert_eq!(key.as_mut(), &mut [0; 32]);
        assert_eq!(key.len(), 32);
        assert!(!key.is_empty());
        assert_eq!(key.to_vec(), vec![0; 32]);

        let mut key = SecretKey::from([0; 16]);
        assert_eq!(key.as_ref(), &[0; 16]);
        assert_eq!(key.as_mut(), &mut [0; 16]);

        let mut key = SecretKey::from([0; 24]);
        assert_eq!(key.as_ref(), &[0; 24]);
        assert_eq!(key.as_mut(), &mut [0; 24]);

        let mut key = SecretKey::from([0; 32]);
        assert_eq!(key.as_ref(), &[0; 32]);
        assert_eq!(key.as_mut(), &mut [0; 32]);

        let key = SecretKey::Aes128([0; 16]);
        assert_eq!(key.to_vec(), vec![0; 16]);

        let key = SecretKey::Aes192([0; 24]);
        assert_eq!(key.to_vec(), vec![0; 24]);

        let key = SecretKey::Aes256([0; 32]);
        assert_eq!(key.to_vec(), vec![0; 32]);
    }

    /// Only 16-, 24- and 32-byte slices convert into keys.
    #[test]
    fn test_try_from() {
        assert!(SecretKey::try_from([0; 15].as_slice()).is_err());
        assert!(SecretKey::try_from([0; 16].as_slice()).is_ok());
        assert!(SecretKey::try_from([0; 23].as_slice()).is_err());
        assert!(SecretKey::try_from([0; 24].as_slice()).is_ok());
        assert!(SecretKey::try_from([0; 31].as_slice()).is_err());
        assert!(SecretKey::try_from([0; 32].as_slice()).is_ok());
    }

    /// Round-trips a message: nonce || ciphertext || tag, checking the frame size.
    fn encrypt_decrypt_versioned(vsn: EncryptionAlgorithm) {
        let k1 = SecretKey::Aes128([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let plain_text = b"this is a plain text message";
        let extra = b"random data";

        let mut encrypted = BytesMut::new();
        let nonce = EncryptionAlgorithm::write_nonce(&mut encrypted);
        let data_offset = encrypted.len();
        encrypted.put_slice(plain_text);

        // Encrypt only the payload that follows the nonce, then stitch the buffer back.
        let mut dst = encrypted.split_off(data_offset);
        println!("before encrypted: {} {:?}", dst.len(), dst.as_ref());
        vsn.encrypt(k1, nonce, extra, &mut dst).unwrap();
        println!("encrypted: {} {:?}", dst.len(), dst.as_ref());
        encrypted.unsplit(dst);

        let exp_len = vsn.encrypted_len(plain_text.len());
        assert_eq!(encrypted.len(), exp_len - 5);
        // Skip past the nonce so only ciphertext || tag remains for `decrypt`.
        EncryptionAlgorithm::read_nonce(&mut encrypted);
        vsn.decrypt(&k1, &nonce, extra, &mut encrypted).unwrap();
        assert_eq!(encrypted.as_ref(), plain_text);
    }

    /// Decryption fails with every key except the one used to encrypt.
    fn decrypt_by_other_key(algo: EncryptionAlgorithm) {
        let k1 = SecretKey::Aes128([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let plain_text = b"this is a plain text message";
        let extra = b"random data";

        let mut encrypted = BytesMut::new();
        let nonce = EncryptionAlgorithm::write_nonce(&mut encrypted);
        let data_offset = encrypted.len();
        encrypted.put_slice(plain_text);

        let mut dst = encrypted.split_off(data_offset);
        algo.encrypt(k1, nonce, extra, &mut dst).unwrap();
        encrypted.unsplit(dst);
        let exp_len = algo.encrypted_len(plain_text.len());
        assert_eq!(encrypted.len(), exp_len - 5);
        EncryptionAlgorithm::read_nonce(&mut encrypted);

        // TEST_KEYS reversed ends with k1, so only the last attempt may succeed.
        for (idx, k) in TEST_KEYS.iter().rev().enumerate() {
            if idx == TEST_KEYS.len() - 1 {
                algo.decrypt(k, &nonce, extra, &mut encrypted).unwrap();
                assert_eq!(encrypted.as_ref(), plain_text);
                return;
            }
            let e = algo.decrypt(k, &nonce, extra, &mut encrypted).unwrap_err();
            assert_eq!(e.to_string(), "failed to encrypt/decrypt");
        }
    }

    #[test]
    fn test_encrypt_decrypt_v0() {
        encrypt_decrypt_versioned(EncryptionAlgorithm::Pkcs7);
    }

    #[test]
    fn test_encrypt_decrypt_v1() {
        encrypt_decrypt_versioned(EncryptionAlgorithm::NoPadding);
    }

    #[test]
    fn test_decrypt_by_other_key_v0() {
        let algo = EncryptionAlgorithm::Pkcs7;
        decrypt_by_other_key(algo);
    }

    #[test]
    fn test_decrypt_by_other_key_v1() {
        let algo = EncryptionAlgorithm::NoPadding;
        decrypt_by_other_key(algo);
    }

    /// Canonical names, aliases and the `unknown(<tag>)` form all parse.
    #[test]
    fn test_encrypt_algorithm_from_str() {
        assert_eq!(
            "aes-gcm-no-padding".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::NoPadding
        );
        assert_eq!(
            "aes-gcm-nopadding".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::NoPadding
        );
        assert_eq!(
            "aes-gcm-pkcs7".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::Pkcs7
        );
        assert_eq!(
            "NoPadding".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::NoPadding
        );
        assert_eq!(
            "no-padding".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::NoPadding
        );
        assert_eq!(
            "nopadding".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::NoPadding
        );
        assert_eq!(
            "no_padding".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::NoPadding
        );
        assert_eq!(
            "NOPADDING".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::NoPadding
        );
        assert_eq!(
            "unknown(33)".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::Unknown(33)
        );
        assert!("unknown".parse::<EncryptionAlgorithm>().is_err());
    }

    /// Serde round-trips through both a human-readable (JSON) and a binary (bincode) format.
    #[cfg(feature = "serde")]
    #[quickcheck_macros::quickcheck]
    fn encryption_algorithm_serde(algo: EncryptionAlgorithm) -> bool {
        use bincode::config::standard;

        let Ok(serialized) = serde_json::to_string(&algo) else {
            return false;
        };
        let Ok(deserialized) = serde_json::from_str(&serialized) else {
            return false;
        };
        if algo != deserialized {
            return false;
        }

        let Ok(serialized) = bincode::serde::encode_to_vec(algo, standard()) else {
            return false;
        };

        let Ok((deserialized, _)) = bincode::serde::decode_from_slice(&serialized, standard()) else {
            return false;
        };

        algo == deserialized
    }

    /// `encode_base64` into a `base64_len` buffer matches `to_base64` exactly.
    #[test]
    fn test_encode_base64() {
        for k in [
            SecretKey::random_aes128(),
            SecretKey::random_aes192(),
            SecretKey::random_aes256(),
        ] {
            let mut buf = vec![0; k.base64_len()];
            k.encode_base64(&mut buf).unwrap();
            assert_eq!(&buf, k.to_base64().as_bytes());
        }
    }

    /// Base64 round-trips, and invalid or over-long input is rejected.
    #[test]
    fn test_try_from_str() {
        for k in &[
            SecretKey::random_aes128(),
            SecretKey::random_aes192(),
            SecretKey::random_aes256(),
        ] {
            let s = k.to_base64();
            let key = SecretKey::try_from(s.as_str()).unwrap();
            assert_eq!(k, key.as_ref());
        }

        let buf = "invalid base64 string";
        let key = SecretKey::try_from(buf);
        assert!(key.is_err());

        // Two concatenated keys decode past any valid key size.
        let mut buf = SecretKey::random_aes256().to_base64();
        buf.push_str(SecretKey::random_aes128().to_base64().as_str());
        let key = SecretKey::try_from(buf.as_str());
        assert!(key.is_err());
    }

    /// Distinct AES-128 keys; the first one is the key used for encryption in the tests above.
    const TEST_KEYS: &[SecretKey] = &[
        SecretKey::Aes128([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
        SecretKey::Aes128([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
        SecretKey::Aes128([8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7]),
    ];
}