1use chacha20poly1305::aead::{Aead, NewAead};
2use litl::{impl_debug_as_litl, impl_single_tagged_data_serde, SingleTaggedData, TaggedDataError};
3use rand07::{rngs::OsRng, RngCore};
4use serde::{de::DeserializeOwned, Serialize};
5use serde_derive::{Deserialize, Serialize};
6use std::{array::TryFromSliceError, borrow::Cow, fmt::Debug, marker::PhantomData, ops::Deref};
7use thiserror::Error;
8use zeroize::Zeroizing;
9
10pub enum KeySecretInner {
11 KeySecretV1(chacha20poly1305::Key),
12}
13
14impl SingleTaggedData for KeySecretInner {
15 const TAG: &'static str = "keySecret";
16
17 fn as_bytes(&self) -> Cow<[u8]> {
18 match self {
19 KeySecretInner::KeySecretV1(key) => {
20 let bytes: &[u8] = key.as_ref();
21 Cow::from(bytes)
22 }
23 }
24 }
25
26 fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
27 where
28 Self: Sized,
29 {
30 let key_bytes: [u8; CHACHA_20_POLY1305_KEY_LEN] = data
31 .try_into()
32 .map_err(|err| TaggedDataError::data_error(Into::<KeySecretError>::into(err)))?;
33 Ok(KeySecretInner::KeySecretV1(chacha20poly1305::Key::from(
34 key_bytes,
35 )))
36 }
37}
38
39impl KeySecretInner {
40 pub fn const_time_eq(&self, other: &Self) -> bool {
41 match (self, other) {
42 (KeySecretInner::KeySecretV1(a), KeySecretInner::KeySecretV1(b)) => {
43 constant_time_eq::constant_time_eq(a, b)
45 }
46 }
47 }
48}
49
50#[derive(Debug, Error)]
51pub enum KeySecretError {
52 #[error("Invalid KeySecret length")]
53 InvalidLength(#[from] TryFromSliceError),
54}
55
56impl_single_tagged_data_serde!(KeySecretInner);
57
58pub const KEY_ID_LEN: usize = 16;
59pub const CHACHA_20_POLY1305_KEY_LEN: usize = std::mem::size_of::<chacha20poly1305::Key>();
60pub const CHACHA_20_POLY1305_NONCE_LEN: usize = std::mem::size_of::<chacha20poly1305::Nonce>();
61
62#[derive(Copy, Clone, PartialEq, Eq, Hash)]
63pub struct KeyID([u8; KEY_ID_LEN]);
64
65impl SingleTaggedData for KeyID {
66 const TAG: &'static str = "keyID";
67
68 fn as_bytes(&self) -> Cow<[u8]> {
69 Cow::from(self.0.as_ref())
70 }
71
72 fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
73 where
74 Self: Sized,
75 {
76 let key_id_bytes: [u8; KEY_ID_LEN] = data
77 .try_into()
78 .map_err(|err| TaggedDataError::data_error(Into::<KeyIDError>::into(err)))?;
79 Ok(KeyID(key_id_bytes))
80 }
81}
82
83#[derive(Debug, Error)]
84pub enum KeyIDError {
85 #[error("Invalid KeyID length")]
86 InvalidLength(#[from] TryFromSliceError),
87}
88
89impl_single_tagged_data_serde!(KeyID);
90impl_debug_as_litl!(KeyID);
91
92#[derive(Serialize, Deserialize)]
93pub struct KeySecret {
94 pub id: KeyID,
95 secret: KeySecretInner,
96}
97
98impl KeySecret {
99 pub fn new_random() -> Self {
100 let mut key = [0u8; CHACHA_20_POLY1305_KEY_LEN];
101 OsRng {}.fill_bytes(&mut key);
102 let mut id = [0u8; KEY_ID_LEN];
103 OsRng {}.fill_bytes(&mut id);
104 KeySecret {
105 id: KeyID(id),
106 secret: KeySecretInner::KeySecretV1(*chacha20poly1305::Key::from_slice(&key)),
107 }
108 }
109
110 pub fn encrypt<T: Serialize>(&self, value: &T) -> Encrypted<T> {
111 let cipher = chacha20poly1305::ChaCha20Poly1305::new(self);
112 let mut nonce_bytes = [0; CHACHA_20_POLY1305_NONCE_LEN];
113 OsRng {}.fill_bytes(&mut nonce_bytes);
114 let nonce = chacha20poly1305::Nonce::from(nonce_bytes);
115
116 Encrypted {
117 encrypted: EncryptedInner::EncryptedV1(NonceAndCiphertextChacha20Poly1305::new(
118 &nonce,
119 &cipher
120 .encrypt(
121 &nonce,
122 Zeroizing::new(litl::to_vec(value).unwrap()).as_slice(),
123 )
124 .unwrap(),
125 )),
126 for_key: self.id,
127 _marker: PhantomData,
128 }
129 }
130
131 pub fn decrypt<T: DeserializeOwned>(
132 &self,
133 encrypted: &Encrypted<T>,
134 ) -> Result<T, DecryptionError> {
135 match &encrypted.encrypted {
137 EncryptedInner::EncryptedV1(nonce_and_ciphertext) => {
138 let cipher = chacha20poly1305::ChaCha20Poly1305::new(&self);
139 let (nonce, ciphertext) = nonce_and_ciphertext.as_nonce_and_ciphertext();
140
141 let plaintext: Zeroizing<Vec<u8>> = cipher
142 .decrypt(nonce, ciphertext)
143 .map(Zeroizing::new)
144 .map_err(DecryptionError::DecryptionError)?;
145 litl::from_slice_owned::<T>(plaintext.as_slice())
146 .map_err(DecryptionError::DeserializeError)
147 }
148 }
149 }
150
151 pub fn const_time_eq(&self, other: &Self) -> bool {
152 self.secret.const_time_eq(&other.secret)
153 }
154}
155
156impl Deref for KeySecret {
157 type Target = chacha20poly1305::Key;
158
159 fn deref(&self) -> &Self::Target {
160 match self.secret {
161 KeySecretInner::KeySecretV1(ref key) => key,
162 }
163 }
164}
165
166impl Debug for KeySecret {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.write_str("(Secret Key)")
169 }
170}
171
172#[derive(Error, Debug)]
173pub enum DecryptionError {
174 #[error("Decryption error.")]
175 DecryptionError(chacha20poly1305::aead::Error),
176 #[error("Error converting from decrypted bytes.")]
177 DeserializeError(litl::Error),
178}
179
180#[derive(Clone, PartialEq, Eq, Hash)]
181pub struct NonceAndCiphertextChacha20Poly1305(Vec<u8>);
182
183impl NonceAndCiphertextChacha20Poly1305 {
184 pub fn new(nonce: &chacha20poly1305::Nonce, ciphertext: &[u8]) -> Self {
185 let mut bytes = Vec::with_capacity(ciphertext.len() + CHACHA_20_POLY1305_NONCE_LEN);
186 bytes.extend_from_slice(nonce);
187 bytes.extend_from_slice(ciphertext);
188 NonceAndCiphertextChacha20Poly1305(bytes)
189 }
190
191 pub fn as_nonce_and_ciphertext(&self) -> (&chacha20poly1305::Nonce, &[u8]) {
192 (
193 chacha20poly1305::Nonce::from_slice(&self.0[0..CHACHA_20_POLY1305_NONCE_LEN]),
194 &self.0[CHACHA_20_POLY1305_NONCE_LEN..],
195 )
196 }
197}
198
199#[derive(Clone, PartialEq, Eq, Hash)]
200pub enum EncryptedInner {
201 EncryptedV1(NonceAndCiphertextChacha20Poly1305),
202}
203
204impl SingleTaggedData for EncryptedInner {
205 const TAG: &'static str = "encryptedV1";
206
207 fn as_bytes(&self) -> Cow<[u8]> {
208 match self {
209 EncryptedInner::EncryptedV1(nonce_and_ciphertext) => {
210 Cow::from(nonce_and_ciphertext.0.as_slice())
211 }
212 }
213 }
214
215 fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
216 where
217 Self: Sized,
218 {
219 Ok(EncryptedInner::EncryptedV1(
220 NonceAndCiphertextChacha20Poly1305(data.to_vec()),
221 ))
222 }
223}
224
225impl_single_tagged_data_serde!(EncryptedInner);
226impl_debug_as_litl!(EncryptedInner);
227
228#[derive(Serialize, Deserialize)]
229pub struct Encrypted<T> {
230 for_key: KeyID,
231 encrypted: EncryptedInner,
232 #[serde(skip)]
233 _marker: PhantomData<T>,
234}
235
236impl<T> PartialEq for Encrypted<T> {
237 fn eq(&self, other: &Self) -> bool {
238 self.for_key == other.for_key && self.encrypted == other.encrypted
239 }
240}
241
242impl<T> Eq for Encrypted<T> {}
243
244impl<T> Clone for Encrypted<T> {
245 fn clone(&self) -> Self {
246 Encrypted {
247 for_key: self.for_key,
248 encrypted: self.encrypted.clone(),
249 _marker: PhantomData,
250 }
251 }
252}
253
254#[allow(clippy::derive_hash_xor_eq)]
255impl<T> std::hash::Hash for Encrypted<T> {
256 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
257 self.for_key.hash(state);
258 self.encrypted.hash(state);
259 }
260}
261
262impl<T> Encrypted<T> {
263 pub fn as_encrypted_value(self) -> Encrypted<litl::Val> {
264 Encrypted {
265 for_key: self.for_key,
266 encrypted: self.encrypted,
267 _marker: PhantomData,
268 }
269 }
270}
271
272impl Encrypted<litl::Val> {
273 pub fn as_encrypted_type<T: DeserializeOwned>(self) -> Encrypted<T> {
274 Encrypted {
275 for_key: self.for_key,
276 encrypted: self.encrypted,
277 _marker: PhantomData,
278 }
279 }
280}
281
282impl<T> Debug for Encrypted<T> {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 f.write_str(&litl::to_string_pretty(&self).unwrap())
285 }
286}
287
#[cfg(test)]
mod test {
    use serde_derive::{Deserialize, Serialize};

