ridl/
symm_encr.rs

1use chacha20poly1305::aead::{Aead, NewAead};
2use litl::{impl_debug_as_litl, impl_single_tagged_data_serde, SingleTaggedData, TaggedDataError};
3use rand07::{rngs::OsRng, RngCore};
4use serde::{de::DeserializeOwned, Serialize};
5use serde_derive::{Deserialize, Serialize};
6use std::{array::TryFromSliceError, borrow::Cow, fmt::Debug, marker::PhantomData, ops::Deref};
7use thiserror::Error;
8use zeroize::Zeroizing;
9
10pub enum KeySecretInner {
11    KeySecretV1(chacha20poly1305::Key),
12}
13
14impl SingleTaggedData for KeySecretInner {
15    const TAG: &'static str = "keySecret";
16
17    fn as_bytes(&self) -> Cow<[u8]> {
18        match self {
19            KeySecretInner::KeySecretV1(key) => {
20                let bytes: &[u8] = key.as_ref();
21                Cow::from(bytes)
22            }
23        }
24    }
25
26    fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
27    where
28        Self: Sized,
29    {
30        let key_bytes: [u8; CHACHA_20_POLY1305_KEY_LEN] = data
31            .try_into()
32            .map_err(|err| TaggedDataError::data_error(Into::<KeySecretError>::into(err)))?;
33        Ok(KeySecretInner::KeySecretV1(chacha20poly1305::Key::from(
34            key_bytes,
35        )))
36    }
37}
38
39impl KeySecretInner {
40    pub fn const_time_eq(&self, other: &Self) -> bool {
41        match (self, other) {
42            (KeySecretInner::KeySecretV1(a), KeySecretInner::KeySecretV1(b)) => {
43                // TODO: vet this
44                constant_time_eq::constant_time_eq(a, b)
45            }
46        }
47    }
48}
49
50#[derive(Debug, Error)]
51pub enum KeySecretError {
52    #[error("Invalid KeySecret length")]
53    InvalidLength(#[from] TryFromSliceError),
54}
55
56impl_single_tagged_data_serde!(KeySecretInner);
57
58pub const KEY_ID_LEN: usize = 16;
59pub const CHACHA_20_POLY1305_KEY_LEN: usize = std::mem::size_of::<chacha20poly1305::Key>();
60pub const CHACHA_20_POLY1305_NONCE_LEN: usize = std::mem::size_of::<chacha20poly1305::Nonce>();
61
62#[derive(Copy, Clone, PartialEq, Eq, Hash)]
63pub struct KeyID([u8; KEY_ID_LEN]);
64
65impl SingleTaggedData for KeyID {
66    const TAG: &'static str = "keyID";
67
68    fn as_bytes(&self) -> Cow<[u8]> {
69        Cow::from(self.0.as_ref())
70    }
71
72    fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
73    where
74        Self: Sized,
75    {
76        let key_id_bytes: [u8; KEY_ID_LEN] = data
77            .try_into()
78            .map_err(|err| TaggedDataError::data_error(Into::<KeyIDError>::into(err)))?;
79        Ok(KeyID(key_id_bytes))
80    }
81}
82
83#[derive(Debug, Error)]
84pub enum KeyIDError {
85    #[error("Invalid KeyID length")]
86    InvalidLength(#[from] TryFromSliceError),
87}
88
89impl_single_tagged_data_serde!(KeyID);
90impl_debug_as_litl!(KeyID);
91
92#[derive(Serialize, Deserialize)]
93pub struct KeySecret {
94    pub id: KeyID,
95    secret: KeySecretInner,
96}
97
98impl KeySecret {
99    pub fn new_random() -> Self {
100        let mut key = [0u8; CHACHA_20_POLY1305_KEY_LEN];
101        OsRng {}.fill_bytes(&mut key);
102        let mut id = [0u8; KEY_ID_LEN];
103        OsRng {}.fill_bytes(&mut id);
104        KeySecret {
105            id: KeyID(id),
106            secret: KeySecretInner::KeySecretV1(*chacha20poly1305::Key::from_slice(&key)),
107        }
108    }
109
110    pub fn encrypt<T: Serialize>(&self, value: &T) -> Encrypted<T> {
111        let cipher = chacha20poly1305::ChaCha20Poly1305::new(self);
112        let mut nonce_bytes = [0; CHACHA_20_POLY1305_NONCE_LEN];
113        OsRng {}.fill_bytes(&mut nonce_bytes);
114        let nonce = chacha20poly1305::Nonce::from(nonce_bytes);
115
116        Encrypted {
117            encrypted: EncryptedInner::EncryptedV1(NonceAndCiphertextChacha20Poly1305::new(
118                &nonce,
119                &cipher
120                    .encrypt(
121                        &nonce,
122                        Zeroizing::new(litl::to_vec(value).unwrap()).as_slice(),
123                    )
124                    .unwrap(),
125            )),
126            for_key: self.id,
127            _marker: PhantomData,
128        }
129    }
130
131    pub fn decrypt<T: DeserializeOwned>(
132        &self,
133        encrypted: &Encrypted<T>,
134    ) -> Result<T, DecryptionError> {
135        // TODO: check for correct key id
136        match &encrypted.encrypted {
137            EncryptedInner::EncryptedV1(nonce_and_ciphertext) => {
138                let cipher = chacha20poly1305::ChaCha20Poly1305::new(&self);
139                let (nonce, ciphertext) = nonce_and_ciphertext.as_nonce_and_ciphertext();
140
141                let plaintext: Zeroizing<Vec<u8>> = cipher
142                    .decrypt(nonce, ciphertext)
143                    .map(Zeroizing::new)
144                    .map_err(DecryptionError::DecryptionError)?;
145                litl::from_slice_owned::<T>(plaintext.as_slice())
146                    .map_err(DecryptionError::DeserializeError)
147            }
148        }
149    }
150
151    pub fn const_time_eq(&self, other: &Self) -> bool {
152        self.secret.const_time_eq(&other.secret)
153    }
154}
155
156impl Deref for KeySecret {
157    type Target = chacha20poly1305::Key;
158
159    fn deref(&self) -> &Self::Target {
160        match self.secret {
161            KeySecretInner::KeySecretV1(ref key) => key,
162        }
163    }
164}
165
166impl Debug for KeySecret {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.write_str("(Secret Key)")
169    }
170}
171
172#[derive(Error, Debug)]
173pub enum DecryptionError {
174    #[error("Decryption error.")]
175    DecryptionError(chacha20poly1305::aead::Error),
176    #[error("Error converting from decrypted bytes.")]
177    DeserializeError(litl::Error),
178}
179
180#[derive(Clone, PartialEq, Eq, Hash)]
181pub struct NonceAndCiphertextChacha20Poly1305(Vec<u8>);
182
183impl NonceAndCiphertextChacha20Poly1305 {
184    pub fn new(nonce: &chacha20poly1305::Nonce, ciphertext: &[u8]) -> Self {
185        let mut bytes = Vec::with_capacity(ciphertext.len() + CHACHA_20_POLY1305_NONCE_LEN);
186        bytes.extend_from_slice(nonce);
187        bytes.extend_from_slice(ciphertext);
188        NonceAndCiphertextChacha20Poly1305(bytes)
189    }
190
191    pub fn as_nonce_and_ciphertext(&self) -> (&chacha20poly1305::Nonce, &[u8]) {
192        (
193            chacha20poly1305::Nonce::from_slice(&self.0[0..CHACHA_20_POLY1305_NONCE_LEN]),
194            &self.0[CHACHA_20_POLY1305_NONCE_LEN..],
195        )
196    }
197}
198
199#[derive(Clone, PartialEq, Eq, Hash)]
200pub enum EncryptedInner {
201    EncryptedV1(NonceAndCiphertextChacha20Poly1305),
202}
203
204impl SingleTaggedData for EncryptedInner {
205    const TAG: &'static str = "encryptedV1";
206
207    fn as_bytes(&self) -> Cow<[u8]> {
208        match self {
209            EncryptedInner::EncryptedV1(nonce_and_ciphertext) => {
210                Cow::from(nonce_and_ciphertext.0.as_slice())
211            }
212        }
213    }
214
215    fn from_bytes(data: &[u8]) -> Result<Self, TaggedDataError>
216    where
217        Self: Sized,
218    {
219        Ok(EncryptedInner::EncryptedV1(
220            NonceAndCiphertextChacha20Poly1305(data.to_vec()),
221        ))
222    }
223}
224
225impl_single_tagged_data_serde!(EncryptedInner);
226impl_debug_as_litl!(EncryptedInner);
227
228#[derive(Serialize, Deserialize)]
229pub struct Encrypted<T> {
230    for_key: KeyID,
231    encrypted: EncryptedInner,
232    #[serde(skip)]
233    _marker: PhantomData<T>,
234}
235
236impl<T> PartialEq for Encrypted<T> {
237    fn eq(&self, other: &Self) -> bool {
238        self.for_key == other.for_key && self.encrypted == other.encrypted
239    }
240}
241
242impl<T> Eq for Encrypted<T> {}
243
244impl<T> Clone for Encrypted<T> {
245    fn clone(&self) -> Self {
246        Encrypted {
247            for_key: self.for_key,
248            encrypted: self.encrypted.clone(),
249            _marker: PhantomData,
250        }
251    }
252}
253
254#[allow(clippy::derive_hash_xor_eq)]
255impl<T> std::hash::Hash for Encrypted<T> {
256    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
257        self.for_key.hash(state);
258        self.encrypted.hash(state);
259    }
260}
261
262impl<T> Encrypted<T> {
263    pub fn as_encrypted_value(self) -> Encrypted<litl::Val> {
264        Encrypted {
265            for_key: self.for_key,
266            encrypted: self.encrypted,
267            _marker: PhantomData,
268        }
269    }
270}
271
272impl Encrypted<litl::Val> {
273    pub fn as_encrypted_type<T: DeserializeOwned>(self) -> Encrypted<T> {
274        Encrypted {
275            for_key: self.for_key,
276            encrypted: self.encrypted,
277            _marker: PhantomData,
278        }
279    }
280}
281
282impl<T> Debug for Encrypted<T> {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        f.write_str(&litl::to_string_pretty(&self).unwrap())
285    }
286}
287
#[cfg(test)]
mod test {
    use serde_derive::{Deserialize, Serialize};

