1use std::collections::BTreeSet;
4
5use mdk_storage_traits::GroupId;
6use mdk_storage_traits::groups::error::GroupError;
7use mdk_storage_traits::groups::types::{Group, GroupExporterSecret, GroupRelay, SelfUpdateState};
8use mdk_storage_traits::groups::{GroupStorage, MAX_MESSAGE_LIMIT, MessageSortOrder, Pagination};
9use mdk_storage_traits::messages::types::Message;
10use nostr::{PublicKey, RelayUrl};
11use rusqlite::{OptionalExtension, params};
12
13use crate::db::{Hash32, Nonce12};
14use crate::validation::{
15 MAX_ADMIN_PUBKEYS_JSON_SIZE, MAX_GROUP_DESCRIPTION_LENGTH, MAX_GROUP_NAME_LENGTH,
16 validate_size, validate_string_length,
17};
18use crate::{MdkSqliteStorage, db};
19
20#[inline]
21fn into_group_err<T>(e: T) -> GroupError
22where
23 T: std::error::Error,
24{
25 GroupError::DatabaseError(e.to_string())
26}
27
28impl GroupStorage for MdkSqliteStorage {
29 fn all_groups(&self) -> Result<Vec<Group>, GroupError> {
30 self.with_connection(|conn| {
31 let mut stmt = conn
32 .prepare("SELECT * FROM groups")
33 .map_err(into_group_err)?;
34
35 let groups_iter = stmt
36 .query_map([], db::row_to_group)
37 .map_err(into_group_err)?;
38
39 let mut groups: Vec<Group> = Vec::new();
40
41 for group_result in groups_iter {
42 match group_result {
43 Ok(group) => {
44 groups.push(group);
45 }
46 Err(e) => {
47 tracing::warn!(
48 error = %e,
49 "Failed to deserialize group row, skipping"
50 );
51 }
52 }
53 }
54
55 Ok(groups)
56 })
57 }
58
59 fn find_group_by_mls_group_id(
60 &self,
61 mls_group_id: &GroupId,
62 ) -> Result<Option<Group>, GroupError> {
63 self.with_connection(|conn| {
64 let mut stmt = conn
65 .prepare("SELECT * FROM groups WHERE mls_group_id = ?")
66 .map_err(into_group_err)?;
67
68 stmt.query_row([mls_group_id.as_slice()], db::row_to_group)
69 .optional()
70 .map_err(into_group_err)
71 })
72 }
73
74 fn find_group_by_nostr_group_id(
75 &self,
76 nostr_group_id: &[u8; 32],
77 ) -> Result<Option<Group>, GroupError> {
78 self.with_connection(|conn| {
79 let mut stmt = conn
80 .prepare("SELECT * FROM groups WHERE nostr_group_id = ?")
81 .map_err(into_group_err)?;
82
83 stmt.query_row(params![nostr_group_id], db::row_to_group)
84 .optional()
85 .map_err(into_group_err)
86 })
87 }
88
89 fn save_group(&self, group: Group) -> Result<(), GroupError> {
90 validate_string_length(&group.name, MAX_GROUP_NAME_LENGTH, "Group name")
92 .map_err(|e| GroupError::InvalidParameters(e.to_string()))?;
93
94 validate_string_length(
95 &group.description,
96 MAX_GROUP_DESCRIPTION_LENGTH,
97 "Group description",
98 )
99 .map_err(|e| GroupError::InvalidParameters(e.to_string()))?;
100
101 let admin_pubkeys_json: String =
102 serde_json::to_string(&group.admin_pubkeys).map_err(|e| {
103 GroupError::DatabaseError(format!("Failed to serialize admin pubkeys: {}", e))
104 })?;
105
106 validate_size(
108 admin_pubkeys_json.as_bytes(),
109 MAX_ADMIN_PUBKEYS_JSON_SIZE,
110 "Admin pubkeys JSON",
111 )
112 .map_err(|e| GroupError::InvalidParameters(e.to_string()))?;
113
114 let last_message_id: Option<&[u8; 32]> =
115 group.last_message_id.as_ref().map(|id| id.as_bytes());
116 let last_message_at: Option<u64> = group.last_message_at.as_ref().map(|ts| ts.as_secs());
117 let last_message_processed_at: Option<u64> = group
118 .last_message_processed_at
119 .as_ref()
120 .map(|ts| ts.as_secs());
121
122 let last_self_update_at: u64 = match group.self_update_state {
123 SelfUpdateState::Required => 0,
124 SelfUpdateState::CompletedAt(ts) => ts.as_secs(),
125 };
126
127 self.with_connection(|conn| {
128 conn.execute(
129 "INSERT INTO groups
130 (mls_group_id, nostr_group_id, name, description, image_hash, image_key, image_nonce, admin_pubkeys, last_message_id,
131 last_message_at, last_message_processed_at, epoch, state, last_self_update_at)
132 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
133 ON CONFLICT(mls_group_id) DO UPDATE SET
134 nostr_group_id = excluded.nostr_group_id,
135 name = excluded.name,
136 description = excluded.description,
137 image_hash = excluded.image_hash,
138 image_key = excluded.image_key,
139 image_nonce = excluded.image_nonce,
140 admin_pubkeys = excluded.admin_pubkeys,
141 last_message_id = excluded.last_message_id,
142 last_message_at = excluded.last_message_at,
143 last_message_processed_at = excluded.last_message_processed_at,
144 epoch = excluded.epoch,
145 state = excluded.state,
146 last_self_update_at = excluded.last_self_update_at",
147 params![
148 &group.mls_group_id.as_slice(),
149 &group.nostr_group_id,
150 &group.name,
151 &group.description,
152 &group.image_hash.map(Hash32::from),
153 &group.image_key.as_ref().map(|k| Hash32::from(**k)),
154 &group.image_nonce.as_ref().map(|n| Nonce12::from(**n)),
155 &admin_pubkeys_json,
156 last_message_id,
157 &last_message_at,
158 &last_message_processed_at,
159 &(group.epoch as i64),
160 group.state.as_str(),
161 &last_self_update_at
162 ],
163 )
164 .map_err(into_group_err)?;
165
166 Ok(())
167 })
168 }
169
170 fn messages(
171 &self,
172 mls_group_id: &GroupId,
173 pagination: Option<Pagination>,
174 ) -> Result<Vec<Message>, GroupError> {
175 let pagination = pagination.unwrap_or_default();
176 let limit = pagination.limit();
177 let offset = pagination.offset();
178
179 if !(1..=MAX_MESSAGE_LIMIT).contains(&limit) {
181 return Err(GroupError::InvalidParameters(format!(
182 "Limit must be between 1 and {}, got {}",
183 MAX_MESSAGE_LIMIT, limit
184 )));
185 }
186
187 if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
189 return Err(GroupError::InvalidParameters("Group not found".to_string()));
190 }
191
192 let sort_order = pagination.sort_order();
193
194 self.with_connection(|conn| {
195 let query = match sort_order {
196 MessageSortOrder::CreatedAtFirst => {
197 "SELECT * FROM messages WHERE mls_group_id = ? \
198 ORDER BY created_at DESC, processed_at DESC, id DESC \
199 LIMIT ? OFFSET ?"
200 }
201 MessageSortOrder::ProcessedAtFirst => {
202 "SELECT * FROM messages WHERE mls_group_id = ? \
203 ORDER BY processed_at DESC, created_at DESC, id DESC \
204 LIMIT ? OFFSET ?"
205 }
206 };
207
208 let mut stmt = conn.prepare(query).map_err(into_group_err)?;
209
210 let messages_iter = stmt
211 .query_map(
212 params![mls_group_id.as_slice(), limit as i64, offset as i64],
213 db::row_to_message,
214 )
215 .map_err(into_group_err)?;
216
217 let mut messages: Vec<Message> = Vec::new();
218
219 for message_result in messages_iter {
220 let message: Message = message_result.map_err(into_group_err)?;
221 messages.push(message);
222 }
223
224 Ok(messages)
225 })
226 }
227
228 fn last_message(
229 &self,
230 mls_group_id: &GroupId,
231 sort_order: MessageSortOrder,
232 ) -> Result<Option<Message>, GroupError> {
233 if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
234 return Err(GroupError::InvalidParameters("Group not found".to_string()));
235 }
236
237 self.with_connection(|conn| {
238 let query = match sort_order {
239 MessageSortOrder::CreatedAtFirst => {
240 "SELECT * FROM messages WHERE mls_group_id = ? \
241 ORDER BY created_at DESC, processed_at DESC, id DESC \
242 LIMIT 1"
243 }
244 MessageSortOrder::ProcessedAtFirst => {
245 "SELECT * FROM messages WHERE mls_group_id = ? \
246 ORDER BY processed_at DESC, created_at DESC, id DESC \
247 LIMIT 1"
248 }
249 };
250
251 conn.prepare(query)
252 .map_err(into_group_err)?
253 .query_row(params![mls_group_id.as_slice()], db::row_to_message)
254 .optional()
255 .map_err(into_group_err)
256 })
257 }
258
259 fn admins(&self, mls_group_id: &GroupId) -> Result<BTreeSet<PublicKey>, GroupError> {
260 match self.find_group_by_mls_group_id(mls_group_id)? {
262 Some(group) => Ok(group.admin_pubkeys),
263 None => Err(GroupError::InvalidParameters("Group not found".to_string())),
264 }
265 }
266
267 fn group_relays(&self, mls_group_id: &GroupId) -> Result<BTreeSet<GroupRelay>, GroupError> {
268 if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
270 return Err(GroupError::InvalidParameters("Group not found".to_string()));
271 }
272
273 self.with_connection(|conn| {
274 let mut stmt = conn
275 .prepare("SELECT * FROM group_relays WHERE mls_group_id = ?")
276 .map_err(into_group_err)?;
277
278 let relays_iter = stmt
279 .query_map(params![mls_group_id.as_slice()], db::row_to_group_relay)
280 .map_err(into_group_err)?;
281
282 let mut relays: BTreeSet<GroupRelay> = BTreeSet::new();
283
284 for relay_result in relays_iter {
285 let relay: GroupRelay = relay_result.map_err(into_group_err)?;
286 relays.insert(relay);
287 }
288
289 Ok(relays)
290 })
291 }
292
293 fn replace_group_relays(
294 &self,
295 group_id: &GroupId,
296 relays: BTreeSet<RelayUrl>,
297 ) -> Result<(), GroupError> {
298 if self.find_group_by_mls_group_id(group_id)?.is_none() {
300 return Err(GroupError::InvalidParameters("Group not found".to_string()));
301 }
302
303 self.with_connection(|conn| {
304 conn.execute_batch("SAVEPOINT mdk_replace_group_relays")
306 .map_err(into_group_err)?;
307
308 let result: Result<(), GroupError> = (|| {
309 conn.execute(
310 "DELETE FROM group_relays WHERE mls_group_id = ?",
311 params![group_id.as_slice()],
312 )
313 .map_err(into_group_err)?;
314
315 for relay_url in &relays {
316 conn.execute(
317 "INSERT INTO group_relays (mls_group_id, relay_url) VALUES (?, ?)",
318 params![group_id.as_slice(), relay_url.as_str()],
319 )
320 .map_err(into_group_err)?;
321 }
322 Ok(())
323 })();
324
325 match result {
326 Ok(()) => conn
327 .execute_batch("RELEASE SAVEPOINT mdk_replace_group_relays")
328 .map_err(into_group_err),
329 Err(e) => {
330 let _ = conn.execute_batch(
332 "ROLLBACK TO SAVEPOINT mdk_replace_group_relays; \
333 RELEASE SAVEPOINT mdk_replace_group_relays;",
334 );
335 Err(e)
336 }
337 }
338 })
339 }
340
341 fn get_group_exporter_secret(
342 &self,
343 mls_group_id: &GroupId,
344 epoch: u64,
345 ) -> Result<Option<GroupExporterSecret>, GroupError> {
346 if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
348 return Err(GroupError::InvalidParameters("Group not found".to_string()));
349 }
350
351 self.with_connection(|conn| {
352 let mut stmt = conn
353 .prepare(
354 "SELECT * FROM group_exporter_secrets WHERE mls_group_id = ? AND epoch = ? AND label = 'group-event'",
355 )
356 .map_err(into_group_err)?;
357
358 stmt.query_row(
359 params![mls_group_id.as_slice(), epoch],
360 db::row_to_group_exporter_secret,
361 )
362 .optional()
363 .map_err(into_group_err)
364 })
365 }
366
367 fn save_group_exporter_secret(
368 &self,
369 group_exporter_secret: GroupExporterSecret,
370 ) -> Result<(), GroupError> {
371 if self
372 .find_group_by_mls_group_id(&group_exporter_secret.mls_group_id)?
373 .is_none()
374 {
375 return Err(GroupError::InvalidParameters("Group not found".to_string()));
376 }
377
378 self.with_connection(|conn| {
379 conn.execute(
380 "INSERT OR REPLACE INTO group_exporter_secrets (mls_group_id, epoch, secret, label) VALUES (?, ?, ?, 'group-event')",
381 params![&group_exporter_secret.mls_group_id.as_slice(), &group_exporter_secret.epoch, group_exporter_secret.secret.as_ref()],
382 )
383 .map_err(into_group_err)?;
384
385 Ok(())
386 })
387 }
388
389 fn get_group_mip04_exporter_secret(
390 &self,
391 mls_group_id: &GroupId,
392 epoch: u64,
393 ) -> Result<Option<GroupExporterSecret>, GroupError> {
394 if self.find_group_by_mls_group_id(mls_group_id)?.is_none() {
396 return Err(GroupError::InvalidParameters("Group not found".to_string()));
397 }
398
399 self.with_connection(|conn| {
400 let mut stmt = conn
401 .prepare(
402 "SELECT * FROM group_exporter_secrets WHERE mls_group_id = ? AND epoch = ? AND label = 'encrypted-media'",
403 )
404 .map_err(into_group_err)?;
405
406 stmt.query_row(
407 params![mls_group_id.as_slice(), epoch],
408 db::row_to_group_exporter_secret,
409 )
410 .optional()
411 .map_err(into_group_err)
412 })
413 }
414
415 fn save_group_mip04_exporter_secret(
416 &self,
417 group_exporter_secret: GroupExporterSecret,
418 ) -> Result<(), GroupError> {
419 if self
420 .find_group_by_mls_group_id(&group_exporter_secret.mls_group_id)?
421 .is_none()
422 {
423 return Err(GroupError::InvalidParameters("Group not found".to_string()));
424 }
425
426 self.with_connection(|conn| {
427 conn.execute(
428 "INSERT OR REPLACE INTO group_exporter_secrets (mls_group_id, epoch, secret, label) VALUES (?, ?, ?, 'encrypted-media')",
429 params![&group_exporter_secret.mls_group_id.as_slice(), &group_exporter_secret.epoch, group_exporter_secret.secret.as_ref()],
430 )
431 .map_err(into_group_err)?;
432
433 Ok(())
434 })
435 }
436
437 fn prune_group_exporter_secrets_before_epoch(
438 &self,
439 group_id: &GroupId,
440 min_epoch_to_keep: u64,
441 ) -> Result<(), GroupError> {
442 if self.find_group_by_mls_group_id(group_id)?.is_none() {
443 return Err(GroupError::InvalidParameters("Group not found".to_string()));
444 }
445
446 self.with_connection(|conn| {
447 conn.execute(
448 "DELETE FROM group_exporter_secrets WHERE mls_group_id = ? AND epoch < ?",
449 params![group_id.as_slice(), min_epoch_to_keep],
450 )
451 .map_err(into_group_err)?;
452
453 Ok(())
454 })
455 }
456}
457
#[cfg(test)]
mod tests {
    use mdk_storage_traits::Secret;
    use mdk_storage_traits::groups::types::GroupState;
    use mdk_storage_traits::messages::MessageStorage;
    use mdk_storage_traits::messages::types::MessageState;
    use mdk_storage_traits::test_utils::crypto_utils::generate_random_bytes;
    use nostr::{EventId, Kind, Tags, Timestamp, UnsignedEvent};
    use rusqlite::Connection;
    use tempfile::tempdir;

