mls_rs/group/proposal_ref.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use core::ops::Deref;
6
7use super::*;
8use crate::hash_reference::HashReference;
9
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Unique identifier for a proposal message.
///
/// This is a newtype over `HashReference`. `ProposalRef::from_content`
/// computes one by hashing the MLS encoding of the `AuthenticatedContent`
/// that carries the proposal, under the label `"MLS 1.0 Proposal Reference"`.
/// The raw reference bytes are exposed through `as_slice` and by
/// dereferencing to `[u8]`.
pub struct ProposalRef(HashReference);
19
20impl Deref for ProposalRef {
21    type Target = [u8];
22
23    fn deref(&self) -> &Self::Target {
24        &self.0
25    }
26}
27
28#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
29impl ProposalRef {
30    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
31    pub(crate) async fn from_content<CS: CipherSuiteProvider>(
32        cipher_suite_provider: &CS,
33        content: &AuthenticatedContent,
34    ) -> Result<Self, MlsError> {
35        let bytes = &content.mls_encode_to_vec()?;
36
37        Ok(ProposalRef(
38            HashReference::compute(bytes, b"MLS 1.0 Proposal Reference", cipher_suite_provider)
39                .await?,
40        ))
41    }
42
43    pub fn as_slice(&self) -> &[u8] {
44        &self.0
45    }
46}
47
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::group::test_utils::{random_bytes, TEST_GROUP};
    use alloc::boxed::Box;

    impl ProposalRef {
        /// Wraps `bytes` directly as a `ProposalRef` without hashing anything.
        ///
        /// Only available in test builds.
        pub fn new_fake(bytes: Vec<u8>) -> Self {
            Self(HashReference::from(bytes))
        }
    }

    /// Wraps `proposal` in a `PublicMessage` `AuthenticatedContent` from `sender`.
    ///
    /// The content uses `TEST_GROUP` as the group id, epoch 0 and empty
    /// authenticated data. The signature is 128 random bytes, not a signature
    /// computed over the content, and no confirmation tag is set.
    pub fn auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent
    where
        S: Into<Sender>,
    {
        let framed = FramedContent {
            group_id: TEST_GROUP.to_vec(),
            epoch: 0,
            sender: sender.into(),
            authenticated_data: vec![],
            content: Content::Proposal(Box::new(proposal)),
        };

        let auth_data = FramedContentAuthData {
            signature: MessageSignature::from(random_bytes(128)),
            confirmation_tag: None,
        };

        AuthenticatedContent {
            wire_format: WireFormat::PublicMessage,
            content: framed,
            auth: auth_data,
        }
    }
}
80
#[cfg(test)]
mod test {
    use super::test_utils::auth_content_from_proposal;
    use super::*;
    use crate::{
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
        key_package::test_utils::test_key_package,
        tree_kem::leaf_node::test_utils::get_basic_test_node,
    };
    use alloc::boxed::Box;

    use crate::extension::RequiredCapabilitiesExt;

    /// Builds an `ExtensionList` holding one `RequiredCapabilitiesExt` that
    /// lists extension type 42, with default proposals and no credentials.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn get_test_extension_list() -> ExtensionList {
        let test_extension = RequiredCapabilitiesExt {
            extensions: vec![42.into()],
            proposals: Default::default(),
            credentials: vec![],
        };

        let mut extension_list = ExtensionList::new();
        extension_list.set_from(test_extension).unwrap();

        extension_list
    }

    /// One known-answer vector.
    ///
    /// `input` is a hex-encoded `AuthenticatedContent`. `output` is the
    /// expected `ProposalRef` bytes for it under `cipher_suite`.
    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    /// Generates vectors for every (protocol version, cipher suite) pair.
    ///
    /// Each pair gets one vector per proposal kind, in this order: Add,
    /// Update, Remove, GroupContextExtensions. Every proposal is sent by
    /// leaf 0.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn generate_proposal_test_cases() -> Vec<TestCase> {
        let mut test_cases = Vec::new();

        for (protocol_version, cipher_suite) in
            ProtocolVersion::all().flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
        {
            let sender = LeafIndex::unchecked(0);

            let add = auth_content_from_proposal(
                Proposal::Add(Box::new(AddProposal {
                    key_package: test_key_package(protocol_version, cipher_suite, "alice").await,
                })),
                sender,
            );

            let update = auth_content_from_proposal(
                Proposal::Update(UpdateProposal {
                    leaf_node: get_basic_test_node(cipher_suite, "foo").await,
                }),
                sender,
            );

            let remove = auth_content_from_proposal(
                Proposal::Remove(RemoveProposal {
                    to_remove: LeafIndex::unchecked(1),
                }),
                sender,
            );

            let group_context_ext = auth_content_from_proposal(
                Proposal::GroupContextExtensions(get_test_extension_list()),
                sender,
            );

            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

            // One vector per proposal content, in the order listed here.
            for content in [&add, &update, &remove, &group_context_ext] {
                test_cases.push(TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: content.mls_encode_to_vec().unwrap(),
                    output: ProposalRef::from_content(&cipher_suite_provider, content)
                        .await
                        .unwrap()
                        .to_vec(),
                });
            }
        }

        test_cases
    }

    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(proposal_ref, generate_proposal_test_cases().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(proposal_ref, generate_proposal_test_cases())
    }

    /// Checks each vector by recomputing the reference.
    ///
    /// The stored input is decoded and `ProposalRef::from_content` is
    /// recomputed over it. Cipher suites with no available test provider
    /// are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_proposal_ref() {
        let test_cases = load_test_cases().await;

        for one_case in test_cases {
            let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
                continue;
            };

            let proposal_content =
                AuthenticatedContent::mls_decode(&mut one_case.input.as_slice()).unwrap();

            let proposal_ref = ProposalRef::from_content(&cs_provider, &proposal_content)
                .await
                .unwrap();

            let expected_out = ProposalRef(HashReference::from(one_case.output));

            assert_eq!(expected_out, proposal_ref);
        }
    }
}