Skip to main content

mdk_memory_storage/
welcomes.rs

1//! Memory-based storage implementation of the MdkStorageProvider trait for MDK welcomes
2
3use mdk_storage_traits::welcomes::error::WelcomeError;
4use mdk_storage_traits::welcomes::types::*;
5use mdk_storage_traits::welcomes::{MAX_PENDING_WELCOMES_LIMIT, Pagination, WelcomeStorage};
6use nostr::EventId;
7
8use crate::MdkMemoryStorage;
9
10impl WelcomeStorage for MdkMemoryStorage {
11    fn save_welcome(&self, welcome: Welcome) -> Result<(), WelcomeError> {
12        // Validate relay count to prevent memory exhaustion
13        if welcome.group_relays.len() > self.limits.max_relays_per_welcome {
14            return Err(WelcomeError::InvalidParameters(format!(
15                "Welcome relay count exceeds maximum of {} (got {})",
16                self.limits.max_relays_per_welcome,
17                welcome.group_relays.len()
18            )));
19        }
20
21        // Validate individual relay URL lengths
22        for relay in &welcome.group_relays {
23            if relay.as_str().len() > self.limits.max_relay_url_length {
24                return Err(WelcomeError::InvalidParameters(format!(
25                    "Relay URL exceeds maximum length of {} bytes",
26                    self.limits.max_relay_url_length
27                )));
28            }
29        }
30
31        // Validate admin pubkeys count to prevent memory exhaustion
32        if welcome.group_admin_pubkeys.len() > self.limits.max_admins_per_welcome {
33            return Err(WelcomeError::InvalidParameters(format!(
34                "Welcome admin count exceeds maximum of {} (got {})",
35                self.limits.max_admins_per_welcome,
36                welcome.group_admin_pubkeys.len()
37            )));
38        }
39
40        let mut inner = self.inner.write();
41        inner.welcomes_cache.put(welcome.id, welcome);
42
43        Ok(())
44    }
45
46    fn pending_welcomes(
47        &self,
48        pagination: Option<Pagination>,
49    ) -> Result<Vec<Welcome>, WelcomeError> {
50        let pagination = pagination.unwrap_or_default();
51        let limit = pagination.limit();
52        let offset = pagination.offset();
53
54        // Validate limit is within allowed range
55        if !(1..=MAX_PENDING_WELCOMES_LIMIT).contains(&limit) {
56            return Err(WelcomeError::InvalidParameters(format!(
57                "Limit must be between 1 and {}, got {}",
58                MAX_PENDING_WELCOMES_LIMIT, limit
59            )));
60        }
61
62        let inner = self.inner.read();
63        let mut welcomes: Vec<Welcome> = inner
64            .welcomes_cache
65            .iter()
66            .map(|(_, v)| v.clone())
67            .filter(|welcome| welcome.state == WelcomeState::Pending)
68            .collect();
69
70        // Sort by ID (descending) for consistent ordering
71        welcomes.sort_by(|a, b| b.id.cmp(&a.id));
72
73        // Apply pagination
74        let welcomes: Vec<Welcome> = welcomes.into_iter().skip(offset).take(limit).collect();
75
76        Ok(welcomes)
77    }
78
79    fn find_welcome_by_event_id(
80        &self,
81        event_id: &EventId,
82    ) -> Result<Option<Welcome>, WelcomeError> {
83        let inner = self.inner.read();
84        Ok(inner.welcomes_cache.peek(event_id).cloned())
85    }
86
87    fn save_processed_welcome(
88        &self,
89        processed_welcome: ProcessedWelcome,
90    ) -> Result<(), WelcomeError> {
91        let mut inner = self.inner.write();
92        inner
93            .processed_welcomes_cache
94            .put(processed_welcome.wrapper_event_id, processed_welcome);
95
96        Ok(())
97    }
98
99    fn find_processed_welcome_by_event_id(
100        &self,
101        event_id: &EventId,
102    ) -> Result<Option<ProcessedWelcome>, WelcomeError> {
103        let inner = self.inner.read();
104        Ok(inner.processed_welcomes_cache.peek(event_id).cloned())
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use mdk_storage_traits::GroupId;
    use mdk_storage_traits::test_utils::cross_storage::create_test_welcome;
    use nostr::{EventId, Keys, Kind, PublicKey, RelayUrl, Tags, Timestamp, UnsignedEvent};

    use super::*;
    use crate::{
        DEFAULT_MAX_ADMINS_PER_WELCOME, DEFAULT_MAX_RELAY_URL_LENGTH,
        DEFAULT_MAX_RELAYS_PER_WELCOME,
    };

    /// Fixed pubkey used as the event author and welcomer in locally built fixtures.
    fn fixture_pubkey() -> PublicKey {
        PublicKey::parse("npub1a6awmmklxfmspwdv52qq58sk5c07kghwc4v2eaudjx2ju079cdqs2452ys")
            .unwrap()
    }

    /// Builds a deterministic event ID whose hex form is `n` zero-padded to 64 digits.
    fn event_id_from(n: u64) -> EventId {
        EventId::from_hex(&format!("{:064x}", n)).unwrap()
    }

    /// Builds a single-relay set from `url`, panicking if the URL does not parse.
    fn single_relay(url: &str) -> BTreeSet<RelayUrl> {
        BTreeSet::from([RelayUrl::parse(url).unwrap()])
    }

    /// Asserts that `result` is an error whose display text contains `expected`.
    fn assert_rejected<T>(result: Result<T, WelcomeError>, expected: &str) {
        let err = result.err().expect("operation should have been rejected");
        let message = err.to_string();
        assert!(
            message.contains(expected),
            "error message {:?} does not contain {:?}",
            message,
            expected
        );
    }

    /// Shared fixture builder: a pending welcome with the given admin and relay sets.
    ///
    /// Everything else (event content, group metadata, welcomer) is fixed so the
    /// callers only vary the collection under test.
    fn build_welcome(
        mls_group_id: GroupId,
        event_id: EventId,
        group_admin_pubkeys: BTreeSet<PublicKey>,
        group_relays: BTreeSet<RelayUrl>,
    ) -> Welcome {
        let pubkey = fixture_pubkey();

        let event = UnsignedEvent {
            id: Some(event_id),
            pubkey,
            created_at: Timestamp::now(),
            kind: Kind::Custom(444),
            tags: Tags::new(),
            content: "Test welcome content".to_string(),
        };

        Welcome {
            id: event_id,
            event,
            mls_group_id,
            nostr_group_id: [0u8; 32],
            group_name: "Test Group".to_string(),
            group_description: "A test group".to_string(),
            group_image_hash: None,
            group_image_key: None,
            group_image_nonce: None,
            group_admin_pubkeys,
            group_relays,
            welcomer: pubkey,
            member_count: 1,
            state: WelcomeState::Pending,
            wrapper_event_id: event_id,
        }
    }

    /// Builds a pending welcome with `relay_count` distinct relays and a single admin.
    fn create_welcome_with_relays(
        mls_group_id: GroupId,
        event_id: EventId,
        relay_count: usize,
    ) -> Welcome {
        let relays: BTreeSet<RelayUrl> = (0..relay_count)
            .map(|i| RelayUrl::parse(&format!("wss://relay{}.example.com", i)).unwrap())
            .collect();

        build_welcome(
            mls_group_id,
            event_id,
            BTreeSet::from([fixture_pubkey()]),
            relays,
        )
    }

    /// Builds a pending welcome with `admin_count` freshly generated admins and a single relay.
    fn create_welcome_with_admins(
        mls_group_id: GroupId,
        event_id: EventId,
        admin_count: usize,
    ) -> Welcome {
        let admins: BTreeSet<PublicKey> = (0..admin_count)
            .map(|_| Keys::generate().public_key())
            .collect();

        build_welcome(
            mls_group_id,
            event_id,
            admins,
            single_relay("wss://relay.example.com"),
        )
    }

    #[test]
    fn test_save_welcome_relay_count_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Relay count at exactly the limit (should succeed)
        let welcome = create_welcome_with_relays(
            mls_group_id.clone(),
            event_id_from(1),
            DEFAULT_MAX_RELAYS_PER_WELCOME,
        );
        assert!(storage.save_welcome(welcome).is_ok());

        // Relay count exceeding the limit (should fail)
        let welcome = create_welcome_with_relays(
            mls_group_id.clone(),
            event_id_from(2),
            DEFAULT_MAX_RELAYS_PER_WELCOME + 1,
        );
        assert_rejected(
            storage.save_welcome(welcome),
            "Welcome relay count exceeds maximum",
        );
    }

    #[test]
    fn test_save_welcome_relay_url_length_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // URL at exactly the limit (should succeed):
        // "wss://" (6) + domain (MAX - 10) + ".com" (4) = MAX bytes
        let domain = "a".repeat(DEFAULT_MAX_RELAY_URL_LENGTH - 10);
        let mut welcome = create_test_welcome(mls_group_id.clone(), event_id_from(1));
        welcome.group_relays = single_relay(&format!("wss://{}.com", domain));
        assert!(storage.save_welcome(welcome).is_ok());

        // URL exceeding the limit (should fail): 6 + MAX + 4 = MAX + 10 bytes
        let domain = "a".repeat(DEFAULT_MAX_RELAY_URL_LENGTH);
        let mut welcome = create_test_welcome(mls_group_id.clone(), event_id_from(2));
        welcome.group_relays = single_relay(&format!("wss://{}.com", domain));
        assert_rejected(
            storage.save_welcome(welcome),
            "Relay URL exceeds maximum length",
        );
    }

    #[test]
    fn test_save_welcome_admin_count_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Admin count at exactly the limit (should succeed)
        let welcome = create_welcome_with_admins(
            mls_group_id.clone(),
            event_id_from(1),
            DEFAULT_MAX_ADMINS_PER_WELCOME,
        );
        assert!(storage.save_welcome(welcome).is_ok());

        // Admin count exceeding the limit (should fail)
        let welcome = create_welcome_with_admins(
            mls_group_id.clone(),
            event_id_from(2),
            DEFAULT_MAX_ADMINS_PER_WELCOME + 1,
        );
        assert_rejected(
            storage.save_welcome(welcome),
            "Welcome admin count exceeds maximum",
        );
    }

    #[test]
    fn test_pending_welcomes_pagination_memory() {
        let storage = MdkMemoryStorage::new();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Create 25 pending welcomes with increasing IDs
        for n in 1..=25u64 {
            let welcome = create_test_welcome(mls_group_id.clone(), event_id_from(n));
            storage.save_welcome(welcome).unwrap();
        }

        // Test 1: Get all pending welcomes (should use default limit)
        let all_welcomes = storage.pending_welcomes(None).unwrap();
        assert_eq!(all_welcomes.len(), 25);

        // Test 2: Get first 10 welcomes
        let first_10 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert_eq!(first_10.len(), 10);

        // Test 3: Get next 10 welcomes (offset 10)
        let next_10 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(10))))
            .unwrap();
        assert_eq!(next_10.len(), 10);

        // Test 4: Get last 5 welcomes (offset 20)
        let last_5 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(20))))
            .unwrap();
        assert_eq!(last_5.len(), 5);

        // Test 5: Offset beyond available welcomes returns empty
        let beyond = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(30))))
            .unwrap();
        assert!(beyond.is_empty());

        // Test 6: Verify no overlap between pages (every ID, not just the first one)
        let first_page_ids: BTreeSet<EventId> = first_10.iter().map(|w| w.id).collect();
        assert!(
            next_10.iter().all(|w| !first_page_ids.contains(&w.id)),
            "Pages should not overlap"
        );

        // Test 7: Verify ordering is descending by ID within each page
        for page in [&first_10, &next_10] {
            for pair in page.windows(2) {
                assert!(
                    pair[0].id > pair[1].id,
                    "Welcomes should be ordered by ID descending"
                );
            }
        }

        // Test 8: Limit of 0 should return error
        assert_rejected(
            storage.pending_welcomes(Some(Pagination::new(Some(0), Some(0)))),
            "must be between 1 and",
        );

        // Test 9: Limit exceeding MAX should return error.
        // Derived from the constant so the test keeps probing the boundary if it changes.
        assert_rejected(
            storage.pending_welcomes(Some(Pagination::new(
                Some(MAX_PENDING_WELCOMES_LIMIT + 1),
                Some(0),
            ))),
            "must be between 1 and",
        );

        // Test 10: Large offset should work (no MAX_OFFSET validation)
        let far = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(2_000_000))))
            .unwrap();
        assert!(far.is_empty()); // No results at that offset

        // Test 11: Empty results when no pending entries
        let storage2 = MdkMemoryStorage::new();
        let empty = storage2
            .pending_welcomes(Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert!(empty.is_empty());
    }

    /// Test that custom validation limits work correctly for welcomes
    #[test]
    fn test_custom_welcome_limits() {
        use crate::ValidationLimits;

        // Create storage with custom smaller limits
        let limits = ValidationLimits::default()
            .with_max_relays_per_welcome(2)
            .with_max_admins_per_welcome(3)
            .with_max_relay_url_length(50);

        let storage = MdkMemoryStorage::with_limits(limits);

        // Verify limits are accessible
        assert_eq!(storage.limits().max_relays_per_welcome, 2);
        assert_eq!(storage.limits().max_admins_per_welcome, 3);
        assert_eq!(storage.limits().max_relay_url_length, 50);

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Relay count with custom limit (2 relays should succeed)
        let welcome = create_welcome_with_relays(mls_group_id.clone(), event_id_from(1), 2);
        assert!(storage.save_welcome(welcome).is_ok());

        // Relay count exceeding custom limit (3 relays should fail)
        let welcome = create_welcome_with_relays(mls_group_id.clone(), event_id_from(2), 3);
        assert_rejected(storage.save_welcome(welcome), "maximum of 2");

        // Admin count with custom limit (3 admins should succeed)
        let welcome = create_welcome_with_admins(mls_group_id.clone(), event_id_from(3), 3);
        assert!(storage.save_welcome(welcome).is_ok());

        // Admin count exceeding custom limit (4 admins should fail)
        let welcome = create_welcome_with_admins(mls_group_id.clone(), event_id_from(4), 4);
        assert_rejected(storage.save_welcome(welcome), "maximum of 3");

        // Relay URL length with custom limit (50 bytes)
        // URL "wss://a{40}.com" = 6 + 40 + 4 = 50 bytes (should succeed)
        let mut welcome = create_test_welcome(mls_group_id.clone(), event_id_from(5));
        welcome.group_relays = single_relay(&format!("wss://{}.com", "a".repeat(40)));
        assert!(storage.save_welcome(welcome).is_ok());

        // URL exceeding 50 bytes should fail: 6 + 45 + 4 = 55 bytes
        let mut welcome = create_test_welcome(mls_group_id.clone(), event_id_from(6));
        welcome.group_relays = single_relay(&format!("wss://{}.com", "a".repeat(45)));
        assert_rejected(storage.save_welcome(welcome), "50 bytes");
    }
}