mls_rs/group/
snapshot.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use alloc::boxed::Box;
6use alloc::vec::Vec;
7
8use crate::{
9    client::MlsError,
10    client_config::ClientConfig,
11    group::{
12        cipher_suite_provider, epoch::EpochSecrets, key_schedule::KeySchedule,
13        message_hash::MessageHash, state_repo::GroupStateRepository, ConfirmationTag, Group,
14        GroupContext, GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic,
15    },
16    tree_kem::TreeKemPrivate,
17};
18
19#[cfg(feature = "by_ref_proposal")]
20use crate::{
21    crypto::{HpkePublicKey, HpkeSecretKey},
22    group::{
23        proposal_cache::{CachedProposal, ProposalCache},
24        ProposalMessageDescription, ProposalRef,
25    },
26    map::SmallMap,
27};
28
29use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
30use mls_rs_core::crypto::SignatureSecretKey;
31#[cfg(feature = "tree_index")]
32use mls_rs_core::identity::IdentityProvider;
33
34use super::PendingCommit;
35
36pub(crate) use legacy::LegacyPendingCommit;
37
/// Complete image of a [`Group`]'s state that can be persisted and restored.
///
/// Created by [`Group::snapshot`], consumed by [`Group::from_snapshot`], and handed
/// to the group's state repository by [`Group::write_to_storage`]. It carries secret
/// key material (signature secret key, HPKE secret keys, epoch secrets), so any
/// serialized form of it must be treated as sensitive.
///
/// NOTE(review): the derived `MlsEncode`/`MlsDecode` layout presumably follows field
/// declaration order, so reordering or retyping fields would likely break decoding of
/// previously stored snapshots — confirm before changing.
#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Snapshot {
    /// Snapshot format version. `Group::snapshot` always writes `1`;
    /// `Group::from_snapshot` does not inspect it.
    version: u16,
    /// Persistable group state produced by `RawGroupState::export`.
    pub(crate) state: RawGroupState,
    /// Copy of the group's private ratchet-tree state.
    private_tree: TreeKemPrivate,
    /// Copy of the group's epoch secrets.
    epoch_secrets: EpochSecrets,
    /// Copy of the group's key schedule.
    key_schedule: KeySchedule,
    /// Secret keys for this member's pending update proposals, keyed by HPKE
    /// public key (presumably the public key advertised in each update — confirm).
    #[cfg(feature = "by_ref_proposal")]
    pending_updates: SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
    /// The group's pending commit, if any.
    pending_commit_snapshot: PendingCommitSnapshot,
    /// This member's signature secret key.
    signer: SignatureSecretKey,
}
51
/// Stored form of the group's pending commit.
///
/// The explicit `u8` discriminants are part of the encoded snapshot format and
/// must stay stable so that older snapshots keep decoding.
#[derive(Debug, PartialEq, Clone, Default, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub(crate) enum PendingCommitSnapshot {
    /// No pending commit.
    #[default]
    None = 0u8,
    /// Pending commit in the format written by older versions.
    // This must be 1 for backwards compatibility
    LegacyPendingCommit(Box<LegacyPendingCommit>) = 1u8,
    /// Current format: the MLS encoding of a [`PendingCommit`], kept as opaque
    /// bytes and decoded only when needed (e.g. by `commit_hash`).
    PendingCommit(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>) = 2u8,
}
62
63impl From<Vec<u8>> for PendingCommitSnapshot {
64    fn from(value: Vec<u8>) -> Self {
65        Self::PendingCommit(value)
66    }
67}
68
69impl TryFrom<PendingCommit> for PendingCommitSnapshot {
70    type Error = mls_rs_codec::Error;
71
72    fn try_from(value: PendingCommit) -> Result<Self, Self::Error> {
73        value.mls_encode_to_vec().map(Self::PendingCommit)
74    }
75}
76impl PendingCommitSnapshot {
77    pub fn is_none(&self) -> bool {
78        self == &Self::None
79    }
80
81    pub fn commit_hash(&self) -> Result<Option<MessageHash>, MlsError> {
82        match self {
83            Self::None => Ok(None),
84            Self::PendingCommit(bytes) => Ok(Some(
85                PendingCommit::mls_decode(&mut &bytes[..])?.commit_message_hash,
86            )),
87            Self::LegacyPendingCommit(legacy_pending) => {
88                Ok(Some(legacy_pending.commit_message_hash.clone()))
89            }
90        }
91    }
92}
93
/// Persistable subset of [`GroupState`] stored inside a [`Snapshot`].
///
/// Built by `RawGroupState::export` and turned back into a [`GroupState`] by
/// `RawGroupState::import`.
///
/// NOTE(review): field order presumably defines the MLS-encoded layout of stored
/// snapshots — do not reorder without a migration.
#[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct RawGroupState {
    pub(crate) context: GroupContext,
    /// Cached by-reference proposals, copied from the group's `ProposalCache`.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) proposals: SmallMap<ProposalRef, CachedProposal>,
    /// The proposal cache's `own_proposals`, keyed by message hash.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) own_proposals: SmallMap<MessageHash, ProposalMessageDescription>,
    /// Public ratchet tree. `write_to_storage_without_ratchet_tree` clears its
    /// nodes before the snapshot is stored.
    pub(crate) public_tree: TreeKemPublic,
    pub(crate) interim_transcript_hash: InterimTranscriptHash,
    pub(crate) pending_reinit: Option<ReInitProposal>,
    pub(crate) confirmation_tag: ConfirmationTag,
}
107
108impl RawGroupState {
109    pub(crate) fn export(state: &GroupState) -> Self {
110        #[cfg(feature = "tree_index")]
111        let public_tree = state.public_tree.clone();
112
113        #[cfg(not(feature = "tree_index"))]
114        let public_tree = {
115            let mut tree = TreeKemPublic::new();
116            tree.nodes = state.public_tree.nodes.clone();
117            tree
118        };
119
120        Self {
121            context: state.context.clone(),
122            #[cfg(feature = "by_ref_proposal")]
123            proposals: state.proposals.proposals.clone(),
124            #[cfg(feature = "by_ref_proposal")]
125            own_proposals: state.proposals.own_proposals.clone(),
126            public_tree,
127            interim_transcript_hash: state.interim_transcript_hash.clone(),
128            pending_reinit: state.pending_reinit.clone(),
129            confirmation_tag: state.confirmation_tag.clone(),
130        }
131    }
132
133    #[cfg(feature = "tree_index")]
134    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
135    pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError>
136    where
137        C: IdentityProvider,
138    {
139        let context = self.context;
140
141        #[cfg(feature = "by_ref_proposal")]
142        let proposals = ProposalCache::import(
143            context.protocol_version,
144            context.group_id.clone(),
145            self.proposals,
146            self.own_proposals.clone(),
147        );
148
149        let mut public_tree = self.public_tree;
150
151        public_tree
152            .initialize_index_if_necessary(identity_provider, &context.extensions)
153            .await?;
154
155        Ok(GroupState {
156            #[cfg(feature = "by_ref_proposal")]
157            proposals,
158            context,
159            public_tree,
160            interim_transcript_hash: self.interim_transcript_hash,
161            pending_reinit: self.pending_reinit,
162            confirmation_tag: self.confirmation_tag,
163        })
164    }
165
166    #[cfg(not(feature = "tree_index"))]
167    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
168    pub(crate) async fn import(self) -> Result<GroupState, MlsError> {
169        let context = self.context;
170
171        #[cfg(feature = "by_ref_proposal")]
172        let proposals = ProposalCache::import(
173            context.protocol_version,
174            context.group_id.clone(),
175            self.proposals,
176            self.own_proposals.clone(),
177        );
178
179        Ok(GroupState {
180            #[cfg(feature = "by_ref_proposal")]
181            proposals,
182            context,
183            public_tree: self.public_tree,
184            interim_transcript_hash: self.interim_transcript_hash,
185            pending_reinit: self.pending_reinit,
186            confirmation_tag: self.confirmation_tag,
187        })
188    }
189}
190
191impl<C> Group<C>
192where
193    C: ClientConfig + Clone,
194{
195    /// Write the current state of the group to the
196    /// [`GroupStorageProvider`](crate::GroupStateStorage)
197    /// that is currently in use by the group.
198    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
199    pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
200        self.state_repo.write_to_storage(self.snapshot()?).await
201    }
202
203    /// Write the current state of the group to the
204    /// [`GroupStorageProvider`](crate::GroupStateStorage)
205    /// that is currently in use by the group.
206    /// The tree is not included in the state and can be stored
207    /// separately by calling [`Group::export_tree`].
208    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
209    pub async fn write_to_storage_without_ratchet_tree(&mut self) -> Result<(), MlsError> {
210        let mut snapshot = self.snapshot()?;
211        snapshot.state.public_tree.nodes = Default::default();
212
213        self.state_repo.write_to_storage(snapshot).await
214    }
215
216    pub(crate) fn snapshot(&self) -> Result<Snapshot, MlsError> {
217        Ok(Snapshot {
218            state: RawGroupState::export(&self.state),
219            private_tree: self.private_tree.clone(),
220            key_schedule: self.key_schedule.clone(),
221            #[cfg(feature = "by_ref_proposal")]
222            pending_updates: self.pending_updates.clone(),
223            pending_commit_snapshot: self.pending_commit.clone(),
224            epoch_secrets: self.epoch_secrets.clone(),
225            version: 1,
226            signer: self.signer.clone(),
227        })
228    }
229
230    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
231    pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> {
232        let cipher_suite_provider = cipher_suite_provider(
233            config.crypto_provider(),
234            snapshot.state.context.cipher_suite,
235        )?;
236
237        #[cfg(feature = "tree_index")]
238        let identity_provider = config.identity_provider();
239
240        let state_repo = GroupStateRepository::new(
241            #[cfg(feature = "prior_epoch")]
242            snapshot.state.context.group_id.clone(),
243            config.group_state_storage(),
244            config.key_package_repo(),
245            None,
246        )?;
247
248        Ok(Group {
249            config,
250            state: snapshot
251                .state
252                .import(
253                    #[cfg(feature = "tree_index")]
254                    &identity_provider,
255                )
256                .await?,
257            private_tree: snapshot.private_tree,
258            key_schedule: snapshot.key_schedule,
259            #[cfg(feature = "by_ref_proposal")]
260            pending_updates: snapshot.pending_updates,
261            pending_commit: snapshot.pending_commit_snapshot,
262            #[cfg(test)]
263            commit_modifiers: Default::default(),
264            epoch_secrets: snapshot.epoch_secrets,
265            state_repo,
266            cipher_suite_provider,
267            #[cfg(feature = "psk")]
268            previous_psk: None,
269            signer: snapshot.signer,
270        })
271    }
272}
273
/// Types retained so that snapshots written by older versions can still be
/// decoded (exercised by the `legacy_interop` test).
mod legacy {
    use crate::{group::AuthenticatedContent, tree_kem::path_secret::PathSecret};

