1use std::collections::BTreeSet;
4
5use mdk_storage_traits::GroupId;
6use mdk_storage_traits::groups::error::GroupError;
7use mdk_storage_traits::groups::types::*;
8use mdk_storage_traits::groups::{GroupStorage, MAX_MESSAGE_LIMIT, MessageSortOrder, Pagination};
9use mdk_storage_traits::messages::types::Message;
10use nostr::{PublicKey, RelayUrl};
11
12use crate::MdkMemoryStorage;
13
14impl GroupStorage for MdkMemoryStorage {
15 fn save_group(&self, group: Group) -> Result<(), GroupError> {
16 if group.name.len() > self.limits.max_group_name_length {
18 return Err(GroupError::InvalidParameters(format!(
19 "Group name exceeds maximum length of {} bytes (got {} bytes)",
20 self.limits.max_group_name_length,
21 group.name.len()
22 )));
23 }
24
25 if group.description.len() > self.limits.max_group_description_length {
27 return Err(GroupError::InvalidParameters(format!(
28 "Group description exceeds maximum length of {} bytes (got {} bytes)",
29 self.limits.max_group_description_length,
30 group.description.len()
31 )));
32 }
33
34 if group.admin_pubkeys.len() > self.limits.max_admins_per_group {
36 return Err(GroupError::InvalidParameters(format!(
37 "Group admin count exceeds maximum of {} (got {})",
38 self.limits.max_admins_per_group,
39 group.admin_pubkeys.len()
40 )));
41 }
42
43 let mut guard = self.inner.write();
45 let inner = &mut *guard;
46 let groups_cache = &mut inner.groups_cache;
47 let nostr_id_cache = &mut inner.groups_by_nostr_id_cache;
48
49 if let Some(existing_group) = nostr_id_cache.peek(&group.nostr_group_id)
51 && existing_group.mls_group_id != group.mls_group_id
52 {
53 return Err(GroupError::InvalidParameters(
54 "nostr_group_id already exists for a different group".to_string(),
55 ));
56 }
57
58 if let Some(existing_group) = groups_cache.peek(&group.mls_group_id)
60 && existing_group.nostr_group_id != group.nostr_group_id
61 {
62 nostr_id_cache.pop(&existing_group.nostr_group_id);
63 }
64
65 groups_cache.put(group.mls_group_id.clone(), group.clone());
67 nostr_id_cache.put(group.nostr_group_id, group);
68
69 Ok(())
70 }
71
72 fn all_groups(&self) -> Result<Vec<Group>, GroupError> {
73 let inner = self.inner.read();
74 let groups: Vec<Group> = inner.groups_cache.iter().map(|(_, v)| v.clone()).collect();
76 Ok(groups)
77 }
78
79 fn find_group_by_mls_group_id(
80 &self,
81 mls_group_id: &GroupId,
82 ) -> Result<Option<Group>, GroupError> {
83 let inner = self.inner.read();
84 Ok(inner.groups_cache.peek(mls_group_id).cloned())
85 }
86
87 fn find_group_by_nostr_group_id(
88 &self,
89 nostr_group_id: &[u8; 32],
90 ) -> Result<Option<Group>, GroupError> {
91 let inner = self.inner.read();
92 Ok(inner.groups_by_nostr_id_cache.peek(nostr_group_id).cloned())
93 }
94
95 fn messages(
96 &self,
97 mls_group_id: &GroupId,
98 pagination: Option<Pagination>,
99 ) -> Result<Vec<Message>, GroupError> {
100 let pagination = pagination.unwrap_or_default();
101 let limit = pagination.limit();
102 let offset = pagination.offset();
103
104 if !(1..=MAX_MESSAGE_LIMIT).contains(&limit) {
106 return Err(GroupError::InvalidParameters(format!(
107 "Limit must be between 1 and {}, got {}",
108 MAX_MESSAGE_LIMIT, limit
109 )));
110 }
111
112 let inner = self.inner.read();
113
114 if inner.groups_cache.peek(mls_group_id).is_none() {
116 return Err(GroupError::InvalidParameters("Group not found".to_string()));
117 }
118
119 let sort_order = pagination.sort_order();
120
121 match inner.messages_by_group_cache.peek(mls_group_id) {
122 Some(messages_map) => {
123 let mut messages: Vec<Message> = messages_map.values().cloned().collect();
125
126 match sort_order {
129 MessageSortOrder::CreatedAtFirst => {
130 messages.sort_by(|a, b| b.display_order_cmp(a));
131 }
132 MessageSortOrder::ProcessedAtFirst => {
133 messages.sort_by(|a, b| b.processed_at_order_cmp(a));
134 }
135 }
136
137 let start = offset.min(messages.len());
139 let end = (offset + limit).min(messages.len());
140
141 Ok(messages[start..end].to_vec())
142 }
143 None => Ok(Vec::new()),
145 }
146 }
147
148 fn last_message(
149 &self,
150 mls_group_id: &GroupId,
151 sort_order: MessageSortOrder,
152 ) -> Result<Option<Message>, GroupError> {
153 let inner = self.inner.read();
154
155 if inner.groups_cache.peek(mls_group_id).is_none() {
156 return Err(GroupError::InvalidParameters("Group not found".to_string()));
157 }
158
159 match inner.messages_by_group_cache.peek(mls_group_id) {
160 Some(messages_map) if !messages_map.is_empty() => {
161 let winner = match sort_order {
164 MessageSortOrder::CreatedAtFirst => {
165 messages_map.values().max_by(|a, b| a.display_order_cmp(b))
166 }
167 MessageSortOrder::ProcessedAtFirst => messages_map
168 .values()
169 .max_by(|a, b| a.processed_at_order_cmp(b)),
170 };
171 Ok(winner.cloned())
172 }
173 _ => Ok(None),
174 }
175 }
176
177 fn admins(&self, mls_group_id: &GroupId) -> Result<BTreeSet<PublicKey>, GroupError> {
178 match self.find_group_by_mls_group_id(mls_group_id)? {
179 Some(group) => Ok(group.admin_pubkeys.clone()),
180 None => Err(GroupError::InvalidParameters("Group not found".to_string())),
181 }
182 }
183
184 fn group_relays(&self, mls_group_id: &GroupId) -> Result<BTreeSet<GroupRelay>, GroupError> {
185 let inner = self.inner.read();
186
187 if inner.groups_cache.peek(mls_group_id).is_none() {
189 return Err(GroupError::InvalidParameters("Group not found".to_string()));
190 }
191
192 match inner.group_relays_cache.peek(mls_group_id).cloned() {
193 Some(relays) => Ok(relays),
194 None => Ok(BTreeSet::new()),
196 }
197 }
198
199 fn replace_group_relays(
200 &self,
201 group_id: &GroupId,
202 relays: BTreeSet<RelayUrl>,
203 ) -> Result<(), GroupError> {
204 if relays.len() > self.limits.max_relays_per_group {
206 return Err(GroupError::InvalidParameters(format!(
207 "Relay count exceeds maximum of {} (got {})",
208 self.limits.max_relays_per_group,
209 relays.len()
210 )));
211 }
212
213 for relay in &relays {
215 if relay.as_str().len() > self.limits.max_relay_url_length {
216 return Err(GroupError::InvalidParameters(format!(
217 "Relay URL exceeds maximum length of {} bytes",
218 self.limits.max_relay_url_length
219 )));
220 }
221 }
222
223 let mut inner = self.inner.write();
224
225 if inner.groups_cache.peek(group_id).is_none() {
227 return Err(GroupError::InvalidParameters("Group not found".to_string()));
228 }
229
230 let group_relays: BTreeSet<GroupRelay> = relays
232 .into_iter()
233 .map(|relay_url| GroupRelay {
234 mls_group_id: group_id.clone(),
235 relay_url,
236 })
237 .collect();
238
239 inner.group_relays_cache.put(group_id.clone(), group_relays);
241
242 Ok(())
243 }
244
245 fn get_group_exporter_secret(
246 &self,
247 mls_group_id: &GroupId,
248 epoch: u64,
249 ) -> Result<Option<GroupExporterSecret>, GroupError> {
250 let inner = self.inner.read();
251
252 if inner.groups_cache.peek(mls_group_id).is_none() {
254 return Err(GroupError::InvalidParameters("Group not found".to_string()));
255 }
256
257 Ok(inner
259 .group_exporter_secrets_cache
260 .peek(&(mls_group_id.clone(), epoch))
261 .cloned())
262 }
263
264 fn save_group_exporter_secret(
265 &self,
266 group_exporter_secret: GroupExporterSecret,
267 ) -> Result<(), GroupError> {
268 let mut inner = self.inner.write();
269
270 if inner
272 .groups_cache
273 .peek(&group_exporter_secret.mls_group_id)
274 .is_none()
275 {
276 return Err(GroupError::InvalidParameters("Group not found".to_string()));
277 }
278
279 let key = (
281 group_exporter_secret.mls_group_id.clone(),
282 group_exporter_secret.epoch,
283 );
284 inner
285 .group_exporter_secrets_cache
286 .put(key, group_exporter_secret);
287
288 Ok(())
289 }
290
291 fn get_group_mip04_exporter_secret(
292 &self,
293 mls_group_id: &GroupId,
294 epoch: u64,
295 ) -> Result<Option<GroupExporterSecret>, GroupError> {
296 let inner = self.inner.read();
297
298 if inner.groups_cache.peek(mls_group_id).is_none() {
300 return Err(GroupError::InvalidParameters("Group not found".to_string()));
301 }
302
303 Ok(inner
304 .group_mip04_exporter_secrets_cache
305 .peek(&(mls_group_id.clone(), epoch))
306 .cloned())
307 }
308
309 fn save_group_mip04_exporter_secret(
310 &self,
311 group_exporter_secret: GroupExporterSecret,
312 ) -> Result<(), GroupError> {
313 let mut inner = self.inner.write();
314
315 if inner
317 .groups_cache
318 .peek(&group_exporter_secret.mls_group_id)
319 .is_none()
320 {
321 return Err(GroupError::InvalidParameters("Group not found".to_string()));
322 }
323
324 let key = (
325 group_exporter_secret.mls_group_id.clone(),
326 group_exporter_secret.epoch,
327 );
328 inner
329 .group_mip04_exporter_secrets_cache
330 .put(key, group_exporter_secret);
331
332 Ok(())
333 }
334
335 fn prune_group_exporter_secrets_before_epoch(
336 &self,
337 group_id: &GroupId,
338 min_epoch_to_keep: u64,
339 ) -> Result<(), GroupError> {
340 let mut inner = self.inner.write();
341
342 if inner.groups_cache.peek(group_id).is_none() {
343 return Err(GroupError::InvalidParameters("Group not found".to_string()));
344 }
345
346 let group_event_keys: Vec<(GroupId, u64)> = inner
347 .group_exporter_secrets_cache
348 .iter()
349 .filter_map(|(k, _)| {
350 let (gid, epoch) = k;
351 if gid == group_id && *epoch < min_epoch_to_keep {
352 Some((gid.clone(), *epoch))
353 } else {
354 None
355 }
356 })
357 .collect();
358
359 for key in group_event_keys {
360 inner.group_exporter_secrets_cache.pop(&key);
361 }
362
363 let mip04_keys: Vec<(GroupId, u64)> = inner
364 .group_mip04_exporter_secrets_cache
365 .iter()
366 .filter_map(|(k, _)| {
367 let (gid, epoch) = k;
368 if gid == group_id && *epoch < min_epoch_to_keep {
369 Some((gid.clone(), *epoch))
370 } else {
371 None
372 }
373 })
374 .collect();
375
376 for key in mip04_keys {
377 inner.group_mip04_exporter_secrets_cache.pop(&key);
378 }
379
380 Ok(())
381 }
382}
383
#[cfg(test)]
mod tests {
    use mdk_storage_traits::groups::types::GroupState;
    use mdk_storage_traits::messages::MessageStorage;
    use mdk_storage_traits::messages::types::{Message, MessageState};
    use nostr::{EventId, Keys, Kind, Tags, Timestamp, UnsignedEvent};

