1use std::collections::HashMap;
4use std::sync::RwLock;
5
6use alloy::primitives::Address;
7use openmls::credentials::CredentialWithKey;
8use openmls::group::{GroupId, MlsGroup, MlsGroupCreateConfig, MlsGroupJoinConfig};
9use openmls::prelude::{
10 BasicCredential, Ciphersuite, DeserializeBytes, MlsMessageBodyIn, MlsMessageIn,
11 ProcessedMessageContent, ProtocolMessage, StagedWelcome,
12};
13use openmls_basic_credential::SignatureKeyPair;
14use openmls_rust_crypto::{MemoryStorage, RustCrypto};
15use openmls_traits::OpenMlsProvider;
16
17use crate::mls_crypto::{
18 error::{IdentityError, MlsError, MlsServiceError, Result, StorageError},
19 identity::IdentityData,
20 storage::DeMlsStorage,
21 types::{CommitResult, DecryptResult, GroupUpdate, KeyPackageBytes},
22};
23
/// Ciphersuite used by this module: it selects the signature algorithm of the
/// identity key pair generated in `MlsService::init`, and it is the suite that
/// key packages are built for in `MlsService::generate_key_package`.
pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
26
27struct MlsProvider<'a> {
29 crypto: &'a RustCrypto,
30 storage: &'a MemoryStorage,
31}
32
33impl<'a> OpenMlsProvider for MlsProvider<'a> {
34 type CryptoProvider = RustCrypto;
35 type RandProvider = RustCrypto;
36 type StorageProvider = MemoryStorage;
37
38 fn crypto(&self) -> &Self::CryptoProvider {
39 self.crypto
40 }
41
42 fn rand(&self) -> &Self::RandProvider {
43 self.crypto
44 }
45
46 fn storage(&self) -> &Self::StorageProvider {
47 self.storage
48 }
49}
50
51pub struct MlsService<S: DeMlsStorage> {
60 storage: S,
61 crypto: RustCrypto,
62 identity: RwLock<Option<IdentityData>>,
63 groups: RwLock<HashMap<String, MlsGroup>>,
64}
65
66impl<S> MlsService<S>
67where
68 S: DeMlsStorage<MlsStorage = MemoryStorage>,
69{
70 pub fn new(storage: S) -> Self {
72 Self {
73 storage,
74 crypto: RustCrypto::default(),
75 identity: RwLock::new(None),
76 groups: RwLock::new(HashMap::new()),
77 }
78 }
79
80 pub fn init(&self, wallet: Address) -> Result<()> {
89 {
90 let guard = self
91 .identity
92 .read()
93 .map_err(|e| StorageError::Lock(e.to_string()))?;
94 if guard.is_some() {
95 return Err(MlsError::Identity(IdentityError::AlreadyInitialized));
96 }
97 }
98
99 let credential = BasicCredential::new(wallet.as_slice().to_vec());
100 let signer = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())?;
101
102 signer.store(self.storage.mls_storage())?;
104
105 let data = IdentityData {
106 wallet,
107 credential: CredentialWithKey {
108 credential: credential.into(),
109 signature_key: signer.to_public_vec().into(),
110 },
111 signer,
112 };
113
114 let mut guard = self
115 .identity
116 .write()
117 .map_err(|e| StorageError::Lock(e.to_string()))?;
118 *guard = Some(data);
119 Ok(())
120 }
121
122 pub fn wallet_hex(&self) -> String {
124 self.identity
125 .read()
126 .ok()
127 .and_then(|guard| guard.as_ref().map(|id| id.wallet.to_checksum(None)))
128 .unwrap_or_default()
129 }
130
131 pub fn generate_key_package(&self) -> Result<KeyPackageBytes> {
139 let guard = self
140 .identity
141 .read()
142 .map_err(|e| StorageError::Lock(e.to_string()))?;
143 let identity = guard
144 .as_ref()
145 .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
146
147 let provider = self.make_provider();
148
149 let kp_bundle = openmls::key_packages::KeyPackage::builder().build(
150 CIPHERSUITE,
151 &provider,
152 &identity.signer,
153 identity.credential.clone(),
154 )?;
155
156 let kp = kp_bundle.key_package();
157 let hash_ref = kp.hash_ref(provider.crypto())?.as_slice().to_vec();
158 let bytes = serde_json::to_vec(kp).map_err(IdentityError::InvalidJson)?;
159
160 self.storage.store_key_package_ref(&hash_ref)?;
161
162 Ok(KeyPackageBytes::new(
163 bytes,
164 identity.wallet.as_slice().to_vec(),
165 ))
166 }
167
168 pub fn create_group(&self, group_id: &str) -> Result<()> {
177 let guard = self
178 .identity
179 .read()
180 .map_err(|e| StorageError::Lock(e.to_string()))?;
181 let identity = guard
182 .as_ref()
183 .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
184
185 let provider = self.make_provider();
186
187 let config = MlsGroupCreateConfig::builder()
188 .use_ratchet_tree_extension(true)
189 .build();
190
191 let group = MlsGroup::new_with_group_id(
192 &provider,
193 &identity.signer,
194 &config,
195 GroupId::from_slice(group_id.as_bytes()),
196 identity.credential.clone(),
197 )?;
198
199 self.groups
200 .write()
201 .map_err(|e| StorageError::Lock(e.to_string()))?
202 .insert(group_id.to_string(), group);
203
204 Ok(())
205 }
206
207 pub fn join_group(&self, welcome_bytes: &[u8]) -> Result<String> {
212 let provider = self.make_provider();
213
214 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(welcome_bytes)?;
215 let welcome = match mls_message.extract() {
216 MlsMessageBodyIn::Welcome(w) => w,
217 _ => return Err(MlsError::Service(MlsServiceError::UnexpectedMessageType)),
218 };
219
220 let is_for_us = welcome.secrets().iter().any(|s| {
222 self.storage
223 .is_our_key_package(s.new_member().as_slice())
224 .unwrap_or(false)
225 });
226 if !is_for_us {
227 return Err(MlsError::Service(MlsServiceError::WelcomeNotForUs));
228 }
229
230 for secret in welcome.secrets() {
232 let _ = self
233 .storage
234 .remove_key_package_ref(secret.new_member().as_slice());
235 }
236
237 let config = MlsGroupJoinConfig::builder().build();
238 let group = StagedWelcome::new_from_welcome(&provider, &config, welcome, None)?
239 .into_group(&provider)?;
240
241 let group_id = String::from_utf8_lossy(group.group_id().as_slice()).to_string();
242
243 self.groups
244 .write()
245 .map_err(|e| StorageError::Lock(e.to_string()))?
246 .insert(group_id.clone(), group);
247
248 Ok(group_id)
249 }
250
251 pub fn is_welcome_for_us(&self, welcome_bytes: &[u8]) -> Result<bool> {
255 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(welcome_bytes)?;
256 let welcome = match mls_message.extract() {
257 MlsMessageBodyIn::Welcome(w) => w,
258 _ => return Ok(false),
259 };
260
261 Ok(welcome.secrets().iter().any(|s| {
262 self.storage
263 .is_our_key_package(s.new_member().as_slice())
264 .unwrap_or(false)
265 }))
266 }
267
268 pub fn members(&self, group_id: &str) -> Result<Vec<Vec<u8>>> {
270 let groups = self
271 .groups
272 .read()
273 .map_err(|e| StorageError::Lock(e.to_string()))?;
274 let group = groups.get(group_id).ok_or_else(|| {
275 MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
276 })?;
277
278 Ok(group
279 .members()
280 .map(|m| m.credential.serialized_content().to_vec())
281 .collect())
282 }
283
284 pub fn encrypt(&self, group_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
292 let id_guard = self
293 .identity
294 .read()
295 .map_err(|e| StorageError::Lock(e.to_string()))?;
296 let identity = id_guard
297 .as_ref()
298 .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
299
300 let provider = self.make_provider();
301
302 let mut groups = self
303 .groups
304 .write()
305 .map_err(|e| StorageError::Lock(e.to_string()))?;
306 let group = groups.get_mut(group_id).ok_or_else(|| {
307 MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
308 })?;
309
310 let message = group.create_message(&provider, &identity.signer, plaintext)?;
311 Ok(message.to_bytes()?)
312 }
313
314 pub fn decrypt(&self, group_id: &str, ciphertext: &[u8]) -> Result<DecryptResult> {
318 let provider = self.make_provider();
319
320 let mut groups = self
321 .groups
322 .write()
323 .map_err(|e| StorageError::Lock(e.to_string()))?;
324 let group = groups.get_mut(group_id).ok_or_else(|| {
325 MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
326 })?;
327
328 let (mls_message, _) = MlsMessageIn::tls_deserialize_bytes(ciphertext)?;
329 let protocol_message: ProtocolMessage = mls_message.try_into_protocol_message()?;
330
331 if protocol_message.group_id().as_slice() != group.group_id().as_slice() {
333 return Ok(DecryptResult::Ignored);
334 }
335
336 if protocol_message.epoch() < group.epoch() {
338 tracing::debug!(
339 "Ignoring message from old epoch {} (current: {})",
340 protocol_message.epoch().as_u64(),
341 group.epoch().as_u64()
342 );
343 return Ok(DecryptResult::Ignored);
344 }
345
346 let processed = group.process_message(&provider, protocol_message)?;
347
348 match processed.into_content() {
349 ProcessedMessageContent::ApplicationMessage(app) => {
350 Ok(DecryptResult::Application(app.into_bytes()))
351 }
352 ProcessedMessageContent::ProposalMessage(proposal) => {
353 group.store_pending_proposal(provider.storage(), proposal.as_ref().clone())?;
354 Ok(DecryptResult::ProposalStored)
355 }
356 ProcessedMessageContent::StagedCommitMessage(staged) => {
357 let removed = staged.self_removed();
358 group.merge_staged_commit(&provider, *staged)?;
359 if removed {
360 if group.is_active() {
361 return Err(MlsError::Service(MlsServiceError::GroupStillActive));
362 }
363 Ok(DecryptResult::Removed)
364 } else {
365 Ok(DecryptResult::CommitProcessed)
366 }
367 }
368 ProcessedMessageContent::ExternalJoinProposalMessage(_) => Ok(DecryptResult::Ignored),
369 }
370 }
371
372 pub fn commit(&self, group_id: &str, updates: &[GroupUpdate]) -> Result<CommitResult> {
381 let id_guard = self
382 .identity
383 .read()
384 .map_err(|e| StorageError::Lock(e.to_string()))?;
385 let identity = id_guard
386 .as_ref()
387 .ok_or(MlsError::Identity(IdentityError::IdentityNotFound))?;
388
389 let provider = self.make_provider();
390
391 let mut groups = self
392 .groups
393 .write()
394 .map_err(|e| StorageError::Lock(e.to_string()))?;
395 let group = groups.get_mut(group_id).ok_or_else(|| {
396 MlsError::Service(MlsServiceError::GroupNotFound(group_id.to_string()))
397 })?;
398
399 let mut mls_proposals = Vec::new();
400
401 for update in updates {
402 match update {
403 GroupUpdate::Add(key_package) => {
404 let kp: openmls::key_packages::KeyPackage =
405 serde_json::from_slice(key_package.as_bytes())
406 .map_err(MlsServiceError::InvalidKeyPackage)?;
407 let (mls_message_out, _proposal_ref) =
408 group.propose_add_member(&provider, &identity.signer, &kp)?;
409 mls_proposals.push(mls_message_out.to_bytes()?);
410 }
411 GroupUpdate::Remove(wallet_bytes) => {
412 let member_index = group.members().find_map(|m| {
413 if m.credential.serialized_content() == wallet_bytes {
414 Some(m.index)
415 } else {
416 None
417 }
418 });
419 if let Some(index) = member_index {
420 let (mls_message_out, _proposal_ref) =
421 group.propose_remove_member(&provider, &identity.signer, index)?;
422 mls_proposals.push(mls_message_out.to_bytes()?);
423 }
424 }
425 }
426 }
427
428 let (commit_msg, welcome, _group_info) =
429 group.commit_to_pending_proposals(&provider, &identity.signer)?;
430 group.merge_pending_commit(&provider)?;
431
432 let welcome_bytes = match welcome {
433 Some(w) => Some(w.to_bytes()?),
434 None => None,
435 };
436
437 Ok(CommitResult {
438 proposals: mls_proposals,
439 commit: commit_msg.to_bytes()?,
440 welcome: welcome_bytes,
441 })
442 }
443
444 fn make_provider(&self) -> MlsProvider<'_> {
449 MlsProvider {
450 crypto: &self.crypto,
451 storage: self.storage.mls_storage(),
452 }
453 }
454}