1use std::collections::HashMap;
4
5use mdk_storage_traits::GroupId;
6use nostr::EventId;
7#[cfg(test)]
8use nostr::{Kind, Tags, Timestamp, UnsignedEvent};
9
10use mdk_storage_traits::groups::GroupStorage;
11use mdk_storage_traits::messages::MessageStorage;
12use mdk_storage_traits::messages::error::MessageError;
13use mdk_storage_traits::messages::types::*;
14
15use crate::MdkMemoryStorage;
16
17impl MessageStorage for MdkMemoryStorage {
18 fn save_message(&self, message: Message) -> Result<(), MessageError> {
19 match self.find_group_by_mls_group_id(&message.mls_group_id) {
21 Ok(Some(_)) => {
22 }
24 Ok(None) => {
25 return Err(MessageError::InvalidParameters(
26 "Group not found".to_string(),
27 ));
28 }
29 Err(e) => {
30 return Err(MessageError::InvalidParameters(format!(
31 "Failed to verify group existence: {}",
32 e
33 )));
34 }
35 }
36
37 let mut guard = self.inner.write();
39 let inner = &mut *guard;
40 let cache = &mut inner.messages_cache;
41 let group_cache = &mut inner.messages_by_group_cache;
42
43 match group_cache.get_mut(&message.mls_group_id) {
44 Some(group_messages) => {
45 let is_update = group_messages.contains_key(&message.id);
47
48 if !is_update && group_messages.len() >= self.limits.max_messages_per_group {
49 if let Some(oldest_id) = group_messages
52 .iter()
53 .min_by_key(|(_, msg)| msg.created_at)
54 .map(|(id, _)| *id)
55 {
56 group_messages.remove(&oldest_id);
58 cache.pop(&oldest_id);
59 }
60 }
61
62 group_messages.insert(message.id, message.clone());
64 }
65 None => {
66 let mut messages = HashMap::new();
68 let group_id = message.mls_group_id.clone();
69 messages.insert(message.id, message.clone());
70 group_cache.put(group_id, messages);
71 }
72 }
73
74 cache.put(message.id, message);
76
77 Ok(())
78 }
79
80 fn find_message_by_event_id(
81 &self,
82 mls_group_id: &GroupId,
83 event_id: &EventId,
84 ) -> Result<Option<Message>, MessageError> {
85 let inner = self.inner.read();
86 match inner.messages_by_group_cache.peek(mls_group_id) {
87 Some(group_messages) => Ok(group_messages.get(event_id).cloned()),
88 None => Ok(None),
89 }
90 }
91
92 fn find_processed_message_by_event_id(
93 &self,
94 event_id: &EventId,
95 ) -> Result<Option<ProcessedMessage>, MessageError> {
96 let inner = self.inner.read();
97 Ok(inner.processed_messages_cache.peek(event_id).cloned())
98 }
99
100 fn save_processed_message(
101 &self,
102 processed_message: ProcessedMessage,
103 ) -> Result<(), MessageError> {
104 let mut inner = self.inner.write();
105 inner
106 .processed_messages_cache
107 .put(processed_message.wrapper_event_id, processed_message);
108
109 Ok(())
110 }
111
112 fn invalidate_messages_after_epoch(
113 &self,
114 group_id: &GroupId,
115 epoch: u64,
116 ) -> Result<Vec<EventId>, MessageError> {
117 let mut inner = self.inner.write();
118 let mut invalidated_ids = Vec::new();
119
120 if let Some(group_messages) = inner.messages_by_group_cache.get_mut(group_id) {
122 for (event_id, message) in group_messages.iter_mut() {
123 if let Some(msg_epoch) = message.epoch
125 && msg_epoch > epoch
126 {
127 message.state = MessageState::EpochInvalidated;
128 invalidated_ids.push(*event_id);
129 }
130 }
131 }
132
133 for event_id in &invalidated_ids {
135 if let Some(message) = inner.messages_cache.get_mut(event_id) {
136 message.state = MessageState::EpochInvalidated;
137 }
138 }
139
140 Ok(invalidated_ids)
141 }
142
143 fn invalidate_processed_messages_after_epoch(
144 &self,
145 group_id: &GroupId,
146 epoch: u64,
147 ) -> Result<Vec<EventId>, MessageError> {
148 let mut inner = self.inner.write();
149 let mut invalidated_ids = Vec::new();
150
151 let cache = &mut inner.processed_messages_cache;
153 for (wrapper_event_id, processed_message) in cache.iter_mut() {
154 if let Some(ref msg_group_id) = processed_message.mls_group_id
156 && msg_group_id == group_id
157 && let Some(msg_epoch) = processed_message.epoch
158 && msg_epoch > epoch
159 {
160 processed_message.state = ProcessedMessageState::EpochInvalidated;
161 invalidated_ids.push(*wrapper_event_id);
162 }
163 }
164
165 Ok(invalidated_ids)
166 }
167
168 fn find_invalidated_messages(&self, group_id: &GroupId) -> Result<Vec<Message>, MessageError> {
169 let inner = self.inner.read();
170
171 if let Some(group_messages) = inner.messages_by_group_cache.peek(group_id) {
172 let invalidated: Vec<Message> = group_messages
173 .values()
174 .filter(|msg| msg.state == MessageState::EpochInvalidated)
175 .cloned()
176 .collect();
177 Ok(invalidated)
178 } else {
179 Ok(Vec::new())
180 }
181 }
182
183 fn find_invalidated_processed_messages(
184 &self,
185 group_id: &GroupId,
186 ) -> Result<Vec<ProcessedMessage>, MessageError> {
187 let inner = self.inner.read();
188
189 let invalidated: Vec<ProcessedMessage> = inner
190 .processed_messages_cache
191 .iter()
192 .filter_map(|(_, pm)| {
193 if let Some(ref msg_group_id) = pm.mls_group_id
194 && msg_group_id == group_id
195 && pm.state == ProcessedMessageState::EpochInvalidated
196 {
197 return Some(pm.clone());
198 }
199 None
200 })
201 .collect();
202
203 Ok(invalidated)
204 }
205
206 fn find_failed_messages_for_retry(
207 &self,
208 group_id: &GroupId,
209 ) -> Result<Vec<EventId>, MessageError> {
210 let inner = self.inner.read();
211
212 let event_ids: Vec<EventId> = inner
217 .processed_messages_cache
218 .iter()
219 .filter_map(|(wrapper_event_id, pm)| {
220 if let Some(ref msg_group_id) = pm.mls_group_id
221 && msg_group_id == group_id
222 && pm.state == ProcessedMessageState::Failed
223 && pm.epoch.is_none()
224 {
225 return Some(*wrapper_event_id);
226 }
227 None
228 })
229 .collect();
230
231 Ok(event_ids)
232 }
233
234 fn mark_processed_message_retryable(&self, event_id: &EventId) -> Result<(), MessageError> {
235 let mut inner = self.inner.write();
236
237 if let Some(pm) = inner.processed_messages_cache.get_mut(event_id)
239 && pm.state == ProcessedMessageState::Failed
240 {
241 pm.state = ProcessedMessageState::Retryable;
242 return Ok(());
243 }
244
245 Err(MessageError::NotFound)
246 }
247
248 fn find_message_epoch_by_tag_content(
249 &self,
250 group_id: &GroupId,
251 content_substring: &str,
252 ) -> Result<Option<u64>, MessageError> {
253 let inner = self.inner.read();
254
255 let Some(group_messages) = inner.messages_by_group_cache.peek(group_id) else {
256 return Ok(None);
257 };
258
259 for (epoch, message) in group_messages
260 .values()
261 .filter_map(|message| message.epoch.map(|epoch| (epoch, message)))
262 {
263 let tags_json = serde_json::to_string(&message.tags).map_err(|e| {
264 MessageError::DatabaseError(format!("Failed to serialize tags: {e}"))
265 })?;
266
267 if tags_json.contains(content_substring) {
268 return Ok(Some(epoch));
269 }
270 }
271
272 Ok(None)
273 }
274}
275
#[cfg(test)]
mod tests {
    use std::collections::{BTreeSet, HashSet};

