matrix_sdk_sql/
statestore.rs

1//! Database code for matrix-sdk-statestore-sql
2
3use std::collections::BTreeSet;
4
5use crate::{
6    helpers::{BorrowedSqlType, SqlType},
7    StateStore, SupportedDatabase,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::TryStreamExt;
12use matrix_sdk_base::{
13    deserialized_responses::MemberEvent, media::MediaRequest, MinimalRoomMemberEvent, RoomInfo,
14    StateChanges, StoreError,
15};
16use ruma::{
17    events::{
18        presence::PresenceEvent,
19        receipt::Receipt,
20        room::{
21            member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
22            MediaSource,
23        },
24        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
25        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
26    },
27    receipt::ReceiptType,
28    serde::Raw,
29    EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, UserId,
30};
31use sqlx::{
32    database::HasArguments, types::Json, ColumnIndex, Database, Executor, IntoArguments, Row,
33    Transaction,
34};
35
36impl<DB: SupportedDatabase> StateStore<DB>
37where
38    for<'a> <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
39    for<'c> &'c mut <DB as sqlx::Database>::Connection: Executor<'c, Database = DB>,
40    for<'a, 'c> &'c mut Transaction<'a, DB>: Executor<'c, Database = DB>,
41    for<'a> &'a [u8]: BorrowedSqlType<'a, DB>,
42    for<'a> &'a str: BorrowedSqlType<'a, DB>,
43    Vec<u8>: SqlType<DB>,
44    Option<String>: SqlType<DB>,
45    String: SqlType<DB>,
46    Json<Raw<AnyGlobalAccountDataEvent>>: SqlType<DB>,
47    Json<Raw<PresenceEvent>>: SqlType<DB>,
48    Json<SyncRoomMemberEvent>: SqlType<DB>,
49    Json<MinimalRoomMemberEvent>: SqlType<DB>,
50    bool: SqlType<DB>,
51    Json<Raw<AnySyncStateEvent>>: SqlType<DB>,
52    Json<Raw<AnyRoomAccountDataEvent>>: SqlType<DB>,
53    Json<RoomInfo>: SqlType<DB>,
54    Json<Receipt>: SqlType<DB>,
55    Json<Raw<AnyStrippedStateEvent>>: SqlType<DB>,
56    Json<StrippedRoomMemberEvent>: SqlType<DB>,
57    Json<MemberEvent>: SqlType<DB>,
58    for<'a> &'a str: ColumnIndex<<DB as Database>::Row>,
59{
60    /// Put arbitrary data into the custom store
61    ///
62    /// # Errors
63    /// This function will return an error if the upsert cannot be performed
64    pub(crate) async fn set_custom_value(&self, key_ref: &[u8], val: &[u8]) -> Result<()> {
65        let mut key = Vec::with_capacity(7 + key_ref.len());
66        key.extend_from_slice(b"custom:");
67        key.extend_from_slice(key_ref);
68
69        self.insert_kv(&key, val).await
70    }
71
72    /// Get arbitrary data from the custom store
73    ///
74    /// # Errors
75    /// This function will return an error if the database query fails
76    pub(crate) async fn get_custom_value(&self, key_ref: &[u8]) -> Result<Option<Vec<u8>>> {
77        let mut key = Vec::with_capacity(7 + key_ref.len());
78        key.extend_from_slice(b"custom:");
79        key.extend_from_slice(key_ref);
80        self.get_kv(&key).await
81    }
82
83    /// Save the given filter id under the given name
84    ///
85    /// # Errors
86    /// This function will return an error if the upsert cannot be performed
87    pub(crate) async fn save_filter(&self, name: &str, filter_id: &str) -> Result<()> {
88        let mut key = Vec::with_capacity(7 + name.len());
89        key.extend_from_slice(b"filter:");
90        key.extend_from_slice(name.as_bytes());
91
92        self.insert_kv(&key, filter_id.as_bytes()).await
93    }
94
95    /// Get the filter id that was stored under the given filter name.
96    ///
97    /// # Errors
98    /// This function will return an error if the database query fails
99    pub(crate) async fn get_filter(&self, name: &str) -> Result<Option<String>> {
100        let mut key = Vec::with_capacity(7 + name.len());
101        key.extend_from_slice(b"filter:");
102        key.extend_from_slice(name.as_bytes());
103        let result = self.get_kv(&key).await?;
104        match result {
105            Some(value) => Ok(Some(String::from_utf8(value)?)),
106            None => Ok(None),
107        }
108    }
109
110    /// Insert media into the media store
111    ///
112    /// # Errors
113    /// This function will return an error if the media cannot be inserted
114    pub(crate) async fn insert_media(&self, url: &MxcUri, media: &[u8]) -> Result<()> {
115        let mut txn = self.db.begin().await?;
116
117        DB::media_insert_query_1()
118            .bind(url.as_str())
119            .bind(media)
120            .execute(&mut txn)
121            .await?;
122        DB::media_insert_query_2().execute(&mut txn).await?;
123
124        txn.commit().await?;
125        Ok(())
126    }
127
128    /// Deletes media from the media store
129    ///
130    /// # Errors
131    /// This function will return an error if the media cannot be deleted
132    pub(crate) async fn delete_media(&self, url: &MxcUri) -> Result<()> {
133        DB::media_delete_query()
134            .bind(url.as_str())
135            .execute(&*self.db)
136            .await?;
137        Ok(())
138    }
139
140    /// Gets media from the media store
141    ///
142    /// # Errors
143    /// This function will return an error if the query fails
144    pub(crate) async fn get_media(&self, url: &MxcUri) -> Result<Option<Vec<u8>>> {
145        let row = DB::media_load_query()
146            .bind(url.as_str())
147            .fetch_optional(&*self.db)
148            .await?;
149        let row = if let Some(row) = row {
150            row
151        } else {
152            return Ok(None);
153        };
154        Ok(row.try_get("media_data")?)
155    }
156
157    /// Extracts an [`MxcUri`] from a media query
158    ///
159    /// [`MxcUri`]: ruma::identifiers::MxcUri
160    #[must_use]
161    pub(crate) fn extract_media_url(request: &MediaRequest) -> &MxcUri {
162        match request.source {
163            MediaSource::Plain(ref p) => p,
164            MediaSource::Encrypted(ref e) => &e.url,
165        }
166    }
167
168    /// Deletes a room from the room store
169    ///
170    /// # Errors
171    /// This function will return an error if the the query fails
172    pub(crate) async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
173        let mut txn = self.db.begin().await?;
174
175        for query in DB::room_remove_queries() {
176            query.bind(room_id.as_str()).execute(&mut txn).await?;
177        }
178
179        txn.commit().await?;
180        Ok(())
181    }
182
183    /// Sets global account data for an account data event
184    ///
185    /// # Errors
186    /// This function will return an error if the the query fails
187    pub(crate) async fn set_global_account_data<'c>(
188        txn: &mut Transaction<'c, DB>,
189        event_type: &GlobalAccountDataEventType,
190        event_data: Raw<AnyGlobalAccountDataEvent>,
191    ) -> Result<()> {
192        DB::account_data_upsert_query()
193            .bind("")
194            .bind(event_type.to_string())
195            .bind(Json(event_data))
196            .execute(txn)
197            .await?;
198
199        Ok(())
200    }
201
202    /// Get global account data for an account data event type
203    ///
204    /// # Errors
205    /// This function will return an error if the the query fails
206    pub(crate) async fn get_account_data_event(
207        &self,
208        event_type: GlobalAccountDataEventType,
209    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
210        let row = DB::account_data_load_query()
211            .bind("")
212            .bind(event_type.to_string())
213            .fetch_optional(&*self.db)
214            .await?;
215        let row = if let Some(row) = row {
216            row
217        } else {
218            return Ok(None);
219        };
220        let row: Json<Raw<AnyGlobalAccountDataEvent>> = row.try_get("account_data")?;
221        Ok(Some(row.0))
222    }
223
224    /// Get global account data for an account data event type
225    ///
226    /// # Errors
227    /// This function will return an error if the the query fails
228    pub(crate) async fn get_room_account_data_event(
229        &self,
230        room_id: &RoomId,
231        event_type: RoomAccountDataEventType,
232    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
233        let row = DB::account_data_load_query()
234            .bind(room_id.as_str())
235            .bind(event_type.to_string())
236            .fetch_optional(&*self.db)
237            .await?;
238        let row = if let Some(row) = row {
239            row
240        } else {
241            return Ok(None);
242        };
243        let row: Json<Raw<AnyRoomAccountDataEvent>> = row.try_get("account_data")?;
244        Ok(Some(row.0))
245    }
246
247    /// Sets presence for a user
248    ///
249    /// # Errors
250    /// This function will return an error if the the query fails
251    pub(crate) async fn set_presence_event<'c>(
252        txn: &mut Transaction<'c, DB>,
253        user_id: &UserId,
254        presence: Raw<PresenceEvent>,
255    ) -> Result<()> {
256        DB::presence_upsert_query()
257            .bind(user_id.as_str())
258            .bind(Json(presence))
259            .execute(txn)
260            .await?;
261        Ok(())
262    }
263
264    /// Gets presence for a user
265    ///
266    /// # Errors
267    /// This function will return an error if the the query fails
268    pub(crate) async fn get_presence_event(
269        &self,
270        user_id: &UserId,
271    ) -> Result<Option<Raw<PresenceEvent>>> {
272        let row = DB::presence_load_query()
273            .bind(user_id.as_str())
274            .fetch_optional(&*self.db)
275            .await?;
276        let row = if let Some(row) = row {
277            row
278        } else {
279            return Ok(None);
280        };
281        let row: Json<Raw<PresenceEvent>> = row.try_get("presence")?;
282        Ok(Some(row.0))
283    }
284
285    /// Removes a member from a channel
286    ///
287    /// # Errors
288    /// This function will return an error if the the query fails
289    async fn remove_member<'c>(
290        txn: &mut Transaction<'c, DB>,
291        room_id: &RoomId,
292        user_id: &UserId,
293    ) -> Result<()> {
294        DB::member_remove_query()
295            .bind(room_id.as_str())
296            .bind(user_id.as_str())
297            .execute(txn)
298            .await?;
299        Ok(())
300    }
301
302    /// Stores room membership info for a user
303    ///
304    /// # Errors
305    /// This function will return an error if the the query fails
306    pub(crate) async fn set_room_membership<'c>(
307        txn: &mut Transaction<'c, DB>,
308        room_id: &RoomId,
309        user_id: &UserId,
310        member_event: SyncRoomMemberEvent,
311    ) -> Result<()> {
312        let displayname = member_event
313            .as_original()
314            .and_then(|v| v.content.displayname.clone());
315        let joined = match member_event.as_original().map(|v| &v.content.membership) {
316            Some(MembershipState::Join) => true,
317            Some(MembershipState::Invite) => false,
318            _ => return Self::remove_member(txn, room_id, user_id).await,
319        };
320        DB::member_upsert_query()
321            .bind(room_id.as_str())
322            .bind(user_id.as_str())
323            .bind(false)
324            .bind(Json(member_event))
325            .bind(displayname)
326            .bind(joined)
327            .execute(txn)
328            .await?;
329        Ok(())
330    }
331
332    /// Stores stripped room membership info for a user
333    ///
334    /// # Errors
335    /// This function will return an error if the the query fails
336    pub(crate) async fn set_stripped_room_membership<'c>(
337        txn: &mut Transaction<'c, DB>,
338        room_id: &RoomId,
339        user_id: &UserId,
340        member_event: StrippedRoomMemberEvent,
341    ) -> Result<()> {
342        let displayname = member_event.content.displayname.clone();
343        let joined = match member_event.content.membership {
344            MembershipState::Join => true,
345            MembershipState::Invite => false,
346            _ => return Self::remove_member(txn, room_id, user_id).await,
347        };
348        DB::member_upsert_query()
349            .bind(room_id.as_str())
350            .bind(user_id.as_str())
351            .bind(true)
352            .bind(Json(member_event))
353            .bind(displayname)
354            .bind(joined)
355            .execute(txn)
356            .await?;
357        Ok(())
358    }
359
360    /// Stores user profile in room
361    ///
362    /// # Errors
363    /// This function will return an error if the the query fails
364    pub(crate) async fn set_room_profile<'c>(
365        txn: &mut Transaction<'c, DB>,
366        room_id: &RoomId,
367        user_id: &UserId,
368        profile: MinimalRoomMemberEvent,
369    ) -> Result<()> {
370        DB::member_profile_upsert_query()
371            .bind(room_id.as_str())
372            .bind(user_id.as_str())
373            .bind(false)
374            .bind(Json(profile))
375            .execute(txn)
376            .await?;
377        Ok(())
378    }
379
380    /// Stores a state event for a room
381    ///
382    /// # Errors
383    /// This function will return an error if the the query fails
384    pub(crate) async fn set_room_state<'c>(
385        txn: &mut Transaction<'c, DB>,
386        room_id: &RoomId,
387        event_type: &StateEventType,
388        state_key: &str,
389        state: Raw<AnySyncStateEvent>,
390    ) -> Result<()> {
391        DB::state_upsert_query()
392            .bind(room_id.as_str())
393            .bind(event_type.to_string())
394            .bind(state_key)
395            .bind(false)
396            .bind(Json(state))
397            .execute(txn)
398            .await?;
399        Ok(())
400    }
401
402    /// Stores a stripped state event for a room
403    ///
404    /// # Errors
405    /// This function will return an error if the the query fails
406    pub(crate) async fn set_stripped_room_state<'c>(
407        txn: &mut Transaction<'c, DB>,
408        room_id: &RoomId,
409        event_type: &StateEventType,
410        state_key: &str,
411        state: Raw<AnyStrippedStateEvent>,
412    ) -> Result<()> {
413        DB::state_upsert_query()
414            .bind(room_id.as_str())
415            .bind(event_type.to_string())
416            .bind(state_key)
417            .bind(true)
418            .bind(Json(state))
419            .execute(txn)
420            .await?;
421        Ok(())
422    }
423
424    /// Stores account data for a room
425    ///
426    /// # Errors
427    /// This function will return an error if the the query fails
428    pub(crate) async fn set_room_account_data<'c>(
429        txn: &mut Transaction<'c, DB>,
430        room_id: &RoomId,
431        event_type: &RoomAccountDataEventType,
432        event_data: Raw<AnyRoomAccountDataEvent>,
433    ) -> Result<()> {
434        DB::account_data_upsert_query()
435            .bind(room_id.as_str())
436            .bind(event_type.to_string())
437            .bind(Json(event_data))
438            .execute(txn)
439            .await?;
440        Ok(())
441    }
442
443    /// Stores info for a room
444    ///
445    /// # Errors
446    /// This function will return an error if the the query fails
447    pub(crate) async fn set_room_info<'c>(
448        txn: &mut Transaction<'c, DB>,
449        room_id: &RoomId,
450        room_info: RoomInfo,
451    ) -> Result<()> {
452        DB::room_upsert_query()
453            .bind(room_id.as_str())
454            .bind(false)
455            .bind(Json(room_info))
456            .execute(txn)
457            .await?;
458        Ok(())
459    }
460
461    /// Stores stripped info for a room
462    ///
463    /// # Errors
464    /// This function will return an error if the the query fails
465    pub(crate) async fn set_stripped_room_info<'c>(
466        txn: &mut Transaction<'c, DB>,
467        room_id: &RoomId,
468        room_info: RoomInfo,
469    ) -> Result<()> {
470        DB::room_upsert_query()
471            .bind(room_id.as_str())
472            .bind(true)
473            .bind(Json(room_info))
474            .execute(txn)
475            .await?;
476        Ok(())
477    }
478
479    /// Stores receipt for an event
480    ///
481    /// # Errors
482    /// This function will return an error if the the query fails
483    pub(crate) async fn set_receipt<'c>(
484        txn: &mut Transaction<'c, DB>,
485        room_id: &RoomId,
486        event_id: &EventId,
487        receipt_type: &ReceiptType,
488        user_id: &UserId,
489        receipt: Receipt,
490    ) -> Result<()> {
491        DB::receipt_upsert_query()
492            .bind(room_id.as_str())
493            .bind(event_id.as_str())
494            .bind(receipt_type.as_str())
495            .bind(user_id.as_str())
496            .bind(Json(receipt))
497            .execute(txn)
498            .await?;
499        Ok(())
500    }
501
502    /// Retrieves a state event in room by event type and state key
503    ///
504    /// # Errors
505    /// This function will return an error if the the query fails
506    pub(crate) async fn get_state_event(
507        &self,
508        room_id: &RoomId,
509        event_type: StateEventType,
510        state_key: &str,
511    ) -> Result<Option<Raw<AnySyncStateEvent>>> {
512        let row = DB::state_load_query()
513            .bind(room_id.as_str())
514            .bind(event_type.to_string())
515            .bind(state_key)
516            .fetch_optional(&*self.db)
517            .await?;
518        let row = if let Some(row) = row {
519            row
520        } else {
521            return Ok(None);
522        };
523        let row: Json<Raw<AnySyncStateEvent>> = row.try_get("state_event")?;
524        Ok(Some(row.0))
525    }
526
527    /// Retrieves all state events of a given type in a room
528    ///
529    /// # Errors
530    /// This function will return an error if the the query fails
531    pub(crate) async fn get_state_events(
532        &self,
533        room_id: &RoomId,
534        event_type: StateEventType,
535    ) -> Result<Vec<Raw<AnySyncStateEvent>>> {
536        let mut rows = DB::states_load_query()
537            .bind(room_id.as_str())
538            .bind(event_type.to_string())
539            .bind(false)
540            .fetch(&*self.db);
541        let mut result = Vec::new();
542        while let Some(row) = rows.try_next().await? {
543            result.push(
544                row.try_get::<'_, Json<Raw<AnySyncStateEvent>>, _>("state_event")?
545                    .0,
546            );
547        }
548        Ok(result)
549    }
550
551    /// Retrieves the profile of a user in a room
552    ///
553    /// # Errors
554    /// This function will return an error if the the query fails
555    pub(crate) async fn get_profile(
556        &self,
557        room_id: &RoomId,
558        user_id: &UserId,
559    ) -> Result<Option<MinimalRoomMemberEvent>> {
560        let row = DB::profile_load_query()
561            .bind(room_id.as_str())
562            .bind(user_id.as_str())
563            .fetch_optional(&*self.db)
564            .await?;
565        let row = if let Some(row) = row {
566            row
567        } else {
568            return Ok(None);
569        };
570        let row: Json<MinimalRoomMemberEvent> = row.try_get("user_profile")?;
571        Ok(Some(row.0))
572    }
573
574    /// Retrieves a list of user ids in a room
575    ///
576    /// # Errors
577    /// This function will return an error if the the query fails
578    pub(crate) async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>> {
579        let mut rows = DB::members_load_query()
580            .bind(room_id.as_str())
581            .fetch(&*self.db);
582        let mut result = Vec::new();
583        while let Some(row) = rows.try_next().await? {
584            result.push(row.try_get::<'_, String, _>("user_id")?.try_into()?);
585        }
586        Ok(result)
587    }
588
589    /// Retrieves a list of invited user ids in a room
590    ///
591    /// # Errors
592    /// This function will return an error if the the query fails
593    pub(crate) async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>> {
594        let mut rows = DB::members_load_query_with_join_status()
595            .bind(room_id.as_str())
596            .bind(false)
597            .fetch(&*self.db);
598        let mut result = Vec::new();
599        while let Some(row) = rows.try_next().await? {
600            result.push(row.try_get::<'_, String, _>("user_id")?.try_into()?);
601        }
602        Ok(result)
603    }
604
605    /// Retrieves a list of joined user ids in a room
606    ///
607    /// # Errors
608    /// This function will return an error if the the query fails
609    pub(crate) async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>> {
610        let mut rows = DB::members_load_query_with_join_status()
611            .bind(room_id.as_str())
612            .bind(true)
613            .fetch(&*self.db);
614        let mut result = Vec::new();
615        while let Some(row) = rows.try_next().await? {
616            result.push(row.try_get::<'_, String, _>("user_id")?.try_into()?);
617        }
618        Ok(result)
619    }
620
621    /// Retrieves a member event for a user in a room
622    ///
623    /// # Errors
624    /// This function will return an error if the the query fails
625    pub(crate) async fn get_member_event(
626        &self,
627        room_id: &RoomId,
628        user_id: &UserId,
629    ) -> Result<Option<MemberEvent>> {
630        let row = DB::member_load_query()
631            .bind(room_id.as_str())
632            .bind(user_id.as_str())
633            .fetch_optional(&*self.db)
634            .await?;
635        let row = if let Some(row) = row {
636            row
637        } else {
638            return Ok(None);
639        };
640        if row.try_get::<'_, bool, _>("is_partial")? {
641            let row: Json<StrippedRoomMemberEvent> = row.try_get("member_event")?;
642            Ok(Some(MemberEvent::Stripped(row.0)))
643        } else {
644            let row: Json<SyncRoomMemberEvent> = row.try_get("member_event")?;
645            Ok(Some(MemberEvent::Sync(row.0)))
646        }
647    }
648
649    /// Get room infos
650    ///
651    /// # Errors
652    /// This function will return an error if the the query fails
653    async fn get_room_infos_internal(&self, partial: bool) -> Result<Vec<RoomInfo>> {
654        let mut rows = DB::room_info_load_query().bind(partial).fetch(&*self.db);
655        let mut result = Vec::new();
656        while let Some(row) = rows.try_next().await? {
657            result.push((row.try_get::<'_, Json<RoomInfo>, _>("room_info")?).0);
658        }
659        Ok(result)
660    }
661
662    /// Get room infos
663    ///
664    /// # Errors
665    /// This function will return an error if the the query fails
666    pub(crate) async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
667        self.get_room_infos_internal(false).await
668    }
669    /// Get partial room infos
670    ///
671    /// # Errors
672    /// This function will return an error if the the query fails
673    pub(crate) async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>> {
674        self.get_room_infos_internal(true).await
675    }
676
677    /// Get users with display names in room
678    ///
679    /// # Errors
680    /// This function will return an error if the the query fails
681    pub(crate) async fn get_users_with_display_name(
682        &self,
683        room_id: &RoomId,
684        display_name: &str,
685    ) -> Result<BTreeSet<OwnedUserId>> {
686        let mut rows = DB::users_with_display_name_load_query()
687            .bind(room_id.as_ref())
688            .bind(display_name)
689            .fetch(&*self.db);
690        let mut result = BTreeSet::new();
691        while let Some(row) = rows.try_next().await? {
692            result.insert(row.try_get::<'_, String, _>("user_id")?.try_into()?);
693        }
694        Ok(result)
695    }
696
697    /// Get latest receipt for user in room
698    ///
699    /// # Errors
700    /// This function will return an error if the the query fails
701    pub(crate) async fn get_user_room_receipt_event(
702        &self,
703        room_id: &RoomId,
704        receipt_type: ReceiptType,
705        user_id: &UserId,
706    ) -> Result<Option<(OwnedEventId, Receipt)>> {
707        let row = DB::receipt_load_query()
708            .bind(room_id.as_ref())
709            .bind(receipt_type.as_ref())
710            .bind(user_id.as_ref())
711            .fetch_optional(&*self.db)
712            .await?;
713        let row = if let Some(row) = row {
714            row
715        } else {
716            return Ok(None);
717        };
718        let event_id = row.try_get::<'_, String, _>("event_id")?.try_into()?;
719        let receipt = row.try_get::<'_, Json<Receipt>, _>("receipt")?.0;
720        Ok(Some((event_id, receipt)))
721    }
722
723    /// Get all receipts for event in room
724    ///
725    /// # Errors
726    /// This function will return an error if the the query fails
727    pub(crate) async fn get_event_room_receipt_events(
728        &self,
729        room_id: &RoomId,
730        receipt_type: ReceiptType,
731        event_id: &EventId,
732    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
733        let mut rows = DB::event_receipt_load_query()
734            .bind(room_id.as_ref())
735            .bind(receipt_type.as_ref())
736            .bind(event_id.as_ref())
737            .fetch(&*self.db);
738        let mut result = Vec::new();
739        while let Some(row) = rows.try_next().await? {
740            let user_id = row.try_get::<'_, String, _>("user_id")?.try_into()?;
741            let receipt = row.try_get::<'_, Json<Receipt>, _>("receipt")?.0;
742            result.push((user_id, receipt));
743        }
744        Ok(result)
745    }
746
747    /// Put a sync token into the sync token store
748    ///
749    /// # Errors
750    /// This function will return an error if the upsert cannot be performed
751    #[cfg(test)]
752    async fn save_sync_token_test(&self, token: &str) -> Result<()> {
753        self.insert_kv(b"sync_token", token.as_bytes()).await
754    }
755
756    /// Put a sync token into the sync token store
757    ///
758    /// # Errors
759    /// This function will return an error if the upsert cannot be performed
760    pub(crate) async fn save_sync_token<'c>(
761        txn: &mut Transaction<'c, DB>,
762        token: &str,
763    ) -> Result<()> {
764        Self::insert_kv_txn(txn, b"sync_token", token.as_bytes()).await
765    }
766
767    /// Get the last stored sync token
768    ///
769    /// # Errors
770    /// This function will return an error if the database query fails
771    pub(crate) async fn get_sync_token(&self) -> Result<Option<String>> {
772        let result = self.get_kv(b"sync_token").await?;
773        match result {
774            Some(value) => Ok(Some(String::from_utf8(value)?)),
775            None => Ok(None),
776        }
777    }
778
779    /// Insert a key-value pair into the kv table
780    ///
781    /// # Errors
782    /// This function will return an error if the upsert cannot be performed
783    pub(crate) async fn insert_kv(&self, key: &[u8], value: &[u8]) -> Result<()> {
784        DB::kv_upsert_query()
785            .bind(key)
786            .bind(value)
787            .execute(&*self.db)
788            .await?;
789        Ok(())
790    }
791
792    /// Insert a key-value pair into the kv table as part of a transaction
793    ///
794    /// # Errors
795    /// This function will return an error if the upsert cannot be performed
796    pub(crate) async fn insert_kv_txn<'c>(
797        txn: &mut Transaction<'c, DB>,
798        key: &[u8],
799        value: &[u8],
800    ) -> Result<()> {
801        DB::kv_upsert_query()
802            .bind(key)
803            .bind(value)
804            .execute(txn)
805            .await?;
806        Ok(())
807    }
808
809    /// Get a value from the kv table
810    ///
811    /// # Errors
812    /// This function will return an error if the database query fails
813    pub(crate) async fn get_kv(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
814        let row = DB::kv_load_query()
815            .bind(key)
816            .fetch_optional(&*self.db)
817            .await?;
818
819        let row = if let Some(row) = row {
820            row
821        } else {
822            return Ok(None);
823        };
824
825        Ok(row.try_get("kv_value")?)
826    }
827
828    /// Save state changes to the database in a transaction
829    ///
830    /// # Errors
831    /// This function will return an error if the database query fails
832    pub(crate) async fn save_state_changes_txn<'c>(
833        txn: &mut Transaction<'c, DB>,
834        state_changes: &StateChanges,
835    ) -> Result<()> {
836        if let Some(sync_token) = &state_changes.sync_token {
837            Self::save_sync_token(txn, sync_token).await?;
838        }
839
840        for (event_type, event_data) in &state_changes.account_data {
841            Self::set_global_account_data(txn, event_type, event_data.clone()).await?;
842        }
843
844        for (user_id, presence) in &state_changes.presence {
845            Self::set_presence_event(txn, user_id, presence.clone()).await?;
846        }
847
848        for (room_id, room_info) in &state_changes.room_infos {
849            Self::set_room_info(txn, room_id, room_info.clone()).await?;
850        }
851        for (room_id, room_info) in &state_changes.stripped_room_infos {
852            Self::set_stripped_room_info(txn, room_id, room_info.clone()).await?;
853        }
854
855        for (room_id, members) in &state_changes.members {
856            for (user_id, member_event) in members {
857                Self::set_room_membership(txn, room_id, user_id, member_event.clone()).await?;
858            }
859        }
860
861        for (room_id, members) in &state_changes.stripped_members {
862            for (user_id, member_event) in members {
863                Self::set_stripped_room_membership(txn, room_id, user_id, member_event.clone())
864                    .await?;
865            }
866        }
867
868        for (room_id, profiles) in &state_changes.profiles {
869            for (user_id, profile) in profiles {
870                Self::set_room_profile(txn, room_id, user_id, profile.clone()).await?;
871            }
872        }
873
874        for (room_id, state_events) in &state_changes.state {
875            for (event_type, event_data) in state_events {
876                for (state_key, event_data) in event_data {
877                    Self::set_room_state(txn, room_id, event_type, state_key, event_data.clone())
878                        .await?;
879                }
880            }
881        }
882
883        for (room_id, state_events) in &state_changes.stripped_state {
884            for (event_type, event_data) in state_events {
885                for (state_key, event_data) in event_data {
886                    Self::set_stripped_room_state(
887                        txn,
888                        room_id,
889                        event_type,
890                        state_key,
891                        event_data.clone(),
892                    )
893                    .await?;
894                }
895            }
896        }
897
898        for (room_id, account_data) in &state_changes.room_account_data {
899            for (event_type, event_data) in account_data {
900                Self::set_room_account_data(txn, room_id, event_type, event_data.clone()).await?;
901            }
902        }
903
904        for (room_id, receipt) in &state_changes.receipts {
905            for (event_id, receipt) in &receipt.0 {
906                for (receipt_type, receipt) in receipt {
907                    for (user_id, receipt) in receipt {
908                        Self::set_receipt(
909                            txn,
910                            room_id,
911                            event_id,
912                            receipt_type,
913                            user_id,
914                            receipt.clone(),
915                        )
916                        .await?;
917                    }
918                }
919            }
920        }
921
922        Ok(())
923    }
924
925    /// Save state changes to the database
926    ///
927    /// # Errors
928    /// This function will return an error if the database query fails
929    pub(crate) async fn save_state_changes(&self, state_changes: &StateChanges) -> Result<()> {
930        let mut txn = self.db.begin().await?;
931        Self::save_state_changes_txn(&mut txn, state_changes).await?;
932        txn.commit().await?;
933        Ok(())
934    }
935}
936
/// Shorthand for a `Result` whose error type is [`StoreError`]; this is the
/// return type of the `matrix_sdk_base::StateStore` trait methods below.
type StoreResult<T> = Result<T, StoreError>;
939
940#[async_trait]
941impl<DB: SupportedDatabase> matrix_sdk_base::StateStore for StateStore<DB>
942where
943    for<'a> <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
944    for<'c> &'c mut <DB as sqlx::Database>::Connection: Executor<'c, Database = DB>,
945    for<'a, 'c> &'c mut Transaction<'a, DB>: Executor<'c, Database = DB>,
946    for<'a> &'a [u8]: BorrowedSqlType<'a, DB>,
947    for<'a> &'a str: BorrowedSqlType<'a, DB>,
948    Vec<u8>: SqlType<DB>,
949    Option<String>: SqlType<DB>,
950    String: SqlType<DB>,
951    Json<Raw<AnyGlobalAccountDataEvent>>: SqlType<DB>,
952    Json<Raw<PresenceEvent>>: SqlType<DB>,
953    Json<SyncRoomMemberEvent>: SqlType<DB>,
954    Json<MinimalRoomMemberEvent>: SqlType<DB>,
955    bool: SqlType<DB>,
956    Json<Raw<AnySyncStateEvent>>: SqlType<DB>,
957    Json<Raw<AnyRoomAccountDataEvent>>: SqlType<DB>,
958    Json<RoomInfo>: SqlType<DB>,
959    Json<Receipt>: SqlType<DB>,
960    Json<Raw<AnyStrippedStateEvent>>: SqlType<DB>,
961    Json<StrippedRoomMemberEvent>: SqlType<DB>,
962    Json<MemberEvent>: SqlType<DB>,
963    for<'a> &'a str: ColumnIndex<<DB as Database>::Row>,
964{
965    /// Save the given filter id under the given name.
966    ///
967    /// # Arguments
968    ///
969    /// * `filter_name` - The name that should be used to store the filter id.
970    ///
971    /// * `filter_id` - The filter id that should be stored in the state store.
972    async fn save_filter(&self, filter_name: &str, filter_id: &str) -> StoreResult<()> {
973        self.save_filter(filter_name, filter_id)
974            .await
975            .map_err(|e| StoreError::Backend(e.into()))
976    }
977
978    /// Save the set of state changes in the store.
979    async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> {
980        self.save_state_changes(changes)
981            .await
982            .map_err(|e| StoreError::Backend(e.into()))
983    }
984
985    /// Get the filter id that was stored under the given filter name.
986    ///
987    /// # Arguments
988    ///
989    /// * `filter_name` - The name that was used to store the filter id.
990    async fn get_filter(&self, filter_name: &str) -> StoreResult<Option<String>> {
991        self.get_filter(filter_name)
992            .await
993            .map_err(|e| StoreError::Backend(e.into()))
994    }
995
996    /// Get the last stored sync token.
997    async fn get_sync_token(&self) -> StoreResult<Option<String>> {
998        self.get_sync_token()
999            .await
1000            .map_err(|e| StoreError::Backend(e.into()))
1001    }
1002
1003    /// Get the stored presence event for the given user.
1004    ///
1005    /// # Arguments
1006    ///
1007    /// * `user_id` - The id of the user for which we wish to fetch the presence
1008    /// event for.
1009    async fn get_presence_event(
1010        &self,
1011        user_id: &UserId,
1012    ) -> StoreResult<Option<Raw<PresenceEvent>>> {
1013        self.get_presence_event(user_id)
1014            .await
1015            .map_err(|e| StoreError::Backend(e.into()))
1016    }
1017
1018    /// Get a state event out of the state store.
1019    ///
1020    /// # Arguments
1021    ///
1022    /// * `room_id` - The id of the room the state event was received for.
1023    ///
1024    /// * `event_type` - The event type of the state event.
1025    async fn get_state_event(
1026        &self,
1027        room_id: &RoomId,
1028        event_type: StateEventType,
1029        state_key: &str,
1030    ) -> StoreResult<Option<Raw<AnySyncStateEvent>>> {
1031        self.get_state_event(room_id, event_type, state_key)
1032            .await
1033            .map_err(|e| StoreError::Backend(e.into()))
1034    }
1035
1036    /// Get a list of state events for a given room and `StateEventType`.
1037    ///
1038    /// # Arguments
1039    ///
1040    /// * `room_id` - The id of the room to find events for.
1041    ///
1042    /// * `event_type` - The event type.
1043    async fn get_state_events(
1044        &self,
1045        room_id: &RoomId,
1046        event_type: StateEventType,
1047    ) -> StoreResult<Vec<Raw<AnySyncStateEvent>>> {
1048        self.get_state_events(room_id, event_type)
1049            .await
1050            .map_err(|e| StoreError::Backend(e.into()))
1051    }
1052
1053    /// Get the current profile for the given user in the given room.
1054    ///
1055    /// # Arguments
1056    ///
1057    /// * `room_id` - The room id the profile is used in.
1058    ///
1059    /// * `user_id` - The id of the user the profile belongs to.
1060    async fn get_profile(
1061        &self,
1062        room_id: &RoomId,
1063        user_id: &UserId,
1064    ) -> StoreResult<Option<MinimalRoomMemberEvent>> {
1065        self.get_profile(room_id, user_id)
1066            .await
1067            .map_err(|e| StoreError::Backend(e.into()))
1068    }
1069
1070    /// Get the `MemberEvent` for the given state key in the given room id.
1071    ///
1072    /// # Arguments
1073    ///
1074    /// * `room_id` - The room id the member event belongs to.
1075    ///
1076    /// * `state_key` - The user id that the member event defines the state for.
1077    async fn get_member_event(
1078        &self,
1079        room_id: &RoomId,
1080        state_key: &UserId,
1081    ) -> StoreResult<Option<MemberEvent>> {
1082        self.get_member_event(room_id, state_key)
1083            .await
1084            .map_err(|e| StoreError::Backend(e.into()))
1085    }
1086
1087    /// Get all the user ids of members for a given room, for stripped and
1088    /// regular rooms alike.
1089    async fn get_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<OwnedUserId>> {
1090        self.get_user_ids(room_id)
1091            .await
1092            .map_err(|e| StoreError::Backend(e.into()))
1093    }
1094
1095    /// Get all the user ids of members that are in the invited state for a
1096    /// given room, for stripped and regular rooms alike.
1097    async fn get_invited_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<OwnedUserId>> {
1098        self.get_invited_user_ids(room_id)
1099            .await
1100            .map_err(|e| StoreError::Backend(e.into()))
1101    }
1102
1103    /// Get all the user ids of members that are in the joined state for a
1104    /// given room, for stripped and regular rooms alike.
1105    async fn get_joined_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<OwnedUserId>> {
1106        self.get_joined_user_ids(room_id)
1107            .await
1108            .map_err(|e| StoreError::Backend(e.into()))
1109    }
1110
1111    /// Get all the pure `RoomInfo`s the store knows about.
1112    async fn get_room_infos(&self) -> StoreResult<Vec<RoomInfo>> {
1113        self.get_room_infos()
1114            .await
1115            .map_err(|e| StoreError::Backend(e.into()))
1116    }
1117
1118    /// Get all the pure `RoomInfo`s the store knows about.
1119    async fn get_stripped_room_infos(&self) -> StoreResult<Vec<RoomInfo>> {
1120        self.get_stripped_room_infos()
1121            .await
1122            .map_err(|e| StoreError::Backend(e.into()))
1123    }
1124
1125    /// Get all the users that use the given display name in the given room.
1126    ///
1127    /// # Arguments
1128    ///
1129    /// * `room_id` - The id of the room for which the display name users should
1130    /// be fetched for.
1131    ///
1132    /// * `display_name` - The display name that the users use.
1133    async fn get_users_with_display_name(
1134        &self,
1135        room_id: &RoomId,
1136        display_name: &str,
1137    ) -> StoreResult<BTreeSet<OwnedUserId>> {
1138        self.get_users_with_display_name(room_id, display_name)
1139            .await
1140            .map_err(|e| StoreError::Backend(e.into()))
1141    }
1142
1143    /// Get an event out of the account data store.
1144    ///
1145    /// # Arguments
1146    ///
1147    /// * `event_type` - The event type of the account data event.
1148    async fn get_account_data_event(
1149        &self,
1150        event_type: GlobalAccountDataEventType,
1151    ) -> StoreResult<Option<Raw<AnyGlobalAccountDataEvent>>> {
1152        self.get_account_data_event(event_type)
1153            .await
1154            .map_err(|e| StoreError::Backend(e.into()))
1155    }
1156
1157    /// Get an event out of the room account data store.
1158    ///
1159    /// # Arguments
1160    ///
1161    /// * `room_id` - The id of the room for which the room account data event
1162    ///   should
1163    /// be fetched.
1164    ///
1165    /// * `event_type` - The event type of the room account data event.
1166    async fn get_room_account_data_event(
1167        &self,
1168        room_id: &RoomId,
1169        event_type: RoomAccountDataEventType,
1170    ) -> StoreResult<Option<Raw<AnyRoomAccountDataEvent>>> {
1171        self.get_room_account_data_event(room_id, event_type)
1172            .await
1173            .map_err(|e| StoreError::Backend(e.into()))
1174    }
1175
1176    /// Get an event out of the user room receipt store.
1177    ///
1178    /// # Arguments
1179    ///
1180    /// * `room_id` - The id of the room for which the receipt should be
1181    ///   fetched.
1182    ///
1183    /// * `receipt_type` - The type of the receipt.
1184    ///
1185    /// * `user_id` - The id of the user for who the receipt should be fetched.
1186    async fn get_user_room_receipt_event(
1187        &self,
1188        room_id: &RoomId,
1189        receipt_type: ReceiptType,
1190        user_id: &UserId,
1191    ) -> StoreResult<Option<(OwnedEventId, Receipt)>> {
1192        self.get_user_room_receipt_event(room_id, receipt_type, user_id)
1193            .await
1194            .map_err(|e| StoreError::Backend(e.into()))
1195    }
1196
1197    /// Get events out of the event room receipt store.
1198    ///
1199    /// # Arguments
1200    ///
1201    /// * `room_id` - The id of the room for which the receipts should be
1202    ///   fetched.
1203    ///
1204    /// * `receipt_type` - The type of the receipts.
1205    ///
1206    /// * `event_id` - The id of the event for which the receipts should be
1207    ///   fetched.
1208    async fn get_event_room_receipt_events(
1209        &self,
1210        room_id: &RoomId,
1211        receipt_type: ReceiptType,
1212        event_id: &EventId,
1213    ) -> StoreResult<Vec<(OwnedUserId, Receipt)>> {
1214        self.get_event_room_receipt_events(room_id, receipt_type, event_id)
1215            .await
1216            .map_err(|e| StoreError::Backend(e.into()))
1217    }
1218
1219    /// Get arbitrary data from the custom store
1220    ///
1221    /// # Arguments
1222    ///
1223    /// * `key` - The key to fetch data for
1224    async fn get_custom_value(&self, key: &[u8]) -> StoreResult<Option<Vec<u8>>> {
1225        self.get_custom_value(key)
1226            .await
1227            .map_err(|e| StoreError::Backend(e.into()))
1228    }
1229
1230    /// Put arbitrary data into the custom store
1231    ///
1232    /// # Arguments
1233    ///
1234    /// * `key` - The key to insert data into
1235    ///
1236    /// * `value` - The value to insert
1237    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> StoreResult<Option<Vec<u8>>> {
1238        let old_val = self
1239            .get_custom_value(key)
1240            .await
1241            .map_err(|e| StoreError::Backend(e.into()))?;
1242        self.set_custom_value(key, &value)
1243            .await
1244            .map_err(|e| StoreError::Backend(e.into()))?;
1245        Ok(old_val)
1246    }
1247
1248    /// Add a media file's content in the media store.
1249    ///
1250    /// # Arguments
1251    ///
1252    /// * `request` - The `MediaRequest` of the file.
1253    ///
1254    /// * `content` - The content of the file.
1255    async fn add_media_content(&self, request: &MediaRequest, content: Vec<u8>) -> StoreResult<()> {
1256        self.insert_media(Self::extract_media_url(request), &content)
1257            .await
1258            .map_err(|e| StoreError::Backend(e.into()))
1259    }
1260
1261    /// Get a media file's content out of the media store.
1262    ///
1263    /// # Arguments
1264    ///
1265    /// * `request` - The `MediaRequest` of the file.
1266    async fn get_media_content(&self, request: &MediaRequest) -> StoreResult<Option<Vec<u8>>> {
1267        self.get_media(Self::extract_media_url(request))
1268            .await
1269            .map_err(|e| StoreError::Backend(e.into()))
1270    }
1271
1272    /// Removes a media file's content from the media store.
1273    ///
1274    /// # Arguments
1275    ///
1276    /// * `request` - The `MediaRequest` of the file.
1277    async fn remove_media_content(&self, request: &MediaRequest) -> StoreResult<()> {
1278        self.delete_media(Self::extract_media_url(request))
1279            .await
1280            .map_err(|e| StoreError::Backend(e.into()))
1281    }
1282
1283    /// Removes all the media files' content associated to an `MxcUri` from the
1284    /// media store.
1285    ///
1286    /// # Arguments
1287    ///
1288    /// * `uri` - The `MxcUri` of the media files.
1289    async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> StoreResult<()> {
1290        self.delete_media(uri)
1291            .await
1292            .map_err(|e| StoreError::Backend(e.into()))
1293    }
1294
1295    /// Removes a room and all elements associated from the state store.
1296    ///
1297    /// # Arguments
1298    ///
1299    /// * `room_id` - The `RoomId` of the room to delete.
1300    async fn remove_room(&self, room_id: &RoomId) -> StoreResult<()> {
1301        self.remove_room(room_id)
1302            .await
1303            .map_err(|e| StoreError::Backend(e.into()))
1304    }
1305}
1306
#[cfg(test)]
#[allow(unused_imports, unreachable_pub, clippy::unwrap_used)]
mod tests {
    use crate::{StateStore, SupportedDatabase};
    use anyhow::Result;
    use ruma::{MxcUri, OwnedMxcUri};
    use sqlx::{
        database::HasArguments, migrate::Migrate, ColumnIndex, Database, Decode, Encode, Executor,
        IntoArguments, Pool, Type,
    };
    use std::sync::Arc;

