1use std::collections::HashMap;
8use std::convert::Infallible;
9use std::fmt::Debug;
10use std::marker::PhantomData;
11
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15use crate::crypto::x25519::PublicKey;
16use crate::key_bundle::{KeyBundleError, LongTermKeyBundle, OneTimeKeyBundle, latest_key_bundle};
17use crate::traits::{IdentityHandle, IdentityRegistry, KeyBundle, PreKeyRegistry};
18
/// Stateless handle for a registry of identity keys and pre-key bundles.
///
/// `KeyRegistry` holds no data itself: every operation takes a [`KeyRegistryState`] by value and
/// hands back the (possibly updated) state. The type parameter only fixes the identity-handle
/// type that state is keyed by.
#[derive(Clone, Debug)]
pub struct KeyRegistry<ID> {
    /// Zero-sized marker tying the registry to its identity-handle type.
    _marker: PhantomData<ID>,
}
24
25#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
27pub struct KeyRegistryState<ID>
28where
29 ID: IdentityHandle,
30{
31 identities: HashMap<ID, PublicKey>,
32 onetime_bundles: HashMap<ID, Vec<OneTimeKeyBundle>>,
33 longterm_bundles: HashMap<ID, Vec<LongTermKeyBundle>>,
34}
35
36impl<ID> KeyRegistry<ID>
37where
38 ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
39{
40 pub fn init() -> KeyRegistryState<ID> {
42 KeyRegistryState {
43 identities: HashMap::new(),
44 onetime_bundles: HashMap::new(),
45 longterm_bundles: HashMap::new(),
46 }
47 }
48
49 pub fn remove_expired(mut y: KeyRegistryState<ID>) -> KeyRegistryState<ID> {
51 y.longterm_bundles =
52 y.longterm_bundles
53 .into_iter()
54 .fold(HashMap::new(), |mut acc, (id, bundles)| {
55 let bundles = bundles
56 .into_iter()
57 .filter(|bundle| bundle.verify().is_ok())
58 .collect::<Vec<LongTermKeyBundle>>();
59 acc.insert(id, bundles);
60 acc
61 });
62
63 y.onetime_bundles =
64 y.onetime_bundles
65 .into_iter()
66 .fold(HashMap::new(), |mut acc, (id, bundles)| {
67 let bundles = bundles
68 .into_iter()
69 .filter(|bundle| bundle.verify().is_ok())
70 .collect::<Vec<OneTimeKeyBundle>>();
71 acc.insert(id, bundles);
72 acc
73 });
74 y
75 }
76
77 pub fn add_longterm_bundle(
81 mut y: KeyRegistryState<ID>,
82 id: ID,
83 key_bundle: LongTermKeyBundle,
84 ) -> Result<KeyRegistryState<ID>, KeyRegistryError> {
85 key_bundle.verify()?;
86 let existing = y.identities.insert(id, *key_bundle.identity_key());
87 if let Some(existing) = existing {
88 assert_eq!(&existing, key_bundle.identity_key());
90 }
91 y.longterm_bundles
92 .entry(id)
93 .and_modify(|bundles| bundles.push(key_bundle.clone()))
94 .or_insert(vec![key_bundle]);
95 Ok(y)
96 }
97
98 #[cfg(test)]
99 #[allow(non_snake_case)]
100 fn add_longterm_bundle_UNVERIFIED(
101 mut y: KeyRegistryState<ID>,
102 id: ID,
103 key_bundle: LongTermKeyBundle,
104 ) -> KeyRegistryState<ID> {
105 y.longterm_bundles
106 .entry(id)
107 .and_modify(|bundles| bundles.push(key_bundle.clone()))
108 .or_insert(vec![key_bundle]);
109 y
110 }
111
112 pub fn add_onetime_bundle(
116 mut y: KeyRegistryState<ID>,
117 id: ID,
118 key_bundle: OneTimeKeyBundle,
119 ) -> Result<KeyRegistryState<ID>, KeyRegistryError> {
120 key_bundle.verify()?;
121 let existing = y.identities.insert(id, *key_bundle.identity_key());
122 if let Some(existing) = existing {
123 assert_eq!(&existing, key_bundle.identity_key());
125 }
126 y.onetime_bundles
127 .entry(id)
128 .and_modify(|bundles| bundles.push(key_bundle.clone()))
129 .or_insert(vec![key_bundle]);
130 Ok(y)
131 }
132}
133
134impl<ID> PreKeyRegistry<ID, OneTimeKeyBundle> for KeyRegistry<ID>
135where
136 ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
137{
138 type State = KeyRegistryState<ID>;
139
140 type Error = Infallible;
141
142 fn key_bundle(
143 mut y: Self::State,
144 id: &ID,
145 ) -> Result<(Self::State, Option<OneTimeKeyBundle>), Self::Error> {
146 let bundle = y
147 .onetime_bundles
148 .get_mut(id)
149 .and_then(|bundles| bundles.pop());
150 Ok((y, bundle))
151 }
152}
153
154impl<ID> PreKeyRegistry<ID, LongTermKeyBundle> for KeyRegistry<ID>
155where
156 ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
157{
158 type State = KeyRegistryState<ID>;
159
160 type Error = KeyRegistryError;
161
162 fn key_bundle(
163 y: Self::State,
164 id: &ID,
165 ) -> Result<(Self::State, Option<LongTermKeyBundle>), Self::Error> {
166 let Some(bundles) = y.longterm_bundles.get(id) else {
167 return Ok((y, None));
168 };
169
170 let valid_bundle = latest_key_bundle(bundles).cloned();
171
172 if !bundles.is_empty() && valid_bundle.is_none() {
174 return Err(KeyRegistryError::KeyBundlesExpired);
175 }
176
177 Ok((y, valid_bundle))
178 }
179}
180
181impl<ID> IdentityRegistry<ID, KeyRegistryState<ID>> for KeyRegistry<ID>
182where
183 ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
184{
185 type Error = Infallible;
186
187 fn identity_key(y: &KeyRegistryState<ID>, id: &ID) -> Result<Option<PublicKey>, Self::Error> {
188 let key = y.identities.get(id).cloned();
189 Ok(key)
190 }
191}
192
193#[derive(Debug, Error)]
194pub enum KeyRegistryError {
195 #[error(transparent)]
196 KeyBundle(#[from] KeyBundleError),
197
198 #[error("all available key bundles of this member expired")]
199 KeyBundlesExpired,
200}
201
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use crate::Rng;
    use crate::crypto::x25519::SecretKey;
    use crate::key_bundle::{Lifetime, LongTermKeyBundle, PreKey};
    use crate::traits::PreKeyRegistry;

    use super::KeyRegistry;

    /// Current UNIX time in whole seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("SystemTime before UNIX EPOCH!")
            .as_secs()
    }

    /// Builds a long-term bundle around a freshly generated pre-key valid for `lifetime`,
    /// signed with `identity_secret`.
    fn signed_bundle(
        rng: &Rng,
        identity_secret: &SecretKey,
        lifetime: Lifetime,
    ) -> LongTermKeyBundle {
        let prekey_secret = SecretKey::from_bytes(rng.random_array().unwrap());
        let prekey = PreKey::new(prekey_secret.public_key().unwrap(), lifetime);
        let prekey_signature = prekey.sign(identity_secret, rng).unwrap();
        LongTermKeyBundle::new(identity_secret.public_key().unwrap(), prekey, prekey_signature)
    }

    #[test]
    fn latest_key_bundle() {
        let rng = Rng::from_seed([1; 32]);
        let now = unix_now();

        let member_id = 0;
        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        // Both bundles are currently valid; `bundle_1` stays valid longer than `bundle_2`.
        let bundle_1 =
            signed_bundle(&rng, &identity_secret, Lifetime::from_range(now - 60, now + 60));
        let bundle_2 =
            signed_bundle(&rng, &identity_secret, Lifetime::from_range(now - 60, now + 30));

        let mut pki = KeyRegistry::init();
        pki = KeyRegistry::add_longterm_bundle(pki, member_id, bundle_1.clone()).unwrap();
        pki = KeyRegistry::add_longterm_bundle(pki, member_id, bundle_2).unwrap();

        // Expect the longest-lived bundle, even though it was added first.
        assert_eq!(
            KeyRegistry::key_bundle(pki, &member_id).unwrap().1,
            Some(bundle_1)
        );
    }

    #[test]
    fn invalid_bundles() {
        let rng = Rng::from_seed([1; 32]);
        let now = unix_now();

        let member_id = 0;
        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        // Validity window closed 30 seconds ago.
        let expired_bundle =
            signed_bundle(&rng, &identity_secret, Lifetime::from_range(now - 60, now - 30));

        let pki = KeyRegistry::init();

        // The verifying insert refuses the expired bundle...
        let rejected =
            KeyRegistry::add_longterm_bundle(pki.clone(), member_id, expired_bundle.clone());
        assert!(rejected.is_err());

        // ...so bypass verification to force it into the state.
        let pki = KeyRegistry::add_longterm_bundle_UNVERIFIED(pki, member_id, expired_bundle);
        assert_eq!(pki.longterm_bundles.get(&member_id).unwrap().len(), 1);

        // A member whose only bundles are expired yields an error rather than `None`.
        let lookup = <KeyRegistry<usize> as PreKeyRegistry<usize, LongTermKeyBundle>>::key_bundle(
            pki.clone(),
            &member_id,
        );
        assert!(lookup.is_err());
    }

    #[test]
    fn garbage_collection() {
        let rng = Rng::from_seed([1; 32]);
        let now = unix_now();

        let member_id = 0;
        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        let expired_bundle =
            signed_bundle(&rng, &identity_secret, Lifetime::from_range(now - 60, now - 30));
        let valid_bundle =
            signed_bundle(&rng, &identity_secret, Lifetime::from_range(now - 60, now + 60));

        let mut pki = KeyRegistry::init();
        pki = KeyRegistry::add_longterm_bundle_UNVERIFIED(pki, member_id, expired_bundle);
        pki = KeyRegistry::add_longterm_bundle_UNVERIFIED(pki, member_id, valid_bundle.clone());
        assert_eq!(pki.longterm_bundles.get(&member_id).unwrap().len(), 2);

        // Only the expired bundle is collected; the member's entry survives.
        let pki = KeyRegistry::remove_expired(pki);
        assert_eq!(pki.longterm_bundles.get(&member_id).unwrap().len(), 1);

        assert_eq!(
            KeyRegistry::key_bundle(pki, &member_id).unwrap().1,
            Some(valid_bundle)
        );
    }
}