    use super::*;
    use crate::{
        DEFAULT_MAX_ADMINS_PER_GROUP, DEFAULT_MAX_GROUP_DESCRIPTION_LENGTH,
        DEFAULT_MAX_GROUP_NAME_LENGTH, DEFAULT_MAX_RELAY_URL_LENGTH, DEFAULT_MAX_RELAYS_PER_GROUP,
    };

    /// Builds an active, epoch-0 group named "Test Group" with no admins.
    ///
    /// Tests needing different metadata override fields with struct-update syntax.
    fn create_test_group(mls_group_id: GroupId, nostr_group_id: [u8; 32]) -> Group {
        Group {
            mls_group_id,
            nostr_group_id,
            name: "Test Group".to_string(),
            description: "A test group".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        }
    }

    /// A name of exactly the limit is accepted; one byte more is rejected.
    #[test]
    fn test_save_group_name_length_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let mut group = create_test_group(mls_group_id.clone(), [1u8; 32]);
        group.name = "a".repeat(DEFAULT_MAX_GROUP_NAME_LENGTH);
        assert!(storage.save_group(group).is_ok());

        let mut group = create_test_group(GroupId::from_slice(&[2, 3, 4, 5]), [2u8; 32]);
        group.name = "a".repeat(DEFAULT_MAX_GROUP_NAME_LENGTH + 1);
        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group name exceeds maximum length")
        );
    }

    /// A description of exactly the limit is accepted; one byte more is rejected.
    #[test]
    fn test_save_group_description_length_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let mut group = create_test_group(mls_group_id.clone(), [1u8; 32]);
        group.description = "a".repeat(DEFAULT_MAX_GROUP_DESCRIPTION_LENGTH);
        assert!(storage.save_group(group).is_ok());

        let mut group = create_test_group(GroupId::from_slice(&[2, 3, 4, 5]), [2u8; 32]);
        group.description = "a".repeat(DEFAULT_MAX_GROUP_DESCRIPTION_LENGTH + 1);
        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group description exceeds maximum length")
        );
    }

    /// Exactly the maximum number of admins is accepted; one more is rejected.
    #[test]
    fn test_save_group_admin_count_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let mut group = create_test_group(mls_group_id.clone(), [1u8; 32]);
        for _ in 0..DEFAULT_MAX_ADMINS_PER_GROUP {
            group.admin_pubkeys.insert(Keys::generate().public_key());
        }
        assert!(storage.save_group(group).is_ok());

        let mut group = create_test_group(GroupId::from_slice(&[2, 3, 4, 5]), [2u8; 32]);
        for _ in 0..=DEFAULT_MAX_ADMINS_PER_GROUP {
            group.admin_pubkeys.insert(Keys::generate().public_key());
        }
        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group admin count exceeds maximum")
        );
    }

    /// Exactly the maximum number of relays is accepted; one more is rejected.
    #[test]
    fn test_replace_group_relays_count_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let group = create_test_group(mls_group_id.clone(), [1u8; 32]);
        storage.save_group(group).unwrap();

        let mut relays = BTreeSet::new();
        for i in 0..DEFAULT_MAX_RELAYS_PER_GROUP {
            relays.insert(RelayUrl::parse(&format!("wss://relay{}.example.com", i)).unwrap());
        }
        assert!(storage.replace_group_relays(&mls_group_id, relays).is_ok());

        let mut relays = BTreeSet::new();
        for i in 0..=DEFAULT_MAX_RELAYS_PER_GROUP {
            relays.insert(RelayUrl::parse(&format!("wss://relay{}.example.com", i)).unwrap());
        }
        let result = storage.replace_group_relays(&mls_group_id, relays);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Relay count exceeds maximum")
        );
    }

    /// A relay URL within the length limit is accepted; an over-long one is rejected.
    #[test]
    fn test_replace_group_relays_url_length_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let group = create_test_group(mls_group_id.clone(), [1u8; 32]);
        storage.save_group(group).unwrap();

        // "wss://" (6) + domain + ".com" (4): subtracting 10 keeps the URL at the limit.
        let domain = "a".repeat(DEFAULT_MAX_RELAY_URL_LENGTH - 10);
        let url = format!("wss://{}.com", domain);
        let relays = BTreeSet::from([RelayUrl::parse(&url).unwrap()]);
        assert!(storage.replace_group_relays(&mls_group_id, relays).is_ok());

        let domain = "a".repeat(DEFAULT_MAX_RELAY_URL_LENGTH);
        let url = format!("wss://{}.com", domain);
        let relays = BTreeSet::from([RelayUrl::parse(&url).unwrap()]);
        let result = storage.replace_group_relays(&mls_group_id, relays);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Relay URL exceeds maximum length")
        );
    }

    /// Pagination returns disjoint, descending pages and validates its parameters.
    #[test]
    fn test_messages_pagination_memory() {
        let storage = MdkMemoryStorage::new();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let nostr_group_id = [1u8; 32];
        storage
            .save_group(create_test_group(mls_group_id.clone(), nostr_group_id))
            .unwrap();

        // 25 messages with strictly increasing timestamps: message i has created_at 1000 + i.
        let pubkey = Keys::generate().public_key();
        for i in 0..25 {
            let event_id = EventId::from_slice(&[i as u8; 32]).unwrap();
            let wrapper_event_id = EventId::from_slice(&[100 + i as u8; 32]).unwrap();

            let ts = Timestamp::from((1000 + i) as u64);
            let message = Message {
                id: event_id,
                pubkey,
                kind: Kind::from(1u16),
                mls_group_id: mls_group_id.clone(),
                created_at: ts,
                processed_at: ts,
                content: format!("Message {}", i),
                tags: Tags::new(),
                event: UnsignedEvent::new(
                    pubkey,
                    ts,
                    Kind::from(9u16),
                    vec![],
                    format!("content {}", i),
                ),
                wrapper_event_id,
                state: MessageState::Created,
                epoch: None,
            };

            storage.save_message(message).unwrap();
        }

        let all_messages = storage.messages(&mls_group_id, None).unwrap();
        assert_eq!(all_messages.len(), 25);

        let page1 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert_eq!(page1.len(), 10);
        assert_eq!(page1[0].content, "Message 24");

        let page2 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(10))))
            .unwrap();
        assert_eq!(page2.len(), 10);
        assert_eq!(page2[0].content, "Message 14");

        let page3 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(20))))
            .unwrap();
        assert_eq!(page3.len(), 5);
        assert_eq!(page3[0].content, "Message 4");

        let beyond = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(30))))
            .unwrap();
        assert_eq!(beyond.len(), 0);

        let first_id = page1[0].id;
        let second_page_ids: Vec<EventId> = page2.iter().map(|m| m.id).collect();
        assert!(
            !second_page_ids.contains(&first_id),
            "Pages should not overlap"
        );

        for pair in page1.windows(2) {
            assert!(
                pair[0].created_at >= pair[1].created_at,
                "Messages should be ordered by created_at descending"
            );
        }

        let result = storage.messages(&mls_group_id, Some(Pagination::new(Some(0), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        let result = storage.messages(&mls_group_id, Some(Pagination::new(Some(20000), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        let fake_group_id = GroupId::from_slice(&[99, 99, 99, 99]);
        let result = storage.messages(&fake_group_id, Some(Pagination::new(Some(10), Some(0))));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));

        let empty_group_id = GroupId::from_slice(&[5, 6, 7, 8]);
        let empty_group = Group {
            name: "Empty Group".to_string(),
            description: "A group with no messages".to_string(),
            ..create_test_group(empty_group_id.clone(), [2u8; 32])
        };
        storage.save_group(empty_group).unwrap();

        let empty = storage
            .messages(&empty_group_id, Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert_eq!(empty.len(), 0);

        let result = storage.messages(
            &mls_group_id,
            Some(Pagination::new(Some(10), Some(2_000_000))),
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }

    /// Custom limits are reported by the storage and enforced by save/replace operations.
    #[test]
    fn test_custom_group_limits() {
        use crate::ValidationLimits;

        let limits = ValidationLimits::default()
            .with_max_group_name_length(10)
            .with_max_group_description_length(20)
            .with_max_admins_per_group(2)
            .with_max_relays_per_group(3);

        let storage = MdkMemoryStorage::with_limits(limits);

        assert_eq!(storage.limits().max_group_name_length, 10);
        assert_eq!(storage.limits().max_group_description_length, 20);
        assert_eq!(storage.limits().max_admins_per_group, 2);
        assert_eq!(storage.limits().max_relays_per_group, 3);

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let mut group = create_test_group(mls_group_id.clone(), [1u8; 32]);
        group.name = "a".repeat(10);
        assert!(storage.save_group(group).is_ok());

        let mut group = create_test_group(GroupId::from_slice(&[2, 3, 4, 5]), [2u8; 32]);
        group.name = "a".repeat(11);
        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("10 bytes"));

        let mut group = create_test_group(GroupId::from_slice(&[3, 4, 5, 6]), [3u8; 32]);
        group.admin_pubkeys.insert(Keys::generate().public_key());
        group.admin_pubkeys.insert(Keys::generate().public_key());
        assert!(storage.save_group(group).is_ok());

        let mut group = create_test_group(GroupId::from_slice(&[4, 5, 6, 7]), [4u8; 32]);
        for _ in 0..3 {
            group.admin_pubkeys.insert(Keys::generate().public_key());
        }
        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("maximum of 2"));

        let group = create_test_group(GroupId::from_slice(&[5, 6, 7, 8]), [5u8; 32]);
        storage.save_group(group).unwrap();

        let relays = BTreeSet::from([
            RelayUrl::parse("wss://r1.com").unwrap(),
            RelayUrl::parse("wss://r2.com").unwrap(),
            RelayUrl::parse("wss://r3.com").unwrap(),
        ]);
        assert!(
            storage
                .replace_group_relays(&GroupId::from_slice(&[5, 6, 7, 8]), relays)
                .is_ok()
        );

        let relays = BTreeSet::from([
            RelayUrl::parse("wss://r1.com").unwrap(),
            RelayUrl::parse("wss://r2.com").unwrap(),
            RelayUrl::parse("wss://r3.com").unwrap(),
            RelayUrl::parse("wss://r4.com").unwrap(),
        ]);
        let result = storage.replace_group_relays(&GroupId::from_slice(&[5, 6, 7, 8]), relays);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("maximum of 3"));
    }

    /// A second group may not claim a nostr_group_id already owned by another group.
    #[test]
    fn test_nostr_group_id_collision_rejected() {
        let storage = MdkMemoryStorage::new();

        let mls_group_id_1 = GroupId::from_slice(&[1, 2, 3, 4]);
        let shared_nostr_group_id = [42u8; 32];

        let group1 = Group {
            name: "Group 1".to_string(),
            description: "First group".to_string(),
            ..create_test_group(mls_group_id_1.clone(), shared_nostr_group_id)
        };
        storage.save_group(group1).unwrap();

        let mls_group_id_2 = GroupId::from_slice(&[5, 6, 7, 8]);
        let group2 = Group {
            name: "Group 2".to_string(),
            description: "Second group trying to hijack".to_string(),
            ..create_test_group(mls_group_id_2.clone(), shared_nostr_group_id)
        };

        let result = storage.save_group(group2);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("nostr_group_id already exists"),
            "Expected collision error, got: {}",
            err
        );

        // The original owner must still be the one resolved by the shared id.
        let found = storage
            .find_group_by_nostr_group_id(&shared_nostr_group_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.mls_group_id, mls_group_id_1);
        assert_eq!(found.name, "Group 1");
    }

    /// Changing a group's nostr_group_id removes the index entry for the old id.
    #[test]
    fn test_nostr_group_id_update_removes_stale_entry() {
        let storage = MdkMemoryStorage::new();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let old_nostr_group_id = [1u8; 32];
        let new_nostr_group_id = [2u8; 32];

        let group = create_test_group(mls_group_id.clone(), old_nostr_group_id);
        storage.save_group(group).unwrap();

        assert!(
            storage
                .find_group_by_nostr_group_id(&old_nostr_group_id)
                .unwrap()
                .is_some()
        );

        let updated_group = Group {
            name: "Test Group Updated".to_string(),
            epoch: 1,
            ..create_test_group(mls_group_id.clone(), new_nostr_group_id)
        };
        storage.save_group(updated_group).unwrap();

        assert!(
            storage
                .find_group_by_nostr_group_id(&old_nostr_group_id)
                .unwrap()
                .is_none(),
            "Old nostr_group_id should not find the group after update"
        );

        let found = storage
            .find_group_by_nostr_group_id(&new_nostr_group_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.mls_group_id, mls_group_id);
        assert_eq!(found.name, "Test Group Updated");
        assert_eq!(found.epoch, 1);
    }

    /// Re-saving a group with the same ids updates it in place.
    #[test]
    fn test_same_group_update_allowed() {
        let storage = MdkMemoryStorage::new();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let nostr_group_id = [1u8; 32];

        let group = create_test_group(mls_group_id.clone(), nostr_group_id);
        storage.save_group(group).unwrap();

        let updated_group = Group {
            name: "Updated Group Name".to_string(),
            description: "Updated description".to_string(),
            epoch: 1,
            ..create_test_group(mls_group_id.clone(), nostr_group_id)
        };

        let result = storage.save_group(updated_group);
        assert!(result.is_ok());

        let found = storage
            .find_group_by_mls_group_id(&mls_group_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.name, "Updated Group Name");
        assert_eq!(found.epoch, 1);
    }
}