1use alloc::vec::Vec;
6
7use mls_rs_core::{
8 crypto::{CipherSuite, SignatureSecretKey},
9 extension::ExtensionList,
10 identity::SigningIdentity,
11 protocol_version::ProtocolVersion,
12};
13
14use crate::time::MlsTime;
15use crate::{client::MlsError, Client, Group, MlsMessage};
16
17use super::{
18 proposal::ReInitProposal, ClientConfig, ExportedTree, JustPreSharedKeyID, MessageProcessor,
19 NewMemberInfo, PreSharedKeyID, PskGroupId, PskSecretInput, ResumptionPSKUsage, ResumptionPsk,
20};
21
/// Group parameters used by the resumption flows in this module.
///
/// When creating a group (`resumption_create_group`) these are the parameters
/// the new group is built with. When joining (`resumption_join_group`) they are
/// the values the joined group is checked against.
struct ResumptionGroupParameters<'a> {
    /// Group id of the new group. On join it is only compared when the caller
    /// requests group-id verification.
    group_id: &'a [u8],
    /// Cipher suite of the new group.
    cipher_suite: CipherSuite,
    /// Protocol version of the new group.
    version: ProtocolVersion,
    /// Group context extensions of the new group.
    extensions: &'a ExtensionList,
}
28
29pub struct ReinitClient<C: ClientConfig + Clone> {
30 client: Client<C>,
31 reinit: ReInitProposal,
32 psk_input: PskSecretInput,
33}
34
35impl<C> Group<C>
36where
37 C: ClientConfig + Clone,
38{
39 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
49 pub async fn branch(
50 &self,
51 sub_group_id: Vec<u8>,
52 new_key_packages: Vec<MlsMessage>,
53 timestamp: Option<MlsTime>,
54 ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
55 let new_group_params = ResumptionGroupParameters {
56 group_id: &sub_group_id,
57 cipher_suite: self.cipher_suite(),
58 version: self.protocol_version(),
59 extensions: &self.group_state().context.extensions,
60 };
61
62 let current_leaf_node_extensions = &self.current_user_leaf_node()?.ungreased_extensions();
63 resumption_create_group(
64 self.config.clone(),
65 new_key_packages,
66 &new_group_params,
67 self.current_member_signing_identity()?.clone(),
69 self.signer.clone(),
70 current_leaf_node_extensions,
71 #[cfg(any(feature = "private_message", feature = "psk"))]
72 self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
73 timestamp,
74 )
75 .await
76 }
77
78 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
80 pub async fn join_subgroup(
81 &self,
82 welcome: &MlsMessage,
83 tree_data: Option<ExportedTree<'_>>,
84 maybe_time: Option<MlsTime>,
85 ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
86 let expected_new_group_prams = ResumptionGroupParameters {
87 group_id: &[],
88 cipher_suite: self.cipher_suite(),
89 version: self.protocol_version(),
90 extensions: &self.group_state().context.extensions,
91 };
92
93 resumption_join_group(
94 self.config.clone(),
95 self.signer.clone(),
96 welcome,
97 tree_data,
98 expected_new_group_prams,
99 false,
100 self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
101 maybe_time,
102 )
103 .await
104 }
105
106 pub fn get_reinit_client(
118 self,
119 new_signer: Option<SignatureSecretKey>,
120 new_signing_identity: Option<SigningIdentity>,
121 ) -> Result<ReinitClient<C>, MlsError> {
122 let psk_input = self.resumption_psk_input(ResumptionPSKUsage::Reinit)?;
123
124 let new_signing_identity = new_signing_identity
125 .map(Ok)
126 .unwrap_or_else(|| self.current_member_signing_identity().cloned())?;
127
128 let reinit = self
129 .state
130 .pending_reinit
131 .ok_or(MlsError::PendingReInitNotFound)?;
132
133 let new_signer = match new_signer {
134 Some(signer) => signer,
135 None => self.signer,
136 };
137
138 let client = Client::new(
139 self.config,
140 Some(new_signer),
141 Some((new_signing_identity, reinit.new_cipher_suite())),
142 reinit.new_version(),
143 );
144
145 Ok(ReinitClient {
146 client,
147 reinit,
148 psk_input,
149 })
150 }
151
152 fn resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError> {
153 let psk = self.epoch_secrets.resumption_secret.clone();
154
155 let id = JustPreSharedKeyID::Resumption(ResumptionPsk {
156 usage,
157 psk_group_id: PskGroupId(self.group_id().to_vec()),
158 psk_epoch: self.current_epoch(),
159 });
160
161 let id = PreSharedKeyID::new(id, self.cipher_suite_provider())?;
162 Ok(PskSecretInput { id, psk })
163 }
164}
165
166impl<C: ClientConfig + Clone> ReinitClient<C> {
170 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
173 pub async fn generate_key_package(
174 &self,
175 timestamp: Option<MlsTime>,
176 ) -> Result<MlsMessage, MlsError> {
177 self.client
178 .generate_key_package_message(Default::default(), Default::default(), timestamp)
179 .await
180 }
181
182 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
190 pub async fn commit(
191 self,
192 new_key_packages: Vec<MlsMessage>,
193 new_leaf_node_extensions: ExtensionList,
194 timestamp: Option<MlsTime>,
195 ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
196 let new_group_params = ResumptionGroupParameters {
197 group_id: self.reinit.group_id(),
198 cipher_suite: self.reinit.new_cipher_suite(),
199 version: self.reinit.new_version(),
200 extensions: self.reinit.new_group_context_extensions(),
201 };
202
203 resumption_create_group(
204 self.client.config.clone(),
205 new_key_packages,
206 &new_group_params,
207 self.client.signing_identity.unwrap().0,
209 self.client.signer.unwrap(),
210 &new_leaf_node_extensions,
211 #[cfg(any(feature = "private_message", feature = "psk"))]
212 self.psk_input,
213 timestamp,
214 )
215 .await
216 }
217
218 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
220 pub async fn join(
221 self,
222 welcome: &MlsMessage,
223 tree_data: Option<ExportedTree<'_>>,
224 maybe_time: Option<MlsTime>,
225 ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
226 let reinit = self.reinit;
227
228 let expected_group_params = ResumptionGroupParameters {
229 group_id: reinit.group_id(),
230 cipher_suite: reinit.new_cipher_suite(),
231 version: reinit.new_version(),
232 extensions: reinit.new_group_context_extensions(),
233 };
234
235 resumption_join_group(
236 self.client.config,
237 self.client.signer.unwrap(),
239 welcome,
240 tree_data,
241 expected_group_params,
242 true,
243 self.psk_input,
244 maybe_time,
245 )
246 .await
247 }
248}
249
250#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
251#[allow(clippy::too_many_arguments)]
252async fn resumption_create_group<C: ClientConfig + Clone>(
253 config: C,
254 new_key_packages: Vec<MlsMessage>,
255 new_group_params: &ResumptionGroupParameters<'_>,
256 signing_identity: SigningIdentity,
257 signer: SignatureSecretKey,
258 leaf_node_extensions: &ExtensionList,
259 psk_input: PskSecretInput,
260 timestamp: Option<MlsTime>,
261) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
262 let mut group = Group::new(
264 config,
265 Some(new_group_params.group_id.to_vec()),
266 new_group_params.cipher_suite,
267 new_group_params.version,
268 signing_identity,
269 new_group_params.extensions.clone(),
270 leaf_node_extensions.clone(),
271 signer,
272 timestamp,
273 )
274 .await?;
275
276 group.previous_psk = Some(psk_input);
278
279 let mut commit = group.commit_builder();
281
282 for kp in new_key_packages.into_iter() {
283 commit = commit.add_member(kp)?;
284 }
285
286 let commit = commit.build().await?;
287 group.apply_pending_commit().await?;
288
289 group.previous_psk = None;
291
292 Ok((group, commit.welcome_messages))
293}
294
295#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
296#[allow(clippy::too_many_arguments)]
297async fn resumption_join_group<C: ClientConfig + Clone>(
298 config: C,
299 signer: SignatureSecretKey,
300 welcome: &MlsMessage,
301 tree_data: Option<ExportedTree<'_>>,
302 expected_new_group_params: ResumptionGroupParameters<'_>,
303 verify_group_id: bool,
304 psk_input: PskSecretInput,
305 maybe_time: Option<MlsTime>,
306) -> Result<(Group<C>, NewMemberInfo), MlsError> {
307 let psk_input = Some(psk_input);
308
309 let (group, new_member_info) =
310 Group::<C>::from_welcome_message(welcome, tree_data, config, signer, psk_input, maybe_time)
311 .await?;
312
313 if group.protocol_version() != expected_new_group_params.version {
314 Err(MlsError::ProtocolVersionMismatch)
315 } else if group.cipher_suite() != expected_new_group_params.cipher_suite {
316 Err(MlsError::CipherSuiteMismatch)
317 } else if verify_group_id && group.group_id() != expected_new_group_params.group_id {
318 Err(MlsError::GroupIdMismatch)
319 } else if &group.group_state().context.extensions != expected_new_group_params.extensions {
320 Err(MlsError::ReInitExtensionsMismatch)
321 } else {
322 Ok((group, new_member_info))
323 }
324}