1use alloc::boxed::Box;
6use alloc::vec::Vec;
7
8use crate::{
9 client::MlsError,
10 client_config::ClientConfig,
11 group::{
12 cipher_suite_provider, epoch::EpochSecrets, key_schedule::KeySchedule,
13 message_hash::MessageHash, state_repo::GroupStateRepository, ConfirmationTag, Group,
14 GroupContext, GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic,
15 },
16 tree_kem::TreeKemPrivate,
17};
18
19#[cfg(feature = "by_ref_proposal")]
20use crate::{
21 crypto::{HpkePublicKey, HpkeSecretKey},
22 group::{
23 proposal_cache::{CachedProposal, ProposalCache},
24 ProposalMessageDescription, ProposalRef,
25 },
26 map::SmallMap,
27};
28
29use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
30use mls_rs_core::crypto::SignatureSecretKey;
31#[cfg(feature = "tree_index")]
32use mls_rs_core::identity::IdentityProvider;
33
34use super::PendingCommit;
35
36pub(crate) use legacy::LegacyPendingCommit;
37
/// Everything needed to rebuild a `Group`: produced by `Group::snapshot` and
/// consumed by `Group::from_snapshot`.
///
/// The derived MLS codec writes the fields in declaration order, so reordering
/// them changes the stored format.
#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Snapshot {
    /// Snapshot format version. `Group::snapshot` currently writes `1`;
    /// `Group::from_snapshot` does not read it.
    version: u16,
    /// Group state in its storage form.
    pub(crate) state: RawGroupState,
    /// This member's private tree state.
    private_tree: TreeKemPrivate,
    /// Secrets of the current epoch.
    epoch_secrets: EpochSecrets,
    /// Key schedule of the current epoch.
    key_schedule: KeySchedule,
    /// Pending update keys: HPKE public key -> (HPKE secret key, optional signer).
    #[cfg(feature = "by_ref_proposal")]
    pending_updates: SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
    /// The group's pending commit, if any.
    pending_commit_snapshot: PendingCommitSnapshot,
    /// This member's signing key.
    signer: SignatureSecretKey,
}
51
/// Stored form of the group's pending commit, if there is one.
// NOTE(review): the explicit `u8` discriminants presumably double as the encoded
// variant tag, so renumbering or reusing them would break existing snapshots — confirm.
#[derive(Debug, PartialEq, Clone, Default, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub(crate) enum PendingCommitSnapshot {
    /// No commit is pending. This is the `Default`.
    #[default]
    None = 0u8,
    /// Pending commit in the older, structured snapshot format.
    LegacyPendingCommit(Box<LegacyPendingCommit>) = 1u8,
    /// A `PendingCommit` in its MLS encoding (see the `TryFrom<PendingCommit>` impl).
    PendingCommit(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>) = 2u8,
}
62
63impl From<Vec<u8>> for PendingCommitSnapshot {
64 fn from(value: Vec<u8>) -> Self {
65 Self::PendingCommit(value)
66 }
67}
68
69impl TryFrom<PendingCommit> for PendingCommitSnapshot {
70 type Error = mls_rs_codec::Error;
71
72 fn try_from(value: PendingCommit) -> Result<Self, Self::Error> {
73 value.mls_encode_to_vec().map(Self::PendingCommit)
74 }
75}
76impl PendingCommitSnapshot {
77 pub fn is_none(&self) -> bool {
78 self == &Self::None
79 }
80
81 pub fn commit_hash(&self) -> Result<Option<MessageHash>, MlsError> {
82 match self {
83 Self::None => Ok(None),
84 Self::PendingCommit(bytes) => Ok(Some(
85 PendingCommit::mls_decode(&mut &bytes[..])?.commit_message_hash,
86 )),
87 Self::LegacyPendingCommit(legacy_pending) => {
88 Ok(Some(legacy_pending.commit_message_hash.clone()))
89 }
90 }
91 }
92}
93
/// Storage form of `GroupState`. It is produced by `RawGroupState::export` and
/// turned back into a `GroupState` by `RawGroupState::import`.
///
/// The derived MLS codec writes the fields in declaration order, so reordering
/// them changes the stored format.
#[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct RawGroupState {
    pub(crate) context: GroupContext,
    /// Proposal map taken from the group's `ProposalCache`.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) proposals: SmallMap<ProposalRef, CachedProposal>,
    /// The `ProposalCache`'s `own_proposals` map, keyed by message hash.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) own_proposals: SmallMap<MessageHash, ProposalMessageDescription>,
    /// Public ratchet tree. Without `tree_index`, `export` stores only its nodes;
    /// `write_to_storage_without_ratchet_tree` stores it with the nodes cleared.
    pub(crate) public_tree: TreeKemPublic,
    pub(crate) interim_transcript_hash: InterimTranscriptHash,
    pub(crate) pending_reinit: Option<ReInitProposal>,
    pub(crate) confirmation_tag: ConfirmationTag,
}
107
108impl RawGroupState {
109 pub(crate) fn export(state: &GroupState) -> Self {
110 #[cfg(feature = "tree_index")]
111 let public_tree = state.public_tree.clone();
112
113 #[cfg(not(feature = "tree_index"))]
114 let public_tree = {
115 let mut tree = TreeKemPublic::new();
116 tree.nodes = state.public_tree.nodes.clone();
117 tree
118 };
119
120 Self {
121 context: state.context.clone(),
122 #[cfg(feature = "by_ref_proposal")]
123 proposals: state.proposals.proposals.clone(),
124 #[cfg(feature = "by_ref_proposal")]
125 own_proposals: state.proposals.own_proposals.clone(),
126 public_tree,
127 interim_transcript_hash: state.interim_transcript_hash.clone(),
128 pending_reinit: state.pending_reinit.clone(),
129 confirmation_tag: state.confirmation_tag.clone(),
130 }
131 }
132
133 #[cfg(feature = "tree_index")]
134 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
135 pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError>
136 where
137 C: IdentityProvider,
138 {
139 let context = self.context;
140
141 #[cfg(feature = "by_ref_proposal")]
142 let proposals = ProposalCache::import(
143 context.protocol_version,
144 context.group_id.clone(),
145 self.proposals,
146 self.own_proposals.clone(),
147 );
148
149 let mut public_tree = self.public_tree;
150
151 public_tree
152 .initialize_index_if_necessary(identity_provider, &context.extensions)
153 .await?;
154
155 Ok(GroupState {
156 #[cfg(feature = "by_ref_proposal")]
157 proposals,
158 context,
159 public_tree,
160 interim_transcript_hash: self.interim_transcript_hash,
161 pending_reinit: self.pending_reinit,
162 confirmation_tag: self.confirmation_tag,
163 })
164 }
165
166 #[cfg(not(feature = "tree_index"))]
167 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
168 pub(crate) async fn import(self) -> Result<GroupState, MlsError> {
169 let context = self.context;
170
171 #[cfg(feature = "by_ref_proposal")]
172 let proposals = ProposalCache::import(
173 context.protocol_version,
174 context.group_id.clone(),
175 self.proposals,
176 self.own_proposals.clone(),
177 );
178
179 Ok(GroupState {
180 #[cfg(feature = "by_ref_proposal")]
181 proposals,
182 context,
183 public_tree: self.public_tree,
184 interim_transcript_hash: self.interim_transcript_hash,
185 pending_reinit: self.pending_reinit,
186 confirmation_tag: self.confirmation_tag,
187 })
188 }
189}
190
191impl<C> Group<C>
192where
193 C: ClientConfig + Clone,
194{
195 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
199 pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
200 self.state_repo.write_to_storage(self.snapshot()?).await
201 }
202
203 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
209 pub async fn write_to_storage_without_ratchet_tree(&mut self) -> Result<(), MlsError> {
210 let mut snapshot = self.snapshot()?;
211 snapshot.state.public_tree.nodes = Default::default();
212
213 self.state_repo.write_to_storage(snapshot).await
214 }
215
216 pub(crate) fn snapshot(&self) -> Result<Snapshot, MlsError> {
217 Ok(Snapshot {
218 state: RawGroupState::export(&self.state),
219 private_tree: self.private_tree.clone(),
220 key_schedule: self.key_schedule.clone(),
221 #[cfg(feature = "by_ref_proposal")]
222 pending_updates: self.pending_updates.clone(),
223 pending_commit_snapshot: self.pending_commit.clone(),
224 epoch_secrets: self.epoch_secrets.clone(),
225 version: 1,
226 signer: self.signer.clone(),
227 })
228 }
229
230 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
231 pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> {
232 let cipher_suite_provider = cipher_suite_provider(
233 config.crypto_provider(),
234 snapshot.state.context.cipher_suite,
235 )?;
236
237 #[cfg(feature = "tree_index")]
238 let identity_provider = config.identity_provider();
239
240 let state_repo = GroupStateRepository::new(
241 #[cfg(feature = "prior_epoch")]
242 snapshot.state.context.group_id.clone(),
243 config.group_state_storage(),
244 config.key_package_repo(),
245 None,
246 )?;
247
248 Ok(Group {
249 config,
250 state: snapshot
251 .state
252 .import(
253 #[cfg(feature = "tree_index")]
254 &identity_provider,
255 )
256 .await?,
257 private_tree: snapshot.private_tree,
258 key_schedule: snapshot.key_schedule,
259 #[cfg(feature = "by_ref_proposal")]
260 pending_updates: snapshot.pending_updates,
261 pending_commit: snapshot.pending_commit_snapshot,
262 #[cfg(test)]
263 commit_modifiers: Default::default(),
264 epoch_secrets: snapshot.epoch_secrets,
265 state_repo,
266 cipher_suite_provider,
267 #[cfg(feature = "psk")]
268 previous_psk: None,
269 signer: snapshot.signer,
270 })
271 }
272}
273
mod legacy {
    use crate::{group::AuthenticatedContent, tree_kem::path_secret::PathSecret};

