1use alloc::boxed::Box;
2use core::fmt;
3
4use chacha20poly1305::aead::{Aead, AeadCore, KeyInit, OsRng};
5use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
6use generic_array::typenum::Unsigned;
7
8use crate::session::key::SessionSharedSecret;
9use crate::versioning::DeserializationError;
10
/// Errors that can occur while encrypting with a session shared secret.
#[derive(Debug)]
pub enum EncryptionError {
    /// The AEAD encryption call rejected the input; this is reported as the
    /// plaintext being too large to encrypt.
    PlaintextTooLarge,
}

impl fmt::Display for EncryptionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::PlaintextTooLarge => "Plaintext is too large to encrypt",
        };
        f.write_str(message)
    }
}
25
26#[derive(Debug)]
27pub enum DecryptionError {
29 CiphertextTooShort,
31 AuthenticationFailed,
37 DeserializationFailed(DeserializationError),
39}
40
41impl fmt::Display for DecryptionError {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 Self::CiphertextTooShort => write!(f, "The ciphertext must include the nonce"),
45 Self::AuthenticationFailed => write!(
46 f,
47 "Decryption of ciphertext failed: \
48 either someone tampered with the ciphertext or \
49 you are using an incorrect decryption key."
50 ),
51 Self::DeserializationFailed(err) => write!(f, "deserialization failed: {err}"),
52 }
53 }
54}
55
56type NonceSize = <ChaCha20Poly1305 as AeadCore>::NonceSize;
57
58pub fn encrypt_with_shared_secret(
59 shared_secret: &SessionSharedSecret,
60 plaintext: &[u8],
61) -> Result<Box<[u8]>, EncryptionError> {
62 let key = Key::from_slice(shared_secret.as_ref());
63 let cipher = ChaCha20Poly1305::new(key);
64 let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
65 let mut result = nonce.to_vec();
66 let ciphertext = cipher
67 .encrypt(&nonce, plaintext.as_ref())
68 .map_err(|_err| EncryptionError::PlaintextTooLarge)?;
69 result.extend(ciphertext);
70 Ok(result.into_boxed_slice())
71}
72
73pub fn decrypt_with_shared_secret(
74 shared_secret: &SessionSharedSecret,
75 ciphertext: &[u8],
76) -> Result<Box<[u8]>, DecryptionError> {
77 let nonce_size = <NonceSize as Unsigned>::to_usize();
78 let buf_size = ciphertext.len();
79 if buf_size < nonce_size {
80 return Err(DecryptionError::CiphertextTooShort);
81 }
82 let nonce = Nonce::from_slice(&ciphertext[..nonce_size]);
83 let encrypted_data = &ciphertext[nonce_size..];
84
85 let key = Key::from_slice(shared_secret.as_ref());
86 let cipher = ChaCha20Poly1305::new(key);
87 let plaintext = cipher
88 .decrypt(nonce, encrypted_data)
89 .map_err(|_err| DecryptionError::AuthenticationFailed)?;
90 Ok(plaintext.into_boxed_slice())
91}
92
93pub mod key {
95 use alloc::boxed::Box;
96 use alloc::string::String;
97 use core::fmt;
98
99 use generic_array::{
100 typenum::{Unsigned, U32},
101 GenericArray,
102 };
103 use rand::SeedableRng;
104 use rand_chacha::ChaCha20Rng;
105 use rand_core::{CryptoRng, OsRng, RngCore};
106 use serde::{Deserialize, Deserializer, Serialize, Serializer};
107 use umbral_pre::serde_bytes;
108 use x25519_dalek::{PublicKey, SharedSecret, StaticSecret};
109 use zeroize::ZeroizeOnDrop;
110
111 use crate::secret_box::{kdf, SecretBox};
112 use crate::versioning::{
113 messagepack_deserialize, messagepack_serialize, ProtocolObject, ProtocolObjectInner,
114 };
115
116 #[derive(ZeroizeOnDrop)]
118 pub struct SessionSharedSecret {
119 derived_bytes: [u8; 32],
120 }
121
122 impl SessionSharedSecret {
124 pub fn new(shared_secret: SharedSecret) -> Self {
126 let info = b"SESSION_SHARED_SECRET_DERIVATION/";
127 let derived_key = kdf::<U32>(shared_secret.as_bytes(), Some(info));
128 let derived_bytes = <[u8; 32]>::try_from(derived_key.as_secret().as_slice()).unwrap();
129 Self { derived_bytes }
130 }
131
132 pub fn as_bytes(&self) -> &[u8; 32] {
134 &self.derived_bytes
135 }
136 }
137
138 impl AsRef<[u8]> for SessionSharedSecret {
139 fn as_ref(&self) -> &[u8] {
141 self.as_bytes()
142 }
143 }
144
145 impl fmt::Display for SessionSharedSecret {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 write!(f, "SessionSharedSecret...")
149 }
150 }
151
152 #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
154 pub struct SessionStaticKey(PublicKey);
155
156 impl SessionStaticKey {
158 pub fn to_bytes(&self) -> [u8; 32] {
160 self.0.to_bytes()
161 }
162 }
163
164 impl AsRef<[u8]> for SessionStaticKey {
165 fn as_ref(&self) -> &[u8] {
167 self.0.as_bytes()
168 }
169 }
170
171 impl fmt::Display for SessionStaticKey {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 write!(f, "SessionStaticKey: {}", hex::encode(&self.as_ref()[..8]))
175 }
176 }
177
178 impl Serialize for SessionStaticKey {
179 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
180 where
181 S: Serializer,
182 {
183 serde_bytes::as_hex::serialize(self.0.as_bytes(), serializer)
184 }
185 }
186
187 impl serde_bytes::TryFromBytes for SessionStaticKey {
188 type Error = core::array::TryFromSliceError;
189 fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
190 let array: [u8; 32] = bytes.try_into()?;
191 Ok(SessionStaticKey(PublicKey::from(array)))
192 }
193 }
194
195 impl<'a> Deserialize<'a> for SessionStaticKey {
196 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
197 where
198 D: Deserializer<'a>,
199 {
200 serde_bytes::as_hex::deserialize(deserializer)
201 }
202 }
203
204 impl<'a> ProtocolObjectInner<'a> for SessionStaticKey {
205 fn version() -> (u16, u16) {
206 (2, 0)
207 }
208
209 fn brand() -> [u8; 4] {
210 *b"TSSk"
211 }
212
213 fn unversioned_to_bytes(&self) -> Box<[u8]> {
214 messagepack_serialize(&self)
215 }
216
217 fn unversioned_from_bytes(
218 minor_version: u16,
219 bytes: &[u8],
220 ) -> Option<Result<Self, String>> {
221 if minor_version == 0 {
222 Some(messagepack_deserialize(bytes))
223 } else {
224 None
225 }
226 }
227 }
228
229 impl<'a> ProtocolObject<'a> for SessionStaticKey {}
230
231 #[derive(ZeroizeOnDrop)]
233 pub struct SessionStaticSecret(pub(crate) StaticSecret);
234
235 impl SessionStaticSecret {
236 pub fn derive_shared_secret(
238 &self,
239 their_public_key: &SessionStaticKey,
240 ) -> SessionSharedSecret {
241 let shared_secret = self.0.diffie_hellman(&their_public_key.0);
242 SessionSharedSecret::new(shared_secret)
243 }
244
245 pub fn random_from_rng(csprng: &mut (impl RngCore + CryptoRng)) -> Self {
247 let secret_key = StaticSecret::random_from_rng(csprng);
248 Self(secret_key)
249 }
250
251 pub fn random() -> Self {
253 Self::random_from_rng(&mut OsRng)
254 }
255
256 pub fn public_key(&self) -> SessionStaticKey {
258 let public_key = PublicKey::from(&self.0);
259 SessionStaticKey(public_key)
260 }
261 }
262
263 impl fmt::Display for SessionStaticSecret {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266 write!(f, "SessionStaticSecret:...")
267 }
268 }
269
270 type SessionSecretFactorySeedSize = U32;
272 type SessionSecretFactoryDerivedKeySize = U32;
274 type SessionSecretFactorySeed = GenericArray<u8, SessionSecretFactorySeedSize>;
275
276 #[derive(Debug)]
278 pub struct InvalidSessionSecretFactorySeedLength;
279
280 impl fmt::Display for InvalidSessionSecretFactorySeedLength {
281 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282 write!(f, "Invalid seed length")
283 }
284 }
285
286 #[derive(Clone, ZeroizeOnDrop, PartialEq)]
289 pub struct SessionSecretFactory(SecretBox<SessionSecretFactorySeed>);
290
291 impl SessionSecretFactory {
292 pub fn random_with_rng(rng: &mut (impl CryptoRng + RngCore)) -> Self {
294 let mut bytes = SecretBox::new(SessionSecretFactorySeed::default());
295 rng.fill_bytes(bytes.as_mut_secret());
296 Self(bytes)
297 }
298
299 pub fn random() -> Self {
301 Self::random_with_rng(&mut OsRng)
302 }
303
304 pub fn seed_size() -> usize {
306 SessionSecretFactorySeedSize::to_usize()
307 }
308
309 pub fn from_secure_randomness(
314 seed: &[u8],
315 ) -> Result<Self, InvalidSessionSecretFactorySeedLength> {
316 if seed.len() != Self::seed_size() {
317 return Err(InvalidSessionSecretFactorySeedLength);
318 }
319 Ok(Self(SecretBox::new(*SessionSecretFactorySeed::from_slice(
320 seed,
321 ))))
322 }
323
324 pub fn make_key(&self, label: &[u8]) -> SessionStaticSecret {
326 let prefix = b"SESSION_KEY_DERIVATION/";
327 let info = [prefix, label].concat();
328 let seed = kdf::<SessionSecretFactoryDerivedKeySize>(self.0.as_secret(), Some(&info));
329 let mut rng =
330 ChaCha20Rng::from_seed(<[u8; 32]>::try_from(seed.as_secret().as_slice()).unwrap());
331 SessionStaticSecret::random_from_rng(&mut rng)
332 }
333 }
334
335 impl fmt::Display for SessionSecretFactory {
336 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337 write!(f, "SessionSecretFactory:...")
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use generic_array::typenum::Unsigned;
    use rand_core::RngCore;

    use crate::session::key::SessionStaticSecret;
    use crate::session::{
        decrypt_with_shared_secret, encrypt_with_shared_secret, DecryptionError, NonceSize,
    };
    use crate::versioning::ProtocolObjectInner;
    use crate::{SessionSecretFactory, SessionStaticKey};

    /// Inputs shorter than the nonce are rejected with `CiphertextTooShort`.
    #[test]
    fn decryption_with_shared_secret() {
        let service_secret = SessionStaticSecret::random();

        let requester_secret = SessionStaticSecret::random();
        let requester_public_key = requester_secret.public_key();

        let service_shared_secret = service_secret.derive_shared_secret(&requester_public_key);

        let ciphertext = b"1".to_vec().into_boxed_slice();
        let nonce_size = <NonceSize as Unsigned>::to_usize();
        assert!(ciphertext.len() < nonce_size);

        assert!(matches!(
            decrypt_with_shared_secret(&service_shared_secret, &ciphertext).unwrap_err(),
            DecryptionError::CiphertextTooShort
        ));
    }

    /// A modified ciphertext, or decryption under an unrelated key, must
    /// surface as `AuthenticationFailed` rather than yielding plaintext.
    #[test]
    fn decryption_rejects_tampering_and_wrong_key() {
        let secret_1 = SessionStaticSecret::random();
        let secret_2 = SessionStaticSecret::random();
        let shared_1 = secret_1.derive_shared_secret(&secret_2.public_key());
        let shared_2 = secret_2.derive_shared_secret(&secret_1.public_key());

        let plaintext = b"some plaintext";
        let ciphertext = encrypt_with_shared_secret(&shared_1, plaintext).unwrap();

        // Sanity check: the matching key round-trips.
        let decrypted = decrypt_with_shared_secret(&shared_2, &ciphertext).unwrap();
        assert_eq!(&decrypted[..], &plaintext[..]);

        // Flipping one bit of the authenticated payload must be detected.
        let mut tampered = ciphertext.to_vec();
        let last = tampered.len() - 1;
        tampered[last] ^= 1;
        assert!(matches!(
            decrypt_with_shared_secret(&shared_2, &tampered).unwrap_err(),
            DecryptionError::AuthenticationFailed
        ));

        // A key derived from an unrelated secret must be rejected.
        let unrelated =
            SessionStaticSecret::random().derive_shared_secret(&secret_2.public_key());
        assert!(matches!(
            decrypt_with_shared_secret(&unrelated, &ciphertext).unwrap_err(),
            DecryptionError::AuthenticationFailed
        ));
    }

    #[test]
    fn request_key_factory() {
        let secret_factory = SessionSecretFactory::random();

        let label_1 = b"label_1".to_vec().into_boxed_slice();
        let service_secret_key = secret_factory.make_key(label_1.as_ref());
        let service_public_key = service_secret_key.public_key();

        let label_2 = b"label_2".to_vec().into_boxed_slice();
        let requester_secret_key = secret_factory.make_key(label_2.as_ref());
        let requester_public_key = requester_secret_key.public_key();

        let service_shared_secret = service_secret_key.derive_shared_secret(&requester_public_key);
        let requester_shared_secret =
            requester_secret_key.derive_shared_secret(&service_public_key);

        let data_to_encrypt = b"The Tyranny of Merit".to_vec().into_boxed_slice();
        let ciphertext =
            encrypt_with_shared_secret(&requester_shared_secret, data_to_encrypt.as_ref()).unwrap();
        let decrypted_data =
            decrypt_with_shared_secret(&service_shared_secret, &ciphertext).unwrap();
        assert_eq!(decrypted_data, data_to_encrypt);

        let same_requester_secret_key = secret_factory.make_key(label_2.as_ref());
        let same_requester_public_key = same_requester_secret_key.public_key();
        assert_eq!(requester_public_key, same_requester_public_key);

        let other_secret_factory = SessionSecretFactory::random();
        let not_same_requester_secret_key = other_secret_factory.make_key(label_2.as_ref());
        let not_same_requester_public_key = not_same_requester_secret_key.public_key();
        assert_ne!(requester_public_key, not_same_requester_public_key);

        let mut secret_factory_seed = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut secret_factory_seed);
        let seeded_secret_factory_1 =
            SessionSecretFactory::from_secure_randomness(&secret_factory_seed).unwrap();
        let seeded_secret_factory_2 =
            SessionSecretFactory::from_secure_randomness(&secret_factory_seed).unwrap();

        let key_label = b"seeded_factory_key_label".to_vec().into_boxed_slice();
        let sk_1 = seeded_secret_factory_1.make_key(&key_label);
        let pk_1 = sk_1.public_key();

        let sk_2 = seeded_secret_factory_2.make_key(&key_label);
        let pk_2 = sk_2.public_key();

        assert_eq!(pk_1, pk_2);

        let bytes = [0u8; 32];
        let factory = SessionSecretFactory::from_secure_randomness(&bytes);
        assert!(factory.is_ok());

        let bytes = [0u8; 31];
        let factory = SessionSecretFactory::from_secure_randomness(&bytes);
        assert!(factory.is_err());
    }

    #[test]
    fn session_static_key() {
        let public_key_1: SessionStaticKey = SessionStaticSecret::random().public_key();
        let public_key_2: SessionStaticKey = SessionStaticSecret::random().public_key();

        let public_key_1_bytes = public_key_1.unversioned_to_bytes();
        let public_key_2_bytes = public_key_2.unversioned_to_bytes();

        assert_eq!(public_key_1_bytes.len(), public_key_2_bytes.len());

        let deserialized_public_key_1 =
            SessionStaticKey::unversioned_from_bytes(0, &public_key_1_bytes)
                .unwrap()
                .unwrap();
        let deserialized_public_key_2 =
            SessionStaticKey::unversioned_from_bytes(0, &public_key_2_bytes)
                .unwrap()
                .unwrap();

        assert_eq!(public_key_1, deserialized_public_key_1);
        assert_eq!(public_key_2, deserialized_public_key_2);
    }
}