Skip to main content

mdk_sqlite_storage/
groups.rs

1//! Implementation of GroupStorage trait for SQLite storage.
2
3use std::collections::BTreeSet;
4
5use mdk_storage_traits::GroupId;
6use mdk_storage_traits::groups::error::GroupError;
7use mdk_storage_traits::groups::types::{Group, GroupExporterSecret, GroupRelay, SelfUpdateState};
8use mdk_storage_traits::groups::{GroupStorage, MAX_MESSAGE_LIMIT, MessageSortOrder, Pagination};
9use mdk_storage_traits::messages::types::Message;
10use nostr::{PublicKey, RelayUrl};
11use rusqlite::{OptionalExtension, params};
12
13use crate::db::{Hash32, Nonce12};
14use crate::validation::{
15    MAX_ADMIN_PUBKEYS_JSON_SIZE, MAX_GROUP_DESCRIPTION_LENGTH, MAX_GROUP_NAME_LENGTH,
16    validate_size, validate_string_length,
17};
18use crate::{MdkSqliteStorage, db};
19
20#[inline]
21fn into_group_err<T>(e: T) -> GroupError
22where
23    T: std::error::Error,
24{
25    GroupError::DatabaseError(e.to_string())
26}
27
28impl GroupStorage for MdkSqliteStorage {
29    fn all_groups(&self) -> Result<Vec<Group>, GroupError> {
30        self.with_connection(|conn| {
31            let mut stmt = conn
32                .prepare("SELECT * FROM groups")
33                .map_err(into_group_err)?;
34
35            let groups_iter = stmt
36                .query_map([], db::row_to_group)
37                .map_err(into_group_err)?;
38
39            let mut groups: Vec<Group> = Vec::new();
40
41            for group_result in groups_iter {
42                match group_result {
43                    Ok(group) => {
44                        groups.push(group);
45                    }
46                    Err(e) => {
47                        tracing::warn!(
48                            error = %e,
49                            "Failed to deserialize group row, skipping"
50                        );
51                    }
52                }
53            }
54
55            Ok(groups)
56        })
57    }
58
59    fn find_group_by_mls_group_id(
60        &self,
61        mls_group_id: &GroupId,
62    ) -> Result<Option<Group>, GroupError> {
63        self.with_connection(|conn| {
64            let mut stmt = conn
65                .prepare("SELECT * FROM groups WHERE mls_group_id = ?")
66                .map_err(into_group_err)?;
67
68            stmt.query_row([mls_group_id.as_slice()], db::row_to_group)
69                .optional()
70                .map_err(into_group_err)
71        })
72    }
73
74    fn find_group_by_nostr_group_id(
75        &self,
76        nostr_group_id: &[u8; 32],
77    ) -> Result<Option<Group>, GroupError> {
78        self.with_connection(|conn| {
79            let mut stmt = conn
80                .prepare("SELECT * FROM groups WHERE nostr_group_id = ?")
81                .map_err(into_group_err)?;
82
83            stmt.query_row(params![nostr_group_id], db::row_to_group)
84                .optional()
85                .map_err(into_group_err)
86        })
87    }
88
89    fn save_group(&self, group: Group) -> Result<(), GroupError> {
90        // Validate group name and description lengths
91        validate_string_length(&group.name, MAX_GROUP_NAME_LENGTH, "Group name")
92            .map_err(|e| GroupError::InvalidParameters(e.to_string()))?;
93
94        validate_string_length(
95            &group.description,
96            MAX_GROUP_DESCRIPTION_LENGTH,
97            "Group description",
98        )
99        .map_err(|e| GroupError::InvalidParameters(e.to_string()))?;
100
101        let admin_pubkeys_json: String =
102            serde_json::to_string(&group.admin_pubkeys).map_err(|e| {
103                GroupError::DatabaseError(format!("Failed to serialize admin pubkeys: {}", e))
104            })?;
105
106        // Validate admin pubkeys JSON size
107        validate_size(
108            admin_pubkeys_json.as_bytes(),
109            MAX_ADMIN_PUBKEYS_JSON_SIZE,
110            "Admin pubkeys JSON",
111        )
112        .map_err(|e| GroupError::InvalidParameters(e.to_string()))?;
113
114        let last_message_id: Option<&[u8; 32]> =
115            group.last_message_id.as_ref().map(|id| id.as_bytes());
116        let last_message_at: Option<u64> = group.last_message_at.as_ref().map(|ts| ts.as_secs());
117        let last_message_processed_at: Option<u64> = group
118            .last_message_processed_at
119            .as_ref()
120            .map(|ts| ts.as_secs());
121
122        let last_self_update_at: u64 = match group.self_update_state {
123            SelfUpdateState::Required => 0,
124            SelfUpdateState::CompletedAt(ts) => ts.as_secs(),
125        };
126
127        self.with_connection(|conn| {
128            conn.execute(
129                "INSERT INTO groups
130             (mls_group_id, nostr_group_id, name, description, image_hash, image_key, image_nonce, admin_pubkeys, last_message_id,
131              last_message_at, last_message_processed_at, epoch, state, last_self_update_at)
132             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
133             ON CONFLICT(mls_group_id) DO UPDATE SET
134                nostr_group_id = excluded.nostr_group_id,
135                name = excluded.name,
136                description = excluded.description,
137                image_hash = excluded.image_hash,
138                image_key = excluded.image_key,
139                image_nonce = excluded.image_nonce,
140                admin_pubkeys = excluded.admin_pubkeys,
141                last_message_id = excluded.last_message_id,
142                last_message_at = excluded.last_message_at,
143                last_message_processed_at = excluded.last_message_processed_at,
144                epoch = excluded.epoch,
145                state = excluded.state,
146                last_self_update_at = excluded.last_self_update_at",
147                params![
148                    &group.mls_group_id.as_slice(),
149                    &group.nostr_group_id,
150                    &group.name,
151                    &group.description,
152                    &group.image_hash.map(Hash32::from),
153                    &group.image_key.as_ref().map(|k| Hash32::from(**k)),
154                    &group.image_nonce.as_ref().map(|n| Nonce12::from(**n)),
155                    &admin_pubkeys_json,
156                    last_message_id,
157                    &last_message_at,
158                    &last_message_processed_at,
159                    &(group.epoch as i64),
160                    group.state.as_str(),
161                    &last_self_update_at
162                ],
163            )
164            .map_err(into_group_err)?;
165
166            Ok(())
167        })
168    }
169
170    fn messages(
171        &self,
172        mls_group_id: &GroupId,
173        pagination: Option<Pagination>,
174    ) -> Result<Vec<Message>, GroupError> {
175        let pagination = pagination.unwrap_or_default();
176        let limit = pagination.limit();
177        let offset = pagination.offset();
178
179        // Validate limit is within allowed range
180        if !(1..=MAX_MESSAGE_LIMIT).contains(&limit) {
181            return Err(GroupError::InvalidParameters(format!(
182                "Limit must be between 1 and {}, got {}",
183                MAX_MESSAGE_LIMIT, limit
184            )));
185        }
186
187        // First verify the group exists
188        if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
189            return Err(GroupError::InvalidParameters("Group not found".to_string()));
190        }
191
192        let sort_order = pagination.sort_order();
193
194        self.with_connection(|conn| {
195            let query = match sort_order {
196                MessageSortOrder::CreatedAtFirst => {
197                    "SELECT * FROM messages WHERE mls_group_id = ? \
198                     ORDER BY created_at DESC, processed_at DESC, id DESC \
199                     LIMIT ? OFFSET ?"
200                }
201                MessageSortOrder::ProcessedAtFirst => {
202                    "SELECT * FROM messages WHERE mls_group_id = ? \
203                     ORDER BY processed_at DESC, created_at DESC, id DESC \
204                     LIMIT ? OFFSET ?"
205                }
206            };
207
208            let mut stmt = conn.prepare(query).map_err(into_group_err)?;
209
210            let messages_iter = stmt
211                .query_map(
212                    params![mls_group_id.as_slice(), limit as i64, offset as i64],
213                    db::row_to_message,
214                )
215                .map_err(into_group_err)?;
216
217            let mut messages: Vec<Message> = Vec::new();
218
219            for message_result in messages_iter {
220                let message: Message = message_result.map_err(into_group_err)?;
221                messages.push(message);
222            }
223
224            Ok(messages)
225        })
226    }
227
228    fn last_message(
229        &self,
230        mls_group_id: &GroupId,
231        sort_order: MessageSortOrder,
232    ) -> Result<Option<Message>, GroupError> {
233        if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
234            return Err(GroupError::InvalidParameters("Group not found".to_string()));
235        }
236
237        self.with_connection(|conn| {
238            let query = match sort_order {
239                MessageSortOrder::CreatedAtFirst => {
240                    "SELECT * FROM messages WHERE mls_group_id = ? \
241                     ORDER BY created_at DESC, processed_at DESC, id DESC \
242                     LIMIT 1"
243                }
244                MessageSortOrder::ProcessedAtFirst => {
245                    "SELECT * FROM messages WHERE mls_group_id = ? \
246                     ORDER BY processed_at DESC, created_at DESC, id DESC \
247                     LIMIT 1"
248                }
249            };
250
251            conn.prepare(query)
252                .map_err(into_group_err)?
253                .query_row(params![mls_group_id.as_slice()], db::row_to_message)
254                .optional()
255                .map_err(into_group_err)
256        })
257    }
258
259    fn admins(&self, mls_group_id: &GroupId) -> Result<BTreeSet<PublicKey>, GroupError> {
260        // Get the group which contains the admin_pubkeys
261        match self.find_group_by_mls_group_id(mls_group_id)? {
262            Some(group) => Ok(group.admin_pubkeys),
263            None => Err(GroupError::InvalidParameters("Group not found".to_string())),
264        }
265    }
266
267    fn group_relays(&self, mls_group_id: &GroupId) -> Result<BTreeSet<GroupRelay>, GroupError> {
268        // First verify the group exists
269        if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
270            return Err(GroupError::InvalidParameters("Group not found".to_string()));
271        }
272
273        self.with_connection(|conn| {
274            let mut stmt = conn
275                .prepare("SELECT * FROM group_relays WHERE mls_group_id = ?")
276                .map_err(into_group_err)?;
277
278            let relays_iter = stmt
279                .query_map(params![mls_group_id.as_slice()], db::row_to_group_relay)
280                .map_err(into_group_err)?;
281
282            let mut relays: BTreeSet<GroupRelay> = BTreeSet::new();
283
284            for relay_result in relays_iter {
285                let relay: GroupRelay = relay_result.map_err(into_group_err)?;
286                relays.insert(relay);
287            }
288
289            Ok(relays)
290        })
291    }
292
293    fn replace_group_relays(
294        &self,
295        group_id: &GroupId,
296        relays: BTreeSet<RelayUrl>,
297    ) -> Result<(), GroupError> {
298        // First verify the group exists
299        if self.find_group_by_mls_group_id(group_id)?.is_none() {
300            return Err(GroupError::InvalidParameters("Group not found".to_string()));
301        }
302
303        self.with_connection(|conn| {
304            // Use a savepoint for atomicity (works both inside/outside an existing transaction).
305            conn.execute_batch("SAVEPOINT mdk_replace_group_relays")
306                .map_err(into_group_err)?;
307
308            let result: Result<(), GroupError> = (|| {
309                conn.execute(
310                    "DELETE FROM group_relays WHERE mls_group_id = ?",
311                    params![group_id.as_slice()],
312                )
313                .map_err(into_group_err)?;
314
315                for relay_url in &relays {
316                    conn.execute(
317                        "INSERT INTO group_relays (mls_group_id, relay_url) VALUES (?, ?)",
318                        params![group_id.as_slice(), relay_url.as_str()],
319                    )
320                    .map_err(into_group_err)?;
321                }
322                Ok(())
323            })();
324
325            match result {
326                Ok(()) => conn
327                    .execute_batch("RELEASE SAVEPOINT mdk_replace_group_relays")
328                    .map_err(into_group_err),
329                Err(e) => {
330                    // Best-effort cleanup to keep connection usable.
331                    let _ = conn.execute_batch(
332                        "ROLLBACK TO SAVEPOINT mdk_replace_group_relays; \
333                         RELEASE SAVEPOINT mdk_replace_group_relays;",
334                    );
335                    Err(e)
336                }
337            }
338        })
339    }
340
341    fn get_group_exporter_secret(
342        &self,
343        mls_group_id: &GroupId,
344        epoch: u64,
345    ) -> Result<Option<GroupExporterSecret>, GroupError> {
346        // First verify the group exists
347        if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
348            return Err(GroupError::InvalidParameters("Group not found".to_string()));
349        }
350
351        self.with_connection(|conn| {
352            let mut stmt = conn
353                .prepare(
354                    "SELECT * FROM group_exporter_secrets WHERE mls_group_id = ? AND epoch = ?",
355                )
356                .map_err(into_group_err)?;
357
358            stmt.query_row(
359                params![mls_group_id.as_slice(), epoch],
360                db::row_to_group_exporter_secret,
361            )
362            .optional()
363            .map_err(into_group_err)
364        })
365    }
366
367    fn save_group_exporter_secret(
368        &self,
369        group_exporter_secret: GroupExporterSecret,
370    ) -> Result<(), GroupError> {
371        if self
372            .find_group_by_mls_group_id(&group_exporter_secret.mls_group_id)?
373            .is_none()
374        {
375            return Err(GroupError::InvalidParameters("Group not found".to_string()));
376        }
377
378        self.with_connection(|conn| {
379            conn.execute(
380                "INSERT OR REPLACE INTO group_exporter_secrets (mls_group_id, epoch, secret) VALUES (?, ?, ?)",
381                params![&group_exporter_secret.mls_group_id.as_slice(), &group_exporter_secret.epoch, group_exporter_secret.secret.as_ref()],
382            )
383            .map_err(into_group_err)?;
384
385            Ok(())
386        })
387    }
388}
389
#[cfg(test)]
mod tests {
    use mdk_storage_traits::Secret;
    use mdk_storage_traits::groups::types::GroupState;
    use mdk_storage_traits::messages::MessageStorage;
    use mdk_storage_traits::messages::types::MessageState;
    use mdk_storage_traits::test_utils::crypto_utils::generate_random_bytes;
    use nostr::{EventId, Kind, Tags, Timestamp, UnsignedEvent};
    use rusqlite::Connection;
    use tempfile::tempdir;