    /// Opens a state store on a fresh in-memory SQLite database.
    ///
    /// Public so that the SQLite integration test module can reuse it.
    #[cfg(feature = "sqlite")]
    pub async fn open_sqlite_database() -> Result<StateStore<sqlx::Sqlite>> {
        let pool = sqlx::SqlitePool::connect("sqlite://:memory:").await?;
        let pool = Arc::new(pool);
        let state_store = StateStore::new(&pool).await?;
        Ok(state_store)
    }

    /// Opens a state store on the `postgres` database of a local PostgreSQL
    /// server, using the `postgres` / `postgres` credentials.
    #[cfg(feature = "postgres")]
    async fn open_postgres_database() -> Result<StateStore<sqlx::Postgres>> {
        let pool =
            sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres").await?;
        let pool = Arc::new(pool);
        let state_store = StateStore::new(&pool).await?;
        Ok(state_store)
    }

    /// Custom values can be written, read back, and overwritten without
    /// affecting values stored under other keys (SQLite).
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_custom_values() {
        let store = open_sqlite_database().await.unwrap();

        // Nothing is stored under the key yet.
        assert!(store.get_custom_value(b"test").await.unwrap().is_none());

        store.set_custom_value(b"test", b"test").await.unwrap();
        let first = store.get_custom_value(b"test").await.unwrap();
        assert_eq!(first.as_deref(), Some(&b"test"[..]));

        // A second key does not disturb the first one.
        store.set_custom_value(b"test2", b"test3").await.unwrap();
        let second = store.get_custom_value(b"test2").await.unwrap();
        assert_eq!(second.as_deref(), Some(&b"test3"[..]));
        let first = store.get_custom_value(b"test").await.unwrap();
        assert_eq!(first.as_deref(), Some(&b"test"[..]));

        // Overwriting the first key leaves the second one untouched.
        store.set_custom_value(b"test", b"test4").await.unwrap();
        let first = store.get_custom_value(b"test").await.unwrap();
        assert_eq!(first.as_deref(), Some(&b"test4"[..]));
        let second = store.get_custom_value(b"test2").await.unwrap();
        assert_eq!(second.as_deref(), Some(&b"test3"[..]));
    }

