1use std::collections::BTreeSet;
4
5use crate::{
6 helpers::{BorrowedSqlType, SqlType},
7 StateStore, SupportedDatabase,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::TryStreamExt;
12use matrix_sdk_base::{
13 deserialized_responses::MemberEvent, media::MediaRequest, MinimalRoomMemberEvent, RoomInfo,
14 StateChanges, StoreError,
15};
16use ruma::{
17 events::{
18 presence::PresenceEvent,
19 receipt::Receipt,
20 room::{
21 member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
22 MediaSource,
23 },
24 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
25 AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
26 },
27 receipt::ReceiptType,
28 serde::Raw,
29 EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, UserId,
30};
31use sqlx::{
32 database::HasArguments, types::Json, ColumnIndex, Database, Executor, IntoArguments, Row,
33 Transaction,
34};
35
36impl<DB: SupportedDatabase> StateStore<DB>
37where
38 for<'a> <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
39 for<'c> &'c mut <DB as sqlx::Database>::Connection: Executor<'c, Database = DB>,
40 for<'a, 'c> &'c mut Transaction<'a, DB>: Executor<'c, Database = DB>,
41 for<'a> &'a [u8]: BorrowedSqlType<'a, DB>,
42 for<'a> &'a str: BorrowedSqlType<'a, DB>,
43 Vec<u8>: SqlType<DB>,
44 Option<String>: SqlType<DB>,
45 String: SqlType<DB>,
46 Json<Raw<AnyGlobalAccountDataEvent>>: SqlType<DB>,
47 Json<Raw<PresenceEvent>>: SqlType<DB>,
48 Json<SyncRoomMemberEvent>: SqlType<DB>,
49 Json<MinimalRoomMemberEvent>: SqlType<DB>,
50 bool: SqlType<DB>,
51 Json<Raw<AnySyncStateEvent>>: SqlType<DB>,
52 Json<Raw<AnyRoomAccountDataEvent>>: SqlType<DB>,
53 Json<RoomInfo>: SqlType<DB>,
54 Json<Receipt>: SqlType<DB>,
55 Json<Raw<AnyStrippedStateEvent>>: SqlType<DB>,
56 Json<StrippedRoomMemberEvent>: SqlType<DB>,
57 Json<MemberEvent>: SqlType<DB>,
58 for<'a> &'a str: ColumnIndex<<DB as Database>::Row>,
59{
60 pub(crate) async fn set_custom_value(&self, key_ref: &[u8], val: &[u8]) -> Result<()> {
65 let mut key = Vec::with_capacity(7 + key_ref.len());
66 key.extend_from_slice(b"custom:");
67 key.extend_from_slice(key_ref);
68
69 self.insert_kv(&key, val).await
70 }
71
72 pub(crate) async fn get_custom_value(&self, key_ref: &[u8]) -> Result<Option<Vec<u8>>> {
77 let mut key = Vec::with_capacity(7 + key_ref.len());
78 key.extend_from_slice(b"custom:");
79 key.extend_from_slice(key_ref);
80 self.get_kv(&key).await
81 }
82
83 pub(crate) async fn save_filter(&self, name: &str, filter_id: &str) -> Result<()> {
88 let mut key = Vec::with_capacity(7 + name.len());
89 key.extend_from_slice(b"filter:");
90 key.extend_from_slice(name.as_bytes());
91
92 self.insert_kv(&key, filter_id.as_bytes()).await
93 }
94
95 pub(crate) async fn get_filter(&self, name: &str) -> Result<Option<String>> {
100 let mut key = Vec::with_capacity(7 + name.len());
101 key.extend_from_slice(b"filter:");
102 key.extend_from_slice(name.as_bytes());
103 let result = self.get_kv(&key).await?;
104 match result {
105 Some(value) => Ok(Some(String::from_utf8(value)?)),
106 None => Ok(None),
107 }
108 }
109
110 pub(crate) async fn insert_media(&self, url: &MxcUri, media: &[u8]) -> Result<()> {
115 let mut txn = self.db.begin().await?;
116
117 DB::media_insert_query_1()
118 .bind(url.as_str())
119 .bind(media)
120 .execute(&mut txn)
121 .await?;
122 DB::media_insert_query_2().execute(&mut txn).await?;
123
124 txn.commit().await?;
125 Ok(())
126 }
127
128 pub(crate) async fn delete_media(&self, url: &MxcUri) -> Result<()> {
133 DB::media_delete_query()
134 .bind(url.as_str())
135 .execute(&*self.db)
136 .await?;
137 Ok(())
138 }
139
140 pub(crate) async fn get_media(&self, url: &MxcUri) -> Result<Option<Vec<u8>>> {
145 let row = DB::media_load_query()
146 .bind(url.as_str())
147 .fetch_optional(&*self.db)
148 .await?;
149 let row = if let Some(row) = row {
150 row
151 } else {
152 return Ok(None);
153 };
154 Ok(row.try_get("media_data")?)
155 }
156
157 #[must_use]
161 pub(crate) fn extract_media_url(request: &MediaRequest) -> &MxcUri {
162 match request.source {
163 MediaSource::Plain(ref p) => p,
164 MediaSource::Encrypted(ref e) => &e.url,
165 }
166 }
167
168 pub(crate) async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
173 let mut txn = self.db.begin().await?;
174
175 for query in DB::room_remove_queries() {
176 query.bind(room_id.as_str()).execute(&mut txn).await?;
177 }
178
179 txn.commit().await?;
180 Ok(())
181 }
182
183 pub(crate) async fn set_global_account_data<'c>(
188 txn: &mut Transaction<'c, DB>,
189 event_type: &GlobalAccountDataEventType,
190 event_data: Raw<AnyGlobalAccountDataEvent>,
191 ) -> Result<()> {
192 DB::account_data_upsert_query()
193 .bind("")
194 .bind(event_type.to_string())
195 .bind(Json(event_data))
196 .execute(txn)
197 .await?;
198
199 Ok(())
200 }
201
202 pub(crate) async fn get_account_data_event(
207 &self,
208 event_type: GlobalAccountDataEventType,
209 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
210 let row = DB::account_data_load_query()
211 .bind("")
212 .bind(event_type.to_string())
213 .fetch_optional(&*self.db)
214 .await?;
215 let row = if let Some(row) = row {
216 row
217 } else {
218 return Ok(None);
219 };
220 let row: Json<Raw<AnyGlobalAccountDataEvent>> = row.try_get("account_data")?;
221 Ok(Some(row.0))
222 }
223
224 pub(crate) async fn get_room_account_data_event(
229 &self,
230 room_id: &RoomId,
231 event_type: RoomAccountDataEventType,
232 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
233 let row = DB::account_data_load_query()
234 .bind(room_id.as_str())
235 .bind(event_type.to_string())
236 .fetch_optional(&*self.db)
237 .await?;
238 let row = if let Some(row) = row {
239 row
240 } else {
241 return Ok(None);
242 };
243 let row: Json<Raw<AnyRoomAccountDataEvent>> = row.try_get("account_data")?;
244 Ok(Some(row.0))
245 }
246
247 pub(crate) async fn set_presence_event<'c>(
252 txn: &mut Transaction<'c, DB>,
253 user_id: &UserId,
254 presence: Raw<PresenceEvent>,
255 ) -> Result<()> {
256 DB::presence_upsert_query()
257 .bind(user_id.as_str())
258 .bind(Json(presence))
259 .execute(txn)
260 .await?;
261 Ok(())
262 }
263
264 pub(crate) async fn get_presence_event(
269 &self,
270 user_id: &UserId,
271 ) -> Result<Option<Raw<PresenceEvent>>> {
272 let row = DB::presence_load_query()
273 .bind(user_id.as_str())
274 .fetch_optional(&*self.db)
275 .await?;
276 let row = if let Some(row) = row {
277 row
278 } else {
279 return Ok(None);
280 };
281 let row: Json<Raw<PresenceEvent>> = row.try_get("presence")?;
282 Ok(Some(row.0))
283 }
284
285 async fn remove_member<'c>(
290 txn: &mut Transaction<'c, DB>,
291 room_id: &RoomId,
292 user_id: &UserId,
293 ) -> Result<()> {
294 DB::member_remove_query()
295 .bind(room_id.as_str())
296 .bind(user_id.as_str())
297 .execute(txn)
298 .await?;
299 Ok(())
300 }
301
302 pub(crate) async fn set_room_membership<'c>(
307 txn: &mut Transaction<'c, DB>,
308 room_id: &RoomId,
309 user_id: &UserId,
310 member_event: SyncRoomMemberEvent,
311 ) -> Result<()> {
312 let displayname = member_event
313 .as_original()
314 .and_then(|v| v.content.displayname.clone());
315 let joined = match member_event.as_original().map(|v| &v.content.membership) {
316 Some(MembershipState::Join) => true,
317 Some(MembershipState::Invite) => false,
318 _ => return Self::remove_member(txn, room_id, user_id).await,
319 };
320 DB::member_upsert_query()
321 .bind(room_id.as_str())
322 .bind(user_id.as_str())
323 .bind(false)
324 .bind(Json(member_event))
325 .bind(displayname)
326 .bind(joined)
327 .execute(txn)
328 .await?;
329 Ok(())
330 }
331
332 pub(crate) async fn set_stripped_room_membership<'c>(
337 txn: &mut Transaction<'c, DB>,
338 room_id: &RoomId,
339 user_id: &UserId,
340 member_event: StrippedRoomMemberEvent,
341 ) -> Result<()> {
342 let displayname = member_event.content.displayname.clone();
343 let joined = match member_event.content.membership {
344 MembershipState::Join => true,
345 MembershipState::Invite => false,
346 _ => return Self::remove_member(txn, room_id, user_id).await,
347 };
348 DB::member_upsert_query()
349 .bind(room_id.as_str())
350 .bind(user_id.as_str())
351 .bind(true)
352 .bind(Json(member_event))
353 .bind(displayname)
354 .bind(joined)
355 .execute(txn)
356 .await?;
357 Ok(())
358 }
359
360 pub(crate) async fn set_room_profile<'c>(
365 txn: &mut Transaction<'c, DB>,
366 room_id: &RoomId,
367 user_id: &UserId,
368 profile: MinimalRoomMemberEvent,
369 ) -> Result<()> {
370 DB::member_profile_upsert_query()
371 .bind(room_id.as_str())
372 .bind(user_id.as_str())
373 .bind(false)
374 .bind(Json(profile))
375 .execute(txn)
376 .await?;
377 Ok(())
378 }
379
380 pub(crate) async fn set_room_state<'c>(
385 txn: &mut Transaction<'c, DB>,
386 room_id: &RoomId,
387 event_type: &StateEventType,
388 state_key: &str,
389 state: Raw<AnySyncStateEvent>,
390 ) -> Result<()> {
391 DB::state_upsert_query()
392 .bind(room_id.as_str())
393 .bind(event_type.to_string())
394 .bind(state_key)
395 .bind(false)
396 .bind(Json(state))
397 .execute(txn)
398 .await?;
399 Ok(())
400 }
401
402 pub(crate) async fn set_stripped_room_state<'c>(
407 txn: &mut Transaction<'c, DB>,
408 room_id: &RoomId,
409 event_type: &StateEventType,
410 state_key: &str,
411 state: Raw<AnyStrippedStateEvent>,
412 ) -> Result<()> {
413 DB::state_upsert_query()
414 .bind(room_id.as_str())
415 .bind(event_type.to_string())
416 .bind(state_key)
417 .bind(true)
418 .bind(Json(state))
419 .execute(txn)
420 .await?;
421 Ok(())
422 }
423
424 pub(crate) async fn set_room_account_data<'c>(
429 txn: &mut Transaction<'c, DB>,
430 room_id: &RoomId,
431 event_type: &RoomAccountDataEventType,
432 event_data: Raw<AnyRoomAccountDataEvent>,
433 ) -> Result<()> {
434 DB::account_data_upsert_query()
435 .bind(room_id.as_str())
436 .bind(event_type.to_string())
437 .bind(Json(event_data))
438 .execute(txn)
439 .await?;
440 Ok(())
441 }
442
443 pub(crate) async fn set_room_info<'c>(
448 txn: &mut Transaction<'c, DB>,
449 room_id: &RoomId,
450 room_info: RoomInfo,
451 ) -> Result<()> {
452 DB::room_upsert_query()
453 .bind(room_id.as_str())
454 .bind(false)
455 .bind(Json(room_info))
456 .execute(txn)
457 .await?;
458 Ok(())
459 }
460
461 pub(crate) async fn set_stripped_room_info<'c>(
466 txn: &mut Transaction<'c, DB>,
467 room_id: &RoomId,
468 room_info: RoomInfo,
469 ) -> Result<()> {
470 DB::room_upsert_query()
471 .bind(room_id.as_str())
472 .bind(true)
473 .bind(Json(room_info))
474 .execute(txn)
475 .await?;
476 Ok(())
477 }
478
479 pub(crate) async fn set_receipt<'c>(
484 txn: &mut Transaction<'c, DB>,
485 room_id: &RoomId,
486 event_id: &EventId,
487 receipt_type: &ReceiptType,
488 user_id: &UserId,
489 receipt: Receipt,
490 ) -> Result<()> {
491 DB::receipt_upsert_query()
492 .bind(room_id.as_str())
493 .bind(event_id.as_str())
494 .bind(receipt_type.as_str())
495 .bind(user_id.as_str())
496 .bind(Json(receipt))
497 .execute(txn)
498 .await?;
499 Ok(())
500 }
501
502 pub(crate) async fn get_state_event(
507 &self,
508 room_id: &RoomId,
509 event_type: StateEventType,
510 state_key: &str,
511 ) -> Result<Option<Raw<AnySyncStateEvent>>> {
512 let row = DB::state_load_query()
513 .bind(room_id.as_str())
514 .bind(event_type.to_string())
515 .bind(state_key)
516 .fetch_optional(&*self.db)
517 .await?;
518 let row = if let Some(row) = row {
519 row
520 } else {
521 return Ok(None);
522 };
523 let row: Json<Raw<AnySyncStateEvent>> = row.try_get("state_event")?;
524 Ok(Some(row.0))
525 }
526
527 pub(crate) async fn get_state_events(
532 &self,
533 room_id: &RoomId,
534 event_type: StateEventType,
535 ) -> Result<Vec<Raw<AnySyncStateEvent>>> {
536 let mut rows = DB::states_load_query()
537 .bind(room_id.as_str())
538 .bind(event_type.to_string())
539 .bind(false)
540 .fetch(&*self.db);
541 let mut result = Vec::new();
542 while let Some(row) = rows.try_next().await? {
543 result.push(
544 row.try_get::<'_, Json<Raw<AnySyncStateEvent>>, _>("state_event")?
545 .0,
546 );
547 }
548 Ok(result)
549 }
550
551 pub(crate) async fn get_profile(
556 &self,
557 room_id: &RoomId,
558 user_id: &UserId,
559 ) -> Result<Option<MinimalRoomMemberEvent>> {
560 let row = DB::profile_load_query()
561 .bind(room_id.as_str())
562 .bind(user_id.as_str())
563 .fetch_optional(&*self.db)
564 .await?;
565 let row = if let Some(row) = row {
566 row
567 } else {
568 return Ok(None);
569 };
570 let row: Json<MinimalRoomMemberEvent> = row.try_get("user_profile")?;
571 Ok(Some(row.0))
572 }
573
574 pub(crate) async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>> {
579 let mut rows = DB::members_load_query()
580 .bind(room_id.as_str())
581 .fetch(&*self.db);
582 let mut result = Vec::new();
583 while let Some(row) = rows.try_next().await? {
584 result.push(row.try_get::<'_, String, _>("user_id")?.try_into()?);
585 }
586 Ok(result)
587 }
588
589 pub(crate) async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>> {
594 let mut rows = DB::members_load_query_with_join_status()
595 .bind(room_id.as_str())
596 .bind(false)
597 .fetch(&*self.db);
598 let mut result = Vec::new();
599 while let Some(row) = rows.try_next().await? {
600 result.push(row.try_get::<'_, String, _>("user_id")?.try_into()?);
601 }
602 Ok(result)
603 }
604
605 pub(crate) async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>> {
610 let mut rows = DB::members_load_query_with_join_status()
611 .bind(room_id.as_str())
612 .bind(true)
613 .fetch(&*self.db);
614 let mut result = Vec::new();
615 while let Some(row) = rows.try_next().await? {
616 result.push(row.try_get::<'_, String, _>("user_id")?.try_into()?);
617 }
618 Ok(result)
619 }
620
621 pub(crate) async fn get_member_event(
626 &self,
627 room_id: &RoomId,
628 user_id: &UserId,
629 ) -> Result<Option<MemberEvent>> {
630 let row = DB::member_load_query()
631 .bind(room_id.as_str())
632 .bind(user_id.as_str())
633 .fetch_optional(&*self.db)
634 .await?;
635 let row = if let Some(row) = row {
636 row
637 } else {
638 return Ok(None);
639 };
640 if row.try_get::<'_, bool, _>("is_partial")? {
641 let row: Json<StrippedRoomMemberEvent> = row.try_get("member_event")?;
642 Ok(Some(MemberEvent::Stripped(row.0)))
643 } else {
644 let row: Json<SyncRoomMemberEvent> = row.try_get("member_event")?;
645 Ok(Some(MemberEvent::Sync(row.0)))
646 }
647 }
648
649 async fn get_room_infos_internal(&self, partial: bool) -> Result<Vec<RoomInfo>> {
654 let mut rows = DB::room_info_load_query().bind(partial).fetch(&*self.db);
655 let mut result = Vec::new();
656 while let Some(row) = rows.try_next().await? {
657 result.push((row.try_get::<'_, Json<RoomInfo>, _>("room_info")?).0);
658 }
659 Ok(result)
660 }
661
662 pub(crate) async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
667 self.get_room_infos_internal(false).await
668 }
669 pub(crate) async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>> {
674 self.get_room_infos_internal(true).await
675 }
676
677 pub(crate) async fn get_users_with_display_name(
682 &self,
683 room_id: &RoomId,
684 display_name: &str,
685 ) -> Result<BTreeSet<OwnedUserId>> {
686 let mut rows = DB::users_with_display_name_load_query()
687 .bind(room_id.as_ref())
688 .bind(display_name)
689 .fetch(&*self.db);
690 let mut result = BTreeSet::new();
691 while let Some(row) = rows.try_next().await? {
692 result.insert(row.try_get::<'_, String, _>("user_id")?.try_into()?);
693 }
694 Ok(result)
695 }
696
697 pub(crate) async fn get_user_room_receipt_event(
702 &self,
703 room_id: &RoomId,
704 receipt_type: ReceiptType,
705 user_id: &UserId,
706 ) -> Result<Option<(OwnedEventId, Receipt)>> {
707 let row = DB::receipt_load_query()
708 .bind(room_id.as_ref())
709 .bind(receipt_type.as_ref())
710 .bind(user_id.as_ref())
711 .fetch_optional(&*self.db)
712 .await?;
713 let row = if let Some(row) = row {
714 row
715 } else {
716 return Ok(None);
717 };
718 let event_id = row.try_get::<'_, String, _>("event_id")?.try_into()?;
719 let receipt = row.try_get::<'_, Json<Receipt>, _>("receipt")?.0;
720 Ok(Some((event_id, receipt)))
721 }
722
723 pub(crate) async fn get_event_room_receipt_events(
728 &self,
729 room_id: &RoomId,
730 receipt_type: ReceiptType,
731 event_id: &EventId,
732 ) -> Result<Vec<(OwnedUserId, Receipt)>> {
733 let mut rows = DB::event_receipt_load_query()
734 .bind(room_id.as_ref())
735 .bind(receipt_type.as_ref())
736 .bind(event_id.as_ref())
737 .fetch(&*self.db);
738 let mut result = Vec::new();
739 while let Some(row) = rows.try_next().await? {
740 let user_id = row.try_get::<'_, String, _>("user_id")?.try_into()?;
741 let receipt = row.try_get::<'_, Json<Receipt>, _>("receipt")?.0;
742 result.push((user_id, receipt));
743 }
744 Ok(result)
745 }
746
747 #[cfg(test)]
752 async fn save_sync_token_test(&self, token: &str) -> Result<()> {
753 self.insert_kv(b"sync_token", token.as_bytes()).await
754 }
755
756 pub(crate) async fn save_sync_token<'c>(
761 txn: &mut Transaction<'c, DB>,
762 token: &str,
763 ) -> Result<()> {
764 Self::insert_kv_txn(txn, b"sync_token", token.as_bytes()).await
765 }
766
767 pub(crate) async fn get_sync_token(&self) -> Result<Option<String>> {
772 let result = self.get_kv(b"sync_token").await?;
773 match result {
774 Some(value) => Ok(Some(String::from_utf8(value)?)),
775 None => Ok(None),
776 }
777 }
778
779 pub(crate) async fn insert_kv(&self, key: &[u8], value: &[u8]) -> Result<()> {
784 DB::kv_upsert_query()
785 .bind(key)
786 .bind(value)
787 .execute(&*self.db)
788 .await?;
789 Ok(())
790 }
791
792 pub(crate) async fn insert_kv_txn<'c>(
797 txn: &mut Transaction<'c, DB>,
798 key: &[u8],
799 value: &[u8],
800 ) -> Result<()> {
801 DB::kv_upsert_query()
802 .bind(key)
803 .bind(value)
804 .execute(txn)
805 .await?;
806 Ok(())
807 }
808
809 pub(crate) async fn get_kv(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
814 let row = DB::kv_load_query()
815 .bind(key)
816 .fetch_optional(&*self.db)
817 .await?;
818
819 let row = if let Some(row) = row {
820 row
821 } else {
822 return Ok(None);
823 };
824
825 Ok(row.try_get("kv_value")?)
826 }
827
828 pub(crate) async fn save_state_changes_txn<'c>(
833 txn: &mut Transaction<'c, DB>,
834 state_changes: &StateChanges,
835 ) -> Result<()> {
836 if let Some(sync_token) = &state_changes.sync_token {
837 Self::save_sync_token(txn, sync_token).await?;
838 }
839
840 for (event_type, event_data) in &state_changes.account_data {
841 Self::set_global_account_data(txn, event_type, event_data.clone()).await?;
842 }
843
844 for (user_id, presence) in &state_changes.presence {
845 Self::set_presence_event(txn, user_id, presence.clone()).await?;
846 }
847
848 for (room_id, room_info) in &state_changes.room_infos {
849 Self::set_room_info(txn, room_id, room_info.clone()).await?;
850 }
851 for (room_id, room_info) in &state_changes.stripped_room_infos {
852 Self::set_stripped_room_info(txn, room_id, room_info.clone()).await?;
853 }
854
855 for (room_id, members) in &state_changes.members {
856 for (user_id, member_event) in members {
857 Self::set_room_membership(txn, room_id, user_id, member_event.clone()).await?;
858 }
859 }
860
861 for (room_id, members) in &state_changes.stripped_members {
862 for (user_id, member_event) in members {
863 Self::set_stripped_room_membership(txn, room_id, user_id, member_event.clone())
864 .await?;
865 }
866 }
867
868 for (room_id, profiles) in &state_changes.profiles {
869 for (user_id, profile) in profiles {
870 Self::set_room_profile(txn, room_id, user_id, profile.clone()).await?;
871 }
872 }
873
874 for (room_id, state_events) in &state_changes.state {
875 for (event_type, event_data) in state_events {
876 for (state_key, event_data) in event_data {
877 Self::set_room_state(txn, room_id, event_type, state_key, event_data.clone())
878 .await?;
879 }
880 }
881 }
882
883 for (room_id, state_events) in &state_changes.stripped_state {
884 for (event_type, event_data) in state_events {
885 for (state_key, event_data) in event_data {
886 Self::set_stripped_room_state(
887 txn,
888 room_id,
889 event_type,
890 state_key,
891 event_data.clone(),
892 )
893 .await?;
894 }
895 }
896 }
897
898 for (room_id, account_data) in &state_changes.room_account_data {
899 for (event_type, event_data) in account_data {
900 Self::set_room_account_data(txn, room_id, event_type, event_data.clone()).await?;
901 }
902 }
903
904 for (room_id, receipt) in &state_changes.receipts {
905 for (event_id, receipt) in &receipt.0 {
906 for (receipt_type, receipt) in receipt {
907 for (user_id, receipt) in receipt {
908 Self::set_receipt(
909 txn,
910 room_id,
911 event_id,
912 receipt_type,
913 user_id,
914 receipt.clone(),
915 )
916 .await?;
917 }
918 }
919 }
920 }
921
922 Ok(())
923 }
924
925 pub(crate) async fn save_state_changes(&self, state_changes: &StateChanges) -> Result<()> {
930 let mut txn = self.db.begin().await?;
931 Self::save_state_changes_txn(&mut txn, state_changes).await?;
932 txn.commit().await?;
933 Ok(())
934 }
935}
936
937type StoreResult<T> = Result<T, StoreError>;
939
940#[async_trait]
941impl<DB: SupportedDatabase> matrix_sdk_base::StateStore for StateStore<DB>
942where
943 for<'a> <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
944 for<'c> &'c mut <DB as sqlx::Database>::Connection: Executor<'c, Database = DB>,
945 for<'a, 'c> &'c mut Transaction<'a, DB>: Executor<'c, Database = DB>,
946 for<'a> &'a [u8]: BorrowedSqlType<'a, DB>,
947 for<'a> &'a str: BorrowedSqlType<'a, DB>,
948 Vec<u8>: SqlType<DB>,
949 Option<String>: SqlType<DB>,
950 String: SqlType<DB>,
951 Json<Raw<AnyGlobalAccountDataEvent>>: SqlType<DB>,
952 Json<Raw<PresenceEvent>>: SqlType<DB>,
953 Json<SyncRoomMemberEvent>: SqlType<DB>,
954 Json<MinimalRoomMemberEvent>: SqlType<DB>,
955 bool: SqlType<DB>,
956 Json<Raw<AnySyncStateEvent>>: SqlType<DB>,
957 Json<Raw<AnyRoomAccountDataEvent>>: SqlType<DB>,
958 Json<RoomInfo>: SqlType<DB>,
959 Json<Receipt>: SqlType<DB>,
960 Json<Raw<AnyStrippedStateEvent>>: SqlType<DB>,
961 Json<StrippedRoomMemberEvent>: SqlType<DB>,
962 Json<MemberEvent>: SqlType<DB>,
963 for<'a> &'a str: ColumnIndex<<DB as Database>::Row>,
964{
965 async fn save_filter(&self, filter_name: &str, filter_id: &str) -> StoreResult<()> {
973 self.save_filter(filter_name, filter_id)
974 .await
975 .map_err(|e| StoreError::Backend(e.into()))
976 }
977
978 async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> {
980 self.save_state_changes(changes)
981 .await
982 .map_err(|e| StoreError::Backend(e.into()))
983 }
984
985 async fn get_filter(&self, filter_name: &str) -> StoreResult<Option<String>> {
991 self.get_filter(filter_name)
992 .await
993 .map_err(|e| StoreError::Backend(e.into()))
994 }
995
996 async fn get_sync_token(&self) -> StoreResult<Option<String>> {
998 self.get_sync_token()
999 .await
1000 .map_err(|e| StoreError::Backend(e.into()))
1001 }
1002
1003 async fn get_presence_event(
1010 &self,
1011 user_id: &UserId,
1012 ) -> StoreResult<Option<Raw<PresenceEvent>>> {
1013 self.get_presence_event(user_id)
1014 .await
1015 .map_err(|e| StoreError::Backend(e.into()))
1016 }
1017
1018 async fn get_state_event(
1026 &self,
1027 room_id: &RoomId,
1028 event_type: StateEventType,
1029 state_key: &str,
1030 ) -> StoreResult<Option<Raw<AnySyncStateEvent>>> {
1031 self.get_state_event(room_id, event_type, state_key)
1032 .await
1033 .map_err(|e| StoreError::Backend(e.into()))
1034 }
1035
1036 async fn get_state_events(
1044 &self,
1045 room_id: &RoomId,
1046 event_type: StateEventType,
1047 ) -> StoreResult<Vec<Raw<AnySyncStateEvent>>> {
1048 self.get_state_events(room_id, event_type)
1049 .await
1050 .map_err(|e| StoreError::Backend(e.into()))
1051 }
1052
1053 async fn get_profile(
1061 &self,
1062 room_id: &RoomId,
1063 user_id: &UserId,
1064 ) -> StoreResult<Option<MinimalRoomMemberEvent>> {
1065 self.get_profile(room_id, user_id)
1066 .await
1067 .map_err(|e| StoreError::Backend(e.into()))
1068 }
1069
1070 async fn get_member_event(
1078 &self,
1079 room_id: &RoomId,
1080 state_key: &UserId,
1081 ) -> StoreResult<Option<MemberEvent>> {
1082 self.get_member_event(room_id, state_key)
1083 .await
1084 .map_err(|e| StoreError::Backend(e.into()))
1085 }
1086
1087 async fn get_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<OwnedUserId>> {
1090 self.get_user_ids(room_id)
1091 .await
1092 .map_err(|e| StoreError::Backend(e.into()))
1093 }
1094
1095 async fn get_invited_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<OwnedUserId>> {
1098 self.get_invited_user_ids(room_id)
1099 .await
1100 .map_err(|e| StoreError::Backend(e.into()))
1101 }
1102
1103 async fn get_joined_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<OwnedUserId>> {
1106 self.get_joined_user_ids(room_id)
1107 .await
1108 .map_err(|e| StoreError::Backend(e.into()))
1109 }
1110
1111 async fn get_room_infos(&self) -> StoreResult<Vec<RoomInfo>> {
1113 self.get_room_infos()
1114 .await
1115 .map_err(|e| StoreError::Backend(e.into()))
1116 }
1117
1118 async fn get_stripped_room_infos(&self) -> StoreResult<Vec<RoomInfo>> {
1120 self.get_stripped_room_infos()
1121 .await
1122 .map_err(|e| StoreError::Backend(e.into()))
1123 }
1124
1125 async fn get_users_with_display_name(
1134 &self,
1135 room_id: &RoomId,
1136 display_name: &str,
1137 ) -> StoreResult<BTreeSet<OwnedUserId>> {
1138 self.get_users_with_display_name(room_id, display_name)
1139 .await
1140 .map_err(|e| StoreError::Backend(e.into()))
1141 }
1142
1143 async fn get_account_data_event(
1149 &self,
1150 event_type: GlobalAccountDataEventType,
1151 ) -> StoreResult<Option<Raw<AnyGlobalAccountDataEvent>>> {
1152 self.get_account_data_event(event_type)
1153 .await
1154 .map_err(|e| StoreError::Backend(e.into()))
1155 }
1156
1157 async fn get_room_account_data_event(
1167 &self,
1168 room_id: &RoomId,
1169 event_type: RoomAccountDataEventType,
1170 ) -> StoreResult<Option<Raw<AnyRoomAccountDataEvent>>> {
1171 self.get_room_account_data_event(room_id, event_type)
1172 .await
1173 .map_err(|e| StoreError::Backend(e.into()))
1174 }
1175
1176 async fn get_user_room_receipt_event(
1187 &self,
1188 room_id: &RoomId,
1189 receipt_type: ReceiptType,
1190 user_id: &UserId,
1191 ) -> StoreResult<Option<(OwnedEventId, Receipt)>> {
1192 self.get_user_room_receipt_event(room_id, receipt_type, user_id)
1193 .await
1194 .map_err(|e| StoreError::Backend(e.into()))
1195 }
1196
1197 async fn get_event_room_receipt_events(
1209 &self,
1210 room_id: &RoomId,
1211 receipt_type: ReceiptType,
1212 event_id: &EventId,
1213 ) -> StoreResult<Vec<(OwnedUserId, Receipt)>> {
1214 self.get_event_room_receipt_events(room_id, receipt_type, event_id)
1215 .await
1216 .map_err(|e| StoreError::Backend(e.into()))
1217 }
1218
1219 async fn get_custom_value(&self, key: &[u8]) -> StoreResult<Option<Vec<u8>>> {
1225 self.get_custom_value(key)
1226 .await
1227 .map_err(|e| StoreError::Backend(e.into()))
1228 }
1229
1230 async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> StoreResult<Option<Vec<u8>>> {
1238 let old_val = self
1239 .get_custom_value(key)
1240 .await
1241 .map_err(|e| StoreError::Backend(e.into()))?;
1242 self.set_custom_value(key, &value)
1243 .await
1244 .map_err(|e| StoreError::Backend(e.into()))?;
1245 Ok(old_val)
1246 }
1247
1248 async fn add_media_content(&self, request: &MediaRequest, content: Vec<u8>) -> StoreResult<()> {
1256 self.insert_media(Self::extract_media_url(request), &content)
1257 .await
1258 .map_err(|e| StoreError::Backend(e.into()))
1259 }
1260
1261 async fn get_media_content(&self, request: &MediaRequest) -> StoreResult<Option<Vec<u8>>> {
1267 self.get_media(Self::extract_media_url(request))
1268 .await
1269 .map_err(|e| StoreError::Backend(e.into()))
1270 }
1271
1272 async fn remove_media_content(&self, request: &MediaRequest) -> StoreResult<()> {
1278 self.delete_media(Self::extract_media_url(request))
1279 .await
1280 .map_err(|e| StoreError::Backend(e.into()))
1281 }
1282
1283 async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> StoreResult<()> {
1290 self.delete_media(uri)
1291 .await
1292 .map_err(|e| StoreError::Backend(e.into()))
1293 }
1294
1295 async fn remove_room(&self, room_id: &RoomId) -> StoreResult<()> {
1301 self.remove_room(room_id)
1302 .await
1303 .map_err(|e| StoreError::Backend(e.into()))
1304 }
1305}
1306
1307#[cfg(test)]
1308#[allow(unused_imports, unreachable_pub, clippy::unwrap_used)]
1309mod tests {
1310 use crate::{StateStore, SupportedDatabase};
1311 use anyhow::Result;
1312 use ruma::{MxcUri, OwnedMxcUri};
1313 use sqlx::{
1314 database::HasArguments, migrate::Migrate, ColumnIndex, Database, Decode, Encode, Executor,
1315 IntoArguments, Pool, Type,
1316 };
1317 use std::sync::Arc;
1318 #[cfg(feature = "sqlite")]
1319 pub async fn open_sqlite_database() -> Result<StateStore<sqlx::Sqlite>> {
1320 let db = Arc::new(sqlx::SqlitePool::connect("sqlite://:memory:").await?);
1321 let store = StateStore::new(&db).await?;
1322 Ok(store)
1323 }
1324
1325 #[cfg(feature = "postgres")]
1326 async fn open_postgres_database() -> Result<StateStore<sqlx::Postgres>> {
1327 let db = Arc::new(
1328 sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres").await?,
1329 );
1330 let store = StateStore::new(&db).await?;
1331 Ok(store)
1332 }
1333
1334 #[cfg(feature = "sqlite")]
1335 #[tokio::test]
1336 async fn test_sqlite_custom_values() {
1337 let store = open_sqlite_database().await.unwrap();
1338 assert_eq!(store.get_custom_value(b"test").await.unwrap(), None);
1339 store.set_custom_value(b"test", b"test").await.unwrap();
1340 assert_eq!(
1341 store.get_custom_value(b"test").await.unwrap(),
1342 Some(b"test".to_vec())
1343 );
1344 store.set_custom_value(b"test2", b"test3").await.unwrap();
1345 assert_eq!(
1346 store.get_custom_value(b"test2").await.unwrap(),
1347 Some(b"test3".to_vec())
1348 );
1349 assert_eq!(
1350 store.get_custom_value(b"test").await.unwrap(),
1351 Some(b"test".to_vec())
1352 );
1353 store.set_custom_value(b"test", b"test4").await.unwrap();
1354 assert_eq!(
1355 store.get_custom_value(b"test").await.unwrap(),
1356 Some(b"test4".to_vec())
1357 );
1358 assert_eq!(
1359 store.get_custom_value(b"test2").await.unwrap(),
1360 Some(b"test3".to_vec())
1361 );
1362 }
1363
1364 #[cfg(feature = "postgres")]
1365 #[tokio::test]
1366 #[cfg_attr(not(feature = "ci"), ignore)]
1367 async fn test_postgres_custom_values() {
1368 let store = open_postgres_database().await.unwrap();
1369 assert_eq!(store.get_custom_value(b"test").await.unwrap(), None);
1370 store.set_custom_value(b"test", b"test").await.unwrap();
1371 assert_eq!(
1372 store.get_custom_value(b"test").await.unwrap(),
1373 Some(b"test".to_vec())
1374 );
1375 store.set_custom_value(b"test2", b"test3").await.unwrap();
1376 assert_eq!(
1377 store.get_custom_value(b"test2").await.unwrap(),
1378 Some(b"test3".to_vec())
1379 );
1380 assert_eq!(
1381 store.get_custom_value(b"test").await.unwrap(),
1382 Some(b"test".to_vec())
1383 );
1384 store.set_custom_value(b"test", b"test4").await.unwrap();
1385 assert_eq!(
1386 store.get_custom_value(b"test").await.unwrap(),
1387 Some(b"test4".to_vec())
1388 );
1389 assert_eq!(
1390 store.get_custom_value(b"test2").await.unwrap(),
1391 Some(b"test3".to_vec())
1392 );
1393 }
1394
/// Exercises filter persistence on SQLite.
///
/// Checks that unknown names miss and that saved filters read back. Also checks that saving under
/// an existing name replaces that definition without affecting other names.
#[cfg(feature = "sqlite")]
#[tokio::test]
async fn test_sqlite_filters() {
    let store = open_sqlite_database().await.unwrap();

    // No filter has been saved under this name yet.
    assert!(store.get_filter("test").await.unwrap().is_none());

    store.save_filter("test", "test").await.unwrap();
    assert_eq!(store.get_filter("test").await.unwrap().as_deref(), Some("test"));

    // Saving a second filter keeps the first one intact; reads happen in this order.
    store.save_filter("test2", "test3").await.unwrap();
    for (name, expected) in [("test2", "test3"), ("test", "test")] {
        let found = store.get_filter(name).await.unwrap();
        assert_eq!(found.as_deref(), Some(expected));
    }

    // Re-saving "test" overwrites it; "test2" must be unaffected.
    store.save_filter("test", "test4").await.unwrap();
    for (name, expected) in [("test", "test4"), ("test2", "test3")] {
        let found = store.get_filter(name).await.unwrap();
        assert_eq!(found.as_deref(), Some(expected));
    }
}
1424
/// Exercises filter persistence on Postgres.
///
/// Checks that unknown names miss and that saved filters read back. Also checks that saving under
/// an existing name replaces that definition without affecting other names.
#[cfg(feature = "postgres")]
#[tokio::test]
#[cfg_attr(not(feature = "ci"), ignore)]
async fn test_postgres_filters() {
    let store = open_postgres_database().await.unwrap();

    // No filter has been saved under this name yet.
    assert!(store.get_filter("test").await.unwrap().is_none());

    store.save_filter("test", "test").await.unwrap();
    assert_eq!(store.get_filter("test").await.unwrap().as_deref(), Some("test"));

    // Saving a second filter keeps the first one intact; reads happen in this order.
    store.save_filter("test2", "test3").await.unwrap();
    for (name, expected) in [("test2", "test3"), ("test", "test")] {
        let found = store.get_filter(name).await.unwrap();
        assert_eq!(found.as_deref(), Some(expected));
    }

    // Re-saving "test" overwrites it; "test2" must be unaffected.
    store.save_filter("test", "test4").await.unwrap();
    for (name, expected) in [("test", "test4"), ("test2", "test3")] {
        let found = store.get_filter(name).await.unwrap();
        assert_eq!(found.as_deref(), Some(expected));
    }
}
1455
/// Checks that the SQLite media store drops the earliest entry once enough media is inserted.
///
/// A later entry must still be retrievable afterwards.
#[cfg(feature = "sqlite")]
#[tokio::test]
async fn test_sqlite_mediastore() {
    let store = open_sqlite_database().await.unwrap();
    let oldest: &MxcUri = "mxc://localhost:8080/media/0".into();
    let survivor: &MxcUri = "mxc://localhost:8080/media/1".into();

    store.insert_media(oldest, b"media_0").await.unwrap();
    store.insert_media(survivor, b"media_1").await.unwrap();

    // Add ids 2..=100, for 101 inserted entries in total.
    // NOTE(review): presumably this overflows a 100-entry media cache so the first-inserted
    // entry is evicted — confirm against `insert_media`.
    for index in 2..=100 {
        let filler: OwnedMxcUri = format!("mxc://localhost:8080/media/{}", index).into();
        store.insert_media(&filler, b"media_0").await.unwrap();
    }

    let evicted = store.get_media(oldest).await.unwrap();
    assert_eq!(evicted, None);

    let kept = store.get_media(survivor).await.unwrap();
    assert_eq!(kept, Some(b"media_1".to_vec()));
}
1477
/// Checks that the Postgres media store drops the earliest entry once enough media is inserted.
///
/// A later entry must still be retrievable afterwards.
#[cfg(feature = "postgres")]
#[tokio::test]
#[cfg_attr(not(feature = "ci"), ignore)]
async fn test_postgres_mediastore() {
    let store = open_postgres_database().await.unwrap();
    let oldest: &MxcUri = "mxc://localhost:8080/media/0".into();
    let survivor: &MxcUri = "mxc://localhost:8080/media/1".into();

    store.insert_media(oldest, b"media_0").await.unwrap();
    store.insert_media(survivor, b"media_1").await.unwrap();

    // Add ids 2..=100, for 101 inserted entries in total.
    // NOTE(review): presumably this overflows a 100-entry media cache so the first-inserted
    // entry is evicted — confirm against `insert_media`.
    for index in 2..=100 {
        let filler: OwnedMxcUri = format!("mxc://localhost:8080/media/{}", index).into();
        store.insert_media(&filler, b"media_0").await.unwrap();
    }

    let evicted = store.get_media(oldest).await.unwrap();
    assert_eq!(evicted, None);

    let kept = store.get_media(survivor).await.unwrap();
    assert_eq!(kept, Some(b"media_1".to_vec()));
}
1500
/// A fresh SQLite store has no sync token; a saved token is returned on the next read.
#[cfg(feature = "sqlite")]
#[tokio::test]
async fn test_sqlite_sync_token() {
    let store = open_sqlite_database().await.unwrap();
    assert!(store.get_sync_token().await.unwrap().is_none());

    store.save_sync_token_test("test").await.unwrap();
    let token = store.get_sync_token().await.unwrap();
    assert_eq!(token.as_deref(), Some("test"));
}
1512
/// A fresh Postgres store has no sync token; a saved token is returned on the next read.
#[cfg(feature = "postgres")]
#[tokio::test]
#[cfg_attr(not(feature = "ci"), ignore)]
async fn test_postgres_sync_token() {
    let store = open_postgres_database().await.unwrap();
    assert!(store.get_sync_token().await.unwrap().is_none());

    store.save_sync_token_test("test").await.unwrap();
    let token = store.get_sync_token().await.unwrap();
    assert_eq!(token.as_deref(), Some("test"));
}
1525
/// Writes the same key twice in the SQLite kv store; each read must return the latest write.
#[cfg(feature = "sqlite")]
#[tokio::test]
async fn test_sqlite_kv_store() {
    let store = open_sqlite_database().await.unwrap();
    for value in [&b"value"[..], &b"value2"[..]] {
        store.insert_kv(b"key", value).await.unwrap();
        let stored = store.get_kv(b"key").await.unwrap();
        assert_eq!(stored, Some(value.to_vec()));
    }
}
1537
/// Writes the same key twice in the Postgres kv store; each read must return the latest write.
#[cfg(feature = "postgres")]
#[tokio::test]
#[cfg_attr(not(feature = "ci"), ignore)]
async fn test_postgres_kv_store() {
    let store = open_postgres_database().await.unwrap();
    for value in [&b"value"[..], &b"value2"[..]] {
        store.insert_kv(b"key", value).await.unwrap();
        let stored = store.get_kv(b"key").await.unwrap();
        assert_eq!(stored, Some(value.to_vec()));
    }
}
1550}
1551
// NOTE(review): the lint is silenced for items inside this private module, presumably those
// generated by `statestore_integration_tests!` — confirm.
#[allow(clippy::redundant_pub_crate)]
#[cfg(all(test, feature = "postgres", feature = "ci"))]
mod postgres_integration_test {
    use std::sync::Arc;