    use super::*;

    /// Builds an active, epoch-0 group named "Test Group" with no admins, no image and
    /// no message history.
    ///
    /// Individual tests override fields with struct-update syntax:
    /// `Group { name: ..., ..test_group(&id, nostr_id) }`.
    fn test_group(mls_group_id: &GroupId, nostr_group_id: [u8; 32]) -> Group {
        Group {
            mls_group_id: mls_group_id.clone(),
            nostr_group_id,
            name: "Test Group".to_string(),
            description: "A test group".to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        }
    }

    /// Generates a random 32-byte Nostr group id.
    fn random_nostr_group_id() -> [u8; 32] {
        generate_random_bytes(32).try_into().unwrap()
    }

    #[test]
    fn test_save_and_find_group() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let nostr_group_id = random_nostr_group_id();
        let group = Group {
            image_hash: Some(generate_random_bytes(32).try_into().unwrap()),
            image_key: Some(Secret::new(generate_random_bytes(32).try_into().unwrap())),
            image_nonce: Some(Secret::new(generate_random_bytes(12).try_into().unwrap())),
            ..test_group(&mls_group_id, nostr_group_id)
        };

        // `unwrap` (rather than `assert!(is_ok())`) surfaces the error on failure.
        storage.save_group(group).unwrap();