    use super::*;

    /// Pending-commit record in the older snapshot format.
    ///
    /// This appears to be kept so that snapshots written before pending commits were
    /// stored as encoded bytes (`PendingCommitSnapshot::PendingCommit`) still load.
    /// The `legacy_interop` test exercises this path.
    // The derived codec writes fields in declaration order; do not reorder them.
    #[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub(crate) struct LegacyPendingCommit {
        /// Authenticated content of the pending commit.
        pub content: AuthenticatedContent,
        /// NOTE(review): presumably the private tree to adopt once the commit is applied — confirm.
        pub private_tree: TreeKemPrivate,
        /// Commit secret, stored as a `PathSecret`.
        pub commit_secret: PathSecret,
        /// Hash returned by `PendingCommitSnapshot::commit_hash` for this variant.
        pub commit_message_hash: MessageHash,
    }
}
288
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
            key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
            transcript_hash::InterimTranscriptHash,
        },
        tree_kem::{node::LeafIndex, TreeKemPrivate},
    };

    use super::{RawGroupState, Snapshot};

    /// Builds a minimal `Snapshot` for tests.
    ///
    /// - The group context is `get_test_group_context(epoch_id, cipher_suite)`.
    /// - The confirmation tag is `ConfirmationTag::empty` for the suite.
    /// - The private tree is for leaf 0.
    /// - The tree, proposal maps, transcript hash, pending state and signer are empty.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
        let context = get_test_group_context(epoch_id, cipher_suite).await;

        let suite_provider = test_cipher_suite_provider(cipher_suite);
        let confirmation_tag = ConfirmationTag::empty(&suite_provider).await;

        let state = RawGroupState {
            context,
            #[cfg(feature = "by_ref_proposal")]
            proposals: Default::default(),
            #[cfg(feature = "by_ref_proposal")]
            own_proposals: Default::default(),
            public_tree: Default::default(),
            interim_transcript_hash: InterimTranscriptHash::from(vec![]),
            pending_reinit: None,
            confirmation_tag,
        };

        Snapshot {
            version: 1,
            state,
            private_tree: TreeKemPrivate::new(LeafIndex::unchecked(0)),
            epoch_secrets: get_test_epoch_secrets(cipher_suite),
            key_schedule: get_test_key_schedule(cipher_suite),
            #[cfg(feature = "by_ref_proposal")]
            pending_updates: Default::default(),
            pending_commit_snapshot: Default::default(),
            signer: vec![].into(),
        }
    }
}
332
#[cfg(test)]
mod tests {
    use alloc::vec;
    use mls_rs_core::group::{GroupState, GroupStateStorage};