    use matrix_sdk_base::{statestore_integration_tests, StateStore, StoreError};
    use rand::distributions::{Alphanumeric, DistString};
    use sqlx::migrate::MigrateDatabase;

    use super::StoreResult;

    /// Builds a state store on a newly named Postgres database.
    ///
    /// The database name is 16 random alphanumeric characters, and the database is created if it
    /// does not already exist.
    async fn get_store_anyhow() -> anyhow::Result<impl StateStore> {
        // The RNG is used as a temporary so it is not held across any `.await` below.
        let database_name = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
        let url = format!("postgres://postgres:postgres@localhost:5432/{}", database_name);

        let already_exists = sqlx::Postgres::database_exists(&url).await?;
        if !already_exists {
            sqlx::Postgres::create_database(&url).await?;
        }

        let pool = sqlx::PgPool::connect(&url).await?;
        let shared_pool = Arc::new(pool);
        Ok(crate::StateStore::new(&shared_pool).await?)
    }

    /// Adapts [`get_store_anyhow`] to the `StoreResult` expected by the shared integration tests.
    async fn get_store() -> StoreResult<impl StateStore> {
        match get_store_anyhow().await {
            Ok(store) => Ok(store),
            Err(failure) => Err(StoreError::Backend(failure.into())),
        }
    }

    statestore_integration_tests! { integration }
}
1580
// NOTE(review): the lint is silenced for items inside this private module, presumably those
// generated by `statestore_integration_tests!` — confirm.
#[allow(clippy::redundant_pub_crate)]
#[cfg(all(test, feature = "sqlite"))]
mod sqlite_integration_test {
    use matrix_sdk_base::{statestore_integration_tests, StateStore, StoreError};

    use super::StoreResult;

    /// Opens the SQLite test store and maps any failure into a backend `StoreError`.
    async fn get_store() -> StoreResult<impl StateStore> {
        let opened = super::tests::open_sqlite_database().await;
        match opened {
            Ok(store) => Ok(store),
            Err(failure) => Err(StoreError::Backend(failure.into())),
        }
    }

    statestore_integration_tests! { integration }
}