    use super::*;

    /// Builds a minimal active group: epoch 0, no admins, no image, no message
    /// history, self-update required. Tests that need other values override
    /// fields with struct-update syntax.
    fn make_group(
        mls_group_id: &GroupId,
        nostr_group_id: [u8; 32],
        name: &str,
        description: &str,
    ) -> Group {
        Group {
            mls_group_id: mls_group_id.clone(),
            nostr_group_id,
            name: name.to_string(),
            description: description.to_string(),
            admin_pubkeys: BTreeSet::new(),
            last_message_id: None,
            last_message_at: None,
            last_message_processed_at: None,
            epoch: 0,
            state: GroupState::Active,
            image_hash: None,
            image_key: None,
            image_nonce: None,
            self_update_state: SelfUpdateState::Required,
        }
    }

    /// Returns 32 random bytes as a fixed-size array.
    fn random_id() -> [u8; 32] {
        generate_random_bytes(32).try_into().unwrap()
    }

    #[test]
    fn test_save_and_find_group() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let nostr_group_id = random_id();

        // Populate the optional image columns so they are exercised too.
        let group = Group {
            image_hash: Some(generate_random_bytes(32).try_into().unwrap()),
            image_key: Some(Secret::new(generate_random_bytes(32).try_into().unwrap())),
            image_nonce: Some(Secret::new(generate_random_bytes(12).try_into().unwrap())),
            ..make_group(&mls_group_id, nostr_group_id, "Test Group", "A test group")
        };