    use super::*;

    /// Pending-commit layout stored by older snapshots, carried in the
    /// `PendingCommitSnapshot::LegacyPendingCommit` variant (discriminant `1`).
    ///
    /// NOTE(review): the field order presumably must match what older versions
    /// encoded — do not reorder or retype.
    #[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub(crate) struct LegacyPendingCommit {
        pub content: AuthenticatedContent,
        pub private_tree: TreeKemPrivate,
        pub commit_secret: PathSecret,
        /// Read by `PendingCommitSnapshot::commit_hash` for legacy snapshots.
        pub commit_message_hash: MessageHash,
    }
}
288
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
            key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
            transcript_hash::InterimTranscriptHash,
        },
        tree_kem::{node::LeafIndex, TreeKemPrivate},
    };

    use super::{RawGroupState, Snapshot};

    /// Builds a minimal version-1 [`Snapshot`] for `cipher_suite` whose group
    /// context is at epoch `epoch_id`. The snapshot has an empty public tree, an
    /// empty transcript hash and confirmation tag, no pending commit or reinit,
    /// and an empty signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
        let context = get_test_group_context(epoch_id, cipher_suite).await;
        let suite_provider = test_cipher_suite_provider(cipher_suite);
        let confirmation_tag = ConfirmationTag::empty(&suite_provider).await;

        let state = RawGroupState {
            context,
            #[cfg(feature = "by_ref_proposal")]
            proposals: Default::default(),
            #[cfg(feature = "by_ref_proposal")]
            own_proposals: Default::default(),
            public_tree: Default::default(),
            interim_transcript_hash: InterimTranscriptHash::from(vec![]),
            pending_reinit: None,
            confirmation_tag,
        };

        Snapshot {
            version: 1,
            state,
            private_tree: TreeKemPrivate::new(LeafIndex::unchecked(0)),
            epoch_secrets: get_test_epoch_secrets(cipher_suite),
            key_schedule: get_test_key_schedule(cipher_suite),
            #[cfg(feature = "by_ref_proposal")]
            pending_updates: Default::default(),
            pending_commit_snapshot: Default::default(),
            signer: vec![].into(),
        }
    }
}
332
#[cfg(test)]
mod tests {
    use alloc::vec;
    use mls_rs_core::group::{GroupState, GroupStateStorage};

