Skip to main content

mdk_sqlite_storage/
messages.rs

1//! Implementation of MessageStorage trait for SQLite storage.
2
3use mdk_storage_traits::messages::MessageStorage;
4use mdk_storage_traits::messages::error::MessageError;
5use mdk_storage_traits::messages::types::{Message, ProcessedMessage};
6use nostr::{EventId, JsonUtil};
7use rusqlite::{OptionalExtension, params};
8
9use crate::validation::{
10    MAX_EVENT_JSON_SIZE, MAX_MESSAGE_CONTENT_SIZE, MAX_TAGS_JSON_SIZE, validate_size,
11    validate_string_length,
12};
13use crate::{MdkSqliteStorage, db};
14
15#[inline]
16fn into_message_err<T>(e: T) -> MessageError
17where
18    T: std::error::Error,
19{
20    MessageError::DatabaseError(e.to_string())
21}
22
23impl MessageStorage for MdkSqliteStorage {
24    fn save_message(&self, message: Message) -> Result<(), MessageError> {
25        // Validate content size
26        validate_string_length(
27            &message.content,
28            MAX_MESSAGE_CONTENT_SIZE,
29            "Message content",
30        )
31        .map_err(|e| MessageError::InvalidParameters(e.to_string()))?;
32
33        // Serialize complex types to JSON
34        let tags_json: String = serde_json::to_string(&message.tags)
35            .map_err(|e| MessageError::DatabaseError(format!("Failed to serialize tags: {}", e)))?;
36
37        // Validate tags JSON size
38        validate_size(tags_json.as_bytes(), MAX_TAGS_JSON_SIZE, "Tags JSON")
39            .map_err(|e| MessageError::InvalidParameters(e.to_string()))?;
40
41        // Serialize event to JSON
42        let event_json = message.event.as_json();
43
44        // Validate event JSON size
45        validate_size(event_json.as_bytes(), MAX_EVENT_JSON_SIZE, "Event JSON")
46            .map_err(|e| MessageError::InvalidParameters(e.to_string()))?;
47
48        self.with_connection(|conn| {
49            conn.execute(
50                "INSERT INTO messages
51             (id, pubkey, kind, mls_group_id, created_at, processed_at, content, tags, event, wrapper_event_id, epoch, state)
52             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
53             ON CONFLICT(mls_group_id, id) DO UPDATE SET
54                 pubkey = excluded.pubkey,
55                 kind = excluded.kind,
56                 created_at = excluded.created_at,
57                 processed_at = excluded.processed_at,
58                 content = excluded.content,
59                 tags = excluded.tags,
60                 event = excluded.event,
61                 wrapper_event_id = excluded.wrapper_event_id,
62                 epoch = excluded.epoch,
63                 state = excluded.state",
64                params![
65                    message.id.as_bytes(),
66                    message.pubkey.as_bytes(),
67                    message.kind.as_u16(),
68                    message.mls_group_id.as_slice(),
69                    message.created_at.as_secs(),
70                    message.processed_at.as_secs(),
71                    &message.content,
72                    &tags_json,
73                    &event_json,
74                    message.wrapper_event_id.as_bytes(),
75                    message.epoch,
76                    message.state.as_str(),
77                ],
78            )
79            .map_err(into_message_err)?;
80
81            Ok(())
82        })
83    }
84
85    fn find_message_by_event_id(
86        &self,
87        mls_group_id: &mdk_storage_traits::GroupId,
88        event_id: &EventId,
89    ) -> Result<Option<Message>, MessageError> {
90        self.with_connection(|conn| {
91            let mut stmt = conn
92                .prepare("SELECT * FROM messages WHERE mls_group_id = ? AND id = ?")
93                .map_err(into_message_err)?;
94
95            stmt.query_row(
96                params![mls_group_id.as_slice(), event_id.to_bytes()],
97                db::row_to_message,
98            )
99            .optional()
100            .map_err(into_message_err)
101        })
102    }
103
104    fn save_processed_message(
105        &self,
106        processed_message: ProcessedMessage,
107    ) -> Result<(), MessageError> {
108        // Convert message_event_id to bytes if it exists
109        let message_event_id = processed_message
110            .message_event_id
111            .as_ref()
112            .map(|id| id.to_bytes());
113
114        // Convert mls_group_id to bytes if it exists
115        let mls_group_id = processed_message
116            .mls_group_id
117            .as_ref()
118            .map(|id| id.as_slice().to_vec());
119
120        self.with_connection(|conn| {
121            conn.execute(
122                "INSERT OR REPLACE INTO processed_messages
123             (wrapper_event_id, message_event_id, processed_at, epoch, mls_group_id, state, failure_reason)
124             VALUES (?, ?, ?, ?, ?, ?, ?)",
125                params![
126                    &processed_message.wrapper_event_id.to_bytes(),
127                    &message_event_id,
128                    &processed_message.processed_at.as_secs(),
129                    &processed_message.epoch,
130                    &mls_group_id,
131                    &processed_message.state.to_string(),
132                    &processed_message.failure_reason
133                ],
134            )
135            .map_err(into_message_err)?;
136
137            Ok(())
138        })
139    }
140
141    fn find_processed_message_by_event_id(
142        &self,
143        event_id: &EventId,
144    ) -> Result<Option<ProcessedMessage>, MessageError> {
145        self.with_connection(|conn| {
146            let mut stmt = conn
147                .prepare("SELECT * FROM processed_messages WHERE wrapper_event_id = ?")
148                .map_err(into_message_err)?;
149
150            stmt.query_row(params![event_id.to_bytes()], db::row_to_processed_message)
151                .optional()
152                .map_err(into_message_err)
153        })
154    }
155
156    fn invalidate_messages_after_epoch(
157        &self,
158        group_id: &mdk_storage_traits::GroupId,
159        epoch: u64,
160    ) -> Result<Vec<EventId>, MessageError> {
161        self.with_connection(|conn| {
162            // First, get the event IDs that will be invalidated
163            let mut stmt = conn
164                .prepare(
165                    "SELECT id FROM messages
166                     WHERE mls_group_id = ? AND epoch > ?",
167                )
168                .map_err(into_message_err)?;
169
170            let event_ids: Vec<EventId> = stmt
171                .query_map(params![group_id.as_slice(), epoch], |row| {
172                    let id_blob: Vec<u8> = row.get(0)?;
173                    Ok(id_blob)
174                })
175                .map_err(into_message_err)?
176                .filter_map(|r| r.ok())
177                .filter_map(|id_blob| EventId::from_slice(&id_blob).ok())
178                .collect();
179
180            // Then update the state to epoch_invalidated
181            conn.execute(
182                "UPDATE messages SET state = 'epoch_invalidated'
183                 WHERE mls_group_id = ? AND epoch > ?",
184                params![group_id.as_slice(), epoch],
185            )
186            .map_err(into_message_err)?;
187
188            Ok(event_ids)
189        })
190    }
191
192    fn invalidate_processed_messages_after_epoch(
193        &self,
194        group_id: &mdk_storage_traits::GroupId,
195        epoch: u64,
196    ) -> Result<Vec<EventId>, MessageError> {
197        self.with_connection(|conn| {
198            // First, get the wrapper event IDs that will be invalidated
199            let mut stmt = conn
200                .prepare(
201                    "SELECT wrapper_event_id FROM processed_messages
202                     WHERE mls_group_id = ? AND epoch > ?",
203                )
204                .map_err(into_message_err)?;
205
206            let event_ids: Vec<EventId> = stmt
207                .query_map(params![group_id.as_slice(), epoch], |row| {
208                    let id_blob: Vec<u8> = row.get(0)?;
209                    Ok(id_blob)
210                })
211                .map_err(into_message_err)?
212                .filter_map(|r| r.ok())
213                .filter_map(|id_blob| EventId::from_slice(&id_blob).ok())
214                .collect();
215
216            // Then update the state to epoch_invalidated
217            conn.execute(
218                "UPDATE processed_messages SET state = 'epoch_invalidated'
219                 WHERE mls_group_id = ? AND epoch > ?",
220                params![group_id.as_slice(), epoch],
221            )
222            .map_err(into_message_err)?;
223
224            Ok(event_ids)
225        })
226    }
227
228    fn find_invalidated_messages(
229        &self,
230        group_id: &mdk_storage_traits::GroupId,
231    ) -> Result<Vec<Message>, MessageError> {
232        self.with_connection(|conn| {
233            let mut stmt = conn
234                .prepare(
235                    "SELECT * FROM messages
236                     WHERE mls_group_id = ? AND state = 'epoch_invalidated'",
237                )
238                .map_err(into_message_err)?;
239
240            let messages: Vec<Message> = stmt
241                .query_map(params![group_id.as_slice()], db::row_to_message)
242                .map_err(into_message_err)?
243                .filter_map(|r| r.ok())
244                .collect();
245
246            Ok(messages)
247        })
248    }
249
250    fn find_invalidated_processed_messages(
251        &self,
252        group_id: &mdk_storage_traits::GroupId,
253    ) -> Result<Vec<ProcessedMessage>, MessageError> {
254        self.with_connection(|conn| {
255            let mut stmt = conn
256                .prepare(
257                    "SELECT * FROM processed_messages
258                     WHERE mls_group_id = ? AND state = 'epoch_invalidated'",
259                )
260                .map_err(into_message_err)?;
261
262            let messages: Vec<ProcessedMessage> = stmt
263                .query_map(params![group_id.as_slice()], db::row_to_processed_message)
264                .map_err(into_message_err)?
265                .filter_map(|r| r.ok())
266                .collect();
267
268            Ok(messages)
269        })
270    }
271
272    fn find_failed_messages_for_retry(
273        &self,
274        group_id: &mdk_storage_traits::GroupId,
275    ) -> Result<Vec<EventId>, MessageError> {
276        self.with_connection(|conn| {
277            // Find processed messages that:
278            // - Are for this group
279            // - Have state = Failed
280            // - Have epoch IS NULL (decryption failed before epoch could be determined)
281            let mut stmt = conn
282                .prepare(
283                    "SELECT wrapper_event_id FROM processed_messages
284                     WHERE mls_group_id = ? AND state = 'failed' AND epoch IS NULL",
285                )
286                .map_err(into_message_err)?;
287
288            let event_ids: Vec<EventId> = stmt
289                .query_map(params![group_id.as_slice()], |row| {
290                    let id_blob: Vec<u8> = row.get(0)?;
291                    Ok(id_blob)
292                })
293                .map_err(into_message_err)?
294                .filter_map(|r| r.ok())
295                .filter_map(|id_blob| EventId::from_slice(&id_blob).ok())
296                .collect();
297
298            Ok(event_ids)
299        })
300    }
301
302    fn mark_processed_message_retryable(&self, event_id: &EventId) -> Result<(), MessageError> {
303        self.with_connection(|conn| {
304            // Only update messages that are currently in Failed state
305            let rows_affected = conn
306                .execute(
307                    "UPDATE processed_messages SET state = 'retryable'
308                     WHERE wrapper_event_id = ? AND state = 'failed'",
309                    params![event_id.to_bytes()],
310                )
311                .map_err(into_message_err)?;
312
313            if rows_affected == 0 {
314                return Err(MessageError::NotFound);
315            }
316
317            Ok(())
318        })
319    }
320
321    fn find_message_epoch_by_tag_content(
322        &self,
323        group_id: &mdk_storage_traits::GroupId,
324        content_substring: &str,
325    ) -> Result<Option<u64>, MessageError> {
326        let escaped = content_substring
327            .replace('\\', "\\\\")
328            .replace('%', "\\%")
329            .replace('_', "\\_");
330        let pattern = format!("%{}%", escaped);
331        self.with_connection(|conn| {
332            let mut stmt = conn
333                .prepare(
334                    "SELECT epoch FROM messages
335                     WHERE mls_group_id = ? AND tags LIKE ? ESCAPE '\\' AND epoch IS NOT NULL
336                     LIMIT 1",
337                )
338                .map_err(into_message_err)?;
339
340            stmt.query_row(params![group_id.as_slice(), &pattern], |row| {
341                row.get::<_, u64>(0)
342            })
343            .optional()
344            .map_err(into_message_err)
345        })
346    }
347}
348
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use mdk_storage_traits::GroupId;
    use mdk_storage_traits::groups::GroupStorage;
    use mdk_storage_traits::groups::types::{Group, GroupState, SelfUpdateState};
    use mdk_storage_traits::messages::types::{MessageState, ProcessedMessageState};
    use nostr::{EventId, Kind, PublicKey, Tags, Timestamp, UnsignedEvent};

    use super::*;

    /// Hex public key used as the author of every test message.
    const TEST_PUBKEY: &str = "79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798";

    /// Parses [`TEST_PUBKEY`].
    fn test_pubkey() -> PublicKey {
        PublicKey::parse(TEST_PUBKEY).unwrap()
    }

    /// Deterministic event id whose 32 bytes all equal `byte`.
    fn event_id(byte: u8) -> EventId {
        EventId::from_slice(&[byte; 32]).unwrap()
    }

    /// Builds an active, empty group. `label` (at most 32 bytes) is copied into the
    /// start of an otherwise zeroed nostr group id.
    fn test_group(mls_group_id: &GroupId, label: &[u8], name: &str, description: &str) -> Group {
        let mut nostr_group_id = [0u8; 32];
        nostr_group_id[..label.len()].copy_from_slice(label);
        Group {
            mls_group_id: mls_group_id.clone(),
            nostr_group_id,
            name: name.to_string(),
            description: description.to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        }
    }

    /// Builds a kind-1 message in state `Created` with empty tags.
    fn test_message(
        id: EventId,
        mls_group_id: &GroupId,
        wrapper_event_id: EventId,
        content: &str,
        epoch: Option<u64>,
    ) -> Message {
        let pubkey = test_pubkey();
        let now = Timestamp::now();
        Message {
            id,
            pubkey,
            kind: Kind::from(1u16),
            mls_group_id: mls_group_id.clone(),
            created_at: now,
            processed_at: now,
            content: content.to_string(),
            tags: Tags::new(),
            event: UnsignedEvent::new(pubkey, now, Kind::from(9u16), vec![], "content".to_string()),
            wrapper_event_id,
            epoch,
            state: MessageState::Created,
        }
    }

    /// Builds a processing record without an inner message id or failure reason.
    fn test_processed_message(
        wrapper_event_id: EventId,
        mls_group_id: Option<GroupId>,
        epoch: Option<u64>,
        state: ProcessedMessageState,
    ) -> ProcessedMessage {
        ProcessedMessage {
            wrapper_event_id,
            message_event_id: None,
            processed_at: Timestamp::from(1_000_000_000u64),
            epoch,
            mls_group_id,
            state,
            failure_reason: None,
        }
    }

    #[test]
    fn test_save_and_find_message() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        // Messages require a valid group foreign key, so create the group first.
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(test_group(&mls_group_id, b"test_group_12", "Test Group", "A test group"))
            .unwrap();

        let event_id =
            EventId::parse("6a2affe9878ebcf50c10cf74c7b25aad62e0db9fb347f6aafeda30e9f578f260")
                .unwrap();
        let wrapper_event_id =
            EventId::parse("3287abd422284bc3679812c373c52ed4aa0af4f7c57b9c63ec440f6c3ed6c3a2")
                .unwrap();

        let message = test_message(
            event_id,
            &mls_group_id,
            wrapper_event_id,
            "Test message content",
            Some(1),
        );
        storage.save_message(message).unwrap();

        let found_message = storage
            .find_message_by_event_id(&mls_group_id, &event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_message.id, event_id);
        assert_eq!(found_message.pubkey, test_pubkey());
        assert_eq!(found_message.content, "Test message content");
    }

    #[test]
    fn test_processed_message() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let wrapper_event_id =
            EventId::parse("3287abd422284bc3679812c373c52ed4aa0af4f7c57b9c63ec440f6c3ed6c3a2")
                .unwrap();
        let message_event_id =
            EventId::parse("6a2affe9878ebcf50c10cf74c7b25aad62e0db9fb347f6aafeda30e9f578f260")
                .unwrap();

        let processed_message = ProcessedMessage {
            message_event_id: Some(message_event_id),
            ..test_processed_message(
                wrapper_event_id,
                None,
                Some(1),
                ProcessedMessageState::Processed,
            )
        };
        storage.save_processed_message(processed_message).unwrap();

        let found = storage
            .find_processed_message_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.wrapper_event_id, wrapper_event_id);
        assert_eq!(found.message_event_id.unwrap(), message_event_id);
        assert_eq!(found.state, ProcessedMessageState::Processed);
    }

    #[test]
    fn test_message_content_size_validation() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(test_group(&mls_group_id, b"test_group_12", "Test Group", "Test"))
            .unwrap();

        // Content one byte over the 1 MB limit.
        let oversized_content = "x".repeat(1024 * 1024 + 1);
        let message = test_message(
            event_id(0x00),
            &mls_group_id,
            event_id(0x11),
            &oversized_content,
            None,
        );

        let err_msg = storage.save_message(message).unwrap_err().to_string();
        assert!(err_msg.contains("Message content exceeds maximum"));
    }

    #[test]
    fn test_messages_cannot_overwrite_across_groups() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id_1 = GroupId::from_slice(&[1, 2, 3, 4]);
        let mls_group_id_2 = GroupId::from_slice(&[5, 6, 7, 8]);
        storage
            .save_group(test_group(
                &mls_group_id_1,
                b"test_group_1",
                "Test Group 1",
                "First test group",
            ))
            .unwrap();
        storage
            .save_group(test_group(
                &mls_group_id_2,
                b"test_group_2",
                "Test Group 2",
                "Second test group",
            ))
            .unwrap();

        // Same inner event id in both groups, distinct wrapper events.
        let same_event_id =
            EventId::parse("6a2affe9878ebcf50c10cf74c7b25aad62e0db9fb347f6aafeda30e9f578f260")
                .unwrap();
        let wrapper_event_id_1 =
            EventId::parse("3287abd422284bc3679812c373c52ed4aa0af4f7c57b9c63ec440f6c3ed6c3a1")
                .unwrap();
        let wrapper_event_id_2 =
            EventId::parse("3287abd422284bc3679812c373c52ed4aa0af4f7c57b9c63ec440f6c3ed6c3a2")
                .unwrap();

        storage
            .save_message(test_message(
                same_event_id,
                &mls_group_id_1,
                wrapper_event_id_1,
                "Message in group 1",
                Some(1),
            ))
            .unwrap();
        storage
            .save_message(test_message(
                same_event_id,
                &mls_group_id_2,
                wrapper_event_id_2,
                "Message in group 2",
                Some(2),
            ))
            .unwrap();

        // Saving into group 2 must not have overwritten group 1's row, and each
        // group-scoped lookup must return that group's own row.
        for (group_id, content) in [
            (&mls_group_id_1, "Message in group 1"),
            (&mls_group_id_2, "Message in group 2"),
        ] {
            let found = storage
                .find_message_by_event_id(group_id, &same_event_id)
                .unwrap()
                .unwrap();
            assert_eq!(found.content, content);
            assert_eq!(&found.mls_group_id, group_id);
        }
    }

    #[test]
    fn test_mark_processed_message_retryable() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let wrapper_event_id =
            EventId::parse("3287abd422284bc3679812c373c52ed4aa0af4f7c57b9c63ec440f6c3ed6c3a2")
                .unwrap();
        let processed_message = ProcessedMessage {
            failure_reason: Some("Decryption failed".to_string()),
            ..test_processed_message(
                wrapper_event_id,
                Some(GroupId::from_slice(&[1, 2, 3, 4])),
                None,
                ProcessedMessageState::Failed,
            )
        };
        storage
            .save_processed_message(processed_message)
            .expect("Failed to save processed message");

        let found = storage
            .find_processed_message_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, ProcessedMessageState::Failed);

        storage
            .mark_processed_message_retryable(&wrapper_event_id)
            .expect("Failed to mark message as retryable");

        let found = storage
            .find_processed_message_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, ProcessedMessageState::Retryable);
        // The state change must not clear the recorded failure reason.
        assert_eq!(found.failure_reason, Some("Decryption failed".to_string()));
    }

    #[test]
    fn test_mark_nonexistent_message_retryable_fails() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let wrapper_event_id =
            EventId::parse("3287abd422284bc3679812c373c52ed4aa0af4f7c57b9c63ec440f6c3ed6c3a2")
                .unwrap();

        let result = storage.mark_processed_message_retryable(&wrapper_event_id);
        assert!(matches!(result, Err(MessageError::NotFound)));
    }

    #[test]
    fn test_mark_non_failed_message_retryable_fails() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let wrapper_event_id =
            EventId::parse("3287abd422284bc3679812c373c52ed4aa0af4f7c57b9c63ec440f6c3ed6c3a2")
                .unwrap();
        storage
            .save_processed_message(test_processed_message(
                wrapper_event_id,
                Some(GroupId::from_slice(&[1, 2, 3, 4])),
                Some(1),
                ProcessedMessageState::Processed,
            ))
            .expect("Failed to save processed message");

        // Only Failed records may become Retryable.
        let result = storage.mark_processed_message_retryable(&wrapper_event_id);
        assert!(matches!(result, Err(MessageError::NotFound)));

        let found = storage
            .find_processed_message_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found.state, ProcessedMessageState::Processed);
    }

    #[test]
    fn test_invalidate_messages_after_epoch() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(test_group(&mls_group_id, b"test_group_12", "Test Group", "A test group"))
            .unwrap();

        // Messages 1..=4 carry epochs None, 1, 2, 3 respectively.
        let epochs = [None, Some(1), Some(2), Some(3)];
        for (n, epoch) in (1u8..).zip(epochs) {
            storage
                .save_message(test_message(
                    event_id(n),
                    &mls_group_id,
                    event_id(100 + n),
                    "m",
                    epoch,
                ))
                .unwrap();
        }

        // NULL epochs never compare greater than 1, so only messages 3 and 4 qualify.
        let expected = vec![[3u8; 32], [4u8; 32]];

        let mut invalidated: Vec<[u8; 32]> = storage
            .invalidate_messages_after_epoch(&mls_group_id, 1)
            .unwrap()
            .iter()
            .map(|id| id.to_bytes())
            .collect();
        invalidated.sort();
        assert_eq!(invalidated, expected);

        let mut found: Vec<[u8; 32]> = storage
            .find_invalidated_messages(&mls_group_id)
            .unwrap()
            .iter()
            .map(|message| message.id.to_bytes())
            .collect();
        found.sort();
        assert_eq!(found, expected);
    }

    #[test]
    fn test_invalidate_processed_messages_after_epoch() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        for (n, epoch) in [(1u8, Some(1)), (2u8, Some(2))] {
            storage
                .save_processed_message(test_processed_message(
                    event_id(n),
                    Some(group_id.clone()),
                    epoch,
                    ProcessedMessageState::Processed,
                ))
                .unwrap();
        }

        let invalidated = storage
            .invalidate_processed_messages_after_epoch(&group_id, 1)
            .unwrap();
        assert_eq!(invalidated, vec![event_id(2)]);

        let found = storage
            .find_invalidated_processed_messages(&group_id)
            .unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].wrapper_event_id, event_id(2));
    }

    #[test]
    fn test_find_failed_messages_for_retry() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();
        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Only the Failed record without an epoch is eligible for retry.
        for (n, epoch, state) in [
            (1u8, None, ProcessedMessageState::Failed),
            (2u8, Some(1), ProcessedMessageState::Failed),
            (3u8, None, ProcessedMessageState::Processed),
        ] {
            storage
                .save_processed_message(test_processed_message(
                    event_id(n),
                    Some(group_id.clone()),
                    epoch,
                    state,
                ))
                .unwrap();
        }

        let retry = storage.find_failed_messages_for_retry(&group_id).unwrap();
        assert_eq!(retry, vec![event_id(1)]);
    }

    /// Verifies that %, _, and \ in content_substring are treated as literal
    /// characters and not as SQL LIKE wildcards.
    #[test]
    fn test_find_message_epoch_by_tag_content_escapes_like_wildcards() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(test_group(&group_id, &[1, 2, 3, 4], "Test Group", "A test group"))
            .unwrap();

        let pubkey = test_pubkey();
        let created_at = Timestamp::from(1000u64);

        // Store a message with tags containing "x abc" (no wildcards)
        let tags = Tags::parse(vec![vec!["imeta", "x abc"]]).unwrap();
        let message = Message {
            kind: Kind::from(445u16),
            created_at,
            processed_at: created_at,
            tags: tags.clone(),
            event: UnsignedEvent::new(
                pubkey,
                created_at,
                Kind::from(445u16),
                tags,
                "".to_string(),
            ),
            state: MessageState::Processed,
            ..test_message(event_id(10), &group_id, event_id(200), "", Some(42))
        };
        storage.save_message(message).unwrap();

        // Searching for exact content should find it
        let result = storage
            .find_message_epoch_by_tag_content(&group_id, "x abc")
            .unwrap();
        assert_eq!(result, Some(42), "Exact substring should match");

        // Searching with SQL wildcard % should NOT match (treated literally)
        let result = storage
            .find_message_epoch_by_tag_content(&group_id, "x%abc")
            .unwrap();
        assert_eq!(
            result, None,
            "% must be treated as a literal, not a wildcard"
        );

        // Searching with SQL wildcard _ should NOT match (treated literally)
        let result = storage
            .find_message_epoch_by_tag_content(&group_id, "x_abc")
            .unwrap();
        assert_eq!(
            result, None,
            "_ must be treated as a literal, not a wildcard"
        );

        // Searching with backslash should NOT match (treated literally)
        let result = storage
            .find_message_epoch_by_tag_content(&group_id, "x\\abc")
            .unwrap();
        assert_eq!(
            result, None,
            "\\ must be treated as a literal, not an escape"
        );
    }
}