    use mdk_storage_traits::groups::GroupStorage;
    use mdk_storage_traits::groups::types::{Group, GroupState, SelfUpdateState};
    use nostr::Keys;

    use super::*;

    /// Builds a minimal active `Group` whose `nostr_group_id` is the first
    /// (at most 32) bytes of `group_id`, zero-padded.
    fn create_test_group(group_id: GroupId) -> Group {
        let mut nostr_group_id = [0u8; 32];
        let group_id_bytes = group_id.as_slice();
        nostr_group_id[..group_id_bytes.len().min(32)]
            .copy_from_slice(&group_id_bytes[..group_id_bytes.len().min(32)]);

        Group {
            mls_group_id: group_id.clone(),
            nostr_group_id,
            name: "Test Group".to_string(),
            description: "A test group".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        }
    }

    /// Builds a `Message` with no epoch. See `create_test_message_with_epoch`.
    fn create_test_message(
        event_id: EventId,
        group_id: GroupId,
        content: &str,
        timestamp: u64,
    ) -> Message {
        create_test_message_with_epoch(event_id, group_id, content, timestamp, None)
    }

    /// Builds a `Message` in state `Created` with a fresh random pubkey, empty tags and a
    /// fixed wrapper event id. `timestamp` is used for both `created_at` and `processed_at`.
    fn create_test_message_with_epoch(
        event_id: EventId,
        group_id: GroupId,
        content: &str,
        timestamp: u64,
        epoch: Option<u64>,
    ) -> Message {
        let pubkey = Keys::generate().public_key();
        let wrapper_event_id = EventId::from_slice(&[200u8; 32]).unwrap();
        let ts = Timestamp::from(timestamp);

        Message {
            id: event_id,
            pubkey,
            kind: Kind::from(1u16),
            mls_group_id: group_id,
            created_at: ts,
            processed_at: ts,
            content: content.to_string(),
            tags: Tags::new(),
            event: UnsignedEvent::new(pubkey, ts, Kind::from(9u16), vec![], content.to_string()),
            wrapper_event_id,
            epoch,
            state: MessageState::Created,
        }
    }

