p2panda_encryption/
key_registry.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Manager for public key material of other members.
4//!
5//! Peers should actively look for fresh key bundles in the network, check for invalid or expired
6//! ones and automatically choose the latest for groups.
7use std::collections::HashMap;
8use std::convert::Infallible;
9use std::fmt::Debug;
10use std::marker::PhantomData;
11
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15use crate::crypto::x25519::PublicKey;
16use crate::key_bundle::{KeyBundleError, LongTermKeyBundle, OneTimeKeyBundle, latest_key_bundle};
17use crate::traits::{IdentityHandle, IdentityRegistry, KeyBundle, PreKeyRegistry};
18
/// Key registry maintaining the public key material of other members we've collected.
///
/// The type itself holds no data: every operation is an associated function that takes a
/// [`KeyRegistryState`] by value and hands back the updated state. The marker field only
/// carries the `ID` type parameter.
#[derive(Debug, Clone)]
pub struct KeyRegistry<ID> {
    _marker: PhantomData<ID>,
}
24
25/// Serializable state of key registry (for persistence).
26#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
27pub struct KeyRegistryState<ID>
28where
29    ID: IdentityHandle,
30{
31    identities: HashMap<ID, PublicKey>,
32    onetime_bundles: HashMap<ID, Vec<OneTimeKeyBundle>>,
33    longterm_bundles: HashMap<ID, Vec<LongTermKeyBundle>>,
34}
35
36impl<ID> KeyRegistry<ID>
37where
38    ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
39{
40    /// Returns newly initialised key-registry state.
41    pub fn init() -> KeyRegistryState<ID> {
42        KeyRegistryState {
43            identities: HashMap::new(),
44            onetime_bundles: HashMap::new(),
45            longterm_bundles: HashMap::new(),
46        }
47    }
48
49    /// Remove all expired key bundles from registry.
50    pub fn remove_expired(mut y: KeyRegistryState<ID>) -> KeyRegistryState<ID> {
51        y.longterm_bundles =
52            y.longterm_bundles
53                .into_iter()
54                .fold(HashMap::new(), |mut acc, (id, bundles)| {
55                    let bundles = bundles
56                        .into_iter()
57                        .filter(|bundle| bundle.verify().is_ok())
58                        .collect::<Vec<LongTermKeyBundle>>();
59                    acc.insert(id, bundles);
60                    acc
61                });
62
63        y.onetime_bundles =
64            y.onetime_bundles
65                .into_iter()
66                .fold(HashMap::new(), |mut acc, (id, bundles)| {
67                    let bundles = bundles
68                        .into_iter()
69                        .filter(|bundle| bundle.verify().is_ok())
70                        .collect::<Vec<OneTimeKeyBundle>>();
71                    acc.insert(id, bundles);
72                    acc
73                });
74        y
75    }
76
77    /// Adds long-term pre-key bundle to the registry.
78    ///
79    /// This throws an error if an expired or invalid bundle was added.
80    pub fn add_longterm_bundle(
81        mut y: KeyRegistryState<ID>,
82        id: ID,
83        key_bundle: LongTermKeyBundle,
84    ) -> Result<KeyRegistryState<ID>, KeyRegistryError> {
85        key_bundle.verify()?;
86        let existing = y.identities.insert(id, *key_bundle.identity_key());
87        if let Some(existing) = existing {
88            // Sanity check.
89            assert_eq!(&existing, key_bundle.identity_key());
90        }
91        y.longterm_bundles
92            .entry(id)
93            .and_modify(|bundles| bundles.push(key_bundle.clone()))
94            .or_insert(vec![key_bundle]);
95        Ok(y)
96    }
97
98    #[cfg(test)]
99    #[allow(non_snake_case)]
100    fn add_longterm_bundle_UNVERIFIED(
101        mut y: KeyRegistryState<ID>,
102        id: ID,
103        key_bundle: LongTermKeyBundle,
104    ) -> KeyRegistryState<ID> {
105        y.longterm_bundles
106            .entry(id)
107            .and_modify(|bundles| bundles.push(key_bundle.clone()))
108            .or_insert(vec![key_bundle]);
109        y
110    }
111
112    /// Adds one-time pre-key bundle to the registry.
113    ///
114    /// This throws an error if an expired or invalid bundle was added.
115    pub fn add_onetime_bundle(
116        mut y: KeyRegistryState<ID>,
117        id: ID,
118        key_bundle: OneTimeKeyBundle,
119    ) -> Result<KeyRegistryState<ID>, KeyRegistryError> {
120        key_bundle.verify()?;
121        let existing = y.identities.insert(id, *key_bundle.identity_key());
122        if let Some(existing) = existing {
123            // Sanity check.
124            assert_eq!(&existing, key_bundle.identity_key());
125        }
126        y.onetime_bundles
127            .entry(id)
128            .and_modify(|bundles| bundles.push(key_bundle.clone()))
129            .or_insert(vec![key_bundle]);
130        Ok(y)
131    }
132}
133
134impl<ID> PreKeyRegistry<ID, OneTimeKeyBundle> for KeyRegistry<ID>
135where
136    ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
137{
138    type State = KeyRegistryState<ID>;
139
140    type Error = Infallible;
141
142    fn key_bundle(
143        mut y: Self::State,
144        id: &ID,
145    ) -> Result<(Self::State, Option<OneTimeKeyBundle>), Self::Error> {
146        let bundle = y
147            .onetime_bundles
148            .get_mut(id)
149            .and_then(|bundles| bundles.pop());
150        Ok((y, bundle))
151    }
152}
153
154impl<ID> PreKeyRegistry<ID, LongTermKeyBundle> for KeyRegistry<ID>
155where
156    ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
157{
158    type State = KeyRegistryState<ID>;
159
160    type Error = KeyRegistryError;
161
162    fn key_bundle(
163        y: Self::State,
164        id: &ID,
165    ) -> Result<(Self::State, Option<LongTermKeyBundle>), Self::Error> {
166        let Some(bundles) = y.longterm_bundles.get(id) else {
167            return Ok((y, None));
168        };
169
170        let valid_bundle = latest_key_bundle(bundles).cloned();
171
172        // Even though key bundles are available we couldn't find any non-expired ones.
173        if !bundles.is_empty() && valid_bundle.is_none() {
174            return Err(KeyRegistryError::KeyBundlesExpired);
175        }
176
177        Ok((y, valid_bundle))
178    }
179}
180
181impl<ID> IdentityRegistry<ID, KeyRegistryState<ID>> for KeyRegistry<ID>
182where
183    ID: IdentityHandle + Serialize + for<'a> Deserialize<'a>,
184{
185    type Error = Infallible;
186
187    fn identity_key(y: &KeyRegistryState<ID>, id: &ID) -> Result<Option<PublicKey>, Self::Error> {
188        let key = y.identities.get(id).cloned();
189        Ok(key)
190    }
191}
192
193#[derive(Debug, Error)]
194pub enum KeyRegistryError {
195    #[error(transparent)]
196    KeyBundle(#[from] KeyBundleError),
197
198    #[error("all available key bundles of this member expired")]
199    KeyBundlesExpired,
200}
201
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use crate::Rng;
    use crate::crypto::x25519::SecretKey;
    use crate::key_bundle::{Lifetime, LongTermKeyBundle, PreKey};
    use crate::traits::PreKeyRegistry;

    use super::KeyRegistry;

    /// Current UNIX timestamp in seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("SystemTime before UNIX EPOCH!")
            .as_secs()
    }

    /// Builds a long-term bundle with a fresh pre-key whose lifetime is `start..end`, signed by
    /// `identity_secret`. Draws from `rng` in the same order as inline construction would.
    fn signed_longterm_bundle(
        rng: &Rng,
        identity_secret: &SecretKey,
        start: u64,
        end: u64,
    ) -> LongTermKeyBundle {
        let prekey_secret = SecretKey::from_bytes(rng.random_array().unwrap());
        let prekey = PreKey::new(
            prekey_secret.public_key().unwrap(),
            Lifetime::from_range(start, end),
        );
        let prekey_signature = prekey.sign(identity_secret, rng).unwrap();

        LongTermKeyBundle::new(
            identity_secret.public_key().unwrap(),
            prekey,
            prekey_signature,
        )
    }

    #[test]
    fn latest_key_bundle() {
        let rng = Rng::from_seed([1; 32]);
        let now = unix_now();

        let member_id = 0;
        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        // Two bundles of the same identity; the second one expires earlier.
        let bundle_1 = signed_longterm_bundle(&rng, &identity_secret, now - 60, now + 60);
        let bundle_2 = signed_longterm_bundle(&rng, &identity_secret, now - 60, now + 30);

        let y = KeyRegistry::init();
        let y = KeyRegistry::add_longterm_bundle(y, member_id, bundle_1.clone()).unwrap();
        let pki = KeyRegistry::add_longterm_bundle(y, member_id, bundle_2).unwrap();

        // The bundle with the "furthest" expiry date wins.
        assert_eq!(
            KeyRegistry::key_bundle(pki, &member_id).unwrap().1,
            Some(bundle_1)
        );
    }

    #[test]
    fn invalid_bundles() {
        let rng = Rng::from_seed([1; 32]);
        let now = unix_now();

        let member_id = 0;
        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        // Lifetime already ended 30 seconds ago.
        let expired_bundle = signed_longterm_bundle(&rng, &identity_secret, now - 60, now - 30);

        let pki = KeyRegistry::init();

        // The verified insert path rejects an expired bundle.
        let rejected =
            KeyRegistry::add_longterm_bundle(pki.clone(), member_id, expired_bundle.clone());
        assert!(rejected.is_err());

        // Force the expired bundle in, bypassing verification.
        let pki = KeyRegistry::add_longterm_bundle_UNVERIFIED(pki, member_id, expired_bundle);
        assert_eq!(pki.longterm_bundles.get(&member_id).unwrap().len(), 1);

        // Asking for a bundle when only expired ones are known is an error.
        let lookup = <KeyRegistry<usize> as PreKeyRegistry<usize, LongTermKeyBundle>>::key_bundle(
            pki.clone(),
            &member_id,
        );
        assert!(lookup.is_err());
    }

    #[test]
    fn garbage_collection() {
        let rng = Rng::from_seed([1; 32]);
        let now = unix_now();

        let member_id = 0;
        let identity_secret = SecretKey::from_bytes(rng.random_array().unwrap());

        let expired_bundle = signed_longterm_bundle(&rng, &identity_secret, now - 60, now - 30);
        let valid_bundle = signed_longterm_bundle(&rng, &identity_secret, now - 60, now + 60);

        // Insert both unverified so the expired one really lands in the registry.
        let y = KeyRegistry::init();
        let y = KeyRegistry::add_longterm_bundle_UNVERIFIED(y, member_id, expired_bundle);
        let pki = KeyRegistry::add_longterm_bundle_UNVERIFIED(y, member_id, valid_bundle.clone());
        assert_eq!(pki.longterm_bundles.get(&member_id).unwrap().len(), 2);

        // Garbage collection drops the expired bundle only.
        let pki = KeyRegistry::remove_expired(pki);
        assert_eq!(pki.longterm_bundles.get(&member_id).unwrap().len(), 1);

        // The surviving bundle is the one handed out.
        assert_eq!(
            KeyRegistry::key_bundle(pki, &member_id).unwrap().1,
            Some(valid_bundle)
        );
    }
}