Skip to main content

nucypher_core/
session.rs

1use alloc::boxed::Box;
2use core::fmt;
3
4use chacha20poly1305::aead::{Aead, AeadCore, KeyInit, OsRng};
5use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
6use generic_array::typenum::Unsigned;
7
8use crate::session::key::SessionSharedSecret;
9use crate::versioning::DeserializationError;
10
/// Error returned by `encrypt_with_shared_secret`.
#[derive(Debug)]
pub enum EncryptionError {
    /// The cipher rejected the plaintext; `encrypt_with_shared_secret` reports every
    /// backend encryption error through this variant.
    PlaintextTooLarge,
}

impl fmt::Display for EncryptionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::PlaintextTooLarge => "Plaintext is too large to encrypt",
        };
        f.write_str(message)
    }
}
25
26#[derive(Debug)]
27/// Errors during decryption.
28pub enum DecryptionError {
29    /// Ciphertext (which should be prepended by the nonce) is shorter than the nonce length.
30    CiphertextTooShort,
31    /// The ciphertext and the attached authentication data are inconsistent.
32    /// This can happen if:
33    /// - an incorrect key is used,
34    /// - the ciphertext is modified or cut short,
35    /// - an incorrect authentication data is provided on decryption.
36    AuthenticationFailed,
37    /// Unable to create object from decrypted ciphertext
38    DeserializationFailed(DeserializationError),
39}
40
41impl fmt::Display for DecryptionError {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            Self::CiphertextTooShort => write!(f, "The ciphertext must include the nonce"),
45            Self::AuthenticationFailed => write!(
46                f,
47                "Decryption of ciphertext failed: \
48                either someone tampered with the ciphertext or \
49                you are using an incorrect decryption key."
50            ),
51            Self::DeserializationFailed(err) => write!(f, "deserialization failed: {err}"),
52        }
53    }
54}
55
56type NonceSize = <ChaCha20Poly1305 as AeadCore>::NonceSize;
57
58pub fn encrypt_with_shared_secret(
59    shared_secret: &SessionSharedSecret,
60    plaintext: &[u8],
61) -> Result<Box<[u8]>, EncryptionError> {
62    let key = Key::from_slice(shared_secret.as_ref());
63    let cipher = ChaCha20Poly1305::new(key);
64    let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
65    let mut result = nonce.to_vec();
66    let ciphertext = cipher
67        .encrypt(&nonce, plaintext.as_ref())
68        .map_err(|_err| EncryptionError::PlaintextTooLarge)?;
69    result.extend(ciphertext);
70    Ok(result.into_boxed_slice())
71}
72
73pub fn decrypt_with_shared_secret(
74    shared_secret: &SessionSharedSecret,
75    ciphertext: &[u8],
76) -> Result<Box<[u8]>, DecryptionError> {
77    let nonce_size = <NonceSize as Unsigned>::to_usize();
78    let buf_size = ciphertext.len();
79    if buf_size < nonce_size {
80        return Err(DecryptionError::CiphertextTooShort);
81    }
82    let nonce = Nonce::from_slice(&ciphertext[..nonce_size]);
83    let encrypted_data = &ciphertext[nonce_size..];
84
85    let key = Key::from_slice(shared_secret.as_ref());
86    let cipher = ChaCha20Poly1305::new(key);
87    let plaintext = cipher
88        .decrypt(nonce, encrypted_data)
89        .map_err(|_err| DecryptionError::AuthenticationFailed)?;
90    Ok(plaintext.into_boxed_slice())
91}
92
93/// Module for session key objects.
94pub mod key {
95    use alloc::boxed::Box;
96    use alloc::string::String;
97    use core::fmt;
98
99    use generic_array::{
100        typenum::{Unsigned, U32},
101        GenericArray,
102    };
103    use rand::SeedableRng;
104    use rand_chacha::ChaCha20Rng;
105    use rand_core::{CryptoRng, OsRng, RngCore};
106    use serde::{Deserialize, Deserializer, Serialize, Serializer};
107    use umbral_pre::serde_bytes;
108    use x25519_dalek::{PublicKey, SharedSecret, StaticSecret};
109    use zeroize::ZeroizeOnDrop;
110
111    use crate::secret_box::{kdf, SecretBox};
112    use crate::versioning::{
113        messagepack_deserialize, messagepack_serialize, ProtocolObject, ProtocolObjectInner,
114    };
115
116    /// A Diffie-Hellman shared secret
117    #[derive(ZeroizeOnDrop)]
118    pub struct SessionSharedSecret {
119        derived_bytes: [u8; 32],
120    }
121
122    /// Implementation of Diffie-Hellman shared secret
123    impl SessionSharedSecret {
124        /// Create new shared secret from underlying library.
125        pub fn new(shared_secret: SharedSecret) -> Self {
126            let info = b"SESSION_SHARED_SECRET_DERIVATION/";
127            let derived_key = kdf::<U32>(shared_secret.as_bytes(), Some(info));
128            let derived_bytes = <[u8; 32]>::try_from(derived_key.as_secret().as_slice()).unwrap();
129            Self { derived_bytes }
130        }
131
132        /// View this shared secret as a byte array.
133        pub fn as_bytes(&self) -> &[u8; 32] {
134            &self.derived_bytes
135        }
136    }
137
138    impl AsRef<[u8]> for SessionSharedSecret {
139        /// View this shared secret as a byte array.
140        fn as_ref(&self) -> &[u8] {
141            self.as_bytes()
142        }
143    }
144
145    impl fmt::Display for SessionSharedSecret {
146        /// Format shared secret information.
147        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148            write!(f, "SessionSharedSecret...")
149        }
150    }
151
152    /// A session public key.
153    #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
154    pub struct SessionStaticKey(PublicKey);
155
156    /// Implementation of session static key
157    impl SessionStaticKey {
158        /// Convert this public key to a byte array.
159        pub fn to_bytes(&self) -> [u8; 32] {
160            self.0.to_bytes()
161        }
162    }
163
164    impl AsRef<[u8]> for SessionStaticKey {
165        /// View this public key as a byte array.
166        fn as_ref(&self) -> &[u8] {
167            self.0.as_bytes()
168        }
169    }
170
171    impl fmt::Display for SessionStaticKey {
172        /// Format public key information.
173        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174            write!(f, "SessionStaticKey: {}", hex::encode(&self.as_ref()[..8]))
175        }
176    }
177
178    impl Serialize for SessionStaticKey {
179        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
180        where
181            S: Serializer,
182        {
183            serde_bytes::as_hex::serialize(self.0.as_bytes(), serializer)
184        }
185    }
186
187    impl serde_bytes::TryFromBytes for SessionStaticKey {
188        type Error = core::array::TryFromSliceError;
189        fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
190            let array: [u8; 32] = bytes.try_into()?;
191            Ok(SessionStaticKey(PublicKey::from(array)))
192        }
193    }
194
195    impl<'a> Deserialize<'a> for SessionStaticKey {
196        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
197        where
198            D: Deserializer<'a>,
199        {
200            serde_bytes::as_hex::deserialize(deserializer)
201        }
202    }
203
204    impl<'a> ProtocolObjectInner<'a> for SessionStaticKey {
205        fn version() -> (u16, u16) {
206            (2, 0)
207        }
208
209        fn brand() -> [u8; 4] {
210            *b"TSSk"
211        }
212
213        fn unversioned_to_bytes(&self) -> Box<[u8]> {
214            messagepack_serialize(&self)
215        }
216
217        fn unversioned_from_bytes(
218            minor_version: u16,
219            bytes: &[u8],
220        ) -> Option<Result<Self, String>> {
221            if minor_version == 0 {
222                Some(messagepack_deserialize(bytes))
223            } else {
224                None
225            }
226        }
227    }
228
229    impl<'a> ProtocolObject<'a> for SessionStaticKey {}
230
231    /// A session secret key.
232    #[derive(ZeroizeOnDrop)]
233    pub struct SessionStaticSecret(pub(crate) StaticSecret);
234
235    impl SessionStaticSecret {
236        /// Perform diffie-hellman
237        pub fn derive_shared_secret(
238            &self,
239            their_public_key: &SessionStaticKey,
240        ) -> SessionSharedSecret {
241            let shared_secret = self.0.diffie_hellman(&their_public_key.0);
242            SessionSharedSecret::new(shared_secret)
243        }
244
245        /// Create secret key from RNG.
246        pub fn random_from_rng(csprng: &mut (impl RngCore + CryptoRng)) -> Self {
247            let secret_key = StaticSecret::random_from_rng(csprng);
248            Self(secret_key)
249        }
250
251        /// Create random secret key.
252        pub fn random() -> Self {
253            Self::random_from_rng(&mut OsRng)
254        }
255
256        /// Returns a public key corresponding to this secret key.
257        pub fn public_key(&self) -> SessionStaticKey {
258            let public_key = PublicKey::from(&self.0);
259            SessionStaticKey(public_key)
260        }
261    }
262
263    impl fmt::Display for SessionStaticSecret {
264        /// Format information above secret key.
265        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266            write!(f, "SessionStaticSecret:...")
267        }
268    }
269
270    // the size of the seed material for key derivation
271    type SessionSecretFactorySeedSize = U32;
272    // the size of the derived key
273    type SessionSecretFactoryDerivedKeySize = U32;
274    type SessionSecretFactorySeed = GenericArray<u8, SessionSecretFactorySeedSize>;
275
276    /// Error thrown when invalid random seed provided for creating key factory.
277    #[derive(Debug)]
278    pub struct InvalidSessionSecretFactorySeedLength;
279
280    impl fmt::Display for InvalidSessionSecretFactorySeedLength {
281        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282            write!(f, "Invalid seed length")
283        }
284    }
285
286    /// This class handles keyring material for session keys, by allowing deterministic
287    /// derivation of `SessionStaticSecret` objects based on labels.
288    #[derive(Clone, ZeroizeOnDrop, PartialEq)]
289    pub struct SessionSecretFactory(SecretBox<SessionSecretFactorySeed>);
290
291    impl SessionSecretFactory {
292        /// Creates a session secret factory using the given RNG.
293        pub fn random_with_rng(rng: &mut (impl CryptoRng + RngCore)) -> Self {
294            let mut bytes = SecretBox::new(SessionSecretFactorySeed::default());
295            rng.fill_bytes(bytes.as_mut_secret());
296            Self(bytes)
297        }
298
299        /// Creates a session secret factory using the default RNG.
300        pub fn random() -> Self {
301            Self::random_with_rng(&mut OsRng)
302        }
303
304        /// Returns the seed size required by
305        pub fn seed_size() -> usize {
306            SessionSecretFactorySeedSize::to_usize()
307        }
308
309        /// Creates a `SessionSecretFactory` using the given random bytes.
310        ///
311        /// **Warning:** make sure the given seed has been obtained
312        /// from a cryptographically secure source of randomness!
313        pub fn from_secure_randomness(
314            seed: &[u8],
315        ) -> Result<Self, InvalidSessionSecretFactorySeedLength> {
316            if seed.len() != Self::seed_size() {
317                return Err(InvalidSessionSecretFactorySeedLength);
318            }
319            Ok(Self(SecretBox::new(*SessionSecretFactorySeed::from_slice(
320                seed,
321            ))))
322        }
323
324        /// Creates a `SessionStaticSecret` deterministically from the given label.
325        pub fn make_key(&self, label: &[u8]) -> SessionStaticSecret {
326            let prefix = b"SESSION_KEY_DERIVATION/";
327            let info = [prefix, label].concat();
328            let seed = kdf::<SessionSecretFactoryDerivedKeySize>(self.0.as_secret(), Some(&info));
329            let mut rng =
330                ChaCha20Rng::from_seed(<[u8; 32]>::try_from(seed.as_secret().as_slice()).unwrap());
331            SessionStaticSecret::random_from_rng(&mut rng)
332        }
333    }
334
335    impl fmt::Display for SessionSecretFactory {
336        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337            write!(f, "SessionSecretFactory:...")
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use generic_array::typenum::Unsigned;
    use rand_core::RngCore;

    use crate::session::key::SessionStaticSecret;
    use crate::session::{
        decrypt_with_shared_secret, encrypt_with_shared_secret, DecryptionError, NonceSize,
    };
    use crate::versioning::ProtocolObjectInner;
    use crate::{SessionSecretFactory, SessionStaticKey};

    #[test]
    fn decryption_with_shared_secret() {
        let service_secret = SessionStaticSecret::random();

        let requester_secret = SessionStaticSecret::random();
        let requester_public_key = requester_secret.public_key();

        let service_shared_secret = service_secret.derive_shared_secret(&requester_public_key);

        let ciphertext = b"1".to_vec().into_boxed_slice(); // length less than nonce size
        let nonce_size = <NonceSize as Unsigned>::to_usize();
        assert!(ciphertext.len() < nonce_size);

        assert!(matches!(
            decrypt_with_shared_secret(&service_shared_secret, &ciphertext).unwrap_err(),
            DecryptionError::CiphertextTooShort
        ));
    }

    #[test]
    fn decryption_rejects_tampered_truncated_or_misdirected_ciphertext() {
        let service_secret = SessionStaticSecret::random();
        let requester_secret = SessionStaticSecret::random();
        let shared_secret = service_secret.derive_shared_secret(&requester_secret.public_key());

        let plaintext = b"The Tyranny of Merit";
        let ciphertext = encrypt_with_shared_secret(&shared_secret, plaintext).unwrap();

        // sanity check: the untouched ciphertext round-trips
        let decrypted = decrypt_with_shared_secret(&shared_secret, &ciphertext).unwrap();
        assert_eq!(&*decrypted, &plaintext[..]);

        // flipping a single bit must be detected by the authentication tag
        let mut tampered = ciphertext.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(matches!(
            decrypt_with_shared_secret(&shared_secret, &tampered).unwrap_err(),
            DecryptionError::AuthenticationFailed
        ));

        // a nonce with no ciphertext passes the length check but cannot authenticate
        let nonce_size = <NonceSize as Unsigned>::to_usize();
        assert!(matches!(
            decrypt_with_shared_secret(&shared_secret, &ciphertext[..nonce_size]).unwrap_err(),
            DecryptionError::AuthenticationFailed
        ));

        // a shared secret from a different key pair must not decrypt
        let outsider_secret = SessionStaticSecret::random();
        let wrong_shared_secret =
            outsider_secret.derive_shared_secret(&requester_secret.public_key());
        assert!(matches!(
            decrypt_with_shared_secret(&wrong_shared_secret, &ciphertext).unwrap_err(),
            DecryptionError::AuthenticationFailed
        ));
    }

    #[test]
    fn request_key_factory() {
        let secret_factory = SessionSecretFactory::random();

        // ensure that shared secret derived from factory can be used correctly
        let label_1 = b"label_1".to_vec().into_boxed_slice();
        let service_secret_key = secret_factory.make_key(label_1.as_ref());
        let service_public_key = service_secret_key.public_key();

        let label_2 = b"label_2".to_vec().into_boxed_slice();
        let requester_secret_key = secret_factory.make_key(label_2.as_ref());
        let requester_public_key = requester_secret_key.public_key();

        let service_shared_secret = service_secret_key.derive_shared_secret(&requester_public_key);
        let requester_shared_secret =
            requester_secret_key.derive_shared_secret(&service_public_key);

        let data_to_encrypt = b"The Tyranny of Merit".to_vec().into_boxed_slice();
        let ciphertext =
            encrypt_with_shared_secret(&requester_shared_secret, data_to_encrypt.as_ref()).unwrap();
        let decrypted_data =
            decrypt_with_shared_secret(&service_shared_secret, &ciphertext).unwrap();
        assert_eq!(decrypted_data, data_to_encrypt);

        // ensure same key can be generated by the same factory using the same seed
        let same_requester_secret_key = secret_factory.make_key(label_2.as_ref());
        let same_requester_public_key = same_requester_secret_key.public_key();
        assert_eq!(requester_public_key, same_requester_public_key);

        // ensure different key generated using same seed but using different factory
        let other_secret_factory = SessionSecretFactory::random();
        let not_same_requester_secret_key = other_secret_factory.make_key(label_2.as_ref());
        let not_same_requester_public_key = not_same_requester_secret_key.public_key();
        assert_ne!(requester_public_key, not_same_requester_public_key);

        // ensure that two secret factories with the same seed generate the same keys
        let mut secret_factory_seed = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut secret_factory_seed);
        let seeded_secret_factory_1 =
            SessionSecretFactory::from_secure_randomness(&secret_factory_seed).unwrap();
        let seeded_secret_factory_2 =
            SessionSecretFactory::from_secure_randomness(&secret_factory_seed).unwrap();

        let key_label = b"seeded_factory_key_label".to_vec().into_boxed_slice();
        let sk_1 = seeded_secret_factory_1.make_key(&key_label);
        let pk_1 = sk_1.public_key();

        let sk_2 = seeded_secret_factory_2.make_key(&key_label);
        let pk_2 = sk_2.public_key();

        assert_eq!(pk_1, pk_2);

        // test secure randomness
        let bytes = [0u8; 32];
        let factory = SessionSecretFactory::from_secure_randomness(&bytes);
        assert!(factory.is_ok());

        let bytes = [0u8; 31];
        let factory = SessionSecretFactory::from_secure_randomness(&bytes);
        assert!(factory.is_err());
    }

    #[test]
    fn session_static_key() {
        let public_key_1: SessionStaticKey = SessionStaticSecret::random().public_key();
        let public_key_2: SessionStaticKey = SessionStaticSecret::random().public_key();

        let public_key_1_bytes = public_key_1.unversioned_to_bytes();
        let public_key_2_bytes = public_key_2.unversioned_to_bytes();

        // serialized public keys should always have the same length
        assert_eq!(public_key_1_bytes.len(), public_key_2_bytes.len());

        let deserialized_public_key_1 =
            SessionStaticKey::unversioned_from_bytes(0, &public_key_1_bytes)
                .unwrap()
                .unwrap();
        let deserialized_public_key_2 =
            SessionStaticKey::unversioned_from_bytes(0, &public_key_2_bytes)
                .unwrap()
                .unwrap();

        assert_eq!(public_key_1, deserialized_public_key_1);
        assert_eq!(public_key_2, deserialized_public_key_2);
    }
}