        storage
            .save_group(group)
            .expect("saving a valid group should succeed");

        let found_group = storage
            .find_group_by_mls_group_id(&mls_group_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_group.nostr_group_id, nostr_group_id);

        let found_group = storage
            .find_group_by_nostr_group_id(&nostr_group_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_group.mls_group_id, mls_group_id);

        let all_groups = storage.all_groups().unwrap();
        assert_eq!(all_groups.len(), 1);
    }

    #[test]
    fn test_group_name_length_validation() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let group = make_group(
            &GroupId::from_slice(&[1, 2, 3, 4]),
            [0u8; 32],
            &"x".repeat(256),
            "Test",
        );

        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group name exceeds maximum length")
        );
    }

    #[test]
    fn test_group_description_length_validation() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let group = make_group(
            &GroupId::from_slice(&[1, 2, 3, 4]),
            [0u8; 32],
            "Test Group",
            &"x".repeat(2001),
        );

        let result = storage.save_group(group);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group description exceeds maximum length")
        );
    }

    #[test]
    fn test_messages_pagination() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(make_group(
                &mls_group_id,
                random_id(),
                "Test Group",
                "A test group",
            ))
            .unwrap();

        // 25 messages with strictly increasing timestamps: "Message 24" is newest.
        let pubkey = PublicKey::from_slice(&[1u8; 32]).unwrap();
        for i in 0..25 {
            let event_id = EventId::from_slice(&[i as u8; 32]).unwrap();
            let wrapper_event_id = EventId::from_slice(&[100 + i as u8; 32]).unwrap();

            let ts = Timestamp::from((1000 + i) as u64);
            let message = Message {
                id: event_id,
                pubkey,
                kind: Kind::from(1u16),
                mls_group_id: mls_group_id.clone(),
                created_at: ts,
                processed_at: ts,
                content: format!("Message {}", i),
                tags: Tags::new(),
                event: UnsignedEvent::new(
                    pubkey,
                    ts,
                    Kind::from(9u16),
                    vec![],
                    format!("content {}", i),
                ),
                wrapper_event_id,
                state: MessageState::Created,
                epoch: None,
            };

            storage.save_message(message).unwrap();
        }

        let page1 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert_eq!(page1.len(), 10);
        assert_eq!(page1[0].content, "Message 24");

        let page2 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(10))))
            .unwrap();
        assert_eq!(page2.len(), 10);
        assert_eq!(page2[0].content, "Message 14");

        let page3 = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(20))))
            .unwrap();
        assert_eq!(page3.len(), 5);
        assert_eq!(page3[0].content, "Message 4");

        let default_messages = storage.messages(&mls_group_id, None).unwrap();
        assert_eq!(default_messages.len(), 25);

        let first_id = page1[0].id;
        let second_page_ids: Vec<EventId> = page2.iter().map(|m| m.id).collect();
        assert!(
            !second_page_ids.contains(&first_id),
            "Pages should not overlap"
        );

        let beyond = storage
            .messages(&mls_group_id, Some(Pagination::new(Some(10), Some(30))))
            .unwrap();
        assert_eq!(beyond.len(), 0);

        let result = storage.messages(&mls_group_id, Some(Pagination::new(Some(0), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        let result = storage.messages(&mls_group_id, Some(Pagination::new(Some(20000), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        let fake_group_id = GroupId::from_slice(&[99, 99, 99, 99]);
        let result = storage.messages(&fake_group_id, Some(Pagination::new(Some(10), Some(0))));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));

        // An offset far past the end is valid and simply yields no rows.
        let result = storage.messages(
            &mls_group_id,
            Some(Pagination::new(Some(10), Some(2_000_000))),
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }

    #[test]
    fn test_group_exporter_secret() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(make_group(
                &mls_group_id,
                random_id(),
                "Test Group",
                "A test group",
            ))
            .unwrap();

        // Save and read back a secret for epoch 1.
        storage
            .save_group_exporter_secret(GroupExporterSecret {
                mls_group_id: mls_group_id.clone(),
                epoch: 1,
                secret: Secret::new([1u8; 32]),
            })
            .unwrap();

        let retrieved_secret = storage
            .get_group_exporter_secret(&mls_group_id, 1)
            .unwrap()
            .unwrap();
        assert_eq!(*retrieved_secret.secret, [1u8; 32]);

        // Saving a different value for the same epoch must replace the old one.
        // Distinct bytes are required, otherwise a stale read would go unnoticed.
        storage
            .save_group_exporter_secret(GroupExporterSecret {
                mls_group_id: mls_group_id.clone(),
                epoch: 1,
                secret: Secret::new([2u8; 32]),
            })
            .unwrap();

        let retrieved_secret = storage
            .get_group_exporter_secret(&mls_group_id, 1)
            .unwrap()
            .unwrap();
        assert_eq!(*retrieved_secret.secret, [2u8; 32]);

        // A secret for another epoch is stored independently.
        storage
            .save_group_exporter_secret(GroupExporterSecret {
                mls_group_id: mls_group_id.clone(),
                epoch: 2,
                secret: Secret::new([3u8; 32]),
            })
            .unwrap();

        let retrieved_secret1 = storage
            .get_group_exporter_secret(&mls_group_id, 1)
            .unwrap()
            .unwrap();
        let retrieved_secret2 = storage
            .get_group_exporter_secret(&mls_group_id, 2)
            .unwrap()
            .unwrap();

        assert_eq!(*retrieved_secret1.secret, [2u8; 32]);
        assert_eq!(*retrieved_secret2.secret, [3u8; 32]);
    }

    #[test]
    fn test_all_groups_skips_corrupted_rows() {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let storage = MdkSqliteStorage::new_unencrypted(&db_path).unwrap();

        let mls_group_id1 = GroupId::from_slice(&[1, 2, 3, 4]);
        storage
            .save_group(make_group(
                &mls_group_id1,
                random_id(),
                "Group 1",
                "First group",
            ))
            .unwrap();

        let mls_group_id2 = GroupId::from_slice(&[5, 6, 7, 8]);
        storage
            .save_group(make_group(
                &mls_group_id2,
                random_id(),
                "Group 2",
                "Second group",
            ))
            .unwrap();

        // Insert a row behind the storage layer's back whose `state` is not a
        // valid GroupState, so deserializing it fails.
        let corrupt_conn = Connection::open(&db_path).unwrap();
        let corrupted_nostr_id = random_id();
        corrupt_conn
            .execute(
                "INSERT INTO groups (mls_group_id, nostr_group_id, name, description, admin_pubkeys, epoch, state) VALUES (?, ?, ?, ?, ?, ?, ?)",
                params![
                    &[9u8; 16],
                    &corrupted_nostr_id,
                    "Corrupted Group",
                    "This group has invalid state",
                    "[]",
                    0,
                    "invalid_state"
                ],
            )
            .unwrap();

        let all_groups = storage.all_groups().unwrap();
        assert_eq!(all_groups.len(), 2);
        assert_eq!(all_groups[0].mls_group_id, mls_group_id1);
        assert_eq!(all_groups[1].mls_group_id, mls_group_id2);
    }
}