1use alloc::boxed::Box;
2use core::fmt;
3
4use chacha20poly1305::{
5 aead::{Aead, AeadCore, Payload},
6 Key, KeyInit, KeySizeUser, XChaCha20Poly1305, XNonce,
7};
8use generic_array::{typenum::Unsigned, ArrayLength, GenericArray};
9use hkdf::Hkdf;
10use rand_core::{CryptoRng, RngCore};
11use sha2::Sha256;
12use zeroize::ZeroizeOnDrop;
13
14use crate::secret_box::SecretBox;
15
/// Errors that can be returned while encrypting data.
#[derive(PartialEq, Eq, Debug)]
pub enum EncryptionError {
    /// The underlying AEAD cipher refused to encrypt the given plaintext.
    PlaintextTooLarge,
}

impl fmt::Display for EncryptionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::PlaintextTooLarge => "Plaintext is too large to encrypt",
        };
        f.write_str(message)
    }
}
30
/// Errors that can be returned while decrypting a ciphertext.
#[derive(PartialEq, Eq, Debug)]
pub enum DecryptionError {
    /// The input is too short to even contain the nonce prefix.
    CiphertextTooShort,
    /// The AEAD cipher rejected the ciphertext during decryption.
    AuthenticationFailed,
}

impl fmt::Display for DecryptionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::CiphertextTooShort => "The ciphertext must include the nonce",
            Self::AuthenticationFailed => {
                "Decryption of ciphertext failed: \
                 either someone tampered with the ciphertext or \
                 you are using an incorrect decryption key."
            }
        };
        f.write_str(message)
    }
}
57
58pub(crate) fn kdf<S: ArrayLength<u8>>(
59 seed: &[u8],
60 salt: Option<&[u8]>,
61 info: Option<&[u8]>,
62) -> SecretBox<GenericArray<u8, S>> {
63 let hk = Hkdf::<Sha256>::new(salt, seed);
64
65 let mut okm = SecretBox::new(GenericArray::<u8, S>::default());
66
67 let def_info = info.unwrap_or(&[]);
68
69 hk.expand(def_info, okm.as_mut_secret()).unwrap();
71
72 okm
73}
74
75type NonceSize = <XChaCha20Poly1305 as AeadCore>::NonceSize;
76
77#[allow(clippy::upper_case_acronyms)]
78#[derive(ZeroizeOnDrop)]
79pub(crate) struct DEM {
80 cipher: XChaCha20Poly1305,
81}
82
83impl DEM {
84 pub fn new(key_seed: &[u8]) -> Self {
85 type KeySize = <XChaCha20Poly1305 as KeySizeUser>::KeySize;
86 let key_bytes = kdf::<KeySize>(key_seed, None, None);
87 let key = SecretBox::new(*Key::from_slice(key_bytes.as_secret()));
90 let cipher = XChaCha20Poly1305::new(key.as_secret());
91 Self { cipher }
92 }
93
94 pub fn encrypt(
95 &self,
96 rng: &mut (impl CryptoRng + RngCore),
97 data: &[u8],
98 authenticated_data: &[u8],
99 ) -> Result<Box<[u8]>, EncryptionError> {
100 let mut nonce = GenericArray::<u8, NonceSize>::default();
101 rng.fill_bytes(&mut nonce);
102 let nonce = XNonce::from_slice(&nonce);
103 let payload = Payload {
104 msg: data,
105 aad: authenticated_data,
106 };
107
108 let mut result = nonce.to_vec();
109 let enc_data = self
110 .cipher
111 .encrypt(nonce, payload)
112 .or(Err(EncryptionError::PlaintextTooLarge))?;
113
114 result.extend(enc_data);
117 Ok(result.into_boxed_slice())
118 }
119
120 pub fn decrypt(
121 &self,
122 ciphertext: impl AsRef<[u8]>,
123 authenticated_data: &[u8],
124 ) -> Result<Box<[u8]>, DecryptionError> {
125 let nonce_size = <NonceSize as Unsigned>::to_usize();
126 let buf_size = ciphertext.as_ref().len();
127
128 if buf_size < nonce_size {
129 return Err(DecryptionError::CiphertextTooShort);
130 }
131
132 let nonce = XNonce::from_slice(&ciphertext.as_ref()[..nonce_size]);
133 let payload = Payload {
134 msg: &ciphertext.as_ref()[nonce_size..],
135 aad: authenticated_data,
136 };
137 self.cipher
138 .decrypt(nonce, payload)
139 .map(|pt| pt.into_boxed_slice())
140 .or(Err(DecryptionError::AuthenticationFailed))
141 }
142}
143
#[cfg(test)]
mod tests {
    use generic_array::typenum::U32;

    use super::kdf;
    use crate::curve::CurvePoint;
    use crate::secret_box::SecretBox;

    /// HKDF output must be reproducible for identical inputs and must change with the salt.
    #[test]
    fn test_kdf() {
        let seed = SecretBox::new(CurvePoint::generator().to_compressed_array());
        let salt: &[u8] = b"abcdefg";
        let info: &[u8] = b"sdasdasd";

        // Only the salt varies between the derivations below.
        let derive = |maybe_salt: Option<&[u8]>| kdf::<U32>(seed.as_secret(), maybe_salt, Some(info));

        let first = derive(Some(salt));
        let repeated = derive(Some(salt));
        assert_eq!(first.as_secret(), repeated.as_secret());

        let unsalted = derive(None);
        assert_ne!(first.as_secret(), unsalted.as_secret());
    }
}