    use crate::{
        client::test_utils::{TestClientBuilder, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            test_utils::{test_group, TestGroup},
            Group,
        },
        storage_provider::in_memory::InMemoryGroupStateStorage,
    };

    /// Loads the checked-in legacy snapshot (`test_data/legacy_snapshot.mls`) and
    /// applies its pending commit through the backwards-compatible path.
    #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn legacy_interop() {
        let mut storage = InMemoryGroupStateStorage::new();

        let legacy_snapshot = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_data/legacy_snapshot.mls"
        ));

        let group_state = GroupState {
            id: b"group".into(),
            data: legacy_snapshot.to_vec(),
        };

        storage
            .write(group_state, Default::default(), Default::default())
            .await
            .unwrap();

        let client = TestClientBuilder::new_for_test()
            .group_state_storage(storage)
            .build();

        let mut group = client.load_group(b"group").await.unwrap();

        group
            .apply_pending_commit_backwards_compatible()
            .await
            .unwrap();
    }

    /// Snapshots `group`, restores it with the same config, and asserts the restored
    /// group state equals the original. With `tree_index`, the tree internals are
    /// compared as well.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn snapshot_restore(group: TestGroup) {
        let snapshot = group.snapshot().unwrap();

        let group_restored = Group::from_snapshot(group.config.clone(), snapshot)
            .await
            .unwrap();

        assert!(Group::equal_group_state(&group, &group_restored));

        #[cfg(feature = "tree_index")]
        assert!(group_restored
            .state
            .public_tree
            .equal_internals(&group.state.public_tree));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_commit_can_be_serialized_to_json() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
        group.commit(vec![]).await.unwrap();

        snapshot_restore(group).await
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_updates_can_be_serialized_to_json() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        let update_proposal = group.update_proposal().await;

        // Send the proposal so the group holds proposal state to round-trip. Unwrap
        // rather than discard the result: if sending failed, the test would restore
        // a group without that state and pass without testing anything.
        group
            .proposal_message(update_proposal, vec![])
            .await
            .unwrap();

        snapshot_restore(group).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_can_be_serialized_to_json_with_internals() {
        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        snapshot_restore(group).await
    }

    /// A test snapshot survives a `serde_json` round trip unchanged.
    #[cfg(feature = "serde")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn serde() {
        let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;
        let json = serde_json::to_string_pretty(&snapshot).unwrap();
        let recovered = serde_json::from_str(&json).unwrap();
        assert_eq!(snapshot, recovered);
    }
}