Skip to main content

mdk_memory_storage/
messages.rs

//! Memory-based implementation of the `MessageStorage` trait for MDK messages, as part of
//! `MdkMemoryStorage`'s `MdkStorageProvider` support.
2
3use std::collections::HashMap;
4
5use mdk_storage_traits::GroupId;
6use nostr::EventId;
7#[cfg(test)]
8use nostr::{Kind, Tags, Timestamp, UnsignedEvent};
9
10use mdk_storage_traits::groups::GroupStorage;
11use mdk_storage_traits::messages::MessageStorage;
12use mdk_storage_traits::messages::error::MessageError;
13use mdk_storage_traits::messages::types::*;
14
15use crate::MdkMemoryStorage;
16
17impl MessageStorage for MdkMemoryStorage {
18    fn save_message(&self, message: Message) -> Result<(), MessageError> {
19        // Verify that the group exists before saving the message
20        match self.find_group_by_mls_group_id(&message.mls_group_id) {
21            Ok(Some(_)) => {
22                // Group exists, proceed with saving
23            }
24            Ok(None) => {
25                return Err(MessageError::InvalidParameters(
26                    "Group not found".to_string(),
27                ));
28            }
29            Err(e) => {
30                return Err(MessageError::InvalidParameters(format!(
31                    "Failed to verify group existence: {}",
32                    e
33                )));
34            }
35        }
36
37        // Acquire lock on inner storage
38        let mut guard = self.inner.write();
39        let inner = &mut *guard;
40        let cache = &mut inner.messages_cache;
41        let group_cache = &mut inner.messages_by_group_cache;
42
43        match group_cache.get_mut(&message.mls_group_id) {
44            Some(group_messages) => {
45                // Check if this is an update (message already exists) or a new message
46                let is_update = group_messages.contains_key(&message.id);
47
48                if !is_update && group_messages.len() >= self.limits.max_messages_per_group {
49                    // Evict the oldest message to make room for the new one
50                    // Find the message with the oldest created_at timestamp
51                    if let Some(oldest_id) = group_messages
52                        .iter()
53                        .min_by_key(|(_, msg)| msg.created_at)
54                        .map(|(id, _)| *id)
55                    {
56                        // Remove from both caches to prevent orphaned entries
57                        group_messages.remove(&oldest_id);
58                        cache.pop(&oldest_id);
59                    }
60                }
61
62                // O(1) insert or update using HashMap
63                group_messages.insert(message.id, message.clone());
64            }
65            None => {
66                // Create new HashMap for this group
67                let mut messages = HashMap::new();
68                let group_id = message.mls_group_id.clone();
69                messages.insert(message.id, message.clone());
70                group_cache.put(group_id, messages);
71            }
72        }
73
74        // Save in the messages cache
75        cache.put(message.id, message);
76
77        Ok(())
78    }
79
80    fn find_message_by_event_id(
81        &self,
82        mls_group_id: &GroupId,
83        event_id: &EventId,
84    ) -> Result<Option<Message>, MessageError> {
85        let inner = self.inner.read();
86        match inner.messages_by_group_cache.peek(mls_group_id) {
87            Some(group_messages) => Ok(group_messages.get(event_id).cloned()),
88            None => Ok(None),
89        }
90    }
91
92    fn find_processed_message_by_event_id(
93        &self,
94        event_id: &EventId,
95    ) -> Result<Option<ProcessedMessage>, MessageError> {
96        let inner = self.inner.read();
97        Ok(inner.processed_messages_cache.peek(event_id).cloned())
98    }
99
100    fn save_processed_message(
101        &self,
102        processed_message: ProcessedMessage,
103    ) -> Result<(), MessageError> {
104        let mut inner = self.inner.write();
105        inner
106            .processed_messages_cache
107            .put(processed_message.wrapper_event_id, processed_message);
108
109        Ok(())
110    }
111
112    fn invalidate_messages_after_epoch(
113        &self,
114        group_id: &GroupId,
115        epoch: u64,
116    ) -> Result<Vec<EventId>, MessageError> {
117        let mut inner = self.inner.write();
118        let mut invalidated_ids = Vec::new();
119
120        // Get the group messages
121        if let Some(group_messages) = inner.messages_by_group_cache.get_mut(group_id) {
122            for (event_id, message) in group_messages.iter_mut() {
123                // Only invalidate messages with epoch > target
124                if let Some(msg_epoch) = message.epoch
125                    && msg_epoch > epoch
126                {
127                    message.state = MessageState::EpochInvalidated;
128                    invalidated_ids.push(*event_id);
129                }
130            }
131        }
132
133        // Also update in the messages_cache
134        for event_id in &invalidated_ids {
135            if let Some(message) = inner.messages_cache.get_mut(event_id) {
136                message.state = MessageState::EpochInvalidated;
137            }
138        }
139
140        Ok(invalidated_ids)
141    }
142
143    fn invalidate_processed_messages_after_epoch(
144        &self,
145        group_id: &GroupId,
146        epoch: u64,
147    ) -> Result<Vec<EventId>, MessageError> {
148        let mut inner = self.inner.write();
149        let mut invalidated_ids = Vec::new();
150
151        // Iterate through all processed messages and invalidate those matching the group and epoch
152        let cache = &mut inner.processed_messages_cache;
153        for (wrapper_event_id, processed_message) in cache.iter_mut() {
154            // Check if this message belongs to the specified group
155            if let Some(ref msg_group_id) = processed_message.mls_group_id
156                && msg_group_id == group_id
157                && let Some(msg_epoch) = processed_message.epoch
158                && msg_epoch > epoch
159            {
160                processed_message.state = ProcessedMessageState::EpochInvalidated;
161                invalidated_ids.push(*wrapper_event_id);
162            }
163        }
164
165        Ok(invalidated_ids)
166    }
167
168    fn find_invalidated_messages(&self, group_id: &GroupId) -> Result<Vec<Message>, MessageError> {
169        let inner = self.inner.read();
170
171        if let Some(group_messages) = inner.messages_by_group_cache.peek(group_id) {
172            let invalidated: Vec<Message> = group_messages
173                .values()
174                .filter(|msg| msg.state == MessageState::EpochInvalidated)
175                .cloned()
176                .collect();
177            Ok(invalidated)
178        } else {
179            Ok(Vec::new())
180        }
181    }
182
183    fn find_invalidated_processed_messages(
184        &self,
185        group_id: &GroupId,
186    ) -> Result<Vec<ProcessedMessage>, MessageError> {
187        let inner = self.inner.read();
188
189        let invalidated: Vec<ProcessedMessage> = inner
190            .processed_messages_cache
191            .iter()
192            .filter_map(|(_, pm)| {
193                if let Some(ref msg_group_id) = pm.mls_group_id
194                    && msg_group_id == group_id
195                    && pm.state == ProcessedMessageState::EpochInvalidated
196                {
197                    return Some(pm.clone());
198                }
199                None
200            })
201            .collect();
202
203        Ok(invalidated)
204    }
205
206    fn find_failed_messages_for_retry(
207        &self,
208        group_id: &GroupId,
209    ) -> Result<Vec<EventId>, MessageError> {
210        let inner = self.inner.read();
211
212        // Find processed messages that:
213        // - Are for this group
214        // - Have state = Failed
215        // - Have epoch = None (decryption failed before epoch could be determined)
216        let event_ids: Vec<EventId> = inner
217            .processed_messages_cache
218            .iter()
219            .filter_map(|(wrapper_event_id, pm)| {
220                if let Some(ref msg_group_id) = pm.mls_group_id
221                    && msg_group_id == group_id
222                    && pm.state == ProcessedMessageState::Failed
223                    && pm.epoch.is_none()
224                {
225                    return Some(*wrapper_event_id);
226                }
227                None
228            })
229            .collect();
230
231        Ok(event_ids)
232    }
233
234    fn mark_processed_message_retryable(&self, event_id: &EventId) -> Result<(), MessageError> {
235        let mut inner = self.inner.write();
236
237        // Only update messages that are currently in Failed state
238        if let Some(pm) = inner.processed_messages_cache.get_mut(event_id)
239            && pm.state == ProcessedMessageState::Failed
240        {
241            pm.state = ProcessedMessageState::Retryable;
242            return Ok(());
243        }
244
245        Err(MessageError::NotFound)
246    }
247
248    fn find_message_epoch_by_tag_content(
249        &self,
250        group_id: &GroupId,
251        content_substring: &str,
252    ) -> Result<Option<u64>, MessageError> {
253        let inner = self.inner.read();
254
255        let Some(group_messages) = inner.messages_by_group_cache.peek(group_id) else {
256            return Ok(None);
257        };
258
259        for (epoch, message) in group_messages
260            .values()
261            .filter_map(|message| message.epoch.map(|epoch| (epoch, message)))
262        {
263            let tags_json = serde_json::to_string(&message.tags).map_err(|e| {
264                MessageError::DatabaseError(format!("Failed to serialize tags: {e}"))
265            })?;
266
267            if tags_json.contains(content_substring) {
268                return Ok(Some(epoch));
269            }
270        }
271
272        Ok(None)
273    }
274}
275
276#[cfg(test)]
277mod tests {
278    use std::collections::BTreeSet;
279
280    use mdk_storage_traits::groups::GroupStorage;
281    use mdk_storage_traits::groups::types::{Group, GroupState, SelfUpdateState};
282    use nostr::Keys;
283
284    use super::*;
285
286    fn create_test_group(group_id: GroupId) -> Group {
287        // Use the group_id bytes to derive a unique nostr_group_id
288        let mut nostr_group_id = [0u8; 32];
289        let group_id_bytes = group_id.as_slice();
290        nostr_group_id[..group_id_bytes.len().min(32)]
291            .copy_from_slice(&group_id_bytes[..group_id_bytes.len().min(32)]);
292
293        Group {
294            mls_group_id: group_id.clone(),
295            nostr_group_id,
296            name: "Test Group".to_string(),
297            description: "A test group".to_string(),
298            admin_pubkeys: BTreeSet::new(),
299            last_message_id: None,
300            last_message_at: None,
301            last_message_processed_at: None,
302            epoch: 0,
303            state: GroupState::Active,
304            image_hash: None,
305            image_key: None,
306            image_nonce: None,
307            self_update_state: SelfUpdateState::Required,
308        }
309    }
310
311    fn create_test_message(
312        event_id: EventId,
313        group_id: GroupId,
314        content: &str,
315        timestamp: u64,
316    ) -> Message {
317        create_test_message_with_epoch(event_id, group_id, content, timestamp, None)
318    }
319
320    fn create_test_message_with_epoch(
321        event_id: EventId,
322        group_id: GroupId,
323        content: &str,
324        timestamp: u64,
325        epoch: Option<u64>,
326    ) -> Message {
327        let pubkey = Keys::generate().public_key();
328        let wrapper_event_id = EventId::from_slice(&[200u8; 32]).unwrap();
329        let ts = Timestamp::from(timestamp);
330
331        Message {
332            id: event_id,
333            pubkey,
334            kind: Kind::from(1u16),
335            mls_group_id: group_id,
336            created_at: ts,
337            processed_at: ts,
338            content: content.to_string(),
339            tags: Tags::new(),
340            event: UnsignedEvent::new(pubkey, ts, Kind::from(9u16), vec![], content.to_string()),
341            wrapper_event_id,
342            epoch,
343            state: MessageState::Created,
344        }
345    }
346
347    /// Test that saving a message with the same EventId updates the existing message
348    /// rather than creating a duplicate. This verifies the O(1) update behavior
349    /// of the HashMap-based implementation.
350    #[test]
351    fn test_save_message_update_existing() {
352        let storage = MdkMemoryStorage::new();
353
354        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
355        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();
356
357        // Create the group first
358        let group = create_test_group(group_id.clone());
359        storage.save_group(group).unwrap();
360
361        // Save initial message
362        let message1 = create_test_message(event_id, group_id.clone(), "Original content", 1000);
363        storage.save_message(message1).unwrap();
364
365        // Verify initial message is saved
366        let found = storage
367            .find_message_by_event_id(&group_id, &event_id)
368            .unwrap()
369            .unwrap();
370        assert_eq!(found.content, "Original content");
371
372        // Verify the group cache has exactly 1 message
373        {
374            let inner = storage.inner.read();
375            let cache = &inner.messages_by_group_cache;
376            let group_messages = cache.peek(&group_id).unwrap();
377            assert_eq!(group_messages.len(), 1);
378        }
379
380        // Save updated message with same EventId but different content
381        let message2 = create_test_message(event_id, group_id.clone(), "Updated content", 1001);
382        storage.save_message(message2).unwrap();
383
384        // Verify the message was updated, not duplicated
385        let found = storage
386            .find_message_by_event_id(&group_id, &event_id)
387            .unwrap()
388            .unwrap();
389        assert_eq!(found.content, "Updated content");
390        assert_eq!(found.created_at, Timestamp::from(1001u64));
391
392        // Verify the group cache still has exactly 1 message (no duplicates)
393        {
394            let inner = storage.inner.read();
395            let cache = &inner.messages_by_group_cache;
396            let group_messages = cache.peek(&group_id).unwrap();
397            assert_eq!(
398                group_messages.len(),
399                1,
400                "Should have exactly 1 message after update, not 2"
401            );
402            assert_eq!(
403                group_messages.get(&event_id).unwrap().content,
404                "Updated content"
405            );
406        }
407    }
408
409    /// Test that messages are properly isolated between different groups
410    #[test]
411    fn test_save_message_multiple_groups() {
412        let storage = MdkMemoryStorage::new();
413
414        let group1_id = GroupId::from_slice(&[1, 1, 1, 1]);
415        let group2_id = GroupId::from_slice(&[2, 2, 2, 2]);
416
417        // Create the groups first
418        let group1 = create_test_group(group1_id.clone());
419        storage.save_group(group1).unwrap();
420        let group2 = create_test_group(group2_id.clone());
421        storage.save_group(group2).unwrap();
422
423        // Save messages to group 1
424        for i in 0..3 {
425            let event_id = EventId::from_slice(&[i as u8; 32]).unwrap();
426            let message = create_test_message(
427                event_id,
428                group1_id.clone(),
429                &format!("Group1 Message {}", i),
430                1000 + i as u64,
431            );
432            storage.save_message(message).unwrap();
433        }
434
435        // Save messages to group 2
436        for i in 0..5 {
437            let event_id = EventId::from_slice(&[100 + i as u8; 32]).unwrap();
438            let message = create_test_message(
439                event_id,
440                group2_id.clone(),
441                &format!("Group2 Message {}", i),
442                2000 + i as u64,
443            );
444            storage.save_message(message).unwrap();
445        }
446
447        // Verify group 1 has 3 messages
448        {
449            let inner = storage.inner.read();
450            let cache = &inner.messages_by_group_cache;
451            let group1_messages = cache.peek(&group1_id).unwrap();
452            assert_eq!(group1_messages.len(), 3);
453        }
454
455        // Verify group 2 has 5 messages
456        {
457            let inner = storage.inner.read();
458            let cache = &inner.messages_by_group_cache;
459            let group2_messages = cache.peek(&group2_id).unwrap();
460            assert_eq!(group2_messages.len(), 5);
461        }
462
463        // Verify messages are correctly associated with their groups
464        let event_id_group1 = EventId::from_slice(&[0u8; 32]).unwrap();
465        let found = storage
466            .find_message_by_event_id(&group1_id, &event_id_group1)
467            .unwrap()
468            .unwrap();
469        assert_eq!(found.mls_group_id, group1_id);
470        assert!(found.content.contains("Group1"));
471
472        let event_id_group2 = EventId::from_slice(&[100u8; 32]).unwrap();
473        let found = storage
474            .find_message_by_event_id(&group2_id, &event_id_group2)
475            .unwrap()
476            .unwrap();
477        assert_eq!(found.mls_group_id, group2_id);
478        assert!(found.content.contains("Group2"));
479    }
480
481    /// Test that multiple updates to the same message work correctly
482    #[test]
483    fn test_save_message_multiple_updates() {
484        let storage = MdkMemoryStorage::new();
485
486        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
487        let event_id = EventId::from_slice(&[50u8; 32]).unwrap();
488
489        // Create the group first
490        let group = create_test_group(group_id.clone());
491        storage.save_group(group).unwrap();
492
493        // Perform multiple updates to the same message
494        for i in 0..10 {
495            let message = create_test_message(
496                event_id,
497                group_id.clone(),
498                &format!("Version {}", i),
499                1000 + i as u64,
500            );
501            storage.save_message(message).unwrap();
502        }
503
504        // Verify only the final version exists
505        let found = storage
506            .find_message_by_event_id(&group_id, &event_id)
507            .unwrap()
508            .unwrap();
509        assert_eq!(found.content, "Version 9");
510
511        // Verify the group cache has exactly 1 message
512        {
513            let inner = storage.inner.read();
514            let cache = &inner.messages_by_group_cache;
515            let group_messages = cache.peek(&group_id).unwrap();
516            assert_eq!(
517                group_messages.len(),
518                1,
519                "Should have exactly 1 message after 10 updates"
520            );
521        }
522    }
523
524    /// Test that updating message state works correctly
525    #[test]
526    fn test_save_message_state_update() {
527        let storage = MdkMemoryStorage::new();
528
529        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
530        let event_id = EventId::from_slice(&[75u8; 32]).unwrap();
531
532        // Create the group first
533        let group = create_test_group(group_id.clone());
534        storage.save_group(group).unwrap();
535
536        // Save message with Created state
537        let mut message = create_test_message(event_id, group_id.clone(), "Test content", 1000);
538        message.state = MessageState::Created;
539        storage.save_message(message).unwrap();
540
541        // Verify initial state
542        let found = storage
543            .find_message_by_event_id(&group_id, &event_id)
544            .unwrap()
545            .unwrap();
546        assert_eq!(found.state, MessageState::Created);
547
548        // Update message with Processed state
549        let mut message = create_test_message(event_id, group_id.clone(), "Test content", 1000);
550        message.state = MessageState::Processed;
551        storage.save_message(message).unwrap();
552
553        // Verify state was updated
554        let found = storage
555            .find_message_by_event_id(&group_id, &event_id)
556            .unwrap()
557            .unwrap();
558        assert_eq!(found.state, MessageState::Processed);
559
560        // Verify still only 1 message in the group
561        {
562            let inner = storage.inner.read();
563            let cache = &inner.messages_by_group_cache;
564            let group_messages = cache.peek(&group_id).unwrap();
565            assert_eq!(group_messages.len(), 1);
566        }
567    }
568
569    /// Test that the messages per group limit is enforced and oldest messages are evicted.
570    /// This test verifies that:
571    /// 1. When max_messages_per_group is reached, the oldest message is evicted
572    /// 2. Evicted messages are removed from BOTH caches (messages_cache and messages_by_group_cache)
573    /// 3. Updates to existing messages don't trigger eviction
574    #[test]
575    fn test_save_message_per_group_limit_eviction() {
576        use crate::{DEFAULT_MAX_MESSAGES_PER_GROUP, ValidationLimits};
577
578        // Create storage with a large cache size for testing
579        let limits = ValidationLimits::default().with_cache_size(20000);
580        let storage = MdkMemoryStorage::with_limits(limits);
581
582        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
583
584        // Create the group first
585        let group = create_test_group(group_id.clone());
586        storage.save_group(group).unwrap();
587
588        // Save exactly DEFAULT_MAX_MESSAGES_PER_GROUP messages
589        // Use timestamps 1000..1000+MAX to establish age ordering
590        for i in 0..DEFAULT_MAX_MESSAGES_PER_GROUP {
591            let mut event_bytes = [0u8; 32];
592            event_bytes[0] = (i % 256) as u8;
593            event_bytes[1] = ((i / 256) % 256) as u8;
594            event_bytes[2] = ((i / 65536) % 256) as u8;
595            let event_id = EventId::from_slice(&event_bytes).unwrap();
596            let message = create_test_message(
597                event_id,
598                group_id.clone(),
599                &format!("Message {}", i),
600                1000 + i as u64, // Oldest message has timestamp 1000
601            );
602            storage.save_message(message).unwrap();
603        }
604
605        // Verify all messages are stored
606        {
607            let inner = storage.inner.read();
608            let cache = &inner.messages_by_group_cache;
609            let group_messages = cache.peek(&group_id).unwrap();
610            assert_eq!(group_messages.len(), DEFAULT_MAX_MESSAGES_PER_GROUP);
611        }
612
613        // The oldest message (index 0, timestamp 1000) should exist
614        let oldest_event_id = EventId::from_slice(&[0u8; 32]).unwrap();
615        {
616            let found = storage
617                .find_message_by_event_id(&group_id, &oldest_event_id)
618                .unwrap();
619            assert!(
620                found.is_some(),
621                "Oldest message should exist before eviction"
622            );
623        }
624
625        // Now add one more message to trigger eviction
626        let new_event_bytes = [255u8; 32]; // Unique event ID
627        let new_event_id = EventId::from_slice(&new_event_bytes).unwrap();
628        let new_message = create_test_message(
629            new_event_id,
630            group_id.clone(),
631            "New message triggering eviction",
632            999999, // Much newer timestamp
633        );
634        storage.save_message(new_message).unwrap();
635
636        // Verify the count is still at MAX (eviction occurred)
637        {
638            let inner = storage.inner.read();
639            let cache = &inner.messages_by_group_cache;
640            let group_messages = cache.peek(&group_id).unwrap();
641            assert_eq!(
642                group_messages.len(),
643                DEFAULT_MAX_MESSAGES_PER_GROUP,
644                "Should still have DEFAULT_MAX_MESSAGES_PER_GROUP after eviction"
645            );
646        }
647
648        // The oldest message should have been evicted from messages_by_group_cache
649        {
650            let found = storage
651                .find_message_by_event_id(&group_id, &oldest_event_id)
652                .unwrap();
653            assert!(
654                found.is_none(),
655                "Oldest message should be evicted from messages_by_group_cache"
656            );
657        }
658
659        // CRITICAL: The oldest message should ALSO be evicted from messages_cache
660        // This verifies the coordinated eviction fix
661        {
662            let inner = storage.inner.read();
663            let cache = &inner.messages_cache;
664            assert!(
665                !cache.contains(&oldest_event_id),
666                "Oldest message should be evicted from messages_cache too (no orphaned entries)"
667            );
668        }
669
670        // The new message should exist
671        {
672            let found = storage
673                .find_message_by_event_id(&group_id, &new_event_id)
674                .unwrap();
675            assert!(found.is_some(), "New message should exist after eviction");
676            assert_eq!(found.unwrap().content, "New message triggering eviction");
677        }
678
679        // Verify updating an existing message doesn't trigger eviction
680        // Update the second oldest message (index 1)
681        let mut update_event_bytes = [0u8; 32];
682        update_event_bytes[0] = 1;
683        let update_event_id = EventId::from_slice(&update_event_bytes).unwrap();
684        let update_message =
685            create_test_message(update_event_id, group_id.clone(), "Updated Message 1", 2000);
686        storage.save_message(update_message).unwrap();
687
688        // Should still have the same count (no eviction for updates)
689        {
690            let inner = storage.inner.read();
691            let cache = &inner.messages_by_group_cache;
692            let group_messages = cache.peek(&group_id).unwrap();
693            assert_eq!(
694                group_messages.len(),
695                DEFAULT_MAX_MESSAGES_PER_GROUP,
696                "Update should not change message count"
697            );
698        }
699
700        // Verify the message was updated
701        let found = storage
702            .find_message_by_event_id(&group_id, &update_event_id)
703            .unwrap()
704            .unwrap();
705        assert_eq!(found.content, "Updated Message 1");
706    }
707
708    /// Test that custom validation limits work correctly
709    #[test]
710    fn test_custom_message_limit() {
711        use crate::ValidationLimits;
712
713        // Create storage with a custom small message limit for testing
714        let custom_limit = 5;
715        let limits = ValidationLimits::default()
716            .with_cache_size(100)
717            .with_max_messages_per_group(custom_limit);
718        let storage = MdkMemoryStorage::with_limits(limits);
719
720        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
721
722        // Create the group first
723        let group = create_test_group(group_id.clone());
724        storage.save_group(group).unwrap();
725
726        // Save exactly custom_limit messages
727        for i in 0..custom_limit {
728            let mut event_bytes = [0u8; 32];
729            event_bytes[0] = i as u8;
730            let event_id = EventId::from_slice(&event_bytes).unwrap();
731            let message = create_test_message(
732                event_id,
733                group_id.clone(),
734                &format!("Message {}", i),
735                1000 + i as u64,
736            );
737            storage.save_message(message).unwrap();
738        }
739
740        // Verify all messages are stored
741        {
742            let inner = storage.inner.read();
743            let cache = &inner.messages_by_group_cache;
744            let group_messages = cache.peek(&group_id).unwrap();
745            assert_eq!(group_messages.len(), custom_limit);
746        }
747
748        // Add one more message to trigger eviction
749        let new_event_id = EventId::from_slice(&[255u8; 32]).unwrap();
750        let new_message = create_test_message(
751            new_event_id,
752            group_id.clone(),
753            "New message triggering eviction",
754            999999,
755        );
756        storage.save_message(new_message).unwrap();
757
758        // Verify the count is still at custom_limit (eviction occurred)
759        {
760            let inner = storage.inner.read();
761            let cache = &inner.messages_by_group_cache;
762            let group_messages = cache.peek(&group_id).unwrap();
763            assert_eq!(group_messages.len(), custom_limit);
764        }
765
766        // The oldest message (index 0) should have been evicted
767        let oldest_event_id = EventId::from_slice(&[0u8; 32]).unwrap();
768        {
769            let found = storage
770                .find_message_by_event_id(&group_id, &oldest_event_id)
771                .unwrap();
772            assert!(found.is_none(), "Oldest message should be evicted");
773        }
774    }
775
776    #[test]
777    fn test_mark_processed_message_retryable() {
778        use mdk_storage_traits::messages::types::ProcessedMessage;
779
780        let storage = MdkMemoryStorage::new();
781
782        // Create a failed processed message
783        let wrapper_event_id = EventId::from_slice(&[100u8; 32]).unwrap();
784
785        let processed_message = ProcessedMessage {
786            wrapper_event_id,
787            message_event_id: None,
788            processed_at: Timestamp::from(1_000_000_000u64),
789            epoch: None,
790            mls_group_id: Some(GroupId::from_slice(&[1, 2, 3, 4])),
791            state: ProcessedMessageState::Failed,
792            failure_reason: Some("Decryption failed".to_string()),
793        };
794
795        // Save the failed processed message
796        storage
797            .save_processed_message(processed_message)
798            .expect("Failed to save processed message");
799
800        // Verify it's in Failed state
801        let found = storage
802            .find_processed_message_by_event_id(&wrapper_event_id)
803            .unwrap()
804            .unwrap();
805        assert_eq!(found.state, ProcessedMessageState::Failed);
806
807        // Mark as retryable
808        storage
809            .mark_processed_message_retryable(&wrapper_event_id)
810            .expect("Failed to mark message as retryable");
811
812        // Verify state changed to Retryable
813        let found = storage
814            .find_processed_message_by_event_id(&wrapper_event_id)
815            .unwrap()
816            .unwrap();
817        assert_eq!(found.state, ProcessedMessageState::Retryable);
818
819        // Verify failure_reason is preserved
820        assert_eq!(found.failure_reason, Some("Decryption failed".to_string()));
821    }
822
823    #[test]
824    fn test_mark_nonexistent_message_retryable_fails() {
825        use mdk_storage_traits::messages::error::MessageError;
826
827        let storage = MdkMemoryStorage::new();
828
829        let wrapper_event_id = EventId::from_slice(&[100u8; 32]).unwrap();
830
831        // Attempt to mark a non-existent message as retryable
832        let result = storage.mark_processed_message_retryable(&wrapper_event_id);
833        assert!(result.is_err());
834        assert!(matches!(result.unwrap_err(), MessageError::NotFound));
835    }
836
837    #[test]
838    fn test_mark_non_failed_message_retryable_fails() {
839        use mdk_storage_traits::messages::error::MessageError;
840        use mdk_storage_traits::messages::types::ProcessedMessage;
841
842        let storage = MdkMemoryStorage::new();
843
844        // Create a processed message in Processed state (not Failed)
845        let wrapper_event_id = EventId::from_slice(&[100u8; 32]).unwrap();
846
847        let processed_message = ProcessedMessage {
848            wrapper_event_id,
849            message_event_id: None,
850            processed_at: Timestamp::from(1_000_000_000u64),
851            epoch: Some(1),
852            mls_group_id: Some(GroupId::from_slice(&[1, 2, 3, 4])),
853            state: ProcessedMessageState::Processed,
854            failure_reason: None,
855        };
856
857        storage
858            .save_processed_message(processed_message)
859            .expect("Failed to save processed message");
860
861        // Attempt to mark a Processed message as retryable should fail
862        let result = storage.mark_processed_message_retryable(&wrapper_event_id);
863        assert!(result.is_err());
864        assert!(matches!(result.unwrap_err(), MessageError::NotFound));
865
866        // Verify state is unchanged
867        let found = storage
868            .find_processed_message_by_event_id(&wrapper_event_id)
869            .unwrap()
870            .unwrap();
871        assert_eq!(found.state, ProcessedMessageState::Processed);
872    }
873
874    /// Verifies that querying a group with no stored messages returns Ok(None).
875    #[test]
876    fn test_find_message_epoch_by_tag_content_unknown_group() {
877        let storage = MdkMemoryStorage::new();
878        let unknown_group_id = GroupId::from_slice(&[99, 99, 99, 99]);
879
880        let result = storage
881            .find_message_epoch_by_tag_content(&unknown_group_id, "x abcdef")
882            .unwrap();
883
884        assert_eq!(result, None);
885    }
886
887    /// Verifies that when messages exist but none contain the searched
888    /// substring, Ok(None) is returned.
889    #[test]
890    fn test_find_message_epoch_by_tag_content_no_matching_tag() {
891        let storage = MdkMemoryStorage::new();
892        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
893
894        let group = create_test_group(group_id.clone());
895        storage.save_group(group).unwrap();
896
897        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();
898        let message = create_test_message_with_epoch(
899            event_id,
900            group_id.clone(),
901            "some content",
902            1000,
903            Some(5),
904        );
905        storage.save_message(message).unwrap();
906
907        let result = storage
908            .find_message_epoch_by_tag_content(&group_id, "x deadbeef_not_present")
909            .unwrap();
910
911        assert_eq!(result, None);
912    }
913
914    /// Verifies the happy path: a message with matching tag content and a
915    /// non-null epoch returns Ok(Some(epoch)).
916    #[test]
917    fn test_find_message_epoch_by_tag_content_happy_path() {
918        let storage = MdkMemoryStorage::new();
919        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
920
921        let group = create_test_group(group_id.clone());
922        storage.save_group(group).unwrap();
923
924        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();
925        let pubkey = Keys::generate().public_key();
926        let wrapper_event_id = EventId::from_slice(&[200u8; 32]).unwrap();
927
928        let tags = Tags::parse(vec![vec!["imeta", "x abcdef123456"]]).unwrap();
929        let message = Message {
930            id: event_id,
931            pubkey,
932            kind: Kind::from(445u16),
933            mls_group_id: group_id.clone(),
934            created_at: Timestamp::from(1000u64),
935            processed_at: Timestamp::from(1000u64),
936            content: "".to_string(),
937            tags: tags.clone(),
938            event: UnsignedEvent::new(
939                pubkey,
940                Timestamp::from(1000u64),
941                Kind::from(445u16),
942                tags,
943                "".to_string(),
944            ),
945            wrapper_event_id,
946            epoch: Some(7),
947            state: MessageState::Processed,
948        };
949        storage.save_message(message).unwrap();
950
951        let result = storage
952            .find_message_epoch_by_tag_content(&group_id, "x abcdef123456")
953            .unwrap();
954
955        assert_eq!(result, Some(7));
956    }
957
958    /// Verifies that messages with epoch: None are skipped even when their
959    /// tags match the search substring.
960    #[test]
961    fn test_find_message_epoch_by_tag_content_skips_null_epoch() {
962        let storage = MdkMemoryStorage::new();
963        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
964
965        let group = create_test_group(group_id.clone());
966        storage.save_group(group).unwrap();
967
968        let event_id = EventId::from_slice(&[10u8; 32]).unwrap();
969        let pubkey = Keys::generate().public_key();
970        let wrapper_event_id = EventId::from_slice(&[200u8; 32]).unwrap();
971
972        // Store a message with matching tags but epoch: None
973        let tags = Tags::parse(vec![vec!["imeta", "x abcdef123456"]]).unwrap();
974        let message = Message {
975            id: event_id,
976            pubkey,
977            kind: Kind::from(445u16),
978            mls_group_id: group_id.clone(),
979            created_at: Timestamp::from(1000u64),
980            processed_at: Timestamp::from(1000u64),
981            content: "".to_string(),
982            tags: tags.clone(),
983            event: UnsignedEvent::new(
984                pubkey,
985                Timestamp::from(1000u64),
986                Kind::from(445u16),
987                tags,
988                "".to_string(),
989            ),
990            wrapper_event_id,
991            epoch: None,
992            state: MessageState::Processed,
993        };
994        storage.save_message(message).unwrap();
995
996        let result = storage
997            .find_message_epoch_by_tag_content(&group_id, "x abcdef123456")
998            .unwrap();
999
1000        assert_eq!(result, None);
1001    }
1002}