de_mls/mls_crypto/service/
backend.rs1use openmls::{
9 group::MlsGroup,
10 prelude::{
11 ContentType, DeserializeBytes, MlsMessageBodyIn, MlsMessageIn, ProcessedMessageContent,
12 Proposal, ProtocolMessage,
13 },
14};
15use openmls_rust_crypto::RustCrypto;
16use openmls_traits::{OpenMlsProvider, storage::StorageProvider};
17use prost::Message;
18
19use crate::{
20 ds::{APP_MSG_SUBTOPIC, OutboundPacket},
21 mls_crypto::{
22 CommitCandidate, DeMlsStorage, DecryptResult, MlsCommitInput, MlsError, MlsMessageKind,
23 MlsProposalOutput, OpenMlsService, StagedCandidateResult, service::api::MlsService,
24 },
25 protos::de_mls::messages::v1::AppMessage,
26};
27
28pub(super) struct MlsProvider<'a, T: StorageProvider<1>> {
30 crypto: &'a RustCrypto,
31 storage: &'a T,
32}
33
34impl<'a, T: StorageProvider<1>> MlsProvider<'a, T> {
35 pub(super) fn new(crypto: &'a RustCrypto, storage: &'a T) -> Self {
36 Self { crypto, storage }
37 }
38}
39
40impl<'a, T: StorageProvider<1>> OpenMlsProvider for MlsProvider<'a, T> {
41 type CryptoProvider = RustCrypto;
42 type RandProvider = RustCrypto;
43 type StorageProvider = T;
44
45 fn crypto(&self) -> &Self::CryptoProvider {
46 self.crypto
47 }
48
49 fn rand(&self) -> &Self::RandProvider {
50 self.crypto
51 }
52
53 fn storage(&self) -> &Self::StorageProvider {
54 self.storage
55 }
56}
57
58impl<S> OpenMlsService<S>
59where
60 S: DeMlsStorage,
61{
62 fn extract_proposal_action(
63 group: &MlsGroup,
64 proposal: &Proposal,
65 ) -> Result<MlsProposalOutput, MlsError> {
66 match proposal {
67 Proposal::Add(add) => {
68 let id = add
69 .key_package()
70 .leaf_node()
71 .credential()
72 .serialized_content()
73 .to_vec();
74 Ok(MlsProposalOutput::Add(id))
75 }
76 Proposal::Remove(remove) => {
77 let removed = remove.removed();
78 let id = group
79 .member(removed)
80 .map(|c| c.serialized_content().to_vec())
81 .ok_or(MlsError::UnknownLeafIndex(removed.u32()))?;
82 Ok(MlsProposalOutput::Remove(id))
83 }
84 other => Ok(MlsProposalOutput::Other(format!("{other:?}"))),
85 }
86 }
87}
88
89impl<S> MlsService for OpenMlsService<S>
90where
91 S: DeMlsStorage,
92{
93 fn conversation_id(&self) -> &str {
94 &self.conversation_id
95 }
96
97 fn delete(&mut self) -> Result<(), MlsError> {
102 self.group
103 .delete(self.storage.mls_storage())
104 .map_err(MlsError::storage)
105 }
106
107 fn members(&self) -> Result<Vec<Vec<u8>>, MlsError> {
112 Ok(self
113 .group
114 .members()
115 .map(|m| m.credential.serialized_content().to_vec())
116 .collect())
117 }
118
119 fn is_member(&self, identity: &[u8]) -> bool {
120 self.members()
121 .map(|members| members.iter().any(|m| m.as_slice() == identity))
122 .unwrap_or(false)
123 }
124
125 fn current_epoch(&self) -> Result<u64, MlsError> {
126 Ok(self.group.epoch().as_u64())
127 }
128
129 fn create_commit_candidate(
134 &mut self,
135 updates: &[MlsCommitInput],
136 ) -> Result<CommitCandidate, MlsError> {
137 let crypto = &self.crypto;
138 let mls_storage = self.storage.mls_storage();
139 let provider = MlsProvider::new(crypto, mls_storage);
140 let signer = self.credentials.signer();
141 let group = &mut self.group;
142 let mut mls_proposals = Vec::new();
143
144 for update in updates {
145 match update {
146 MlsCommitInput::Add(key_package) => {
147 let kp: openmls::key_packages::KeyPackage =
148 serde_json::from_slice(key_package.as_bytes())
149 .map_err(MlsError::KeyPackageJson)?;
150 let (mls_message_out, _proposal_ref) =
151 group.propose_add_member(&provider, signer, &kp)?;
152 mls_proposals.push(mls_message_out.to_bytes()?);
153 }
154 MlsCommitInput::Remove(identity) => {
155 let member_index = group.members().find_map(|m| {
156 if m.credential.serialized_content() == identity {
157 Some(m.index)
158 } else {
159 None
160 }
161 });
162 if let Some(index) = member_index {
163 let (mls_message_out, _proposal_ref) =
164 group.propose_remove_member(&provider, signer, index)?;
165 mls_proposals.push(mls_message_out.to_bytes()?);
166 }
167 }
168 }
169 }
170
171 let (commit_msg, welcome, _group_info) =
172 group.commit_to_pending_proposals(&provider, signer)?;
173
174 let welcome_bytes = match welcome {
175 Some(w) => Some(w.to_bytes()?),
176 None => None,
177 };
178
179 Ok(CommitCandidate {
180 proposals: mls_proposals,
181 commit: commit_msg.to_bytes()?,
182 welcome: welcome_bytes,
183 })
184 }
185
186 fn merge_own_commit(&mut self) -> Result<(), MlsError> {
187 let crypto = &self.crypto;
188 let mls_storage = self.storage.mls_storage();
189 let provider = MlsProvider::new(crypto, mls_storage);
190 self.group.merge_pending_commit(&provider)?;
191 Ok(())
192 }
193
194 fn discard_own_commit(&mut self) -> Result<(), MlsError> {
195 let crypto = &self.crypto;
196 let mls_storage = self.storage.mls_storage();
197 let provider = MlsProvider::new(crypto, mls_storage);
198 self.group
199 .clear_pending_commit(provider.storage())
200 .map_err(MlsError::storage)?;
201 self.group
202 .clear_pending_proposals(provider.storage())
203 .map_err(MlsError::storage)?;
204 Ok(())
205 }
206
207 fn stage_remote_commit(
212 &mut self,
213 proposals: &[Vec<u8>],
214 commit_bytes: &[u8],
215 ) -> Result<StagedCandidateResult, MlsError> {
216 let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
217 let group = &mut self.group;
218 let conversation_id = &self.conversation_id;
219
220 let mut proposal_senders: Vec<Vec<u8>> = Vec::with_capacity(proposals.len());
222 for (i, proposal_bytes) in proposals.iter().enumerate() {
223 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(proposal_bytes)?;
224 let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
225 let processed = group.process_message(&provider, protocol_message)?;
226 let sender = processed.credential().serialized_content().to_vec();
227 match processed.into_content() {
228 ProcessedMessageContent::ProposalMessage(proposal) => {
229 group
230 .store_pending_proposal(provider.storage(), proposal.as_ref().clone())
231 .map_err(MlsError::storage)?;
232 proposal_senders.push(sender);
233 }
234 _ => {
235 tracing::debug!(
236 group = %conversation_id,
237 index = i,
238 "stage_remote_commit: non-proposal in proposal slot",
239 );
240 return Ok(StagedCandidateResult::Aborted);
241 }
242 }
243 }
244
245 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(commit_bytes)?;
247 let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
248
249 if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
250 tracing::debug!(
251 "stage_remote_commit: ignoring commit for wrong group ID (expected {})",
252 conversation_id,
253 );
254 return Ok(StagedCandidateResult::Aborted);
255 }
256 if protocol_message.epoch() < group.epoch() {
257 tracing::debug!(
258 "stage_remote_commit: ignoring stale commit from epoch {} (current: {})",
259 protocol_message.epoch().as_u64(),
260 group.epoch().as_u64(),
261 );
262 return Ok(StagedCandidateResult::Aborted);
263 }
264
265 let processed = group.process_message(&provider, protocol_message)?;
266 let commit_sender = processed.credential().serialized_content().to_vec();
267
268 let outcome = match processed.into_content() {
269 ProcessedMessageContent::StagedCommitMessage(staged) => {
270 let self_removed = staged.self_removed();
271 let mut actions = Vec::new();
272 for add in staged.add_proposals() {
273 let id = add
274 .add_proposal()
275 .key_package()
276 .leaf_node()
277 .credential()
278 .serialized_content()
279 .to_vec();
280 actions.push(MlsProposalOutput::Add(id));
281 }
282 for remove in staged.remove_proposals() {
283 let removed_index = remove.remove_proposal().removed();
284 let id = group
285 .member(removed_index)
286 .map(|c| c.serialized_content().to_vec())
287 .ok_or(MlsError::UnknownLeafIndex(removed_index.u32()))?;
288 actions.push(MlsProposalOutput::Remove(id));
289 }
290 Some((commit_sender, self_removed, actions, *staged))
291 }
292 _ => {
293 tracing::debug!(
294 "stage_remote_commit: ignoring non-commit message for group {}",
295 conversation_id,
296 );
297 None
298 }
299 };
300
301 match outcome {
302 Some((commit_sender, self_removed, actions, staged)) => {
303 self.pending_staged_commit = Some(staged);
304 Ok(StagedCandidateResult::Staged {
305 commit_sender,
306 proposal_senders,
307 self_removed,
308 actions,
309 })
310 }
311 None => Ok(StagedCandidateResult::Aborted),
312 }
313 }
314
315 fn merge_staged_commit(&mut self) -> Result<(), MlsError> {
316 let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
317 let staged = self
318 .pending_staged_commit
319 .take()
320 .ok_or_else(|| MlsError::NoPendingStagedCommit(self.conversation_id.clone()))?;
321 self.group.merge_staged_commit(&provider, staged)?;
322 Ok(())
323 }
324
325 fn discard_staged_commit(&mut self) -> Result<(), MlsError> {
326 self.pending_staged_commit = None;
327 self.group
328 .clear_pending_proposals(self.storage.mls_storage())
329 .map_err(MlsError::storage)?;
330 Ok(())
331 }
332
333 fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
338 let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
339 let signer = self.credentials.signer();
340 let message = self.group.create_message(&provider, signer, plaintext)?;
341 Ok(message.to_bytes()?)
342 }
343
344 fn build_message(
345 &mut self,
346 app_msg: &AppMessage,
347 app_id: &[u8],
348 ) -> Result<OutboundPacket, MlsError> {
349 let bytes = self.encrypt(&app_msg.encode_to_vec())?;
350 Ok(OutboundPacket::new(
351 bytes,
352 APP_MSG_SUBTOPIC,
353 &self.conversation_id,
354 app_id,
355 ))
356 }
357
358 fn decrypt_application_only(&mut self, ciphertext: &[u8]) -> Result<DecryptResult, MlsError> {
359 let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
360 let group = &mut self.group;
361
362 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(ciphertext)?;
363 let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
364
365 if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
366 return Ok(DecryptResult::Ignored);
367 }
368
369 if protocol_message.epoch() != group.epoch() {
372 return Ok(DecryptResult::Ignored);
373 }
374
375 match protocol_message.content_type() {
378 ContentType::Commit | ContentType::Proposal => {
379 return Ok(DecryptResult::Ignored);
380 }
381 ContentType::Application => {}
382 }
383
384 let processed = group.process_message(&provider, protocol_message)?;
385 let sender_identity = processed.credential().serialized_content().to_vec();
386
387 match processed.into_content() {
388 ProcessedMessageContent::ApplicationMessage(app) => Ok(DecryptResult::Application(
389 app.into_bytes(),
390 sender_identity,
391 )),
392 _ => Ok(DecryptResult::Ignored),
393 }
394 }
395
396 fn decrypt(&mut self, ciphertext: &[u8]) -> Result<DecryptResult, MlsError> {
397 let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
398 let group = &mut self.group;
399 let conversation_id = &self.conversation_id;
400
401 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(ciphertext)?;
402 let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
403
404 if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
405 return Ok(DecryptResult::Ignored);
406 }
407
408 if protocol_message.epoch() != group.epoch() {
411 tracing::debug!(
412 "Ignoring message from epoch {} (current: {})",
413 protocol_message.epoch().as_u64(),
414 group.epoch().as_u64()
415 );
416 return Ok(DecryptResult::Ignored);
417 }
418
419 if protocol_message.content_type() == ContentType::Commit {
420 tracing::debug!(
421 "Ignoring commit on decrypt() path for group {}: use stage_remote_commit() instead",
422 conversation_id,
423 );
424 return Ok(DecryptResult::Ignored);
425 }
426
427 let processed = group.process_message(&provider, protocol_message)?;
428 let sender_identity = processed.credential().serialized_content().to_vec();
429
430 match processed.into_content() {
431 ProcessedMessageContent::ApplicationMessage(app) => Ok(DecryptResult::Application(
432 app.into_bytes(),
433 sender_identity,
434 )),
435 ProcessedMessageContent::ProposalMessage(proposal) => {
436 let action =
437 OpenMlsService::<S>::extract_proposal_action(group, proposal.proposal())?;
438
439 group
440 .store_pending_proposal(provider.storage(), proposal.as_ref().clone())
441 .map_err(MlsError::storage)?;
442 Ok(DecryptResult::ProposalStored(sender_identity, action))
443 }
444 ProcessedMessageContent::StagedCommitMessage(_) => Ok(DecryptResult::Ignored),
445 ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(DecryptResult::Ignored),
446 }
447 }
448
449 fn inspect_message_kind(&self, message_bytes: &[u8]) -> Result<MlsMessageKind, MlsError> {
450 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(message_bytes)?;
451 let protocol = match mls_message.extract() {
452 MlsMessageBodyIn::Welcome(_) => return Ok(MlsMessageKind::Welcome),
453 MlsMessageBodyIn::PrivateMessage(m) => ProtocolMessage::PrivateMessage(m),
454 MlsMessageBodyIn::PublicMessage(m) => ProtocolMessage::PublicMessage(Box::new(m)),
455 _ => return Ok(MlsMessageKind::Other),
456 };
457
458 let kind = match protocol.content_type() {
459 ContentType::Application => MlsMessageKind::Application,
460 ContentType::Proposal => MlsMessageKind::Proposal,
461 ContentType::Commit => MlsMessageKind::Commit,
462 };
463 Ok(kind)
464 }
465}