    use crate::{
        client::test_utils::{TestClientBuilder, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            test_utils::{test_group, TestGroup},
            Group,
        },
        storage_provider::in_memory::InMemoryGroupStateStorage,
    };

    /// Loads a stored snapshot produced by an older library version (presumably
    /// carrying a legacy-format pending commit). Checks that the group loads and
    /// that its pending commit can still be applied.
    #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn legacy_interop() {
        let mut storage = InMemoryGroupStateStorage::new();

        let legacy_snapshot = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/test_data/legacy_snapshot.mls"
        ));

        let group_state = GroupState {
            id: b"group".into(),
            data: legacy_snapshot.to_vec(),
        };

        storage
            .write(group_state, Default::default(), Default::default())
            .await
            .unwrap();

        let client = TestClientBuilder::new_for_test()
            .group_state_storage(storage)
            .build();

        let mut group = client.load_group(b"group").await.unwrap();

        group
            .apply_pending_commit_backwards_compatible()
            .await
            .unwrap();
    }

    /// Takes an in-memory snapshot of `group`, restores a new group from it and
    /// asserts that the restored state (and, with `tree_index`, the tree
    /// internals) matches the original.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn snapshot_restore(group: TestGroup) {
        let snapshot = group.snapshot().unwrap();

        let group_restored = Group::from_snapshot(group.config.clone(), snapshot)
            .await
            .unwrap();

        assert!(Group::equal_group_state(&group, &group_restored));

        #[cfg(feature = "tree_index")]
        assert!(group_restored
            .state
            .public_tree
            .equal_internals(&group.state.public_tree))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_commit_can_be_restored() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
        group.commit(vec![]).await.unwrap();

        snapshot_restore(group).await
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_updates_can_be_restored() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        // Creating the update proposal will add it to pending updates
        let update_proposal = group.update_proposal().await;

        // This will insert the proposal into the internal proposal cache. Failing
        // loudly matters: if this step failed silently, the snapshot would not
        // contain the cached proposal this test is meant to cover.
        group
            .proposal_message(update_proposal, vec![])
            .await
            .unwrap();

        snapshot_restore(group).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_can_be_restored_with_internals() {
        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        snapshot_restore(group).await
    }

    /// Round-trips a test snapshot through `serde_json` and checks it is unchanged.
    #[cfg(feature = "serde")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn serde() {
        let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;
        let json = serde_json::to_string_pretty(&snapshot).unwrap();
        let recovered = serde_json::from_str(&json).unwrap();
        assert_eq!(snapshot, recovered);
    }
}