    /// Custom values can be written, read back, and overwritten without
    /// affecting values stored under other keys (PostgreSQL).
    #[cfg(feature = "postgres")]
    #[tokio::test]
    #[cfg_attr(not(feature = "ci"), ignore)]
    async fn test_postgres_custom_values() {
        let store = open_postgres_database().await.unwrap();

        // Nothing is stored under the key yet.
        assert!(store.get_custom_value(b"test").await.unwrap().is_none());

        store.set_custom_value(b"test", b"test").await.unwrap();
        let first = store.get_custom_value(b"test").await.unwrap();
        assert_eq!(first.as_deref(), Some(&b"test"[..]));

        // A second key does not disturb the first one.
        store.set_custom_value(b"test2", b"test3").await.unwrap();
        let second = store.get_custom_value(b"test2").await.unwrap();
        assert_eq!(second.as_deref(), Some(&b"test3"[..]));
        let first = store.get_custom_value(b"test").await.unwrap();
        assert_eq!(first.as_deref(), Some(&b"test"[..]));

        // Overwriting the first key leaves the second one untouched.
        store.set_custom_value(b"test", b"test4").await.unwrap();
        let first = store.get_custom_value(b"test").await.unwrap();
        assert_eq!(first.as_deref(), Some(&b"test4"[..]));
        let second = store.get_custom_value(b"test2").await.unwrap();
        assert_eq!(second.as_deref(), Some(&b"test3"[..]));
    }

