1use std::sync::Arc;
40
41use anyhow::Result;
42
43#[cfg(feature = "e2e-encryption")]
44use cryptostore::CryptostoreData;
45use helpers::{BorrowedSqlType, SqlType};
46use matrix_sdk_base::store::StoreConfig;
47#[cfg(feature = "e2e-encryption")]
48use matrix_sdk_store_encryption::StoreCipher;
49
50mod helpers;
51pub use helpers::SupportedDatabase;
52use matrix_sdk_base::{deserialized_responses::MemberEvent, MinimalRoomMemberEvent, RoomInfo};
53use ruma::{
54 events::{
55 presence::PresenceEvent,
56 receipt::Receipt,
57 room::member::{StrippedRoomMemberEvent, SyncRoomMemberEvent},
58 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
59 AnySyncStateEvent,
60 },
61 serde::Raw,
62};
63use sqlx::{
64 database::HasArguments, migrate::Migrate, types::Json, ColumnIndex, Database, Executor,
65 IntoArguments, Pool, Transaction,
66};
67
68#[cfg(feature = "e2e-encryption")]
69mod cryptostore;
70mod statestore;
71
72#[allow(single_use_lifetimes)]
74#[derive(Debug)]
75pub struct StateStore<DB: SupportedDatabase> {
76 db: Arc<Pool<DB>>,
78 #[cfg(feature = "e2e-encryption")]
79 cryptostore: Option<CryptostoreData>,
81}
82
83#[allow(single_use_lifetimes)]
84impl<DB: SupportedDatabase> StateStore<DB> {
85 pub async fn new(db: &Arc<Pool<DB>>) -> Result<Self>
90 where
91 <DB as Database>::Connection: Migrate,
92 {
93 let db = Arc::clone(db);
94 let migrator = DB::get_migrator();
95 migrator.run(&*db).await?;
96 #[cfg(not(feature = "e2e-encryption"))]
97 {
98 Ok(Self { db })
99 }
100 #[cfg(feature = "e2e-encryption")]
101 {
102 Ok(Self {
103 db,
104 cryptostore: None,
105 })
106 }
107 }
108
109 #[cfg(feature = "e2e-encryption")]
114 pub(crate) fn ensure_e2e(&self) -> Result<&CryptostoreData> {
115 self.cryptostore
116 .as_ref()
117 .ok_or_else(|| anyhow::anyhow!("Not unlocked"))
118 }
119
120 #[cfg(feature = "e2e-encryption")]
124 pub async fn unlock(&mut self) -> Result<()>
125 where
126 for<'a> <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
127 for<'c> &'c mut <DB as sqlx::Database>::Connection: Executor<'c, Database = DB>,
128 for<'c, 'a> &'a mut Transaction<'c, DB>: Executor<'a, Database = DB>,
129 for<'a> &'a [u8]: BorrowedSqlType<'a, DB>,
130 for<'a> &'a str: BorrowedSqlType<'a, DB>,
131 Vec<u8>: SqlType<DB>,
132 String: SqlType<DB>,
133 bool: SqlType<DB>,
134 Vec<u8>: SqlType<DB>,
135 Option<String>: SqlType<DB>,
136 Json<Raw<AnyGlobalAccountDataEvent>>: SqlType<DB>,
137 Json<Raw<PresenceEvent>>: SqlType<DB>,
138 Json<SyncRoomMemberEvent>: SqlType<DB>,
139 Json<MinimalRoomMemberEvent>: SqlType<DB>,
140 Json<Raw<AnySyncStateEvent>>: SqlType<DB>,
141 Json<Raw<AnyRoomAccountDataEvent>>: SqlType<DB>,
142 Json<RoomInfo>: SqlType<DB>,
143 Json<Receipt>: SqlType<DB>,
144 Json<Raw<AnyStrippedStateEvent>>: SqlType<DB>,
145 Json<StrippedRoomMemberEvent>: SqlType<DB>,
146 Json<MemberEvent>: SqlType<DB>,
147 for<'a> &'a str: ColumnIndex<<DB as Database>::Row>,
148 {
149 self.cryptostore = Some(CryptostoreData::new_unencrypted());
150 self.load_tracked_users().await?;
151 Ok(())
152 }
153
154 #[cfg(feature = "e2e-encryption")]
158 pub async fn unlock_with_passphrase(&mut self, passphrase: &str) -> Result<()>
159 where
160 for<'a> <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
161 for<'c> &'c mut <DB as sqlx::Database>::Connection: Executor<'c, Database = DB>,
162 for<'c, 'a> &'a mut Transaction<'c, DB>: Executor<'a, Database = DB>,
163 for<'a> &'a [u8]: BorrowedSqlType<'a, DB>,
164 for<'a> &'a str: BorrowedSqlType<'a, DB>,
165 Vec<u8>: SqlType<DB>,
166 String: SqlType<DB>,
167 bool: SqlType<DB>,
168 Vec<u8>: SqlType<DB>,
169 Option<String>: SqlType<DB>,
170 Json<Raw<AnyGlobalAccountDataEvent>>: SqlType<DB>,
171 Json<Raw<PresenceEvent>>: SqlType<DB>,
172 Json<SyncRoomMemberEvent>: SqlType<DB>,
173 Json<MinimalRoomMemberEvent>: SqlType<DB>,
174 Json<Raw<AnySyncStateEvent>>: SqlType<DB>,
175 Json<Raw<AnyRoomAccountDataEvent>>: SqlType<DB>,
176 Json<RoomInfo>: SqlType<DB>,
177 Json<Receipt>: SqlType<DB>,
178 Json<Raw<AnyStrippedStateEvent>>: SqlType<DB>,
179 Json<StrippedRoomMemberEvent>: SqlType<DB>,
180 Json<MemberEvent>: SqlType<DB>,
181 for<'a> &'a str: ColumnIndex<<DB as Database>::Row>,
182 {
183 let cipher_export = self.get_kv(b"cipher").await?;
186 if let Some(cipher) = cipher_export {
187 self.cryptostore = Some(CryptostoreData::new(StoreCipher::import(
188 passphrase, &cipher,
189 )?));
190 } else {
191 let cipher = StoreCipher::new()?;
193 self.insert_kv(b"cipher", &cipher.export(passphrase)?)
194 .await?;
195 self.cryptostore = Some(CryptostoreData::new(cipher));
196 }
197 self.load_tracked_users().await?;
198 Ok(())
199 }
200}
201
202pub async fn store_config<DB: SupportedDatabase>(
208 db: &Arc<Pool<DB>>,
209 passphrase: Option<&str>,
210) -> Result<StoreConfig>
211where
212 <DB as Database>::Connection: Migrate,
213 for<'a> <DB as HasArguments<'a>>::Arguments: IntoArguments<'a, DB>,
214 for<'c> &'c mut <DB as sqlx::Database>::Connection: Executor<'c, Database = DB>,
215 for<'c, 'a> &'a mut Transaction<'c, DB>: Executor<'a, Database = DB>,
216 for<'a> &'a [u8]: BorrowedSqlType<'a, DB>,
217 for<'a> &'a str: BorrowedSqlType<'a, DB>,
218 Vec<u8>: SqlType<DB>,
219 String: SqlType<DB>,
220 bool: SqlType<DB>,
221 Vec<u8>: SqlType<DB>,
222 Option<String>: SqlType<DB>,
223 Json<Raw<AnyGlobalAccountDataEvent>>: SqlType<DB>,
224 Json<Raw<PresenceEvent>>: SqlType<DB>,
225 Json<SyncRoomMemberEvent>: SqlType<DB>,
226 Json<MinimalRoomMemberEvent>: SqlType<DB>,
227 Json<Raw<AnySyncStateEvent>>: SqlType<DB>,
228 Json<Raw<AnyRoomAccountDataEvent>>: SqlType<DB>,
229 Json<RoomInfo>: SqlType<DB>,
230 Json<Receipt>: SqlType<DB>,
231 Json<Raw<AnyStrippedStateEvent>>: SqlType<DB>,
232 Json<StrippedRoomMemberEvent>: SqlType<DB>,
233 Json<MemberEvent>: SqlType<DB>,
234 for<'a> &'a str: ColumnIndex<<DB as Database>::Row>,
235{
236 #[cfg(not(feature = "e2e-encryption"))]
237 {
238 let _ = passphrase;
239 let state_store = StateStore::new(db).await?;
240 Ok(StoreConfig::new().state_store(Box::new(state_store)))
241 }
242 #[cfg(feature = "e2e-encryption")]
243 {
244 let state_store = StateStore::new(db).await?;
245 let mut crypto_store = StateStore::new(db).await?;
246 if let Some(passphrase) = passphrase {
247 crypto_store.unlock_with_passphrase(passphrase).await?;
248 } else {
249 crypto_store.unlock().await?;
250 }
251 Ok(StoreConfig::new()
252 .state_store(Box::new(state_store))
253 .crypto_store(Box::new(crypto_store)))
254 }
255}