    // `super` keeps the tests independent of where this module is mounted.
    use super::{DecryptionError, EncryptedInner, KeySecret};

    #[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
    struct TestData {
        bla: [u8; 4],
    }

    #[test]
    fn encryption_roundtrip_works() {
        let data = TestData { bla: [0, 1, 2, 4] };
        let key = KeySecret::new_random();
        let encrypted = key.encrypt(&data);
        println!("{:?}", encrypted);
        let decrypted = key.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn decryption_with_wrong_key_fails() {
        let key = KeySecret::new_random();
        let other_key = KeySecret::new_random();
        let encrypted = key.encrypt(&TestData { bla: [9, 8, 7, 6] });
        assert!(matches!(
            other_key.decrypt(&encrypted),
            Err(DecryptionError::DecryptionError(_))
        ));
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let key = KeySecret::new_random();
        let mut encrypted = key.encrypt(&TestData { bla: [1, 1, 2, 3] });
        // Flip one bit of the trailing authentication tag.
        let EncryptedInner::EncryptedV1(nonce_and_ciphertext) = &mut encrypted.encrypted;
        let last = nonce_and_ciphertext
            .0
            .last_mut()
            .expect("nonce + ciphertext is never empty");
        *last ^= 0x01;
        assert!(matches!(
            key.decrypt(&encrypted),
            Err(DecryptionError::DecryptionError(_))
        ));
    }

    #[test]
    fn encrypting_twice_uses_fresh_nonces() {
        let key = KeySecret::new_random();
        let data = TestData { bla: [5, 5, 5, 5] };
        assert_ne!(key.encrypt(&data), key.encrypt(&data));
    }

    #[test]
    fn const_time_eq_distinguishes_keys() {
        let key = KeySecret::new_random();
        let other_key = KeySecret::new_random();
        assert!(key.const_time_eq(&key));
        assert!(!key.const_time_eq(&other_key));
    }
}