    /// Builds a `ProcessedMessage` for `group_id` with the given wrapper id, epoch and state,
    /// and no failure reason.
    fn create_test_processed_message(
        wrapper_event_id: EventId,
        group_id: GroupId,
        epoch: Option<u64>,
        state: ProcessedMessageState,
    ) -> ProcessedMessage {
        ProcessedMessage {
            wrapper_event_id,
            message_event_id: None,
            processed_at: Timestamp::from(1_000_000_000u64),
            epoch,
            mls_group_id: Some(group_id),
            state,
            failure_reason: None,
        }
    }

    /// Saving a message twice with the same id replaces it rather than adding a duplicate.
    #[test]
    fn test_save_message_update_existing() {
        let storage = MdkMemoryStorage::new();

        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        let message1 = create_test_message(event_id, group_id.clone(), "Original content", 1000);
        storage.save_message(message1).unwrap();

        let found = storage
            .find_message_by_event_id(&group_id, &event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.content, "Original content");

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(group_messages.len(), 1);
        }

        let message2 = create_test_message(event_id, group_id.clone(), "Updated content", 1001);
        storage.save_message(message2).unwrap();

        let found = storage
            .find_message_by_event_id(&group_id, &event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.content, "Updated content");
        assert_eq!(found.created_at, Timestamp::from(1001u64));

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(
                group_messages.len(),
                1,
                "Should have exactly 1 message after update, not 2"
            );
            assert_eq!(
                group_messages.get(&event_id).unwrap().content,
                "Updated content"
            );
        }
    }

    /// Messages are indexed per group: each group sees only its own messages.
    #[test]
    fn test_save_message_multiple_groups() {
        let storage = MdkMemoryStorage::new();

        let group1_id = GroupId::from_slice(&[1, 1, 1, 1]);
        let group2_id = GroupId::from_slice(&[2, 2, 2, 2]);

        let group1 = create_test_group(group1_id.clone());
        storage.save_group(group1).unwrap();
        let group2 = create_test_group(group2_id.clone());
        storage.save_group(group2).unwrap();

        for i in 0..3 {
            let event_id = EventId::from_slice(&[i as u8; 32]).unwrap();
            let message = create_test_message(
                event_id,
                group1_id.clone(),
                &format!("Group1 Message {}", i),
                1000 + i as u64,
            );
            storage.save_message(message).unwrap();
        }

        for i in 0..5 {
            let event_id = EventId::from_slice(&[100 + i as u8; 32]).unwrap();
            let message = create_test_message(
                event_id,
                group2_id.clone(),
                &format!("Group2 Message {}", i),
                2000 + i as u64,
            );
            storage.save_message(message).unwrap();
        }

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group1_messages = cache.peek(&group1_id).unwrap();
            assert_eq!(group1_messages.len(), 3);
        }

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group2_messages = cache.peek(&group2_id).unwrap();
            assert_eq!(group2_messages.len(), 5);
        }

        let event_id_group1 = EventId::from_slice(&[0u8; 32]).unwrap();
        let found = storage
            .find_message_by_event_id(&group1_id, &event_id_group1)
            .unwrap()
            .unwrap();
        assert_eq!(found.mls_group_id, group1_id);
        assert!(found.content.contains("Group1"));

        let event_id_group2 = EventId::from_slice(&[100u8; 32]).unwrap();
        let found = storage
            .find_message_by_event_id(&group2_id, &event_id_group2)
            .unwrap()
            .unwrap();
        assert_eq!(found.mls_group_id, group2_id);
        assert!(found.content.contains("Group2"));
    }

    /// Repeated updates of the same id keep exactly one entry holding the latest version.
    #[test]
    fn test_save_message_multiple_updates() {
        let storage = MdkMemoryStorage::new();

        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let event_id = EventId::from_slice(&[50u8; 32]).unwrap();

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        for i in 0..10 {
            let message = create_test_message(
                event_id,
                group_id.clone(),
                &format!("Version {}", i),
                1000 + i as u64,
            );
            storage.save_message(message).unwrap();
        }

        let found = storage
            .find_message_by_event_id(&group_id, &event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.content, "Version 9");

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(
                group_messages.len(),
                1,
                "Should have exactly 1 message after 10 updates"
            );
        }
    }

    /// Re-saving a message with a new state updates the stored state.
    #[test]
    fn test_save_message_state_update() {
        let storage = MdkMemoryStorage::new();

        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let event_id = EventId::from_slice(&[75u8; 32]).unwrap();

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        let mut message = create_test_message(event_id, group_id.clone(), "Test content", 1000);
        message.state = MessageState::Created;
        storage.save_message(message).unwrap();

        let found = storage
            .find_message_by_event_id(&group_id, &event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, MessageState::Created);

        let mut message = create_test_message(event_id, group_id.clone(), "Test content", 1000);
        message.state = MessageState::Processed;
        storage.save_message(message).unwrap();

        let found = storage
            .find_message_by_event_id(&group_id, &event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, MessageState::Processed);

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(group_messages.len(), 1);
        }
    }

    /// Filling a group to `DEFAULT_MAX_MESSAGES_PER_GROUP` and then adding one more message
    /// evicts the oldest one from both caches. Updates of an existing id do not evict.
    #[test]
    fn test_save_message_per_group_limit_eviction() {
        use crate::{DEFAULT_MAX_MESSAGES_PER_GROUP, ValidationLimits};

        // Large enough that the flat cache never evicts on its own during this test.
        let limits = ValidationLimits::default().with_cache_size(20000);
        let storage = MdkMemoryStorage::with_limits(limits);

        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        for i in 0..DEFAULT_MAX_MESSAGES_PER_GROUP {
            // Encode `i` into the first three bytes so every event id is distinct.
            let mut event_bytes = [0u8; 32];
            event_bytes[0] = (i % 256) as u8;
            event_bytes[1] = ((i / 256) % 256) as u8;
            event_bytes[2] = ((i / 65536) % 256) as u8;
            let event_id = EventId::from_slice(&event_bytes).unwrap();
            let message = create_test_message(
                event_id,
                group_id.clone(),
                &format!("Message {}", i),
                // Strictly increasing timestamps, so message 0 is the unique oldest.
                1000 + i as u64,
            );
            storage.save_message(message).unwrap();
        }

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(group_messages.len(), DEFAULT_MAX_MESSAGES_PER_GROUP);
        }

        let oldest_event_id = EventId::from_slice(&[0u8; 32]).unwrap();
        {
            let found = storage
                .find_message_by_event_id(&group_id, &oldest_event_id)
                .unwrap();
            assert!(
                found.is_some(),
                "Oldest message should exist before eviction"
            );
        }

        let new_event_bytes = [255u8; 32];
        let new_event_id = EventId::from_slice(&new_event_bytes).unwrap();
        let new_message = create_test_message(
            new_event_id,
            group_id.clone(),
            "New message triggering eviction",
            999999,
        );
        storage.save_message(new_message).unwrap();

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(
                group_messages.len(),
                DEFAULT_MAX_MESSAGES_PER_GROUP,
                "Should still have DEFAULT_MAX_MESSAGES_PER_GROUP after eviction"
            );
        }

        {
            let found = storage
                .find_message_by_event_id(&group_id, &oldest_event_id)
                .unwrap();
            assert!(
                found.is_none(),
                "Oldest message should be evicted from messages_by_group_cache"
            );
        }

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_cache;
            assert!(
                !cache.contains(&oldest_event_id),
                "Oldest message should be evicted from messages_cache too (no orphaned entries)"
            );
        }

        {
            let found = storage
                .find_message_by_event_id(&group_id, &new_event_id)
                .unwrap();
            assert!(found.is_some(), "New message should exist after eviction");
            assert_eq!(found.unwrap().content, "New message triggering eviction");
        }

        // Message 1 is still present. Updating it must not trigger another eviction.
        let mut update_event_bytes = [0u8; 32];
        update_event_bytes[0] = 1;
        let update_event_id = EventId::from_slice(&update_event_bytes).unwrap();
        let update_message =
            create_test_message(update_event_id, group_id.clone(), "Updated Message 1", 2000);
        storage.save_message(update_message).unwrap();

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(
                group_messages.len(),
                DEFAULT_MAX_MESSAGES_PER_GROUP,
                "Update should not change message count"
            );
        }

        let found = storage
            .find_message_by_event_id(&group_id, &update_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.content, "Updated Message 1");
    }

    /// A custom `max_messages_per_group` limit is honoured by eviction.
    #[test]
    fn test_custom_message_limit() {
        use crate::ValidationLimits;

        let custom_limit = 5;
        let limits = ValidationLimits::default()
            .with_cache_size(100)
            .with_max_messages_per_group(custom_limit);
        let storage = MdkMemoryStorage::with_limits(limits);

        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        for i in 0..custom_limit {
            let mut event_bytes = [0u8; 32];
            event_bytes[0] = i as u8;
            let event_id = EventId::from_slice(&event_bytes).unwrap();
            let message = create_test_message(
                event_id,
                group_id.clone(),
                &format!("Message {}", i),
                1000 + i as u64,
            );
            storage.save_message(message).unwrap();
        }

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(group_messages.len(), custom_limit);
        }

        let new_event_id = EventId::from_slice(&[255u8; 32]).unwrap();
        let new_message = create_test_message(
            new_event_id,
            group_id.clone(),
            "New message triggering eviction",
            999999,
        );
        storage.save_message(new_message).unwrap();

        {
            let inner = storage.inner.read();
            let cache = &inner.messages_by_group_cache;
            let group_messages = cache.peek(&group_id).unwrap();
            assert_eq!(group_messages.len(), custom_limit);
        }

        let oldest_event_id = EventId::from_slice(&[0u8; 32]).unwrap();
        {
            let found = storage
                .find_message_by_event_id(&group_id, &oldest_event_id)
                .unwrap();
            assert!(found.is_none(), "Oldest message should be evicted");
        }
    }

    /// A `Failed` processed message can be marked `Retryable`, and its failure reason is kept.
    #[test]
    fn test_mark_processed_message_retryable() {
        use mdk_storage_traits::messages::types::ProcessedMessage;

        let storage = MdkMemoryStorage::new();

        let wrapper_event_id = EventId::from_slice(&[100u8; 32]).unwrap();

        let processed_message = ProcessedMessage {
            wrapper_event_id,
            message_event_id: None,
            processed_at: Timestamp::from(1_000_000_000u64),
            epoch: None,
            mls_group_id: Some(GroupId::from_slice(&[1, 2, 3, 4])),
            state: ProcessedMessageState::Failed,
            failure_reason: Some("Decryption failed".to_string()),
        };

        storage
            .save_processed_message(processed_message)
            .expect("Failed to save processed message");

        let found = storage
            .find_processed_message_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, ProcessedMessageState::Failed);

        storage
            .mark_processed_message_retryable(&wrapper_event_id)
            .expect("Failed to mark message as retryable");

        let found = storage
            .find_processed_message_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, ProcessedMessageState::Retryable);

        assert_eq!(found.failure_reason, Some("Decryption failed".to_string()));
    }

    /// Marking an unknown wrapper id as retryable reports `NotFound`.
    #[test]
    fn test_mark_nonexistent_message_retryable_fails() {
        use mdk_storage_traits::messages::error::MessageError;

        let storage = MdkMemoryStorage::new();

        let wrapper_event_id = EventId::from_slice(&[100u8; 32]).unwrap();

        let result = storage.mark_processed_message_retryable(&wrapper_event_id);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), MessageError::NotFound));
    }

    /// Marking a non-`Failed` record as retryable reports `NotFound` and leaves it unchanged.
    #[test]
    fn test_mark_non_failed_message_retryable_fails() {
        use mdk_storage_traits::messages::error::MessageError;
        use mdk_storage_traits::messages::types::ProcessedMessage;

        let storage = MdkMemoryStorage::new();

        let wrapper_event_id = EventId::from_slice(&[100u8; 32]).unwrap();

        let processed_message = ProcessedMessage {
            wrapper_event_id,
            message_event_id: None,
            processed_at: Timestamp::from(1_000_000_000u64),
            epoch: Some(1),
            mls_group_id: Some(GroupId::from_slice(&[1, 2, 3, 4])),
            state: ProcessedMessageState::Processed,
            failure_reason: None,
        };

        storage
            .save_processed_message(processed_message)
            .expect("Failed to save processed message");

        let result = storage.mark_processed_message_retryable(&wrapper_event_id);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), MessageError::NotFound));

        let found = storage
            .find_processed_message_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, ProcessedMessageState::Processed);
    }

    /// Only epochs strictly after the cut-off are invalidated. Messages without an epoch are
    /// never touched, and both caches reflect the new state.
    #[test]
    fn test_invalidate_messages_after_epoch() {
        let storage = MdkMemoryStorage::new();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(create_test_group(group_id.clone()))
            .unwrap();

        let ids: Vec<EventId> = (1u8..=3)
            .map(|i| EventId::from_slice(&[i; 32]).unwrap())
            .collect();
        for (i, event_id) in ids.iter().enumerate() {
            let epoch = i as u64 + 1;
            let message = create_test_message_with_epoch(
                *event_id,
                group_id.clone(),
                &format!("Epoch {}", epoch),
                1000 + epoch,
                Some(epoch),
            );
            storage.save_message(message).unwrap();
        }

        let no_epoch_id = EventId::from_slice(&[9u8; 32]).unwrap();
        storage
            .save_message(create_test_message(
                no_epoch_id,
                group_id.clone(),
                "No epoch",
                2000,
            ))
            .unwrap();

        let invalidated: HashSet<EventId> = storage
            .invalidate_messages_after_epoch(&group_id, 1)
            .unwrap()
            .into_iter()
            .collect();
        let expected: HashSet<EventId> = [ids[1], ids[2]].into_iter().collect();
        assert_eq!(invalidated, expected);

        let found = storage
            .find_message_by_event_id(&group_id, &ids[2])
            .unwrap()
            .unwrap();
        assert_eq!(found.state, MessageState::EpochInvalidated);

        let untouched = storage
            .find_message_by_event_id(&group_id, &ids[0])
            .unwrap()
            .unwrap();
        assert_eq!(untouched.state, MessageState::Created);

        let no_epoch = storage
            .find_message_by_event_id(&group_id, &no_epoch_id)
            .unwrap()
            .unwrap();
        assert_eq!(no_epoch.state, MessageState::Created);

        let found_ids: HashSet<EventId> = storage
            .find_invalidated_messages(&group_id)
            .unwrap()
            .into_iter()
            .map(|m| m.id)
            .collect();
        assert_eq!(found_ids, expected);

        {
            let inner = storage.inner.read();
            let cached = inner.messages_cache.peek(&ids[1]).unwrap();
            assert_eq!(cached.state, MessageState::EpochInvalidated);
        }

        let unknown_group_id = GroupId::from_slice(&[9, 9, 9, 9]);
        assert!(
            storage
                .invalidate_messages_after_epoch(&unknown_group_id, 0)
                .unwrap()
                .is_empty()
        );
        assert!(
            storage
                .find_invalidated_messages(&unknown_group_id)
                .unwrap()
                .is_empty()
        );
    }

    /// Processed-message invalidation is scoped to the group and to epochs strictly after the
    /// cut-off.
    #[test]
    fn test_invalidate_processed_messages_after_epoch() {
        let storage = MdkMemoryStorage::new();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let other_group_id = GroupId::from_slice(&[5, 6, 7, 8]);

        let id_epoch1 = EventId::from_slice(&[1u8; 32]).unwrap();
        let id_epoch2 = EventId::from_slice(&[2u8; 32]).unwrap();
        let id_no_epoch = EventId::from_slice(&[3u8; 32]).unwrap();
        let id_other_group = EventId::from_slice(&[4u8; 32]).unwrap();

        for pm in [
            create_test_processed_message(
                id_epoch1,
                group_id.clone(),
                Some(1),
                ProcessedMessageState::Processed,
            ),
            create_test_processed_message(
                id_epoch2,
                group_id.clone(),
                Some(2),
                ProcessedMessageState::Processed,
            ),
            create_test_processed_message(
                id_no_epoch,
                group_id.clone(),
                None,
                ProcessedMessageState::Processed,
            ),
            create_test_processed_message(
                id_other_group,
                other_group_id.clone(),
                Some(5),
                ProcessedMessageState::Processed,
            ),
        ] {
            storage.save_processed_message(pm).unwrap();
        }

        let invalidated = storage
            .invalidate_processed_messages_after_epoch(&group_id, 1)
            .unwrap();
        assert_eq!(invalidated, vec![id_epoch2]);

        let found = storage
            .find_invalidated_processed_messages(&group_id)
            .unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].wrapper_event_id, id_epoch2);

        let other = storage
            .find_processed_message_by_event_id(&id_other_group)
            .unwrap()
            .unwrap();
        assert_eq!(other.state, ProcessedMessageState::Processed);
        assert!(
            storage
                .find_invalidated_processed_messages(&other_group_id)
                .unwrap()
                .is_empty()
        );
    }

    /// Retry discovery returns only `Failed` records with no epoch in the requested group.
    /// A record stops being reported once it is marked retryable.
    #[test]
    fn test_find_failed_messages_for_retry() {
        let storage = MdkMemoryStorage::new();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let other_group_id = GroupId::from_slice(&[5, 6, 7, 8]);

        let retry_id = EventId::from_slice(&[1u8; 32]).unwrap();
        let failed_with_epoch_id = EventId::from_slice(&[2u8; 32]).unwrap();
        let processed_id = EventId::from_slice(&[3u8; 32]).unwrap();
        let other_group_failed_id = EventId::from_slice(&[4u8; 32]).unwrap();

        for pm in [
            create_test_processed_message(
                retry_id,
                group_id.clone(),
                None,
                ProcessedMessageState::Failed,
            ),
            create_test_processed_message(
                failed_with_epoch_id,
                group_id.clone(),
                Some(1),
                ProcessedMessageState::Failed,
            ),
            create_test_processed_message(
                processed_id,
                group_id.clone(),
                None,
                ProcessedMessageState::Processed,
            ),
            create_test_processed_message(
                other_group_failed_id,
                other_group_id.clone(),
                None,
                ProcessedMessageState::Failed,
            ),
        ] {
            storage.save_processed_message(pm).unwrap();
        }

        assert_eq!(
            storage.find_failed_messages_for_retry(&group_id).unwrap(),
            vec![retry_id]
        );

        storage.mark_processed_message_retryable(&retry_id).unwrap();
        assert!(
            storage
                .find_failed_messages_for_retry(&group_id)
                .unwrap()
                .is_empty()
        );
    }

    /// An unknown group yields `None` rather than an error.
    #[test]
    fn test_find_message_epoch_by_tag_content_unknown_group() {
        let storage = MdkMemoryStorage::new();
        let unknown_group_id = GroupId::from_slice(&[99, 99, 99, 99]);

        let result = storage
            .find_message_epoch_by_tag_content(&unknown_group_id, "x abcdef")
            .unwrap();

        assert_eq!(result, None);
    }

    /// A group whose messages' tags do not contain the substring yields `None`.
    #[test]
    fn test_find_message_epoch_by_tag_content_no_matching_tag() {
        let storage = MdkMemoryStorage::new();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();
        let message = create_test_message_with_epoch(
            event_id,
            group_id.clone(),
            "some content",
            1000,
            Some(5),
        );
        storage.save_message(message).unwrap();

        let result = storage
            .find_message_epoch_by_tag_content(&group_id, "x deadbeef_not_present")
            .unwrap();

        assert_eq!(result, None);
    }

    /// A message whose `imeta` tag contains the substring yields that message's epoch.
    #[test]
    fn test_find_message_epoch_by_tag_content_happy_path() {
        let storage = MdkMemoryStorage::new();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();
        let pubkey = Keys::generate().public_key();
        let wrapper_event_id = EventId::from_slice(&[200u8; 32]).unwrap();

        let tags = Tags::parse(vec![vec!["imeta", "x abcdef123456"]]).unwrap();
        let message = Message {
            id: event_id,
            pubkey,
            kind: Kind::from(445u16),
            mls_group_id: group_id.clone(),
            created_at: Timestamp::from(1000u64),
            processed_at: Timestamp::from(1000u64),
            content: "".to_string(),
            tags: tags.clone(),
            event: UnsignedEvent::new(
                pubkey,
                Timestamp::from(1000u64),
                Kind::from(445u16),
                tags,
                "".to_string(),
            ),
            wrapper_event_id,
            epoch: Some(7),
            state: MessageState::Processed,
        };
        storage.save_message(message).unwrap();

        let result = storage
            .find_message_epoch_by_tag_content(&group_id, "x abcdef123456")
            .unwrap();

        assert_eq!(result, Some(7));
    }

    /// A tag match on a message without an epoch is ignored.
    #[test]
    fn test_find_message_epoch_by_tag_content_skips_null_epoch() {
        let storage = MdkMemoryStorage::new();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        let group = create_test_group(group_id.clone());
        storage.save_group(group).unwrap();

        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();
        let pubkey = Keys::generate().public_key();
        let wrapper_event_id = EventId::from_slice(&[200u8; 32]).unwrap();

        // The tag matches, but the message has no epoch, so it must be skipped.
        let tags = Tags::parse(vec![vec!["imeta", "x abcdef123456"]]).unwrap();
        let message = Message {
            id: event_id,
            pubkey,
            kind: Kind::from(445u16),
            mls_group_id: group_id.clone(),
            created_at: Timestamp::from(1000u64),
            processed_at: Timestamp::from(1000u64),
            content: "".to_string(),
            tags: tags.clone(),
            event: UnsignedEvent::new(
                pubkey,
                Timestamp::from(1000u64),
                Kind::from(445u16),
                tags,
                "".to_string(),
            ),
            wrapper_event_id,
            epoch: None,
            state: MessageState::Processed,
        };
        storage.save_message(message).unwrap();

        let result = storage
            .find_message_epoch_by_tag_content(&group_id, "x abcdef123456")
            .unwrap();

        assert_eq!(result, None);
    }
}