    /// Filters can be saved, read back, and overwritten without affecting
    /// filters stored under other names (SQLite).
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_filters() {
        let store = open_sqlite_database().await.unwrap();

        // No filter is stored under the name yet.
        assert!(store.get_filter("test").await.unwrap().is_none());

        store.save_filter("test", "test").await.unwrap();
        let first = store.get_filter("test").await.unwrap();
        assert_eq!(first.as_deref(), Some("test"));

        // A second filter does not disturb the first one.
        store.save_filter("test2", "test3").await.unwrap();
        let second = store.get_filter("test2").await.unwrap();
        assert_eq!(second.as_deref(), Some("test3"));
        let first = store.get_filter("test").await.unwrap();
        assert_eq!(first.as_deref(), Some("test"));

        // Overwriting the first filter leaves the second one untouched.
        store.save_filter("test", "test4").await.unwrap();
        let first = store.get_filter("test").await.unwrap();
        assert_eq!(first.as_deref(), Some("test4"));
        let second = store.get_filter("test2").await.unwrap();
        assert_eq!(second.as_deref(), Some("test3"));
    }

    /// Filters can be saved, read back, and overwritten without affecting
    /// filters stored under other names (PostgreSQL).
    #[cfg(feature = "postgres")]
    #[tokio::test]
    #[cfg_attr(not(feature = "ci"), ignore)]
    async fn test_postgres_filters() {
        let store = open_postgres_database().await.unwrap();

        // No filter is stored under the name yet.
        assert!(store.get_filter("test").await.unwrap().is_none());

        store.save_filter("test", "test").await.unwrap();
        let first = store.get_filter("test").await.unwrap();
        assert_eq!(first.as_deref(), Some("test"));

        // A second filter does not disturb the first one.
        store.save_filter("test2", "test3").await.unwrap();
        let second = store.get_filter("test2").await.unwrap();
        assert_eq!(second.as_deref(), Some("test3"));
        let first = store.get_filter("test").await.unwrap();
        assert_eq!(first.as_deref(), Some("test"));

        // Overwriting the first filter leaves the second one untouched.
        store.save_filter("test", "test4").await.unwrap();
        let first = store.get_filter("test").await.unwrap();
        assert_eq!(first.as_deref(), Some("test4"));
        let second = store.get_filter("test2").await.unwrap();
        assert_eq!(second.as_deref(), Some("test3"));
    }

