1use mdk_storage_traits::welcomes::error::WelcomeError;
4use mdk_storage_traits::welcomes::types::{ProcessedWelcome, Welcome};
5use mdk_storage_traits::welcomes::{MAX_PENDING_WELCOMES_LIMIT, Pagination, WelcomeStorage};
6use nostr::{EventId, JsonUtil};
7use rusqlite::{OptionalExtension, params};
8
9use crate::db::{Hash32, Nonce12};
10use crate::validation::{
11 MAX_ADMIN_PUBKEYS_JSON_SIZE, MAX_EVENT_JSON_SIZE, MAX_GROUP_DESCRIPTION_LENGTH,
12 MAX_GROUP_NAME_LENGTH, MAX_GROUP_RELAYS_JSON_SIZE, validate_size, validate_string_length,
13};
14use crate::{MdkSqliteStorage, db};
15
16#[inline]
17fn into_welcome_err<T>(e: T) -> WelcomeError
18where
19 T: std::error::Error,
20{
21 WelcomeError::DatabaseError(e.to_string())
22}
23
24impl WelcomeStorage for MdkSqliteStorage {
25 fn save_welcome(&self, welcome: Welcome) -> Result<(), WelcomeError> {
26 validate_string_length(&welcome.group_name, MAX_GROUP_NAME_LENGTH, "Group name")
28 .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
29
30 validate_string_length(
31 &welcome.group_description,
32 MAX_GROUP_DESCRIPTION_LENGTH,
33 "Group description",
34 )
35 .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
36
37 let group_admin_pubkeys_json: String = serde_json::to_string(&welcome.group_admin_pubkeys)
39 .map_err(|e| {
40 WelcomeError::DatabaseError(format!("Failed to serialize admin pubkeys: {}", e))
41 })?;
42
43 validate_size(
45 group_admin_pubkeys_json.as_bytes(),
46 MAX_ADMIN_PUBKEYS_JSON_SIZE,
47 "Admin pubkeys JSON",
48 )
49 .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
50
51 let group_relays_json: String =
52 serde_json::to_string(&welcome.group_relays).map_err(|e| {
53 WelcomeError::DatabaseError(format!("Failed to serialize group relays: {}", e))
54 })?;
55
56 validate_size(
58 group_relays_json.as_bytes(),
59 MAX_GROUP_RELAYS_JSON_SIZE,
60 "Group relays JSON",
61 )
62 .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
63
64 let event_json = welcome.event.as_json();
66
67 validate_size(event_json.as_bytes(), MAX_EVENT_JSON_SIZE, "Event JSON")
69 .map_err(|e| WelcomeError::InvalidParameters(e.to_string()))?;
70
71 self.with_connection(|conn| {
72 conn.execute(
73 "INSERT OR REPLACE INTO welcomes
74 (id, event, mls_group_id, nostr_group_id, group_name, group_description, group_image_hash, group_image_key, group_image_nonce,
75 group_admin_pubkeys, group_relays, welcomer, member_count, state, wrapper_event_id)
76 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
77 params![
78 welcome.id.as_bytes(),
79 &event_json,
80 welcome.mls_group_id.as_slice(),
81 &welcome.nostr_group_id,
82 &welcome.group_name,
83 &welcome.group_description,
84 welcome.group_image_hash.map(Hash32::from),
85 welcome.group_image_key.as_ref().map(|k| Hash32::from(**k)),
86 welcome.group_image_nonce.as_ref().map(|n| Nonce12::from(**n)),
87 &group_admin_pubkeys_json,
88 &group_relays_json,
89 welcome.welcomer.as_bytes(),
90 welcome.member_count as u64,
91 welcome.state.as_str(),
92 welcome.wrapper_event_id.as_bytes(),
93 ],
94 )
95 .map_err(into_welcome_err)?;
96
97 Ok(())
98 })
99 }
100
101 fn find_welcome_by_event_id(
102 &self,
103 event_id: &EventId,
104 ) -> Result<Option<Welcome>, WelcomeError> {
105 self.with_connection(|conn| {
106 let mut stmt = conn
107 .prepare("SELECT * FROM welcomes WHERE id = ?")
108 .map_err(into_welcome_err)?;
109
110 stmt.query_row(params![event_id.as_bytes()], db::row_to_welcome)
111 .optional()
112 .map_err(into_welcome_err)
113 })
114 }
115
116 fn pending_welcomes(
117 &self,
118 pagination: Option<Pagination>,
119 ) -> Result<Vec<Welcome>, WelcomeError> {
120 let pagination = pagination.unwrap_or_default();
121 let limit = pagination.limit();
122 let offset = pagination.offset();
123
124 if !(1..=MAX_PENDING_WELCOMES_LIMIT).contains(&limit) {
126 return Err(WelcomeError::InvalidParameters(format!(
127 "Limit must be between 1 and {}, got {}",
128 MAX_PENDING_WELCOMES_LIMIT, limit
129 )));
130 }
131
132 self.with_connection(|conn| {
133 let mut stmt = conn
134 .prepare(
135 "SELECT * FROM welcomes WHERE state = 'pending'
136 ORDER BY id DESC
137 LIMIT ? OFFSET ?",
138 )
139 .map_err(into_welcome_err)?;
140
141 let welcomes_iter = stmt
142 .query_map(params![limit as i64, offset as i64], db::row_to_welcome)
143 .map_err(into_welcome_err)?;
144
145 let mut welcomes: Vec<Welcome> = Vec::new();
146
147 for welcome_result in welcomes_iter {
148 let welcome: Welcome = welcome_result.map_err(into_welcome_err)?;
149 welcomes.push(welcome);
150 }
151
152 Ok(welcomes)
153 })
154 }
155
156 fn save_processed_welcome(
157 &self,
158 processed_welcome: ProcessedWelcome,
159 ) -> Result<(), WelcomeError> {
160 let welcome_event_id: Option<&[u8; 32]> = processed_welcome
162 .welcome_event_id
163 .as_ref()
164 .map(|id| id.as_bytes());
165
166 self.with_connection(|conn| {
167 conn.execute(
168 "INSERT OR REPLACE INTO processed_welcomes
169 (wrapper_event_id, welcome_event_id, processed_at, state, failure_reason)
170 VALUES (?, ?, ?, ?, ?)",
171 params![
172 processed_welcome.wrapper_event_id.as_bytes(),
173 welcome_event_id,
174 processed_welcome.processed_at.as_secs(),
175 processed_welcome.state.as_str(),
176 &processed_welcome.failure_reason
177 ],
178 )
179 .map_err(into_welcome_err)?;
180
181 Ok(())
182 })
183 }
184
185 fn find_processed_welcome_by_event_id(
186 &self,
187 event_id: &EventId,
188 ) -> Result<Option<ProcessedWelcome>, WelcomeError> {
189 self.with_connection(|conn| {
190 let mut stmt = conn
191 .prepare("SELECT * FROM processed_welcomes WHERE wrapper_event_id = ?")
192 .map_err(into_welcome_err)?;
193
194 stmt.query_row(params![event_id.as_bytes()], db::row_to_processed_welcome)
195 .optional()
196 .map_err(into_welcome_err)
197 })
198 }
199}
200
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use mdk_storage_traits::GroupId;
    use mdk_storage_traits::groups::GroupStorage;
    use mdk_storage_traits::test_utils::cross_storage::{
        create_test_group, create_test_processed_welcome, create_test_welcome,
    };
    use mdk_storage_traits::welcomes::types::{ProcessedWelcomeState, Welcome, WelcomeState};
    use nostr::{EventId, Kind, PublicKey, Timestamp, UnsignedEvent};

    use super::*;

    /// Event ids of `welcomes`, in the order storage returned them.
    fn ids_of(welcomes: &[Welcome]) -> Vec<EventId> {
        welcomes.iter().map(|w| w.id).collect()
    }

    /// Asserts that `result` was rejected because the page limit was out of range.
    fn assert_invalid_limit(result: Result<Vec<Welcome>, WelcomeError>) {
        let err = result.expect_err("out-of-range limit must be rejected");
        assert!(
            err.to_string().contains("must be between 1 and"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_save_and_find_welcome() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let group = create_test_group(mls_group_id.clone());

        let result = storage.save_group(group);
        assert!(result.is_ok(), "{:?}", result);

        let event_id = EventId::all_zeros();
        let welcome = create_test_welcome(mls_group_id.clone(), event_id);

        let result = storage.save_welcome(welcome.clone());
        assert!(result.is_ok(), "{:?}", result);

        let found_welcome = storage
            .find_welcome_by_event_id(&event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_welcome.id, event_id);
        assert_eq!(found_welcome.mls_group_id, mls_group_id);
        assert_eq!(found_welcome.state, welcome.state);

        let pending_welcomes = storage.pending_welcomes(None).unwrap();
        assert_eq!(pending_welcomes.len(), 1);
        assert_eq!(pending_welcomes[0].id, event_id);
    }

    #[test]
    fn test_processed_welcome() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let wrapper_event_id = EventId::all_zeros();
        let welcome_event_id =
            EventId::from_hex("1111111111111111111111111111111111111111111111111111111111111111")
                .unwrap();

        let processed_welcome =
            create_test_processed_welcome(wrapper_event_id, Some(welcome_event_id));

        let result = storage.save_processed_welcome(processed_welcome);
        assert!(result.is_ok(), "{:?}", result);

        let found_processed_welcome = storage
            .find_processed_welcome_by_event_id(&wrapper_event_id)
            .unwrap()
            .unwrap();
        assert_eq!(found_processed_welcome.wrapper_event_id, wrapper_event_id);
        assert_eq!(
            found_processed_welcome.welcome_event_id.unwrap(),
            welcome_event_id
        );
        assert_eq!(
            found_processed_welcome.state,
            ProcessedWelcomeState::Processed
        );
    }

    #[test]
    fn test_welcome_group_name_length_validation() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let group = create_test_group(mls_group_id.clone());
        storage.save_group(group).unwrap();

        // One character over the configured limit, so the test tracks the constant.
        let oversized_name = "x".repeat(MAX_GROUP_NAME_LENGTH + 1);

        let event_id = EventId::all_zeros();
        let pubkey = PublicKey::from_slice(&[1u8; 32]).unwrap();
        let wrapper_event_id =
            EventId::from_hex("1111111111111111111111111111111111111111111111111111111111111111")
                .unwrap();

        let welcome = Welcome {
            id: event_id,
            event: UnsignedEvent::new(
                pubkey,
                Timestamp::now(),
                Kind::from(444u16),
                vec![],
                "content".to_string(),
            ),
            mls_group_id: mls_group_id.clone(),
            nostr_group_id: [0u8; 32],
            group_name: oversized_name,
            group_description: "Test".to_string(),
            group_image_hash: None,
            group_image_key: None,
            group_image_nonce: None,
            group_admin_pubkeys: BTreeSet::new(),
            group_relays: BTreeSet::new(),
            welcomer: pubkey,
            member_count: 1,
            state: WelcomeState::Pending,
            wrapper_event_id,
        };

        let err = storage
            .save_welcome(welcome)
            .expect_err("oversized group name must be rejected");
        assert!(
            err.to_string().contains("Group name exceeds maximum length"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_pending_welcomes_pagination() {
        let storage = MdkSqliteStorage::new_in_memory().unwrap();

        let mls_group_id = GroupId::from_slice(&[1, 2, 3, 4]);
        let group = create_test_group(mls_group_id.clone());
        storage.save_group(group).unwrap();

        // 25 distinct event ids: 0x00..01 through 0x00..19.
        for i in 1..=25 {
            let event_id = EventId::from_hex(&format!("{:064x}", i)).unwrap();
            let welcome = create_test_welcome(mls_group_id.clone(), event_id);
            storage.save_welcome(welcome).unwrap();
        }

        let all_ids = ids_of(&storage.pending_welcomes(None).unwrap());
        assert_eq!(all_ids.len(), 25);

        let first_10 = ids_of(
            &storage
                .pending_welcomes(Some(Pagination::new(Some(10), Some(0))))
                .unwrap(),
        );
        let next_10 = ids_of(
            &storage
                .pending_welcomes(Some(Pagination::new(Some(10), Some(10))))
                .unwrap(),
        );
        let last_5 = ids_of(
            &storage
                .pending_welcomes(Some(Pagination::new(Some(10), Some(20))))
                .unwrap(),
        );
        let beyond = ids_of(
            &storage
                .pending_welcomes(Some(Pagination::new(Some(10), Some(30))))
                .unwrap(),
        );

        assert_eq!(first_10.len(), 10);
        assert_eq!(next_10.len(), 10);
        assert_eq!(last_5.len(), 5);
        assert!(beyond.is_empty());

        // Pages must be consecutive, non-overlapping windows over the same
        // ordering returned by the unpaginated query.
        assert_eq!(first_10.as_slice(), &all_ids[..10]);
        assert_eq!(next_10.as_slice(), &all_ids[10..20]);
        assert_eq!(last_5.as_slice(), &all_ids[20..]);

        // Limits outside 1..=MAX_PENDING_WELCOMES_LIMIT are rejected.
        assert_invalid_limit(storage.pending_welcomes(Some(Pagination::new(Some(0), Some(0)))));
        assert_invalid_limit(
            storage.pending_welcomes(Some(Pagination::new(Some(20000), Some(0)))),
        );

        // A large offset is valid and simply yields an empty page.
        let far_page = storage
            .pending_welcomes(Some(Pagination::new(Some(10), Some(2_000_000))))
            .unwrap();
        assert!(far_page.is_empty());
    }
}