Skip to main content

mdk_sqlite_storage/
welcomes.rs

//! Implementation of the `WelcomeStorage` trait for SQLite storage.
2
3use mdk_storage_traits::welcomes::error::WelcomeError;
4use mdk_storage_traits::welcomes::types::{ProcessedWelcome, Welcome};
5use mdk_storage_traits::welcomes::{MAX_PENDING_WELCOMES_LIMIT, Pagination, WelcomeStorage};
6use nostr::{EventId, JsonUtil};
7use rusqlite::{OptionalExtension, params};
8
9use crate::db::{Hash32, Nonce12};
10use crate::validation::{
11    MAX_ADMIN_PUBKEYS_JSON_SIZE, MAX_EVENT_JSON_SIZE, MAX_GROUP_DESCRIPTION_LENGTH,
12    MAX_GROUP_NAME_LENGTH, MAX_GROUP_RELAYS_JSON_SIZE, validate_size, validate_string_length,
13};
14use crate::{MdkSqliteStorage, db};
15
16#[inline]
17fn into_welcome_err<T>(e: T) -> WelcomeError
18where
19    T: std::error::Error,
20{
21    WelcomeError::DatabaseError(e.to_string())
22}
23
24impl WelcomeStorage for MdkSqliteStorage {
25    fn save_welcome(&self, welcome: Welcome) -> Result<(), WelcomeError> {
26        // Validate group name and description lengths
27        validate_string_length(&welcome.group_name, MAX_GROUP_NAME_LENGTH, "Group name")
28            .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
29
30        validate_string_length(
31            &welcome.group_description,
32            MAX_GROUP_DESCRIPTION_LENGTH,
33            "Group description",
34        )
35        .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
36
37        // Serialize complex types to JSON
38        let group_admin_pubkeys_json: String = serde_json::to_string(&welcome.group_admin_pubkeys)
39            .map_err(|e| {
40                WelcomeError::DatabaseError(format!("Failed to serialize admin pubkeys: {}", e))
41            })?;
42
43        // Validate admin pubkeys JSON size
44        validate_size(
45            group_admin_pubkeys_json.as_bytes(),
46            MAX_ADMIN_PUBKEYS_JSON_SIZE,
47            "Admin pubkeys JSON",
48        )
49        .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
50
51        let group_relays_json: String =
52            serde_json::to_string(&welcome.group_relays).map_err(|e| {
53                WelcomeError::DatabaseError(format!("Failed to serialize group relays: {}", e))
54            })?;
55
56        // Validate group relays JSON size
57        validate_size(
58            group_relays_json.as_bytes(),
59            MAX_GROUP_RELAYS_JSON_SIZE,
60            "Group relays JSON",
61        )
62        .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
63
64        // Serialize event to JSON
65        let event_json = welcome.event.as_json();
66
67        // Validate event JSON size
68        validate_size(event_json.as_bytes(), MAX_EVENT_JSON_SIZE, "Event JSON")
69            .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
70
71        self.with_connection(|conn| {
72            conn.execute(
73                "INSERT OR REPLACE INTO welcomes
74             (id, event, mls_group_id, nostr_group_id, group_name, group_description, group_image_hash, group_image_key, group_image_nonce,
75              group_admin_pubkeys, group_relays, welcomer, member_count, state, wrapper_event_id)
76             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
77                params![
78                    welcome.id.as_bytes(),
79                    &event_json,
80                    welcome.mls_group_id.as_slice(),
81                    &welcome.nostr_group_id,
82                    &welcome.group_name,
83                    &welcome.group_description,
84                    welcome.group_image_hash.map(Hash32::from),
85                    welcome.group_image_key.as_ref().map(|k| Hash32::from(**k)),
86                    welcome.group_image_nonce.as_ref().map(|n| Nonce12::from(**n)),
87                    &group_admin_pubkeys_json,
88                    &group_relays_json,
89                    welcome.welcomer.as_bytes(),
90                    welcome.member_count as u64,
91                    welcome.state.as_str(),
92                    welcome.wrapper_event_id.as_bytes(),
93                ],
94            )
95            .map_err(into_welcome_err)?;
96
97            Ok(())
98        })
99    }
100
101    fn find_welcome_by_event_id(
102        &self,
103        event_id: &EventId,
104    ) -> Result<Option<Welcome>, WelcomeError> {
105        self.with_connection(|conn| {
106            let mut stmt = conn
107                .prepare("SELECT * FROM welcomes WHERE id = ?")
108                .map_err(into_welcome_err)?;
109
110            stmt.query_row(params![event_id.as_bytes()], db::row_to_welcome)
111                .optional()
112                .map_err(into_welcome_err)
113        })
114    }
115
116    fn pending_welcomes(
117        &self,
118        pagination: Option<Pagination>,
119    ) -> Result<Vec<Welcome>, WelcomeError> {
120        let pagination = pagination.unwrap_or_default();
121        let limit = pagination.limit();
122        let offset = pagination.offset();
123
124        // Validate limit is within allowed range
125        if !(1..=MAX_PENDING_WELCOMES_LIMIT).contains(&limit) {
126            return Err(WelcomeError::InvalidParameters(format!(
127                "Limit must be between 1 and {}, got {}",
128                MAX_PENDING_WELCOMES_LIMIT, limit
129            )));
130        }
131
132        self.with_connection(|conn| {
133            let mut stmt = conn
134                .prepare(
135                    "SELECT * FROM welcomes WHERE state = 'pending' 
136                     ORDER BY id DESC 
137                     LIMIT ? OFFSET ?",
138                )
139                .map_err(into_welcome_err)?;
140
141            let welcomes_iter = stmt
142                .query_map(params![limit as i64, offset as i64], db::row_to_welcome)
143                .map_err(into_welcome_err)?;
144
145            let mut welcomes: Vec<Welcome> = Vec::new();
146
147            for welcome_result in welcomes_iter {
148                let welcome: Welcome = welcome_result.map_err(into_welcome_err)?;
149                welcomes.push(welcome);
150            }
151
152            Ok(welcomes)
153        })
154    }
155
156    fn save_processed_welcome(
157        &self,
158        processed_welcome: ProcessedWelcome,
159    ) -> Result<(), WelcomeError> {
160        // Convert welcome_event_id to string if it exists
161        let welcome_event_id: Option<&[u8; 32]> = processed_welcome
162            .welcome_event_id
163            .as_ref()
164            .map(|id| id.as_bytes());
165
166        self.with_connection(|conn| {
167            conn.execute(
168                "INSERT OR REPLACE INTO processed_welcomes
169             (wrapper_event_id, welcome_event_id, processed_at, state, failure_reason)
170             VALUES (?, ?, ?, ?, ?)",
171                params![
172                    processed_welcome.wrapper_event_id.as_bytes(),
173                    welcome_event_id,
174                    processed_welcome.processed_at.as_secs(),
175                    processed_welcome.state.as_str(),
176                    &processed_welcome.failure_reason
177                ],
178            )
179            .map_err(into_welcome_err)?;
180
181            Ok(())
182        })
183    }
184
185    fn find_processed_welcome_by_event_id(
186        &self,
187        event_id: &EventId,
188    ) -> Result<Option<ProcessedWelcome>, WelcomeError> {
189        self.with_connection(|conn| {
190            let mut stmt = conn
191                .prepare("SELECT * FROM processed_welcomes WHERE wrapper_event_id = ?")
192                .map_err(into_welcome_err)?;
193
194            stmt.query_row(params![event_id.as_bytes()], db::row_to_processed_welcome)
195                .optional()
196                .map_err(into_welcome_err)
197        })
198    }
199}
200
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use mdk_storage_traits::GroupId;
    use mdk_storage_traits::groups::GroupStorage;
    use mdk_storage_traits::test_utils::cross_storage::{
        create_test_group, create_test_processed_welcome, create_test_welcome,
    };
    use mdk_storage_traits::welcomes::types::{ProcessedWelcomeState, Welcome, WelcomeState};
    use nostr::{EventId, Kind, PublicKey, Timestamp, UnsignedEvent};

    use super::*;

    /// Round-trips a welcome through `save_welcome` / `find_welcome_by_event_id` and checks
    /// that it is reported by `pending_welcomes`.
    #[test]
    fn test_save_and_find_welcome() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        // First create a group (welcomes require a valid group foreign key)
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let group = create_test_group(mls_group_id.clone());

        let result = storage.save_group(group);
        assert!(result.is_ok(), "{:?}", result);

        let event_id = EventId::all_zeros();
        let welcome = create_test_welcome(mls_group_id.clone(), event_id);

        let result = storage.save_welcome(welcome.clone());
        assert!(result.is_ok(), "{:?}", result);

        let found_welcome = storage
            .find_welcome_by_event_id(&event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_welcome.id, event_id);
        assert_eq!(found_welcome.mls_group_id, mls_group_id);
        assert_eq!(found_welcome.state, welcome.state);

        let pending_welcomes = storage.pending_welcomes(None).unwrap();
        assert_eq!(pending_welcomes.len(), 1);
        assert_eq!(pending_welcomes[0].id, event_id);
    }

    /// Round-trips a processed-welcome record keyed by its wrapper event id.
    #[test]
    fn test_processed_welcome() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let wrapper_event_id = EventId::all_zeros();
        let welcome_event_id =
            EventId::from_hex("1111111111111111111111111111111111111111111111111111111111111111")
                .unwrap();

        let processed_welcome =
            create_test_processed_welcome(wrapper_event_id, Some(welcome_event_id));

        let result = storage.save_processed_welcome(processed_welcome);
        assert!(result.is_ok(), "{:?}", result);

        // Lookup is by the wrapper event id, not the inner welcome event id.
        let found_processed_welcome = storage
            .find_processed_welcome_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_processed_welcome.wrapper_event_id, wrapper_event_id);
        assert_eq!(
            found_processed_welcome.welcome_event_id.unwrap(),
            welcome_event_id
        );
        assert_eq!(
            found_processed_welcome.state,
            ProcessedWelcomeState::Processed
        );
    }

    /// An over-long group name must be rejected before anything is written.
    #[test]
    fn test_welcome_group_name_length_validation() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let group = create_test_group(mls_group_id.clone());
        storage.save_group(group).unwrap();

        let oversized_name = "x".repeat(256);

        let event_id = EventId::all_zeros();
        let pubkey = PublicKey::from_slice(&[1u8; 32]).unwrap();
        let wrapper_event_id =
            EventId::from_hex("1111111111111111111111111111111111111111111111111111111111111111")
                .unwrap();

        let welcome = Welcome {
            id: event_id,
            event: UnsignedEvent::new(
                pubkey,
                Timestamp::now(),
                Kind::from(444u16),
                vec![],
                "content".to_string(),
            ),
            mls_group_id: mls_group_id.clone(),
            nostr_group_id: [0u8; 32],
            group_name: oversized_name,
            group_description: "Test".to_string(),
            group_image_hash: None,
            group_image_key: None,
            group_image_nonce: None,
            group_admin_pubkeys: BTreeSet::new(),
            group_relays: BTreeSet::new(),
            welcomer: pubkey,
            member_count: 1,
            state: WelcomeState::Pending,
            wrapper_event_id,
        };

        let result = storage.save_welcome(welcome);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Group name exceeds maximum length")
        );
    }

    /// Exercises LIMIT/OFFSET paging of `pending_welcomes`, including limit validation.
    #[test]
    fn test_pending_welcomes_pagination() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let group = create_test_group(mls_group_id.clone());
        storage.save_group(group).unwrap();

        // Create 25 pending welcomes
        for i in 0..25 {
            let event_id = EventId::from_hex(&format!(
                "{:064x}",
                i + 1 // Start from 1 to avoid all_zeros
            ))
            .unwrap();
            let welcome = create_test_welcome(mls_group_id.clone(), event_id);
            storage.save_welcome(welcome).unwrap();
        }

        // Default pagination returns all 25
        let all_welcomes = storage.pending_welcomes(None).unwrap();
        assert_eq!(all_welcomes.len(), 25);

        let first_10 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert_eq!(first_10.len(), 10);

        let next_10 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(10))))
            .unwrap();
        assert_eq!(next_10.len(), 10);

        let last_5 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(20))))
            .unwrap();
        assert_eq!(last_5.len(), 5);

        let beyond = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(30))))
            .unwrap();
        assert_eq!(beyond.len(), 0);

        // Pages must be disjoint and, concatenated, reproduce the unpaginated result exactly.
        let paged_ids: Vec<EventId> = first_10
            .iter()
            .chain(&next_10)
            .chain(&last_5)
            .map(|w| w.id)
            .collect();
        let all_ids: Vec<EventId> = all_welcomes.iter().map(|w| w.id).collect();
        assert_eq!(paged_ids, all_ids);
        let distinct_ids: BTreeSet<EventId> = paged_ids.iter().copied().collect();
        assert_eq!(distinct_ids.len(), 25);

        // Limit of 0 should return error
        let result = storage.pending_welcomes(Some(Pagination::new(Some(0), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        // Limit exceeding MAX should return error
        let result = storage.pending_welcomes(Some(Pagination::new(Some(20000), Some(0))));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );

        // Large offset should work (no MAX_OFFSET validation) and yield an empty page
        let result = storage.pending_welcomes(Some(Pagination::new(Some(10), Some(2_000_000))));
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }
}