    /// After 101 distinct media entries have been inserted, the first one can
    /// no longer be read while the second one still can (SQLite).
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_mediastore() {
        let store = open_sqlite_database().await.unwrap();
        let oldest = <&MxcUri>::from("mxc://localhost:8080/media/0");
        let second_oldest = <&MxcUri>::from("mxc://localhost:8080/media/1");

        store.insert_media(oldest, b"media_0").await.unwrap();
        store.insert_media(second_oldest, b"media_1").await.unwrap();

        // Entries 2 through 100 bring the total number of inserts to 101.
        for index in 2..=100 {
            let uri = OwnedMxcUri::from(format!("mxc://localhost:8080/media/{}", index));
            store.insert_media(&uri, b"media_0").await.unwrap();
        }

        assert!(store.get_media(oldest).await.unwrap().is_none());
        let kept = store.get_media(second_oldest).await.unwrap();
        assert_eq!(kept.as_deref(), Some(&b"media_1"[..]));
    }

    /// After 101 distinct media entries have been inserted, the first one can
    /// no longer be read while the second one still can (PostgreSQL).
    #[cfg(feature = "postgres")]
    #[tokio::test]
    #[cfg_attr(not(feature = "ci"), ignore)]
    async fn test_postgres_mediastore() {
        let store = open_postgres_database().await.unwrap();
        let oldest = <&MxcUri>::from("mxc://localhost:8080/media/0");
        let second_oldest = <&MxcUri>::from("mxc://localhost:8080/media/1");

        store.insert_media(oldest, b"media_0").await.unwrap();
        store.insert_media(second_oldest, b"media_1").await.unwrap();

        // Entries 2 through 100 bring the total number of inserts to 101.
        for index in 2..=100 {
            let uri = OwnedMxcUri::from(format!("mxc://localhost:8080/media/{}", index));
            store.insert_media(&uri, b"media_0").await.unwrap();
        }

        assert!(store.get_media(oldest).await.unwrap().is_none());
        let kept = store.get_media(second_oldest).await.unwrap();
        assert_eq!(kept.as_deref(), Some(&b"media_1"[..]));
    }