        // Find by MLS group ID
        let found_group = storage
            .find_group_by_mls_group_id(&mls_group_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_group.nostr_group_id, nostr_group_id);
        assert_eq!(found_group.name, "Test Group");

        // Find by Nostr group ID
        let found_group = storage
            .find_group_by_nostr_group_id(&nostr_group_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_group.mls_group_id, mls_group_id);

        // Get all groups
        let all_groups = storage.all_groups().unwrap();
        assert_eq!(all_groups.len(), 1);
    }

    #[test]
    fn test_group_name_length_validation() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        // Name one character over the 255-character limit.
        let group = Group {
            name: "x".repeat(256),
            ..test_group(&GroupId::from_slice(&[1, 2, 3, 4]), [0u8; 32])
        };

        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group name exceeds maximum length")
        );
    }

    #[test]
    fn test_group_description_length_validation() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        // Description one character over the 2000-character limit.
        let group = Group {
            description: "x".repeat(2001),
            ..test_group(&GroupId::from_slice(&[1, 2, 3, 4]), [0u8; 32])
        };

        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group description exceeds maximum length")
        );
    }

    // Note: Comprehensive storage functionality tests are now in mdk-storage-traits/tests/
    // using shared test functions to ensure consistency between storage implementations

    #[test]
    fn test_messages_pagination() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(test_group(&mls_group_id, random_nostr_group_id()))
            .unwrap();

        // Create 25 test messages with strictly increasing timestamps.
        let pubkey = PublicKey::from_slice(&[1u8; 32]).unwrap();
        for i in 0..25 {
            let event_id = EventId::from_slice(&[i as u8; 32]).unwrap();
            let wrapper_event_id = EventId::from_slice(&[100 + i as u8; 32]).unwrap();

            let ts = Timestamp::from((1000 + i) as u64);
            let message = Message {
                id: event_id,
                pubkey,
                kind: Kind::from(1u16),
                mls_group_id: mls_group_id.clone(),
                created_at: ts,
                processed_at: ts,
                content: format!("Message {}", i),
                tags: Tags::new(),
                event: UnsignedEvent::new(
                    pubkey,
                    ts,
                    Kind::from(9u16),
                    vec![],
                    format!("content {}", i),
                ),
                wrapper_event_id,
                state: MessageState::Created,
                epoch: None,
            };

            storage.save_message(message).unwrap();
        }

        // Test pagination
        let page1 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert_eq!(page1.len(), 10);
        // Should be newest first (highest timestamp)
        assert_eq!(page1[0].content, "Message 24");

        let page2 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(10))))
            .unwrap();
        assert_eq!(page2.len(), 10);
        assert_eq!(page2[0].content, "Message 14");

        let page3 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(20))))
            .unwrap();
        assert_eq!(page3.len(), 5); // Only 5 messages left
        assert_eq!(page3[0].content, "Message 4");

        // Default pagination returns everything here (25 is below the default limit).
        let default_messages = storage.messages(&mls_group_id, None).unwrap();
        assert_eq!(default_messages.len(), 25);

        // Pages must not overlap.
        let first_id = page1[0].id;
        let second_page_ids: Vec<EventId> = page2.iter().map(|m| m.id).collect();
        assert!(
            !second_page_ids.contains(&first_id),
            "Pages should not overlap"
        );

        // Offset beyond available messages returns empty.
        let beyond = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(30))))
            .unwrap();
        assert_eq!(beyond.len(), 0);

        // Limit of 0 should return error.
        let result = storage.messages(&mls_group_id, Some(Pagination::new(Some(0), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        // Limit exceeding MAX should return error.
        let result = storage.messages(&mls_group_id, Some(Pagination::new(Some(20000), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        // Non-existent group returns error.
        let fake_group_id = GroupId::from_slice(&[99, 99, 99, 99]);
        let result = storage.messages(&fake_group_id, Some(Pagination::new(Some(10), Some(0))));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));

        // Large offset should work (no MAX_OFFSET validation).
        let result = storage.messages(
            &mls_group_id,
            Some(Pagination::new(Some(10), Some(2_000_000))),
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0); // No results at that offset
    }

    #[test]
    fn test_group_exporter_secret() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(test_group(&mls_group_id, random_nostr_group_id()))
            .unwrap();

        // Each write uses a distinct byte pattern so a stale or misrouted row is
        // actually detectable; identical values would make these assertions vacuous.
        let secret1 = GroupExporterSecret {
            mls_group_id: mls_group_id.clone(),
            epoch: 1,
            secret: Secret::new([1u8; 32]),
        };
        storage.save_group_exporter_secret(secret1).unwrap();

        let retrieved_secret = storage
            .get_group_exporter_secret(&mls_group_id, 1)
            .unwrap()
            .unwrap();
        assert_eq!(*retrieved_secret.secret, [1u8; 32]);

        // Same (group, epoch) with a different value must replace the first row
        // ("INSERT OR REPLACE").
        let secret2 = GroupExporterSecret {
            mls_group_id: mls_group_id.clone(),
            epoch: 1,
            secret: Secret::new([2u8; 32]),
        };
        storage.save_group_exporter_secret(secret2).unwrap();

        let retrieved_secret = storage
            .get_group_exporter_secret(&mls_group_id, 1)
            .unwrap()
            .unwrap();
        assert_eq!(*retrieved_secret.secret, [2u8; 32]);

        // A different epoch is stored independently.
        let secret3 = GroupExporterSecret {
            mls_group_id: mls_group_id.clone(),
            epoch: 2,
            secret: Secret::new([3u8; 32]),
        };
        storage.save_group_exporter_secret(secret3).unwrap();

        let retrieved_secret1 = storage
            .get_group_exporter_secret(&mls_group_id, 1)
            .unwrap()
            .unwrap();
        let retrieved_secret2 = storage
            .get_group_exporter_secret(&mls_group_id, 2)
            .unwrap()
            .unwrap();
        assert_eq!(*retrieved_secret1.secret, [2u8; 32]);
        assert_eq!(*retrieved_secret2.secret, [3u8; 32]);

        // An epoch that was never saved yields None rather than an error.
        assert!(
            storage
                .get_group_exporter_secret(&mls_group_id, 3)
                .unwrap()
                .is_none()
        );
    }

    #[test]
    fn test_all_groups_skips_corrupted_rows() {
        // Use a file-based database so we can access it from multiple connections
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();

        // Create and save two valid groups
        let mls_group_id1 = GroupId::from_slice(&[1, 2, 3, 4]);
        let group1 = Group {
            name: "Group 1".to_string(),
            description: "First group".to_string(),
            ..test_group(&mls_group_id1, random_nostr_group_id())
        };
        storage.save_group(group1).unwrap();

        let mls_group_id2 = GroupId::from_slice(&[5, 6, 7, 8]);
        let group2 = Group {
            name: "Group 2".to_string(),
            description: "Second group".to_string(),
            ..test_group(&mls_group_id2, random_nostr_group_id())
        };
        storage.save_group(group2).unwrap();

        // Insert a row behind the storage's back whose `state` cannot be decoded.
        let corrupt_conn = Connection::open(&db_path).unwrap();
        let corrupted_nostr_id = random_nostr_group_id();
        corrupt_conn
            .execute(
                "INSERT INTO groups (mls_group_id, nostr_group_id, name, description, admin_pubkeys, epoch, state) VALUES (?, ?, ?, ?, ?, ?, ?)",
                params![
                    &[9u8; 16], // Valid mls_group_id
                    &corrupted_nostr_id,
                    "Corrupted Group",
                    "This group has invalid state",
                    "[]", // Valid JSON for admin_pubkeys
                    0,
                    "invalid_state" // Invalid state that will fail deserialization
                ],
            )
            .unwrap();

        // all_groups should return the two valid groups and skip the corrupted one
        let all_groups = storage.all_groups().unwrap();
        assert_eq!(all_groups.len(), 2);
        assert_eq!(all_groups[0].mls_group_id, mls_group_id1);
        assert_eq!(all_groups[1].mls_group_id, mls_group_id2);
    }
}