Skip to main content

de_mls/mls_crypto/service/
backend.rs

1//! OpenMLS-backed implementation of [`super::MlsService`].
2//!
3//! The trait definition lives in `super::api`; this file is the reference
4//! impl for the OpenMLS engine. It works with any `DeMlsStorage` whose
5//! `MlsStorage` implements [`openmls_traits::storage::StorageProvider`] —
6//! the impl is not tied to `openmls_rust_crypto::MemoryStorage`.
7
8use openmls::{
9    group::MlsGroup,
10    prelude::{
11        ContentType, DeserializeBytes, MlsMessageBodyIn, MlsMessageIn, ProcessedMessageContent,
12        Proposal, ProtocolMessage,
13    },
14};
15use openmls_rust_crypto::RustCrypto;
16use openmls_traits::{OpenMlsProvider, storage::StorageProvider};
17use prost::Message;
18
19use crate::{
20    ds::{APP_MSG_SUBTOPIC, OutboundPacket},
21    mls_crypto::{
22        CommitCandidate, DeMlsStorage, DecryptResult, MlsCommitInput, MlsError, MlsMessageKind,
23        MlsProposalOutput, OpenMlsService, StagedCandidateResult, service::api::MlsService,
24    },
25    protos::de_mls::messages::v1::AppMessage,
26};
27
28/// Internal OpenMLS provider that wraps the configured storage backend.
29pub(super) struct MlsProvider<'a, T: StorageProvider<1>> {
30    crypto: &'a RustCrypto,
31    storage: &'a T,
32}
33
34impl<'a, T: StorageProvider<1>> MlsProvider<'a, T> {
35    pub(super) fn new(crypto: &'a RustCrypto, storage: &'a T) -> Self {
36        Self { crypto, storage }
37    }
38}
39
40impl<'a, T: StorageProvider<1>> OpenMlsProvider for MlsProvider<'a, T> {
41    type CryptoProvider = RustCrypto;
42    type RandProvider = RustCrypto;
43    type StorageProvider = T;
44
45    fn crypto(&self) -> &Self::CryptoProvider {
46        self.crypto
47    }
48
49    fn rand(&self) -> &Self::RandProvider {
50        self.crypto
51    }
52
53    fn storage(&self) -> &Self::StorageProvider {
54        self.storage
55    }
56}
57
58impl<S> OpenMlsService<S>
59where
60    S: DeMlsStorage,
61{
62    fn extract_proposal_action(
63        group: &MlsGroup,
64        proposal: &Proposal,
65    ) -> Result<MlsProposalOutput, MlsError> {
66        match proposal {
67            Proposal::Add(add) => {
68                let id = add
69                    .key_package()
70                    .leaf_node()
71                    .credential()
72                    .serialized_content()
73                    .to_vec();
74                Ok(MlsProposalOutput::Add(id))
75            }
76            Proposal::Remove(remove) => {
77                let removed = remove.removed();
78                let id = group
79                    .member(removed)
80                    .map(|c| c.serialized_content().to_vec())
81                    .ok_or(MlsError::UnknownLeafIndex(removed.u32()))?;
82                Ok(MlsProposalOutput::Remove(id))
83            }
84            other => Ok(MlsProposalOutput::Other(format!("{other:?}"))),
85        }
86    }
87}
88
89impl<S> MlsService for OpenMlsService<S>
90where
91    S: DeMlsStorage,
92{
93    fn conversation_id(&self) -> &str {
94        &self.conversation_id
95    }
96
97    // ══════════════════════════════════════════════════════════
98    // Conversation lifecycle
99    // ══════════════════════════════════════════════════════════
100
101    fn delete(&mut self) -> Result<(), MlsError> {
102        self.group
103            .delete(self.storage.mls_storage())
104            .map_err(MlsError::storage)
105    }
106
107    // ══════════════════════════════════════════════════════════
108    // Membership / state queries
109    // ══════════════════════════════════════════════════════════
110
111    fn members(&self) -> Result<Vec<Vec<u8>>, MlsError> {
112        Ok(self
113            .group
114            .members()
115            .map(|m| m.credential.serialized_content().to_vec())
116            .collect())
117    }
118
119    fn is_member(&self, identity: &[u8]) -> bool {
120        self.members()
121            .map(|members| members.iter().any(|m| m.as_slice() == identity))
122            .unwrap_or(false)
123    }
124
125    fn current_epoch(&self) -> Result<u64, MlsError> {
126        Ok(self.group.epoch().as_u64())
127    }
128
129    // ══════════════════════════════════════════════════════════
130    // Local commit pipeline (steward)
131    // ══════════════════════════════════════════════════════════
132
133    fn create_commit_candidate(
134        &mut self,
135        updates: &[MlsCommitInput],
136    ) -> Result<CommitCandidate, MlsError> {
137        let crypto = &self.crypto;
138        let mls_storage = self.storage.mls_storage();
139        let provider = MlsProvider::new(crypto, mls_storage);
140        let signer = self.credentials.signer();
141        let group = &mut self.group;
142        let mut mls_proposals = Vec::new();
143
144        for update in updates {
145            match update {
146                MlsCommitInput::Add(key_package) => {
147                    let kp: openmls::key_packages::KeyPackage =
148                        serde_json::from_slice(key_package.as_bytes())
149                            .map_err(MlsError::KeyPackageJson)?;
150                    let (mls_message_out, _proposal_ref) =
151                        group.propose_add_member(&provider, signer, &kp)?;
152                    mls_proposals.push(mls_message_out.to_bytes()?);
153                }
154                MlsCommitInput::Remove(identity) => {
155                    let member_index = group.members().find_map(|m| {
156                        if m.credential.serialized_content() == identity {
157                            Some(m.index)
158                        } else {
159                            None
160                        }
161                    });
162                    if let Some(index) = member_index {
163                        let (mls_message_out, _proposal_ref) =
164                            group.propose_remove_member(&provider, signer, index)?;
165                        mls_proposals.push(mls_message_out.to_bytes()?);
166                    }
167                }
168            }
169        }
170
171        let (commit_msg, welcome, _group_info) =
172            group.commit_to_pending_proposals(&provider, signer)?;
173
174        let welcome_bytes = match welcome {
175            Some(w) => Some(w.to_bytes()?),
176            None => None,
177        };
178
179        Ok(CommitCandidate {
180            proposals: mls_proposals,
181            commit: commit_msg.to_bytes()?,
182            welcome: welcome_bytes,
183        })
184    }
185
186    fn merge_own_commit(&mut self) -> Result<(), MlsError> {
187        let crypto = &self.crypto;
188        let mls_storage = self.storage.mls_storage();
189        let provider = MlsProvider::new(crypto, mls_storage);
190        self.group.merge_pending_commit(&provider)?;
191        Ok(())
192    }
193
194    fn discard_own_commit(&mut self) -> Result<(), MlsError> {
195        let crypto = &self.crypto;
196        let mls_storage = self.storage.mls_storage();
197        let provider = MlsProvider::new(crypto, mls_storage);
198        self.group
199            .clear_pending_commit(provider.storage())
200            .map_err(MlsError::storage)?;
201        self.group
202            .clear_pending_proposals(provider.storage())
203            .map_err(MlsError::storage)?;
204        Ok(())
205    }
206
207    // ══════════════════════════════════════════════════════════
208    // Inbound candidate pipeline (stage → merge/discard)
209    // ══════════════════════════════════════════════════════════
210
211    fn stage_remote_commit(
212        &mut self,
213        proposals: &[Vec<u8>],
214        commit_bytes: &[u8],
215    ) -> Result<StagedCandidateResult, MlsError> {
216        let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
217        let group = &mut self.group;
218        let conversation_id = &self.conversation_id;
219
220        // ── Stage every proposal, collecting senders ──
221        let mut proposal_senders: Vec<Vec<u8>> = Vec::with_capacity(proposals.len());
222        for (i, proposal_bytes) in proposals.iter().enumerate() {
223            let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(proposal_bytes)?;
224            let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
225            let processed = group.process_message(&provider, protocol_message)?;
226            let sender = processed.credential().serialized_content().to_vec();
227            match processed.into_content() {
228                ProcessedMessageContent::ProposalMessage(proposal) => {
229                    group
230                        .store_pending_proposal(provider.storage(), proposal.as_ref().clone())
231                        .map_err(MlsError::storage)?;
232                    proposal_senders.push(sender);
233                }
234                _ => {
235                    tracing::debug!(
236                        group = %conversation_id,
237                        index = i,
238                        "stage_remote_commit: non-proposal in proposal slot",
239                    );
240                    return Ok(StagedCandidateResult::Aborted);
241                }
242            }
243        }
244
245        // ── Stage the commit ──
246        let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(commit_bytes)?;
247        let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
248
249        if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
250            tracing::debug!(
251                "stage_remote_commit: ignoring commit for wrong group ID (expected {})",
252                conversation_id,
253            );
254            return Ok(StagedCandidateResult::Aborted);
255        }
256        if protocol_message.epoch() < group.epoch() {
257            tracing::debug!(
258                "stage_remote_commit: ignoring stale commit from epoch {} (current: {})",
259                protocol_message.epoch().as_u64(),
260                group.epoch().as_u64(),
261            );
262            return Ok(StagedCandidateResult::Aborted);
263        }
264
265        let processed = group.process_message(&provider, protocol_message)?;
266        let commit_sender = processed.credential().serialized_content().to_vec();
267
268        let outcome = match processed.into_content() {
269            ProcessedMessageContent::StagedCommitMessage(staged) => {
270                let self_removed = staged.self_removed();
271                let mut actions = Vec::new();
272                for add in staged.add_proposals() {
273                    let id = add
274                        .add_proposal()
275                        .key_package()
276                        .leaf_node()
277                        .credential()
278                        .serialized_content()
279                        .to_vec();
280                    actions.push(MlsProposalOutput::Add(id));
281                }
282                for remove in staged.remove_proposals() {
283                    let removed_index = remove.remove_proposal().removed();
284                    let id = group
285                        .member(removed_index)
286                        .map(|c| c.serialized_content().to_vec())
287                        .ok_or(MlsError::UnknownLeafIndex(removed_index.u32()))?;
288                    actions.push(MlsProposalOutput::Remove(id));
289                }
290                Some((commit_sender, self_removed, actions, *staged))
291            }
292            _ => {
293                tracing::debug!(
294                    "stage_remote_commit: ignoring non-commit message for group {}",
295                    conversation_id,
296                );
297                None
298            }
299        };
300
301        match outcome {
302            Some((commit_sender, self_removed, actions, staged)) => {
303                self.pending_staged_commit = Some(staged);
304                Ok(StagedCandidateResult::Staged {
305                    commit_sender,
306                    proposal_senders,
307                    self_removed,
308                    actions,
309                })
310            }
311            None => Ok(StagedCandidateResult::Aborted),
312        }
313    }
314
315    fn merge_staged_commit(&mut self) -> Result<(), MlsError> {
316        let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
317        let staged = self
318            .pending_staged_commit
319            .take()
320            .ok_or_else(|| MlsError::NoPendingStagedCommit(self.conversation_id.clone()))?;
321        self.group.merge_staged_commit(&provider, staged)?;
322        Ok(())
323    }
324
325    fn discard_staged_commit(&mut self) -> Result<(), MlsError> {
326        self.pending_staged_commit = None;
327        self.group
328            .clear_pending_proposals(self.storage.mls_storage())
329            .map_err(MlsError::storage)?;
330        Ok(())
331    }
332
333    // ══════════════════════════════════════════════════════════
334    // Application messages
335    // ══════════════════════════════════════════════════════════
336
337    fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
338        let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
339        let signer = self.credentials.signer();
340        let message = self.group.create_message(&provider, signer, plaintext)?;
341        Ok(message.to_bytes()?)
342    }
343
344    fn build_message(
345        &mut self,
346        app_msg: &AppMessage,
347        app_id: &[u8],
348    ) -> Result<OutboundPacket, MlsError> {
349        let bytes = self.encrypt(&app_msg.encode_to_vec())?;
350        Ok(OutboundPacket::new(
351            bytes,
352            APP_MSG_SUBTOPIC,
353            &self.conversation_id,
354            app_id,
355        ))
356    }
357
358    fn decrypt_application_only(&mut self, ciphertext: &[u8]) -> Result<DecryptResult, MlsError> {
359        let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
360        let group = &mut self.group;
361
362        let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(ciphertext)?;
363        let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
364
365        if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
366            return Ok(DecryptResult::Ignored);
367        }
368
369        // OpenMLS rejects both old and future epochs; ignore both to avoid
370        // hard errors (a joiner sends at epoch N+1 before we've merged).
371        if protocol_message.epoch() != group.epoch() {
372            return Ok(DecryptResult::Ignored);
373        }
374
375        // Reject commits/proposals before process_message to avoid MLS errors
376        // (e.g. MissingProposal when commit's proposals aren't stored).
377        match protocol_message.content_type() {
378            ContentType::Commit | ContentType::Proposal => {
379                return Ok(DecryptResult::Ignored);
380            }
381            ContentType::Application => {}
382        }
383
384        let processed = group.process_message(&provider, protocol_message)?;
385        let sender_identity = processed.credential().serialized_content().to_vec();
386
387        match processed.into_content() {
388            ProcessedMessageContent::ApplicationMessage(app) => Ok(DecryptResult::Application(
389                app.into_bytes(),
390                sender_identity,
391            )),
392            _ => Ok(DecryptResult::Ignored),
393        }
394    }
395
396    fn decrypt(&mut self, ciphertext: &[u8]) -> Result<DecryptResult, MlsError> {
397        let provider = MlsProvider::new(&self.crypto, self.storage.mls_storage());
398        let group = &mut self.group;
399        let conversation_id = &self.conversation_id;
400
401        let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(ciphertext)?;
402        let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
403
404        if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
405            return Ok(DecryptResult::Ignored);
406        }
407
408        // Old epochs can't be processed; future epochs arrive when a joiner
409        // sends at epoch N+1 before we've merged our pending commit.
410        if protocol_message.epoch() != group.epoch() {
411            tracing::debug!(
412                "Ignoring message from epoch {} (current: {})",
413                protocol_message.epoch().as_u64(),
414                group.epoch().as_u64()
415            );
416            return Ok(DecryptResult::Ignored);
417        }
418
419        if protocol_message.content_type() == ContentType::Commit {
420            tracing::debug!(
421                "Ignoring commit on decrypt() path for group {}: use stage_remote_commit() instead",
422                conversation_id,
423            );
424            return Ok(DecryptResult::Ignored);
425        }
426
427        let processed = group.process_message(&provider, protocol_message)?;
428        let sender_identity = processed.credential().serialized_content().to_vec();
429
430        match processed.into_content() {
431            ProcessedMessageContent::ApplicationMessage(app) => Ok(DecryptResult::Application(
432                app.into_bytes(),
433                sender_identity,
434            )),
435            ProcessedMessageContent::ProposalMessage(proposal) => {
436                let action =
437                    OpenMlsService::<S>::extract_proposal_action(group, proposal.proposal())?;
438
439                group
440                    .store_pending_proposal(provider.storage(), proposal.as_ref().clone())
441                    .map_err(MlsError::storage)?;
442                Ok(DecryptResult::ProposalStored(sender_identity, action))
443            }
444            ProcessedMessageContent::StagedCommitMessage(_) => Ok(DecryptResult::Ignored),
445            ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(DecryptResult::Ignored),
446        }
447    }
448
449    fn inspect_message_kind(&self, message_bytes: &[u8]) -> Result<MlsMessageKind, MlsError> {
450        let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(message_bytes)?;
451        let protocol = match mls_message.extract() {
452            MlsMessageBodyIn::Welcome(_) => return Ok(MlsMessageKind::Welcome),
453            MlsMessageBodyIn::PrivateMessage(m) => ProtocolMessage::PrivateMessage(m),
454            MlsMessageBodyIn::PublicMessage(m) => ProtocolMessage::PublicMessage(Box::new(m)),
455            _ => return Ok(MlsMessageKind::Other),
456        };
457
458        let kind = match protocol.content_type() {
459            ContentType::Application => MlsMessageKind::Application,
460            ContentType::Proposal => MlsMessageKind::Proposal,
461            ContentType::Commit => MlsMessageKind::Commit,
462        };
463        Ok(kind)
464    }
465}