    /// The sync token is absent at first and readable once saved (SQLite).
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_sync_token() {
        let store = open_sqlite_database().await.unwrap();
        assert!(store.get_sync_token().await.unwrap().is_none());

        store.save_sync_token_test("test").await.unwrap();
        let token = store.get_sync_token().await.unwrap();
        assert_eq!(token.as_deref(), Some("test"));
    }

    /// The sync token is absent at first and readable once saved
    /// (PostgreSQL).
    #[cfg(feature = "postgres")]
    #[tokio::test]
    #[cfg_attr(not(feature = "ci"), ignore)]
    async fn test_postgres_sync_token() {
        let store = open_postgres_database().await.unwrap();
        assert!(store.get_sync_token().await.unwrap().is_none());

        store.save_sync_token_test("test").await.unwrap();
        let token = store.get_sync_token().await.unwrap();
        assert_eq!(token.as_deref(), Some("test"));
    }

    /// A key/value entry can be inserted and then overwritten (SQLite).
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_kv_store() {
        let store = open_sqlite_database().await.unwrap();

        store.insert_kv(b"key", b"value").await.unwrap();
        let stored = store.get_kv(b"key").await.unwrap();
        assert_eq!(stored.as_deref(), Some(&b"value"[..]));

        // Inserting under the same key replaces the previous value.
        store.insert_kv(b"key", b"value2").await.unwrap();
        let stored = store.get_kv(b"key").await.unwrap();
        assert_eq!(stored.as_deref(), Some(&b"value2"[..]));
    }