    use super::{EncryptedInner, KeySecret};

    #[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq)]
    struct TestData {
        bla: [u8; 4],
    }

    #[test]
    fn encryption_roundtrip_works() {
        let data = TestData { bla: [0, 1, 2, 4] };
        let key = KeySecret::new_random();
        let encrypted = key.encrypt(&data);
        println!("{:?}", encrypted);
        let decrypted = key.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn decryption_with_other_key_fails() {
        let data = TestData { bla: [0, 1, 2, 4] };
        let key = KeySecret::new_random();
        let other_key = KeySecret::new_random();
        let encrypted = key.encrypt(&data);
        assert!(other_key.decrypt(&encrypted).is_err());
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let key = KeySecret::new_random();
        let mut encrypted = key.encrypt(&TestData { bla: [9, 9, 9, 9] });
        match &mut encrypted.encrypted {
            EncryptedInner::EncryptedV1(nonce_and_ciphertext) => {
                let last = nonce_and_ciphertext
                    .0
                    .last_mut()
                    .expect("encrypted payload is never empty");
                *last ^= 1;
            }
        }
        assert!(key.decrypt(&encrypted).is_err());
    }

    #[test]
    fn fresh_keys_are_distinct() {
        let a = KeySecret::new_random();
        let b = KeySecret::new_random();
        assert!(a.const_time_eq(&a));
        assert!(!a.const_time_eq(&b));
        assert_ne!(a.id, b.id);
    }
}