1use mdk_storage_traits::welcomes::error::WelcomeError;
4use mdk_storage_traits::welcomes::types::*;
5use mdk_storage_traits::welcomes::{MAX_PENDING_WELCOMES_LIMIT, Pagination, WelcomeStorage};
6use nostr::EventId;
7
8use crate::MdkMemoryStorage;
9
10impl WelcomeStorage for MdkMemoryStorage {
11 fn save_welcome(&self, welcome: Welcome) -> Result<(), WelcomeError> {
12 if welcome.group_relays.len() > self.limits.max_relays_per_welcome {
14 return Err(WelcomeError::InvalidParameters(format!(
15 "Welcome relay count exceeds maximum of {} (got {})",
16 self.limits.max_relays_per_welcome,
17 welcome.group_relays.len()
18 )));
19 }
20
21 for relay in &welcome.group_relays {
23 if relay.as_str().len() > self.limits.max_relay_url_length {
24 return Err(WelcomeError::InvalidParameters(format!(
25 "Relay URL exceeds maximum length of {} bytes",
26 self.limits.max_relay_url_length
27 )));
28 }
29 }
30
31 if welcome.group_admin_pubkeys.len() > self.limits.max_admins_per_welcome {
33 return Err(WelcomeError::InvalidParameters(format!(
34 "Welcome admin count exceeds maximum of {} (got {})",
35 self.limits.max_admins_per_welcome,
36 welcome.group_admin_pubkeys.len()
37 )));
38 }
39
40 let mut inner = self.inner.write();
41 inner.welcomes_cache.put(welcome.id, welcome);
42
43 Ok(())
44 }
45
46 fn pending_welcomes(
47 &self,
48 pagination: Option<Pagination>,
49 ) -> Result<Vec<Welcome>, WelcomeError> {
50 let pagination = pagination.unwrap_or_default();
51 let limit = pagination.limit();
52 let offset = pagination.offset();
53
54 if !(1..=MAX_PENDING_WELCOMES_LIMIT).contains(&limit) {
56 return Err(WelcomeError::InvalidParameters(format!(
57 "Limit must be between 1 and {}, got {}",
58 MAX_PENDING_WELCOMES_LIMIT, limit
59 )));
60 }
61
62 let inner = self.inner.read();
63 let mut welcomes: Vec<Welcome> = inner
64 .welcomes_cache
65 .iter()
66 .map(|(_, v)| v.clone())
67 .filter(|welcome| welcome.state == WelcomeState::Pending)
68 .collect();
69
70 welcomes.sort_by(|a, b| b.id.cmp(&a.id));
72
73 let welcomes: Vec<Welcome> = welcomes.into_iter().skip(offset).take(limit).collect();
75
76 Ok(welcomes)
77 }
78
79 fn find_welcome_by_event_id(
80 &self,
81 event_id: &EventId,
82 ) -> Result<Option<Welcome>, WelcomeError> {
83 let inner = self.inner.read();
84 Ok(inner.welcomes_cache.peek(event_id).cloned())
85 }
86
87 fn save_processed_welcome(
88 &self,
89 processed_welcome: ProcessedWelcome,
90 ) -> Result<(), WelcomeError> {
91 let mut inner = self.inner.write();
92 inner
93 .processed_welcomes_cache
94 .put(processed_welcome.wrapper_event_id, processed_welcome);
95
96 Ok(())
97 }
98
99 fn find_processed_welcome_by_event_id(
100 &self,
101 event_id: &EventId,
102 ) -> Result<Option<ProcessedWelcome>, WelcomeError> {
103 let inner = self.inner.read();
104 Ok(inner.processed_welcomes_cache.peek(event_id).cloned())
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use mdk_storage_traits::GroupId;
    use mdk_storage_traits::test_utils::cross_storage::create_test_welcome;
    use nostr::{EventId, Keys, Kind, PublicKey, RelayUrl, Tags, Timestamp, UnsignedEvent};

    use super::*;
    use crate::{
        DEFAULT_MAX_ADMINS_PER_WELCOME, DEFAULT_MAX_RELAY_URL_LENGTH,
        DEFAULT_MAX_RELAYS_PER_WELCOME,
    };

    /// Public key used as the event author and welcomer in the fixtures below.
    const TEST_NPUB: &str = "npub1a6awmmklxfmspwdv52qq58sk5c07kghwc4v2eaudjx2ju079cdqs2452ys";

    /// Builds a deterministic `EventId` from `n`, zero-padded to 64 hex digits.
    fn test_event_id(n: u64) -> EventId {
        EventId::from_hex(&format!("{:064x}", n)).unwrap()
    }

    /// Builds a pending `Welcome` with the given admin and relay sets. Every
    /// other field is a fixed placeholder shared by all fixtures.
    fn build_welcome(
        mls_group_id: GroupId,
        event_id: EventId,
        group_admin_pubkeys: BTreeSet<PublicKey>,
        group_relays: BTreeSet<RelayUrl>,
    ) -> Welcome {
        let pubkey = PublicKey::parse(TEST_NPUB).unwrap();

        let event = UnsignedEvent {
            id: Some(event_id),
            pubkey,
            created_at: Timestamp::now(),
            kind: Kind::Custom(444),
            tags: Tags::new(),
            content: "Test welcome content".to_string(),
        };

        Welcome {
            id: event_id,
            event,
            mls_group_id,
            nostr_group_id: [0u8; 32],
            group_name: "Test Group".to_string(),
            group_description: "A test group".to_string(),
            group_image_hash: None,
            group_image_key: None,
            group_image_nonce: None,
            group_admin_pubkeys,
            group_relays,
            welcomer: pubkey,
            member_count: 1,
            state: WelcomeState::Pending,
            wrapper_event_id: event_id,
        }
    }

    /// A welcome with a single admin (`TEST_NPUB`) and `relay_count` distinct relays.
    fn create_welcome_with_relays(
        mls_group_id: GroupId,
        event_id: EventId,
        relay_count: usize,
    ) -> Welcome {
        let admin = PublicKey::parse(TEST_NPUB).unwrap();
        let relays = (0..relay_count)
            .map(|i| RelayUrl::parse(&format!("wss://relay{}.example.com", i)).unwrap())
            .collect();
        build_welcome(mls_group_id, event_id, BTreeSet::from([admin]), relays)
    }

    /// A welcome with a single relay and `admin_count` freshly generated admin keys.
    fn create_welcome_with_admins(
        mls_group_id: GroupId,
        event_id: EventId,
        admin_count: usize,
    ) -> Welcome {
        let admins = (0..admin_count)
            .map(|_| Keys::generate().public_key())
            .collect();
        let relays = BTreeSet::from([RelayUrl::parse("wss://relay.example.com").unwrap()]);
        build_welcome(mls_group_id, event_id, admins, relays)
    }

    /// The shared cross-storage test welcome, with its relay set replaced by `url`.
    fn create_welcome_with_relay_url(
        mls_group_id: GroupId,
        event_id: EventId,
        url: &str,
    ) -> Welcome {
        let mut welcome = create_test_welcome(mls_group_id, event_id);
        welcome.group_relays = BTreeSet::from([RelayUrl::parse(url).unwrap()]);
        welcome
    }

    #[test]
    fn test_save_welcome_relay_count_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Exactly at the limit is accepted.
        let welcome = create_welcome_with_relays(
            mls_group_id.clone(),
            test_event_id(1),
            DEFAULT_MAX_RELAYS_PER_WELCOME,
        );
        assert!(storage.save_welcome(welcome).is_ok());

        // One over the limit is rejected.
        let welcome = create_welcome_with_relays(
            mls_group_id,
            test_event_id(2),
            DEFAULT_MAX_RELAYS_PER_WELCOME + 1,
        );
        let err = storage.save_welcome(welcome).unwrap_err();
        assert!(err.to_string().contains("Welcome relay count exceeds maximum"));
    }

    #[test]
    fn test_save_welcome_relay_url_length_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // "wss://" + ".com" add 10 bytes, so this URL string is exactly
        // DEFAULT_MAX_RELAY_URL_LENGTH bytes long.
        let url = format!("wss://{}.com", "a".repeat(DEFAULT_MAX_RELAY_URL_LENGTH - 10));
        let welcome = create_welcome_with_relay_url(mls_group_id.clone(), test_event_id(1), &url);
        assert!(storage.save_welcome(welcome).is_ok());

        // This URL string is 10 bytes over the limit.
        let url = format!("wss://{}.com", "a".repeat(DEFAULT_MAX_RELAY_URL_LENGTH));
        let welcome = create_welcome_with_relay_url(mls_group_id, test_event_id(2), &url);
        let err = storage.save_welcome(welcome).unwrap_err();
        assert!(err.to_string().contains("Relay URL exceeds maximum length"));
    }

    #[test]
    fn test_save_welcome_admin_count_validation() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Exactly at the limit is accepted.
        let welcome = create_welcome_with_admins(
            mls_group_id.clone(),
            test_event_id(1),
            DEFAULT_MAX_ADMINS_PER_WELCOME,
        );
        assert!(storage.save_welcome(welcome).is_ok());

        // One over the limit is rejected.
        let welcome = create_welcome_with_admins(
            mls_group_id,
            test_event_id(2),
            DEFAULT_MAX_ADMINS_PER_WELCOME + 1,
        );
        let err = storage.save_welcome(welcome).unwrap_err();
        assert!(err.to_string().contains("Welcome admin count exceeds maximum"));
    }

    #[test]
    fn test_pending_welcomes_pagination_memory() {
        let storage = MdkMemoryStorage::new();
        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        for i in 0..25 {
            let welcome = create_test_welcome(mls_group_id.clone(), test_event_id(i + 1));
            storage.save_welcome(welcome).unwrap();
        }

        // With no pagination, all 25 pending welcomes fit in one page.
        let all_welcomes = storage.pending_welcomes(None).unwrap();
        assert_eq!(all_welcomes.len(), 25);

        let first_10 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert_eq!(first_10.len(), 10);

        let next_10 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(10))))
            .unwrap();
        assert_eq!(next_10.len(), 10);

        let last_5 = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(20))))
            .unwrap();
        assert_eq!(last_5.len(), 5);

        let beyond = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(30))))
            .unwrap();
        assert!(beyond.is_empty());

        // No welcome from the first page may reappear on the second.
        let second_page_ids: BTreeSet<EventId> = next_10.iter().map(|w| w.id).collect();
        assert!(
            first_10.iter().all(|w| !second_page_ids.contains(&w.id)),
            "Pages should not overlap"
        );

        assert!(
            first_10.windows(2).all(|pair| pair[0].id > pair[1].id),
            "Welcomes should be ordered by ID descending"
        );

        // Limits of zero and above the maximum are both rejected.
        for bad_limit in [0, 20000] {
            let err = storage
                .pending_welcomes(Some(Pagination::new(Some(bad_limit), Some(0))))
                .unwrap_err();
            assert!(err.to_string().contains("must be between 1 and"));
        }

        // An offset far past the end yields an empty page, not an error.
        let far = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(2_000_000))))
            .unwrap();
        assert!(far.is_empty());

        // A fresh storage has no pending welcomes.
        let empty = MdkMemoryStorage::new()
            .pending_welcomes(Some(Pagination::new(Some(10), Some(0))))
            .unwrap();
        assert!(empty.is_empty());
    }

    #[test]
    fn test_custom_welcome_limits() {
        use crate::ValidationLimits;

        let limits = ValidationLimits::default()
            .with_max_relays_per_welcome(2)
            .with_max_admins_per_welcome(3)
            .with_max_relay_url_length(50);
        let storage = MdkMemoryStorage::with_limits(limits);

        assert_eq!(storage.limits().max_relays_per_welcome, 2);
        assert_eq!(storage.limits().max_admins_per_welcome, 3);
        assert_eq!(storage.limits().max_relay_url_length, 50);

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);

        // Relay count: 2 accepted, 3 rejected.
        let welcome = create_welcome_with_relays(mls_group_id.clone(), test_event_id(1), 2);
        assert!(storage.save_welcome(welcome).is_ok());

        let welcome = create_welcome_with_relays(mls_group_id.clone(), test_event_id(2), 3);
        let err = storage.save_welcome(welcome).unwrap_err();
        assert!(err.to_string().contains("maximum of 2"));

        // Admin count: 3 accepted, 4 rejected.
        let welcome = create_welcome_with_admins(mls_group_id.clone(), test_event_id(3), 3);
        assert!(storage.save_welcome(welcome).is_ok());

        let welcome = create_welcome_with_admins(mls_group_id.clone(), test_event_id(4), 4);
        let err = storage.save_welcome(welcome).unwrap_err();
        assert!(err.to_string().contains("maximum of 3"));

        // Relay URL length: a 50-byte URL string is accepted, a 55-byte one rejected.
        let url = format!("wss://{}.com", "a".repeat(40));
        let welcome = create_welcome_with_relay_url(mls_group_id.clone(), test_event_id(5), &url);
        assert!(storage.save_welcome(welcome).is_ok());

        let url = format!("wss://{}.com", "a".repeat(45));
        let welcome = create_welcome_with_relay_url(mls_group_id, test_event_id(6), &url);
        let err = storage.save_welcome(welcome).unwrap_err();
        assert!(err.to_string().contains("50 bytes"));
    }
}