    /// A key/value entry can be inserted and then overwritten (PostgreSQL).
    #[cfg(feature = "postgres")]
    #[tokio::test]
    #[cfg_attr(not(feature = "ci"), ignore)]
    async fn test_postgres_kv_store() {
        let store = open_postgres_database().await.unwrap();

        store.insert_kv(b"key", b"value").await.unwrap();
        let stored = store.get_kv(b"key").await.unwrap();
        assert_eq!(stored.as_deref(), Some(&b"value"[..]));

        // Inserting under the same key replaces the previous value.
        store.insert_kv(b"key", b"value2").await.unwrap();
        let stored = store.get_kv(b"key").await.unwrap();
        assert_eq!(stored.as_deref(), Some(&b"value2"[..]));
    }
}
1551
#[allow(clippy::redundant_pub_crate)]
#[cfg(all(test, feature = "postgres", feature = "ci"))]
mod postgres_integration_test {
    use std::sync::Arc;

    use matrix_sdk_base::{statestore_integration_tests, StateStore, StoreError};
    use rand::distributions::{Alphanumeric, DistString};
    use sqlx::migrate::MigrateDatabase;

    use super::StoreResult;

    /// Opens a state store on a PostgreSQL database with a random
    /// 16-character alphanumeric name, creating the database when it does not
    /// exist yet.
    ///
    /// NOTE(review): nothing in this module drops the database afterwards.
    async fn get_store_anyhow() -> anyhow::Result<impl StateStore> {
        let db_name = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
        let url = format!("postgres://postgres:postgres@localhost:5432/{}", db_name);

        let already_present = sqlx::Postgres::database_exists(&url).await?;
        if !already_present {
            sqlx::Postgres::create_database(&url).await?;
        }

        let pool = Arc::new(sqlx::PgPool::connect(&url).await?);
        let state_store = crate::StateStore::new(&pool).await?;
        Ok(state_store)
    }

    /// Store constructor for the integration test suite generated below;
    /// reports any setup failure as `StoreError::Backend`.
    async fn get_store() -> StoreResult<impl StateStore> {
        let opened = get_store_anyhow().await;
        opened.map_err(|source| StoreError::Backend(source.into()))
    }

    statestore_integration_tests! { integration }
}
1580
#[allow(clippy::redundant_pub_crate)]
#[cfg(all(test, feature = "sqlite"))]
mod sqlite_integration_test {
    use matrix_sdk_base::{statestore_integration_tests, StateStore, StoreError};

    use super::StoreResult;

    /// Store constructor for the integration test suite generated below:
    /// opens the in-memory SQLite store from the unit-test module and reports
    /// any setup failure as `StoreError::Backend`.
    async fn get_store() -> StoreResult<impl StateStore> {
        let opened = super::tests::open_sqlite_database().await;
        opened.map_err(|source| StoreError::Backend(source.into()))
    }

    statestore_integration_tests! { integration }
}