Skip to main content

de_mls/mls_crypto/
service.rs

1//! Main MLS service providing all cryptographic operations.
2
3use std::collections::HashMap;
4use std::sync::RwLock;
5
6use alloy::primitives::Address;
7use openmls::credentials::CredentialWithKey;
8use openmls::group::{GroupId, MlsGroup, MlsGroupCreateConfig, MlsGroupJoinConfig};
9use openmls::prelude::{
10    BasicCredential, Ciphersuite, DeserializeBytes, MlsMessageBodyIn, MlsMessageIn,
11    ProcessedMessageContent, ProtocolMessage, StagedWelcome,
12};
13use openmls_basic_credential::SignatureKeyPair;
14use openmls_rust_crypto::{MemoryStorage, RustCrypto};
15use openmls_traits::OpenMlsProvider;
16
17use crate::mls_crypto::{
18    error::{IdentityError, MlsError, MlsServiceError, Result, StorageError},
19    identity::IdentityData,
20    storage::DeMlsStorage,
21    types::{CommitResult, DecryptResult, GroupUpdate, KeyPackageBytes},
22};
23
24/// The MLS ciphersuite used for all operations.
25pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
26
27/// Internal OpenMLS provider that wraps storage.
28struct MlsProvider<'a> {
29    crypto: &'a RustCrypto,
30    storage: &'a MemoryStorage,
31}
32
33impl<'a> OpenMlsProvider for MlsProvider<'a> {
34    type CryptoProvider = RustCrypto;
35    type RandProvider = RustCrypto;
36    type StorageProvider = MemoryStorage;
37
38    fn crypto(&self) -> &Self::CryptoProvider {
39        self.crypto
40    }
41
42    fn rand(&self) -> &Self::RandProvider {
43        self.crypto
44    }
45
46    fn storage(&self) -> &Self::StorageProvider {
47        self.storage
48    }
49}
50
51/// Main MLS service - unified API for all MLS operations.
52///
53/// Groups are managed internally by group ID string. The service handles:
54/// - Identity initialization and management
55/// - Key package generation
56/// - Group creation and joining
57/// - Message encryption and decryption
58/// - Steward commit operations
59pub struct MlsService<S: DeMlsStorage> {
60    storage: S,
61    crypto: RustCrypto,
62    identity: RwLock<Option<IdentityData>>,
63    groups: RwLock<HashMap<String, MlsGroup>>,
64}
65
66impl<S> MlsService<S>
67where
68    S: DeMlsStorage<MlsStorage = MemoryStorage>,
69{
70    /// Create a new MLS service with the given storage backend.
71    pub fn new(storage: S) -> Self {
72        Self {
73            storage,
74            crypto: RustCrypto::default(),
75            identity: RwLock::new(None),
76            groups: RwLock::new(HashMap::new()),
77        }
78    }
79
80    // ══════════════════════════════════════════════════════════
81    // Identity
82    // ══════════════════════════════════════════════════════════
83
84    /// Initialize identity from wallet address.
85    ///
86    /// Creates MLS credentials and signing keys from the wallet address.
87    /// Call this once before using any other methods.
88    pub fn init(&self, wallet: Address) -> Result<()> {
89        {
90            let guard = self
91                .identity
92                .read()
93                .map_err(|e| StorageError::Lock(e.to_string()))?;
94            if guard.is_some() {
95                return Err(MlsError::Identity(IdentityError::AlreadyInitialized));
96            }
97        }
98
99        let credential = BasicCredential::new(wallet.as_slice().to_vec());
100        let signer = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())?;
101
102        // Store signer in OpenMLS storage
103        signer.store(self.storage.mls_storage())?;
104
105        let data = IdentityData {
106            wallet,
107            credential: CredentialWithKey {
108                credential: credential.into(),
109                signature_key: signer.to_public_vec().into(),
110            },
111            signer,
112        };
113
114        let mut guard = self
115            .identity
116            .write()
117            .map_err(|e| StorageError::Lock(e.to_string()))?;
118        *guard = Some(data);
119        Ok(())
120    }
121
122    /// Get the wallet address as a checksummed hex string ("0x...").
123    pub fn wallet_hex(&self) -> String {
124        self.identity
125            .read()
126            .ok()
127            .and_then(|guard| guard.as_ref().map(|id| id.wallet.to_checksum(None)))
128            .unwrap_or_default()
129    }
130
131    // ══════════════════════════════════════════════════════════
132    // Key Packages
133    // ══════════════════════════════════════════════════════════
134
135    /// Generate a key package for joining a group.
136    ///
137    /// Key packages are single-use and should be regenerated after each join.
138    pub fn generate_key_package(&self) -> Result<KeyPackageBytes> {
139        let guard = self
140            .identity
141            .read()
142            .map_err(|e| StorageError::Lock(e.to_string()))?;
143        let identity = guard
144            .as_ref()
145            .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
146
147        let provider = self.make_provider();
148
149        let kp_bundle = openmls::key_packages::KeyPackage::builder().build(
150            CIPHERSUITE,
151            &provider,
152            &identity.signer,
153            identity.credential.clone(),
154        )?;
155
156        let kp = kp_bundle.key_package();
157        let hash_ref = kp.hash_ref(provider.crypto())?.as_slice().to_vec();
158        let bytes = serde_json::to_vec(kp).map_err(IdentityError::InvalidJson)?;
159
160        self.storage.store_key_package_ref(&hash_ref)?;
161
162        Ok(KeyPackageBytes::new(
163            bytes,
164            identity.wallet.as_slice().to_vec(),
165        ))
166    }
167
168    // ══════════════════════════════════════════════════════════
169    // Groups
170    // ══════════════════════════════════════════════════════════
171
172    /// Create a new MLS group.
173    ///
174    /// The group name becomes the MLS group ID. The creator becomes
175    /// the only member and is typically the steward.
176    pub fn create_group(&self, group_id: &str) -> Result<()> {
177        let guard = self
178            .identity
179            .read()
180            .map_err(|e| StorageError::Lock(e.to_string()))?;
181        let identity = guard
182            .as_ref()
183            .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
184
185        let provider = self.make_provider();
186
187        let config = MlsGroupCreateConfig::builder()
188            .use_ratchet_tree_extension(true)
189            .build();
190
191        let group = MlsGroup::new_with_group_id(
192            &provider,
193            &identity.signer,
194            &config,
195            GroupId::from_slice(group_id.as_bytes()),
196            identity.credential.clone(),
197        )?;
198
199        self.groups
200            .write()
201            .map_err(|e| StorageError::Lock(e.to_string()))?
202            .insert(group_id.to_string(), group);
203
204        Ok(())
205    }
206
207    /// Join a group from a welcome message.
208    ///
209    /// Returns the group ID on success. The welcome must be for us
210    /// (contain one of our key package references).
211    pub fn join_group(&self, welcome_bytes: &[u8]) -> Result<String> {
212        let provider = self.make_provider();
213
214        let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(welcome_bytes)?;
215        let welcome = match mls_message.extract() {
216            MlsMessageBodyIn::Welcome(w) => w,
217            _ => return Err(MlsError::Service(MlsServiceError::UnexpectedMessageType)),
218        };
219
220        // Check if this welcome is for us
221        let is_for_us = welcome.secrets().iter().any(|s| {
222            self.storage
223                .is_our_key_package(s.new_member().as_slice())
224                .unwrap_or(false)
225        });
226        if !is_for_us {
227            return Err(MlsError::Service(MlsServiceError::WelcomeNotForUs));
228        }
229
230        // Remove used key package references
231        for secret in welcome.secrets() {
232            let _ = self
233                .storage
234                .remove_key_package_ref(secret.new_member().as_slice());
235        }
236
237        let config = MlsGroupJoinConfig::builder().build();
238        let group = StagedWelcome::new_from_welcome(&provider, &config, welcome, None)?
239            .into_group(&provider)?;
240
241        let group_id = String::from_utf8_lossy(group.group_id().as_slice()).to_string();
242
243        self.groups
244            .write()
245            .map_err(|e| StorageError::Lock(e.to_string()))?
246            .insert(group_id.clone(), group);
247
248        Ok(group_id)
249    }
250
251    /// Check if a welcome message is for us (without joining).
252    ///
253    /// Returns true if the welcome contains one of our key package references.
254    pub fn is_welcome_for_us(&self, welcome_bytes: &[u8]) -> Result<bool> {
255        let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(welcome_bytes)?;
256        let welcome = match mls_message.extract() {
257            MlsMessageBodyIn::Welcome(w) => w,
258            _ => return Ok(false),
259        };
260
261        Ok(welcome.secrets().iter().any(|s| {
262            self.storage
263                .is_our_key_package(s.new_member().as_slice())
264                .unwrap_or(false)
265        }))
266    }
267
268    /// Get all current group members as wallet addresses.
269    pub fn members(&self, group_id: &str) -> Result<Vec<Vec<u8>>> {
270        let groups = self
271            .groups
272            .read()
273            .map_err(|e| StorageError::Lock(e.to_string()))?;
274        let group = groups.get(group_id).ok_or_else(|| {
275            MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
276        })?;
277
278        Ok(group
279            .members()
280            .map(|m| m.credential.serialized_content().to_vec())
281            .collect())
282    }
283
284    // ══════════════════════════════════════════════════════════
285    // Messages
286    // ══════════════════════════════════════════════════════════
287
288    /// Encrypt an application message for the group.
289    ///
290    /// Returns MLS ciphertext that only group members can decrypt.
291    pub fn encrypt(&self, group_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
292        let id_guard = self
293            .identity
294            .read()
295            .map_err(|e| StorageError::Lock(e.to_string()))?;
296        let identity = id_guard
297            .as_ref()
298            .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
299
300        let provider = self.make_provider();
301
302        let mut groups = self
303            .groups
304            .write()
305            .map_err(|e| StorageError::Lock(e.to_string()))?;
306        let group = groups.get_mut(group_id).ok_or_else(|| {
307            MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
308        })?;
309
310        let message = group.create_message(&provider, &identity.signer, plaintext)?;
311        Ok(message.to_bytes()?)
312    }
313
314    /// Decrypt/process an inbound MLS message.
315    ///
316    /// Handles application messages, proposals, and commits.
317    pub fn decrypt(&self, group_id: &str, ciphertext: &[u8]) -> Result<DecryptResult> {
318        let provider = self.make_provider();
319
320        let mut groups = self
321            .groups
322            .write()
323            .map_err(|e| StorageError::Lock(e.to_string()))?;
324        let group = groups.get_mut(group_id).ok_or_else(|| {
325            MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
326        })?;
327
328        let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(ciphertext)?;
329        let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
330
331        // Check group ID
332        if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
333            return Ok(DecryptResult::Ignored);
334        }
335
336        // Ignore messages from old epochs - they can't be processed after the group advances
337        if protocol_message.epoch() < group.epoch() {
338            tracing::debug!(
339                "Ignoring message from old epoch {} (current: {})",
340                protocol_message.epoch().as_u64(),
341                group.epoch().as_u64()
342            );
343            return Ok(DecryptResult::Ignored);
344        }
345
346        let processed = group.process_message(&provider, protocol_message)?;
347
348        match processed.into_content() {
349            ProcessedMessageContent::ApplicationMessage(app) => {
350                Ok(DecryptResult::Application(app.into_bytes()))
351            }
352            ProcessedMessageContent::ProposalMessage(proposal) => {
353                group.store_pending_proposal(provider.storage(), proposal.as_ref().clone())?;
354                Ok(DecryptResult::ProposalStored)
355            }
356            ProcessedMessageContent::StagedCommitMessage(staged) => {
357                let removed = staged.self_removed();
358                group.merge_staged_commit(&provider, *staged)?;
359                if removed {
360                    if group.is_active() {
361                        return Err(MlsError::Service(MlsServiceError::GroupStillActive));
362                    }
363                    Ok(DecryptResult::Removed)
364                } else {
365                    Ok(DecryptResult::CommitProcessed)
366                }
367            }
368            ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(DecryptResult::Ignored),
369        }
370    }
371
372    // ══════════════════════════════════════════════════════════
373    // Steward
374    // ══════════════════════════════════════════════════════════
375
376    /// Create proposals for membership changes and commit them.
377    ///
378    /// This is the core steward operation: takes a list of add/remove
379    /// operations, creates MLS proposals, and commits them in a batch.
380    pub fn commit(&self, group_id: &str, updates: &[GroupUpdate]) -> Result<CommitResult> {
381        let id_guard = self
382            .identity
383            .read()
384            .map_err(|e| StorageError::Lock(e.to_string()))?;
385        let identity = id_guard
386            .as_ref()
387            .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
388
389        let provider = self.make_provider();
390
391        let mut groups = self
392            .groups
393            .write()
394            .map_err(|e| StorageError::Lock(e.to_string()))?;
395        let group = groups.get_mut(group_id).ok_or_else(|| {
396            MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
397        })?;
398
399        let mut mls_proposals = Vec::new();
400
401        for update in updates {
402            match update {
403                GroupUpdate::Add(key_package) => {
404                    let kp: openmls::key_packages::KeyPackage =
405                        serde_json::from_slice(key_package.as_bytes())
406                            .map_err(MlsServiceError::InvalidKeyPackage)?;
407                    let (mls_message_out, _proposal_ref) =
408                        group.propose_add_member(&provider, &identity.signer, &kp)?;
409                    mls_proposals.push(mls_message_out.to_bytes()?);
410                }
411                GroupUpdate::Remove(wallet_bytes) => {
412                    let member_index = group.members().find_map(|m| {
413                        if m.credential.serialized_content() == wallet_bytes {
414                            Some(m.index)
415                        } else {
416                            None
417                        }
418                    });
419                    if let Some(index) = member_index {
420                        let (mls_message_out, _proposal_ref) =
421                            group.propose_remove_member(&provider, &identity.signer, index)?;
422                        mls_proposals.push(mls_message_out.to_bytes()?);
423                    }
424                }
425            }
426        }
427
428        let (commit_msg, welcome, _group_info) =
429            group.commit_to_pending_proposals(&provider, &identity.signer)?;
430        group.merge_pending_commit(&provider)?;
431
432        let welcome_bytes = match welcome {
433            Some(w) => Some(w.to_bytes()?),
434            None => None,
435        };
436
437        Ok(CommitResult {
438            proposals: mls_proposals,
439            commit: commit_msg.to_bytes()?,
440            welcome: welcome_bytes,
441        })
442    }
443
444    // ══════════════════════════════════════════════════════════
445    // Internal
446    // ══════════════════════════════════════════════════════════
447
448    fn make_provider(&self) -> MlsProvider<'_> {
449        MlsProvider {
450            crypto: &self.crypto,
451            storage: self.storage.mls_storage(),
452        }
453    }
454}