1use crate::client::MlsError;
6use crate::{group::PriorEpoch, key_package::KeyPackageRef};
7
8use alloc::collections::VecDeque;
9use alloc::vec::Vec;
10use core::fmt::{self, Debug};
11use mls_rs_codec::{MlsDecode, MlsEncode};
12use mls_rs_core::group::{EpochRecord, GroupState};
13use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};
14
15use super::snapshot::Snapshot;
16
17#[cfg(feature = "psk")]
18use crate::group::ResumptionPsk;
19
20#[cfg(feature = "psk")]
21use mls_rs_core::psk::PreSharedKey;
22
23#[derive(Default, Clone, Debug)]
26struct EpochStorageCommit {
27 pub(crate) inserts: VecDeque<PriorEpoch>,
28 pub(crate) updates: Vec<PriorEpoch>,
29}
30
31#[derive(Clone)]
32pub(crate) struct GroupStateRepository<S, K>
33where
34 S: GroupStateStorage,
35 K: KeyPackageStorage,
36{
37 pending_commit: EpochStorageCommit,
38 pending_key_package_removal: Option<KeyPackageRef>,
39 group_id: Vec<u8>,
40 storage: S,
41 key_package_repo: K,
42}
43
44impl<S, K> Debug for GroupStateRepository<S, K>
45where
46 S: GroupStateStorage + Debug,
47 K: KeyPackageStorage + Debug,
48{
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("GroupStateRepository")
51 .field("pending_commit", &self.pending_commit)
52 .field(
53 "pending_key_package_removal",
54 &self.pending_key_package_removal,
55 )
56 .field(
57 "group_id",
58 &mls_rs_core::debug::pretty_group_id(&self.group_id),
59 )
60 .field("storage", &self.storage)
61 .field("key_package_repo", &self.key_package_repo)
62 .finish()
63 }
64}
65
66impl<S, K> GroupStateRepository<S, K>
67where
68 S: GroupStateStorage,
69 K: KeyPackageStorage,
70{
71 pub fn new(
72 group_id: Vec<u8>,
73 storage: S,
74 key_package_repo: K,
75 key_package_to_remove: Option<KeyPackageRef>,
77 ) -> Result<GroupStateRepository<S, K>, MlsError> {
78 Ok(GroupStateRepository {
79 group_id,
80 storage,
81 pending_key_package_removal: key_package_to_remove,
82 pending_commit: Default::default(),
83 key_package_repo,
84 })
85 }
86
87 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
88 async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
89 if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
90 Ok(Some(max))
91 } else {
92 self.storage
93 .max_epoch_id(&self.group_id)
94 .await
95 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
96 }
97 }
98
99 #[cfg(feature = "psk")]
100 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
101 pub async fn resumption_secret(
102 &self,
103 psk_id: &ResumptionPsk,
104 ) -> Result<Option<PreSharedKey>, MlsError> {
105 if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
107 if psk_id.psk_epoch >= min {
108 return Ok(self
109 .pending_commit
110 .inserts
111 .get((psk_id.psk_epoch - min) as usize)
112 .map(|e| e.secrets.resumption_secret.clone()));
113 }
114 }
115
116 let maybe_pending = self.find_pending(psk_id.psk_epoch);
118
119 if let Some(pending) = maybe_pending {
120 return Ok(Some(
121 self.pending_commit.updates[pending]
122 .secrets
123 .resumption_secret
124 .clone(),
125 ));
126 }
127
128 self.storage
130 .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
131 .await
132 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
133 .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
134 .transpose()
135 }
136
137 #[cfg(feature = "private_message")]
138 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
139 pub async fn get_epoch_mut(
140 &mut self,
141 epoch_id: u64,
142 ) -> Result<Option<&mut PriorEpoch>, MlsError> {
143 if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
145 if epoch_id >= min {
146 return Ok(self
147 .pending_commit
148 .inserts
149 .get_mut((epoch_id - min) as usize));
150 }
151 }
152
153 match self.find_pending(epoch_id) {
156 Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
157 None => self
158 .storage
159 .epoch(&self.group_id, epoch_id)
160 .await
161 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
162 .and_then(|epoch| {
163 PriorEpoch::mls_decode(&mut &*epoch)
164 .map(|epoch| {
165 self.pending_commit.updates.push(epoch);
166 self.pending_commit.updates.last_mut()
167 })
168 .transpose()
169 }),
170 }
171 .transpose()
172 .map_err(Into::into)
173 }
174
175 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
176 pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
177 if epoch.group_id() != self.group_id {
178 return Err(MlsError::GroupIdMismatch);
179 }
180
181 let epoch_id = epoch.epoch_id();
182
183 if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
184 if epoch_id != expected_id {
185 return Err(MlsError::InvalidEpoch);
186 }
187 }
188
189 self.pending_commit.inserts.push_back(epoch);
190
191 Ok(())
192 }
193
194 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
195 pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
196 let inserts = self
197 .pending_commit
198 .inserts
199 .iter()
200 .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
201 .collect::<Result<_, MlsError>>()?;
202
203 let updates = self
204 .pending_commit
205 .updates
206 .iter()
207 .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
208 .collect::<Result<_, MlsError>>()?;
209
210 let group_state = GroupState {
211 data: group_snapshot.mls_encode_to_vec()?,
212 id: group_snapshot.state.context.group_id,
213 };
214
215 self.storage
216 .write(group_state, inserts, updates)
217 .await
218 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
219
220 if let Some(ref key_package_ref) = self.pending_key_package_removal {
221 self.key_package_repo
222 .delete(key_package_ref)
223 .await
224 .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
225 }
226
227 self.pending_commit.inserts.clear();
228 self.pending_commit.updates.clear();
229
230 Ok(())
231 }
232
233 #[cfg(any(feature = "psk", feature = "private_message"))]
234 fn find_pending(&self, epoch_id: u64) -> Option<usize> {
235 self.pending_commit
236 .updates
237 .iter()
238 .position(|ep| ep.context.epoch == epoch_id)
239 }
240}
241
#[cfg(test)]
mod tests {
    use alloc::vec;
    use mls_rs_codec::MlsEncode;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
            test_utils::{random_bytes, test_member, TEST_GROUP},
            PskGroupId, ResumptionPSKUsage,
        },
        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
    };

    use super::*;

    /// Builds a repository for `TEST_GROUP` backed by in-memory storage configured with
    /// `with_max_epoch_retention(retention_limit)`, with no key package to remove.
    fn test_group_state_repo(
        retention_limit: usize,
    ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
        GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new()
                .with_max_epoch_retention(retention_limit)
                .unwrap(),
            InMemoryKeyPackageStorage::default(),
            None,
        )
        .unwrap()
    }

    /// Test `PriorEpoch` for `TEST_GROUP` with the given epoch id.
    fn test_epoch(epoch_id: u64) -> PriorEpoch {
        get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
    }

    /// Test `Snapshot` obtained from `get_test_snapshot` for `epoch_id`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_snapshot(epoch_id: u64) -> Snapshot {
        crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_epoch_inserts() {
        let mut test_repo = test_group_state_repo(1);
        let test_epoch = test_epoch(0);

        test_repo.insert(test_epoch.clone()).await.unwrap();

        // Inserting only stages the epoch; storage is untouched until `write_to_storage`.
        assert_eq!(
            test_repo.pending_commit.inserts.back().unwrap(),
            &test_epoch
        );

        assert!(test_repo.pending_commit.updates.is_empty());

        #[cfg(feature = "std")]
        assert!(test_repo.storage.inner.lock().unwrap().is_empty());
        #[cfg(not(feature = "std"))]
        assert!(test_repo.storage.inner.lock().is_empty());

        // Lookups are served from the pending inserts before anything is written.
        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
        let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();

        assert_eq!(
            prior_epoch.clone().unwrap().secrets.resumption_secret,
            resumption.unwrap()
        );

        assert_eq!(prior_epoch.unwrap(), test_epoch);

        // Writing flushes the staged insert together with the group snapshot.
        let snapshot = test_snapshot(test_epoch.epoch_id()).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch.epoch_id(),
                test_epoch.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_updates() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Epoch 0 is now only in storage, so `get_epoch_mut` loads it into the pending updates.
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        assert_eq!(to_update, &test_epoch_0);

        // Modify the loaded epoch in place through the returned reference.
        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        assert_eq!(test_repo.pending_commit.updates.len(), 1);
        assert!(test_repo.pending_commit.inserts.is_empty());

        assert_eq!(
            test_repo.pending_commit.updates.first().unwrap(),
            &to_update
        );

        // The resumption secret lookup sees the pending (modified) copy of epoch 0.
        let psk_id = ResumptionPsk {
            psk_epoch: 0,
            psk_group_id: PskGroupId(test_repo.group_id.clone()),
            usage: ResumptionPSKUsage::Application,
        };

        let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
        assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));

        // Writing persists the modified epoch; storage still holds a single epoch record.
        let snapshot = test_snapshot(1).await;
        test_repo.write_to_storage(snapshot.clone()).await.unwrap();

        assert!(test_repo.pending_commit.updates.is_empty());
        assert!(test_repo.pending_commit.inserts.is_empty());

        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());

        assert_eq!(stored.epoch_data.len(), 1);

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_and_update() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        // Stage an update to stored epoch 0 ...
        let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
        let new_sender_secret = random_bytes(32);
        to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
        let to_update = to_update.clone();

        // ... and an insert of epoch 1, then write both in one call.
        let test_epoch_1 = test_epoch(1);
        test_repo.insert(test_epoch_1.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(1).await)
            .await
            .unwrap();

        assert!(test_repo.pending_commit.inserts.is_empty());
        assert!(test_repo.pending_commit.updates.is_empty());

        #[cfg(feature = "std")]
        let storage = test_repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let storage = test_repo.storage.inner.lock();

        assert_eq!(storage.len(), 1);

        let stored = storage.get(TEST_GROUP).unwrap();

        assert_eq!(stored.epoch_data.len(), 2);

        assert_eq!(
            stored.epoch_data.front().unwrap(),
            &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
        );

        assert_eq!(
            stored.epoch_data.back().unwrap(),
            &EpochRecord::new(
                test_epoch_1.epoch_id(),
                test_epoch_1.mls_encode_to_vec().unwrap()
            )
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_many_epochs_in_storage() {
        let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();

        let mut test_repo = test_group_state_repo(10);

        for epoch in epochs.iter().cloned() {
            test_repo.insert(epoch).await.unwrap()
        }

        test_repo
            .write_to_storage(test_snapshot(9).await)
            .await
            .unwrap();

        // Pending inserts were cleared by the write, so every lookup is answered from storage.
        for mut epoch in epochs {
            let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap();

            assert_eq!(res, Some(&mut epoch));
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_stored_groups_list() {
        let mut test_repo = test_group_state_repo(2);
        let test_epoch_0 = test_epoch(0);

        test_repo.insert(test_epoch_0.clone()).await.unwrap();

        test_repo
            .write_to_storage(test_snapshot(0).await)
            .await
            .unwrap();

        assert_eq!(
            test_repo.storage.stored_groups(),
            vec![test_epoch_0.context.group_id]
        )
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn reducing_retention_limit_takes_effect_on_epoch_access() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();

        repo.write_to_storage(test_snapshot(0).await).await.unwrap();

        // Fresh repository (no pending state) over the same storage, so epoch 0 can only
        // come from storage, where the retention limit of 1 is expected to have dropped it.
        let mut repo = GroupStateRepository {
            storage: repo.storage,
            ..test_group_state_repo(1)
        };

        let res = repo.get_epoch_mut(0).await.unwrap();

        assert!(res.is_none());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn in_memory_storage_obeys_retention_limit_after_saving() {
        let mut repo = test_group_state_repo(1);

        repo.insert(test_epoch(0)).await.unwrap();
        repo.write_to_storage(test_snapshot(0).await).await.unwrap();
        repo.insert(test_epoch(1)).await.unwrap();
        repo.write_to_storage(test_snapshot(1).await).await.unwrap();

        #[cfg(feature = "std")]
        let lock = repo.storage.inner.lock().unwrap();
        #[cfg(not(feature = "std"))]
        let lock = repo.storage.inner.lock();

        assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn used_key_package_is_deleted() {
        let key_package_repo = InMemoryKeyPackageStorage::default();

        let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
            .await
            .0;

        let (id, data) = key_package.to_storage().unwrap();

        key_package_repo.insert(id, data);

        // The repository is told to remove this key package on its next write.
        let mut repo = GroupStateRepository::new(
            TEST_GROUP.to_vec(),
            InMemoryGroupStateStorage::new(),
            key_package_repo,
            Some(key_package.reference.clone()),
        )
        .unwrap();

        // Present before the write ...
        repo.key_package_repo.get(&key_package.reference).unwrap();

        repo.write_to_storage(test_snapshot(4).await).await.unwrap();

        // ... and gone after it.
        assert!(repo.key_package_repo.get(&key_package.reference).is_none());
    }
}