mls_rs/group/state_repo.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use crate::client::MlsError;
use crate::{group::PriorEpoch, key_package::KeyPackageRef};

use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::fmt::{self, Debug};
use mls_rs_codec::{MlsDecode, MlsEncode};
use mls_rs_core::group::{EpochRecord, GroupState};
use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};

use super::snapshot::Snapshot;

#[cfg(feature = "psk")]
use crate::group::ResumptionPsk;

#[cfg(feature = "psk")]
use mls_rs_core::psk::PreSharedKey;

/// A set of changes to apply to a GroupStateStorage implementation. These changes MUST
/// be made in a single transaction to avoid creating invalid states.
#[derive(Default, Clone, Debug)]
struct EpochStorageCommit {
    pub(crate) inserts: VecDeque<PriorEpoch>,
    pub(crate) updates: Vec<PriorEpoch>,
}

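/// Tracks stored state for a single group. Newly inserted and updated epochs
/// are buffered in memory as an `EpochStorageCommit` and are only persisted to
/// the backing `GroupStateStorage` when `write_to_storage` is called.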
#[derive(Clone)]
pub(crate) struct GroupStateRepository<S, K>
where
    S: GroupStateStorage,
    K: KeyPackageStorage,
{
    pending_commit: EpochStorageCommit,
    pending_key_package_removal: Option<KeyPackageRef>,
    group_id: Vec<u8>,
    storage: S,
    key_package_repo: K,
}

impl<S, K> Debug for GroupStateRepository<S, K>
where
    S: GroupStateStorage + Debug,
    K: KeyPackageStorage + Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GroupStateRepository")
            .field("pending_commit", &self.pending_commit)
            .field(
                "pending_key_package_removal",
                &self.pending_key_package_removal,
            )
            .field(
                "group_id",
                &mls_rs_core::debug::pretty_group_id(&self.group_id),
            )
            .field("storage", &self.storage)
            .field("key_package_repo", &self.key_package_repo)
            .finish()
    }
}

impl<S, K> GroupStateRepository<S, K>
where
    S: GroupStateStorage,
    K: KeyPackageStorage,
{
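    /// Creates a repository for `group_id` backed by `storage` and `key_package_repo`.
    /// If `key_package_to_remove` is `Some`, that key package is deleted from
    /// `key_package_repo` on the next successful `write_to_storage`.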
    pub fn new(
        group_id: Vec<u8>,
        storage: S,
        key_package_repo: K,
        // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
        key_package_to_remove: Option<KeyPackageRef>,
    ) -> Result<GroupStateRepository<S, K>, MlsError> {
        Ok(GroupStateRepository {
            group_id,
            storage,
            pending_key_package_removal: key_package_to_remove,
            pending_commit: Default::default(),
            key_package_repo,
        })
    }

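    /// Returns the highest known epoch id for this group, preferring the newest
    /// pending insert over the value reported by the backing storage.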
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
        if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
            Ok(Some(max))
        } else {
            self.storage
                .max_epoch_id(&self.group_id)
                .await
                .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
        }
    }

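    /// Looks up the resumption secret for `psk_id`, searching pending inserts
    /// first, then pending updates, and finally the backing storage.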
    #[cfg(feature = "psk")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn resumption_secret(
        &self,
        psk_id: &ResumptionPsk,
    ) -> Result<Option<PreSharedKey>, MlsError> {
        // Search the local inserts cache
        if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
            if psk_id.psk_epoch >= min {
                return Ok(self
                    .pending_commit
                    .inserts
                    .get((psk_id.psk_epoch - min) as usize)
                    .map(|e| e.secrets.resumption_secret.clone()));
            }
        }

        // Search the local updates cache
        let maybe_pending = self.find_pending(psk_id.psk_epoch);

        if let Some(pending) = maybe_pending {
            return Ok(Some(
                self.pending_commit.updates[pending]
                    .secrets
                    .resumption_secret
                    .clone(),
            ));
        }

        // Search the stored cache
        self.storage
            .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
            .await
            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
            .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
            .transpose()
    }

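    /// Returns a mutable reference to the requested epoch if it is known.
    /// Epochs loaded from storage are cached in the pending updates list so
    /// that any mutation is persisted by the next `write_to_storage`.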
    #[cfg(feature = "private_message")]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn get_epoch_mut(
        &mut self,
        epoch_id: u64,
    ) -> Result<Option<&mut PriorEpoch>, MlsError> {
        // Search the local inserts cache
        if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
            if epoch_id >= min {
                return Ok(self
                    .pending_commit
                    .inserts
                    .get_mut((epoch_id - min) as usize));
            }
        }

        // Look in the cached updates map, and if not found look in disk storage
        // and insert into the updates map for future caching
        match self.find_pending(epoch_id) {
            Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
            None => self
                .storage
                .epoch(&self.group_id, epoch_id)
                .await
                .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
                .and_then(|epoch| {
                    PriorEpoch::mls_decode(&mut &*epoch)
                        .map(|epoch| {
                            self.pending_commit.updates.push(epoch);
                            self.pending_commit.updates.last_mut()
                        })
                        .transpose()
                }),
        }
        .transpose()
        .map_err(Into::into)
    }

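    /// Queues `epoch` for insertion. The epoch must belong to this group and
    /// must directly follow the newest epoch already known to the repository.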
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
        if epoch.group_id() != self.group_id {
            return Err(MlsError::GroupIdMismatch);
        }

        let epoch_id = epoch.epoch_id();

        if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
            if epoch_id != expected_id {
                return Err(MlsError::InvalidEpoch);
            }
        }

        self.pending_commit.inserts.push_back(epoch);

        Ok(())
    }

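    /// Persists `group_snapshot` together with all pending epoch inserts and
    /// updates in a single `GroupStateStorage::write` call, deletes the pending
    /// key package removal (if any), and then clears the in-memory caches.
    ///
    /// A rough sketch of the intended commit flow (internal API): queue new
    /// epochs with `insert`, mutate cached epochs through `get_epoch_mut`, then
    /// call `write_to_storage` once so the changes land as one transaction.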
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
        let inserts = self
            .pending_commit
            .inserts
            .iter()
            .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
            .collect::<Result<_, MlsError>>()?;

        let updates = self
            .pending_commit
            .updates
            .iter()
            .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
            .collect::<Result<_, MlsError>>()?;

        let group_state = GroupState {
            data: group_snapshot.mls_encode_to_vec()?,
            id: group_snapshot.state.context.group_id,
        };

        self.storage
            .write(group_state, inserts, updates)
            .await
            .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;

        if let Some(ref key_package_ref) = self.pending_key_package_removal {
            self.key_package_repo
                .delete(key_package_ref)
                .await
                .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
        }

        self.pending_commit.inserts.clear();
        self.pending_commit.updates.clear();

        Ok(())
    }

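    /// Returns the index of `epoch_id` within the pending updates cache, if present.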
    #[cfg(any(feature = "psk", feature = "private_message"))]
    fn find_pending(&self, epoch_id: u64) -> Option<usize> {
        self.pending_commit
            .updates
            .iter()
            .position(|ep| ep.context.epoch == epoch_id)
    }
}

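// Tests below exercise the in-memory commit buffering, the single-write flush
// to storage, retention-limit behavior, and key package removal.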
#[cfg(test)]
mod tests {
    use alloc::vec;
    use mls_rs_codec::MlsEncode;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
            test_utils::{random_bytes, test_member, TEST_GROUP},
            PskGroupId, ResumptionPSKUsage,
        },
        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
    };

    use super::*;

    fn test_group_state_repo(
        retention_limit: usize,
    ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
        GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new()
                .with_max_epoch_retention(retention_limit)
                .unwrap(),
            InMemoryKeyPackageStorage::default(),
            None,
        )
        .unwrap()
    }

    fn test_epoch(epoch_id: u64) -> PriorEpoch {
        get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_snapshot(epoch_id: u64) -> Snapshot {
        crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_epoch_inserts() {
        let mut test_repo = test_group_state_repo(1);
        let test_epoch = test_epoch(0);

        test_repo.insert(test_epoch.clone()).await.unwrap();

        // Check the in-memory state
        assert_eq!(
            test_repo.pending_commit.inserts.back().unwrap(),
            &test_epoch
        );

        assert!(test_repo.pending_commit.updates.is_empty());

        #[cfg(feature = "std")]
        assert!(test_repo.storage.inner.lock().unwrap().is_empty());
        #[cfg(not(feature = "std"))]
        assert!(test_repo.storage.inner.lock().is_empty());

        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        // Make sure you can recall an epoch sitting as a pending insert
        let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
        let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();

        assert_eq!(
            prior_epoch.clone().unwrap().secrets.resumption_secret,
            resumption.unwrap()
        );

        assert_eq!(prior_epoch.unwrap(), test_epoch);

        // Write to the storage
        let snapshot = test_snapshot(test_epoch.epoch_id()).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        // Make sure the memory cache cleared
        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        // Make sure the storage was written
        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch.epoch_id(),
                test_epoch.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_updates() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Update the stored epoch
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        assert_eq!(to_update, &test_epoch_0);

        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        assert_eq!(test_repo.pending_commit.updates.len(), 1);
        assert!(test_repo.pending_commit.inserts.is_empty());

        assert_eq!(
            test_repo.pending_commit.updates.first().unwrap(),
            &to_update
        );

        // Make sure you can access an epoch pending update
        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
        assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));

        // Write the update to storage
        let snapshot = test_snapshot(1).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        assert!(test_repo.pending_commit.updates.is_empty());
        assert!(test_repo.pending_commit.inserts.is_empty());

        // Make sure the storage was written
        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_and_update() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Update the stored epoch
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        // Insert another epoch
        let test_epoch_1 = test_epoch(1);
        test_repo.insert(test_epoch_1.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(1).await)
            .await
            .unwrap();

        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        // Make sure the storage was written
        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.epoch_data.len(), 2);

        assert_eq!(
            stored.epoch_data.front().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch_1.epoch_id(),
                test_epoch_1.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_many_epochs_in_storage() {
        let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();

        let mut test_repo = test_group_state_repo(10);

        for epoch in epochs.iter().cloned() {
            test_repo.insert(epoch).await.unwrap()
        }

        test_repo
            .write_to_storage(test_snapshot(9).await)
            .await
            .unwrap();

        for mut epoch in epochs {
            let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap();

            assert_eq!(res, Some(&mut epoch));
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_stored_groups_list() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        assert_eq!(
            test_repo.storage.stored_groups(),
            vec![test_epoch_0.context.group_id]
        )
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn reducing_retention_limit_takes_effect_on_epoch_access() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();

        repo.write_to_storage(test_snapshot(0).await).await.unwrap();

        let mut repo = GroupStateRepository {
            storage: repo.storage,
            ..test_group_state_repo(1)
        };

        let res = repo.get_epoch_mut(0).await.unwrap();

        assert!(res.is_none());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn in_memory_storage_obeys_retention_limit_after_saving() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.write_to_storage(test_snapshot(0).await).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();
        repo.write_to_storage(test_snapshot(1).await).await.unwrap();

        #[cfg(feature = "std")]
        let lock = repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let lock = repo.storage.inner.lock();

        assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn used_key_package_is_deleted() {
        let key_package_repo = InMemoryKeyPackageStorage::default();

        let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
            .await
            .0;

        let (id, data) = key_package.to_storage().unwrap();

        key_package_repo.insert(id, data);

        let mut repo = GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new(),
            key_package_repo,
            Some(key_package.reference.clone()),
        )
        .unwrap();

        repo.key_package_repo.get(&key_package.reference).unwrap();

        repo.write_to_storage(test_snapshot(4).await).await.unwrap();

        assert!(repo.key_package_repo.get(&key_package.reference).is_none());
    }
}