1use std::collections::HashMap;
9
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12
13use crate::crypto::x25519::{PublicKey, SecretKey, X25519Error};
14use crate::crypto::xeddsa::{XEdDSAError, XSignature};
15use crate::crypto::{Rng, RngError};
16use crate::key_bundle::{
17 Lifetime, LongTermKeyBundle, OneTimeKeyBundle, OneTimePreKey, OneTimePreKeyId, PreKey,
18 PreKeyId, latest_prekey,
19};
20use crate::traits::{IdentityManager, PreKeyManager};
21
/// Key manager with no state of its own.
///
/// Every operation takes a [`KeyManagerState`] and hands back the (possibly updated) state, so
/// callers own and persist all key material themselves.
#[derive(Debug, Clone)]
pub struct KeyManager;
26
27#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct KeyManagerState {
30 identity_secret: SecretKey,
31 identity_key: PublicKey,
32 prekeys: PreKeyBundlesState,
33 onetime_secrets: HashMap<OneTimePreKeyId, (PreKeyId, SecretKey)>,
36 onetime_next_id: OneTimePreKeyId,
37}
38
39impl KeyManagerState {
40 pub fn prekey_bundles(&self) -> &PreKeyBundlesState {
41 &self.prekeys
42 }
43}
44
45#[derive(Clone, Debug, Default, Serialize, Deserialize)]
51pub struct PreKeyBundlesState(HashMap<PreKeyId, PreKeyBundle>);
52
53impl PreKeyBundlesState {
54 fn new() -> Self {
56 Self::default()
57 }
58
59 fn latest(&self) -> Option<PreKeyBundle> {
61 let prekeys = self.0.values().map(|state| &state.prekey).collect();
62 let latest = latest_prekey(prekeys);
63 latest.map(|prekey| {
64 self.0
65 .get(prekey.key())
66 .expect("we know the item exists in the set")
67 .clone()
68 })
69 }
70
71 fn contains(&self, id: &PreKeyId) -> bool {
72 self.0.contains_key(id)
73 }
74
75 #[allow(unused)]
76 fn len(&self) -> usize {
77 self.0.len()
78 }
79
80 fn get(&self, id: &PreKeyId) -> Option<&PreKeyBundle> {
81 self.0.get(id)
82 }
83
84 fn insert(mut self, bundle: PreKeyBundle) -> Self {
86 self.0.insert(bundle.id(), bundle);
87 self
88 }
89
90 #[allow(clippy::manual_retain)]
92 fn remove_expired(self) -> Self {
93 Self(
94 self.0
95 .into_iter()
96 .filter(|(_, prekey)| prekey.prekey.verify_lifetime().is_ok())
97 .collect(),
98 )
99 }
100}
101
102#[derive(Clone, Debug, Serialize, Deserialize)]
105pub struct PreKeyBundle {
106 prekey: PreKey,
107 signature: XSignature,
108 secret: SecretKey,
109}
110
111impl PreKeyBundle {
112 pub fn new(
115 identity_secret: &SecretKey,
116 lifetime: Lifetime,
117 rng: &Rng,
118 ) -> Result<Self, KeyManagerError> {
119 let secret = SecretKey::from_bytes(rng.random_array()?);
120 let prekey = PreKey::new(secret.public_key()?, lifetime);
121 let signature = prekey.sign(identity_secret, rng)?;
122
123 Ok(Self {
124 prekey,
125 signature,
126 secret,
127 })
128 }
129
130 pub fn id(&self) -> PreKeyId {
131 *self.prekey.key()
132 }
133
134 pub fn lifetime(&self) -> &Lifetime {
135 self.prekey.lifetime()
136 }
137}
138
139impl KeyManager {
140 pub fn init(identity_secret: &SecretKey) -> Result<KeyManagerState, KeyManagerError> {
142 Ok(KeyManagerState {
143 identity_key: identity_secret.public_key()?,
144 identity_secret: identity_secret.clone(),
145 prekeys: PreKeyBundlesState::new(),
146 onetime_secrets: HashMap::new(),
147 onetime_next_id: 0,
148 })
149 }
150
151 pub fn init_from_prekey_bundles(
154 identity_secret: &SecretKey,
155 prekeys: PreKeyBundlesState,
156 ) -> Result<KeyManagerState, KeyManagerError> {
157 Ok(KeyManagerState {
158 identity_key: identity_secret.public_key()?,
159 identity_secret: identity_secret.clone(),
160 prekeys,
161 onetime_secrets: HashMap::new(),
162 onetime_next_id: 0,
163 })
164 }
165
166 #[cfg(any(test, feature = "test_utils"))]
169 pub fn init_and_generate_prekey(
170 identity_secret: &SecretKey,
171 lifetime: Lifetime,
172 rng: &Rng,
173 ) -> Result<KeyManagerState, KeyManagerError> {
174 let bundle = PreKeyBundle::new(identity_secret, lifetime, rng)?;
175 let prekeys = PreKeyBundlesState::new().insert(bundle);
176
177 Ok(KeyManagerState {
178 identity_key: identity_secret.public_key()?,
179 identity_secret: identity_secret.clone(),
180 prekeys,
181 onetime_secrets: HashMap::new(),
182 onetime_next_id: 0,
183 })
184 }
185
186 #[allow(clippy::manual_retain)]
188 pub fn remove_expired(mut y: KeyManagerState) -> KeyManagerState {
189 y.prekeys = y.prekeys.remove_expired();
191
192 y.onetime_secrets = y
194 .onetime_secrets
195 .into_iter()
196 .filter(|(_, (prekey_id, _))| y.prekeys.contains(prekey_id))
197 .collect();
198
199 y
200 }
201}
202
203impl IdentityManager<KeyManagerState> for KeyManager {
204 fn identity_secret(y: &KeyManagerState) -> &SecretKey {
206 &y.identity_secret
207 }
208}
209
210impl PreKeyManager for KeyManager {
211 type State = KeyManagerState;
212
213 type Error = KeyManagerError;
214
215 fn prekey_secret<'a>(
219 y: &'a Self::State,
220 id: &'a PreKeyId,
221 ) -> Result<&'a SecretKey, Self::Error> {
222 match y.prekeys.get(id) {
223 Some(prekey) => Ok(&prekey.secret),
224 None => Err(KeyManagerError::UnknownPreKeySecret(*id)),
225 }
226 }
227
228 fn rotate_prekey(
230 mut y: Self::State,
231 lifetime: Lifetime,
232 rng: &Rng,
233 ) -> Result<Self::State, Self::Error> {
234 let prekey = PreKeyBundle::new(&y.identity_secret, lifetime, rng)?;
235 y.prekeys = y.prekeys.insert(prekey);
236 Ok(y)
237 }
238
239 fn prekey_bundle(y: &Self::State) -> Result<LongTermKeyBundle, Self::Error> {
244 y.prekeys
245 .latest()
246 .map(|latest| LongTermKeyBundle::new(y.identity_key, latest.prekey, latest.signature))
247 .ok_or(KeyManagerError::NoPreKeysAvailable)
248 }
249
250 fn generate_onetime_bundle(
252 mut y: Self::State,
253 rng: &Rng,
254 ) -> Result<(Self::State, OneTimeKeyBundle), Self::Error> {
255 let latest = y
256 .prekeys
257 .latest()
258 .ok_or(KeyManagerError::NoPreKeysAvailable)?;
259
260 let onetime_secret = SecretKey::from_bytes(rng.random_array()?);
261 let onetime_key = OneTimePreKey::new(onetime_secret.public_key()?, y.onetime_next_id);
262
263 {
264 let existing_key = y
265 .onetime_secrets
266 .insert(onetime_key.id(), (latest.id(), onetime_secret));
267 assert!(
269 existing_key.is_none(),
270 "should never insert same id more than once"
271 );
272 };
273
274 let bundle = OneTimeKeyBundle::new(
275 y.identity_key,
276 latest.prekey,
277 latest.signature,
278 Some(onetime_key),
279 );
280
281 y.onetime_next_id += 1;
282
283 Ok((y, bundle))
284 }
285
286 fn use_onetime_secret(
294 mut y: Self::State,
295 id: OneTimePreKeyId,
296 ) -> Result<(Self::State, Option<SecretKey>), Self::Error> {
297 match y.onetime_secrets.remove(&id) {
298 Some(secret) => Ok((y, Some(secret.1))),
299 None => Err(KeyManagerError::UnknownOneTimeSecret(id)),
300 }
301 }
302}
303
304#[derive(Debug, Error)]
305pub enum KeyManagerError {
306 #[error(transparent)]
307 Rng(#[from] RngError),
308
309 #[error(transparent)]
310 XEdDSA(#[from] XEdDSAError),
311
312 #[error(transparent)]
313 X25519(#[from] X25519Error),
314
315 #[error("could not find one-time pre-key secret with id {0}")]
316 UnknownOneTimeSecret(OneTimePreKeyId),
317
318 #[error("could not find pre-key secret with id {0}")]
319 UnknownPreKeySecret(PreKeyId),
320
321 #[error("no valid pre-keys available, they are either expired or too early")]
322 NoPreKeysAvailable,
323}
324
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use crate::crypto::Rng;
    use crate::crypto::x25519::SecretKey;
    use crate::key_bundle::Lifetime;
    use crate::key_manager::KeyManagerError;
    use crate::traits::KeyBundle;

    use super::{KeyManager, PreKeyManager};

    /// One-time bundles share the latest signed pre-key and carry distinct one-time keys. Each
    /// one-time secret can be consumed exactly once.
    #[test]
    fn generate_onetime_keys() {
        let rng = Rng::from_seed([1; 32]);

        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());
        let state =
            KeyManager::init_and_generate_prekey(&identity_secret, Lifetime::default(), &rng)
                .unwrap();

        let (state, bundle_1) = KeyManager::generate_onetime_bundle(state, &rng).unwrap();
        let (state, bundle_2) = KeyManager::generate_onetime_bundle(state, &rng).unwrap();

        // Both bundles use the same signed pre-key, and we hold its secret.
        assert_eq!(
            bundle_1.signed_prekey(),
            &KeyManager::prekey_secret(&state, bundle_1.signed_prekey())
                .expect("non-expired prekey exists")
                .public_key()
                .unwrap()
        );
        assert_eq!(bundle_1.signed_prekey(), bundle_2.signed_prekey());

        // Both bundles advertise our identity key and carry valid signatures.
        assert_eq!(
            bundle_1.identity_key(),
            &identity_secret.public_key().unwrap()
        );
        assert_eq!(
            bundle_2.identity_key(),
            &identity_secret.public_key().unwrap()
        );
        assert!(bundle_1.verify().is_ok());
        assert!(bundle_2.verify().is_ok());

        // Consuming the one-time secrets removes them from the state.
        let (state, onetime_secret_1) =
            KeyManager::use_onetime_secret(state, bundle_1.onetime_prekey_id().unwrap()).unwrap();
        let (state, onetime_secret_2) =
            KeyManager::use_onetime_secret(state, bundle_2.onetime_prekey_id().unwrap()).unwrap();
        assert_eq!(state.onetime_secrets.len(), 0);

        // Unknown and already-consumed ids are rejected.
        assert!(KeyManager::use_onetime_secret(state.clone(), 42).is_err());
        assert!(
            KeyManager::use_onetime_secret(state.clone(), bundle_1.onetime_prekey_id().unwrap())
                .is_err()
        );
        assert!(
            KeyManager::use_onetime_secret(state.clone(), bundle_2.onetime_prekey_id().unwrap())
                .is_err()
        );

        // The consumed secrets match the advertised one-time pre-keys, which are distinct.
        assert_eq!(
            bundle_1.onetime_prekey().unwrap(),
            &onetime_secret_1.unwrap().public_key().unwrap()
        );
        assert_eq!(
            bundle_2.onetime_prekey().unwrap(),
            &onetime_secret_2.unwrap().public_key().unwrap()
        );
        assert_ne!(bundle_1.onetime_prekey(), bundle_2.onetime_prekey());
        assert_ne!(bundle_1.onetime_prekey_id(), bundle_2.onetime_prekey_id());
    }

    /// `init` yields no usable pre-key, and `init_from_prekey_bundles` restores the bundles
    /// exported by `prekey_bundles`.
    #[test]
    fn init_variants() {
        let rng = Rng::from_seed([1; 32]);
        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        let empty = KeyManager::init(&identity_secret).unwrap();
        assert_eq!(empty.prekeys.len(), 0);
        assert!(matches!(
            KeyManager::prekey_bundle(&empty),
            Err(KeyManagerError::NoPreKeysAvailable)
        ));

        let y = KeyManager::init_and_generate_prekey(&identity_secret, Lifetime::default(), &rng)
            .unwrap();
        let restored =
            KeyManager::init_from_prekey_bundles(&identity_secret, y.prekey_bundles().clone())
                .unwrap();
        assert_eq!(restored.prekeys.len(), 1);
        assert_eq!(restored.onetime_secrets.len(), 0);
        assert!(KeyManager::prekey_bundle(&restored).is_ok());
    }

    /// Expired pre-keys cannot be used for bundles until a fresh one is rotated in.
    #[test]
    fn expired_prekey_bundles() {
        let rng = Rng::from_seed([1; 32]);
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("SystemTime before UNIX EPOCH!")
            .as_secs();

        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        let y = KeyManager::init_and_generate_prekey(
            &identity_secret,
            Lifetime::from_range(now - 120, now - 60),
            &rng,
        )
        .unwrap();

        assert!(matches!(
            KeyManager::prekey_bundle(&y),
            Err(KeyManagerError::NoPreKeysAvailable)
        ));
        assert!(matches!(
            KeyManager::generate_onetime_bundle(y.clone(), &rng),
            Err(KeyManagerError::NoPreKeysAvailable)
        ));

        let y_i = KeyManager::rotate_prekey(y, Lifetime::default(), &rng).unwrap();
        assert!(KeyManager::prekey_bundle(&y_i).is_ok());
    }

    /// `remove_expired` drops expired pre-keys but keeps valid ones, together with the one-time
    /// secrets issued for them.
    #[test]
    fn garbage_collection() {
        let rng = Rng::from_seed([1; 32]);
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("SystemTime before UNIX EPOCH!")
            .as_secs();

        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        let y = KeyManager::init_and_generate_prekey(
            &identity_secret,
            Lifetime::from_range(now - 120, now - 60),
            &rng,
        )
        .unwrap();
        assert_eq!(y.prekeys.len(), 1);

        let y = KeyManager::rotate_prekey(y, Lifetime::default(), &rng).unwrap();
        assert_eq!(y.prekeys.len(), 2);

        // This one-time secret is bound to the freshly rotated (valid) pre-key.
        let (y, bundle) = KeyManager::generate_onetime_bundle(y, &rng).unwrap();
        assert_eq!(y.onetime_secrets.len(), 1);

        let y = KeyManager::remove_expired(y);
        assert_eq!(y.prekeys.len(), 1);
        assert_eq!(y.onetime_secrets.len(), 1);

        // The surviving pre-key is the one the one-time bundle was issued with.
        assert!(KeyManager::prekey_secret(&y, bundle.signed_prekey()).is_ok());
        assert!(KeyManager::prekey_bundle(&y).is_ok());

        // The surviving one-time secret is still usable and matches the advertised key.
        let (y, secret) =
            KeyManager::use_onetime_secret(y, bundle.onetime_prekey_id().unwrap()).unwrap();
        assert_eq!(
            bundle.onetime_prekey().unwrap(),
            &secret.unwrap().public_key().unwrap()
        );
        assert!(y.onetime_secrets.is_empty());
    }
}