1use core::ops::Deref;
6
7use super::*;
8use crate::hash_reference::HashReference;
9
/// Reference identifying an MLS proposal.
///
/// Values are produced by `ProposalRef::from_content`, which hashes the MLS
/// encoding of an `AuthenticatedContent` into a `HashReference` using the
/// label `"MLS 1.0 Proposal Reference"`.
///
/// The raw reference bytes are available through `ProposalRef::as_slice` and
/// through the `Deref<Target = [u8]>` implementation.
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProposalRef(HashReference);
19
20impl Deref for ProposalRef {
21 type Target = [u8];
22
23 fn deref(&self) -> &Self::Target {
24 &self.0
25 }
26}
27
28#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
29impl ProposalRef {
30 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
31 pub(crate) async fn from_content<CS: CipherSuiteProvider>(
32 cipher_suite_provider: &CS,
33 content: &AuthenticatedContent,
34 ) -> Result<Self, MlsError> {
35 let bytes = &content.mls_encode_to_vec()?;
36
37 Ok(ProposalRef(
38 HashReference::compute(bytes, b"MLS 1.0 Proposal Reference", cipher_suite_provider)
39 .await?,
40 ))
41 }
42
43 pub fn as_slice(&self) -> &[u8] {
44 &self.0
45 }
46}
47
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::group::test_utils::{random_bytes, TEST_GROUP};
    use alloc::boxed::Box;

    impl ProposalRef {
        /// Wraps `bytes` directly as a `ProposalRef` without hashing anything.
        pub fn new_fake(bytes: Vec<u8>) -> Self {
            Self(HashReference::from(bytes))
        }
    }

    /// Builds an `AuthenticatedContent` that carries `proposal`.
    ///
    /// The content is a `PublicMessage` for `TEST_GROUP` at epoch 0, sent by
    /// `sender`, with empty authenticated data. The signature is 128 random
    /// bytes, and no confirmation tag is set.
    pub fn auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent
    where
        S: Into<Sender>,
    {
        let framed = FramedContent {
            group_id: TEST_GROUP.to_vec(),
            epoch: 0,
            sender: sender.into(),
            authenticated_data: Vec::new(),
            content: Content::Proposal(Box::new(proposal)),
        };

        let auth_data = FramedContentAuthData {
            signature: MessageSignature::from(random_bytes(128)),
            confirmation_tag: None,
        };

        AuthenticatedContent {
            wire_format: WireFormat::PublicMessage,
            content: framed,
            auth: auth_data,
        }
    }
}
80
#[cfg(test)]
mod test {
    use super::test_utils::auth_content_from_proposal;
    use super::*;
    use crate::{
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
        key_package::test_utils::test_key_package,
        tree_kem::leaf_node::test_utils::get_basic_test_node,
    };
    use alloc::boxed::Box;

    use crate::extension::RequiredCapabilitiesExt;

    /// Builds an extension list containing a single `RequiredCapabilitiesExt`.
    ///
    /// The extension requires extension type 42, with default proposals and
    /// no credentials.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn get_test_extension_list() -> ExtensionList {
        let mut list = ExtensionList::new();

        list.set_from(RequiredCapabilitiesExt {
            extensions: vec![42.into()],
            proposals: Default::default(),
            credentials: vec![],
        })
        .unwrap();

        list
    }

    /// A single proposal-reference test vector.
    ///
    /// `input` is an encoded `AuthenticatedContent`, and `output` is the
    /// expected reference bytes for `cipher_suite`. Both byte fields are
    /// hex-encoded when serialized.
    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    /// Generates vectors for four proposal types under every
    /// (protocol version, cipher suite) pair.
    ///
    /// The proposals are add, update, remove and group-context-extensions,
    /// emitted in that order for each pair.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn generate_proposal_test_cases() -> Vec<TestCase> {
        let mut vectors = Vec::new();

        let combinations =
            ProtocolVersion::all().flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)));

        for (protocol_version, cipher_suite) in combinations {
            let sender = LeafIndex::unchecked(0);

            let add = auth_content_from_proposal(
                Proposal::Add(Box::new(AddProposal {
                    key_package: test_key_package(protocol_version, cipher_suite, "alice").await,
                })),
                sender,
            );

            let update = auth_content_from_proposal(
                Proposal::Update(UpdateProposal {
                    leaf_node: get_basic_test_node(cipher_suite, "foo").await,
                }),
                sender,
            );

            let remove = auth_content_from_proposal(
                Proposal::Remove(RemoveProposal {
                    to_remove: LeafIndex::unchecked(1),
                }),
                sender,
            );

            let group_context_ext = auth_content_from_proposal(
                Proposal::GroupContextExtensions(get_test_extension_list()),
                sender,
            );

            let provider = test_cipher_suite_provider(cipher_suite);

            for content in [add, update, remove, group_context_ext] {
                let input = content.mls_encode_to_vec().unwrap();

                let output = ProposalRef::from_content(&provider, &content)
                    .await
                    .unwrap()
                    .to_vec();

                vectors.push(TestCase {
                    cipher_suite: cipher_suite.into(),
                    input,
                    output,
                });
            }
        }

        vectors
    }

    /// Loads the `proposal_ref` vectors (async build).
    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(proposal_ref, generate_proposal_test_cases().await)
    }

    /// Loads the `proposal_ref` vectors (sync build).
    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(proposal_ref, generate_proposal_test_cases())
    }

    /// Checks that each decoded vector input hashes to the stored reference.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_proposal_ref() {
        let vectors = load_test_cases().await;

        for vector in vectors {
            // Skip vectors for which no test provider is available
            // (`try_test_cipher_suite_provider` returned `None`).
            let Some(provider) = try_test_cipher_suite_provider(vector.cipher_suite) else {
                continue;
            };

            let content = AuthenticatedContent::mls_decode(&mut vector.input.as_slice()).unwrap();

            let computed = ProposalRef::from_content(&provider, &content)
                .await
                .unwrap();

            let expected = ProposalRef(HashReference::from(vector.output));

            assert_eq!(expected, computed);
        }
    }
}