1use core::fmt;
16use std::{
17 borrow::{Borrow, Cow},
18 cmp::min,
19 iter,
20 ops::Deref,
21};
22
23use async_trait::async_trait;
24use deadpool_sqlite::Object as SqliteAsyncConn;
25use itertools::Itertools;
26use matrix_sdk_store_encryption::StoreCipher;
27use ruma::{serde::Raw, time::SystemTime, OwnedEventId, OwnedRoomId};
28use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
29use serde::{de::DeserializeOwned, Serialize};
30use tracing::{error, warn};
31
32use crate::{
33 error::{Error, Result},
34 OpenStoreError, RuntimeConfig,
35};
36
/// A key as written to the database.
///
/// [`EncryptableStore::encode_key`] produces `Hashed` when the store has a
/// cipher and `Plain` otherwise. Note that the derived ordering compares the
/// variant first (every `Plain` sorts before every `Hashed`), then the bytes.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) enum Key {
    /// The raw, unmodified key bytes.
    Plain(Vec<u8>),
    /// A 32-byte hash of the key.
    Hashed([u8; 32]),
}
42
43impl Deref for Key {
44 type Target = [u8];
45
46 fn deref(&self) -> &Self::Target {
47 match self {
48 Key::Plain(slice) => slice,
49 Key::Hashed(bytes) => bytes,
50 }
51 }
52}
53
54impl Borrow<[u8]> for Key {
55 fn borrow(&self) -> &[u8] {
56 self.deref()
57 }
58}
59
60impl rusqlite::ToSql for Key {
61 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
62 self.deref().to_sql()
63 }
64}
65
66#[async_trait]
67pub(crate) trait SqliteAsyncConnExt {
68 async fn execute<P>(
69 &self,
70 sql: impl AsRef<str> + Send + 'static,
71 params: P,
72 ) -> rusqlite::Result<usize>
73 where
74 P: Params + Send + 'static;
75
76 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;
77
78 async fn prepare<T, F>(
79 &self,
80 sql: impl AsRef<str> + Send + 'static,
81 f: F,
82 ) -> rusqlite::Result<T>
83 where
84 T: Send + 'static,
85 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;
86
87 async fn query_row<T, P, F>(
88 &self,
89 sql: impl AsRef<str> + Send + 'static,
90 params: P,
91 f: F,
92 ) -> rusqlite::Result<T>
93 where
94 T: Send + 'static,
95 P: Params + Send + 'static,
96 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;
97
98 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
99 where
100 T: Send + 'static,
101 E: From<rusqlite::Error> + Send + 'static,
102 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;
103
104 async fn chunk_large_query_over<Query, Res>(
105 &self,
106 mut keys_to_chunk: Vec<Key>,
107 result_capacity: Option<usize>,
108 do_query: Query,
109 ) -> Result<Vec<Res>>
110 where
111 Res: Send + 'static,
112 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
113
114 async fn apply_runtime_config(&self, runtime_config: RuntimeConfig) -> Result<()> {
123 let RuntimeConfig { optimize, cache_size, journal_size_limit } = runtime_config;
124
125 if optimize {
126 self.optimize().await?;
127 }
128
129 self.cache_size(cache_size).await?;
130 self.journal_size_limit(journal_size_limit).await?;
131
132 Ok(())
133 }
134
135 async fn optimize(&self) -> Result<()> {
145 self.execute_batch("PRAGMA optimize = 0x10002;").await?;
146 Ok(())
147 }
148
149 async fn cache_size(&self, cache_size: u32) -> Result<()> {
155 let n = cache_size / 1024;
158
159 self.execute_batch(format!("PRAGMA cache_size = -{n};")).await?;
160 Ok(())
161 }
162
163 async fn journal_size_limit(&self, limit: u32) -> Result<()> {
181 self.execute_batch(format!("PRAGMA journal_size_limit = {limit};")).await?;
182 Ok(())
183 }
184
185 async fn vacuum(&self) -> Result<()> {
189 if let Err(error) = self.execute_batch("VACUUM").await {
190 #[cfg(not(any(test, debug_assertions)))]
193 tracing::warn!("Failed to vacuum database: {error}");
194
195 #[cfg(any(test, debug_assertions))]
197 return Err(error.into());
198 }
199
200 Ok(())
201 }
202}
203
204#[async_trait]
205impl SqliteAsyncConnExt for SqliteAsyncConn {
206 async fn execute<P>(
207 &self,
208 sql: impl AsRef<str> + Send + 'static,
209 params: P,
210 ) -> rusqlite::Result<usize>
211 where
212 P: Params + Send + 'static,
213 {
214 self.interact(move |conn| conn.execute(sql.as_ref(), params)).await.unwrap()
215 }
216
217 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
218 self.interact(move |conn| conn.execute_batch(sql.as_ref())).await.unwrap()
219 }
220
221 async fn prepare<T, F>(
222 &self,
223 sql: impl AsRef<str> + Send + 'static,
224 f: F,
225 ) -> rusqlite::Result<T>
226 where
227 T: Send + 'static,
228 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
229 {
230 self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.unwrap()
231 }
232
233 async fn query_row<T, P, F>(
234 &self,
235 sql: impl AsRef<str> + Send + 'static,
236 params: P,
237 f: F,
238 ) -> rusqlite::Result<T>
239 where
240 T: Send + 'static,
241 P: Params + Send + 'static,
242 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
243 {
244 self.interact(move |conn| conn.query_row(sql.as_ref(), params, f)).await.unwrap()
245 }
246
247 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
248 where
249 T: Send + 'static,
250 E: From<rusqlite::Error> + Send + 'static,
251 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
252 {
253 self.interact(move |conn| {
254 let txn = conn.transaction()?;
255 let result = f(&txn)?;
256 txn.commit()?;
257 Ok(result)
258 })
259 .await
260 .unwrap()
261 }
262
263 async fn chunk_large_query_over<Query, Res>(
270 &self,
271 keys_to_chunk: Vec<Key>,
272 result_capacity: Option<usize>,
273 do_query: Query,
274 ) -> Result<Vec<Res>>
275 where
276 Res: Send + 'static,
277 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
278 {
279 self.with_transaction(move |txn| {
280 txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
281 })
282 .await
283 }
284}
285
286pub(crate) trait SqliteTransactionExt {
287 fn chunk_large_query_over<Key, Query, Res>(
288 &self,
289 keys_to_chunk: Vec<Key>,
290 result_capacity: Option<usize>,
291 do_query: Query,
292 ) -> Result<Vec<Res>>
293 where
294 Res: Send + 'static,
295 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
296}
297
298impl SqliteTransactionExt for Transaction<'_> {
299 fn chunk_large_query_over<Key, Query, Res>(
300 &self,
301 mut keys_to_chunk: Vec<Key>,
302 result_capacity: Option<usize>,
303 do_query: Query,
304 ) -> Result<Vec<Res>>
305 where
306 Res: Send + 'static,
307 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
308 {
309 let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)? / 2;
312 let maximum_chunk_size: usize = maximum_chunk_size
313 .try_into()
314 .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
315
316 if keys_to_chunk.len() < maximum_chunk_size {
317 let chunk = keys_to_chunk;
319
320 Ok(do_query(self, chunk)?)
321 } else {
322 let capacity = result_capacity.unwrap_or_default();
326 let mut all_results = Vec::with_capacity(capacity);
327
328 while !keys_to_chunk.is_empty() {
329 let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
331 let chunk = keys_to_chunk;
332 keys_to_chunk = tail;
333
334 all_results.extend(do_query(self, chunk)?);
335 }
336
337 Ok(all_results)
338 }
339 }
340}
341
342pub(crate) trait SqliteKeyValueStoreConnExt {
354 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;
356
357 fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
359 let serialized_value = rmp_serde::to_vec_named(&value)?;
360 self.set_kv(key, &serialized_value)?;
361
362 Ok(())
363 }
364
365 fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
367
368 fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
370 self.set_kv("version", &[version])
371 }
372}
373
374impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
375 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
376 self.execute(
377 "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
378 (key, value),
379 )?;
380 Ok(())
381 }
382
383 fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
384 self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
385 Ok(())
386 }
387}
388
389#[async_trait]
401pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
402 async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
404 self.query_row(
405 "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
406 (),
407 |row| row.get(0),
408 )
409 .await
410 }
411
412 async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
414 let key = key.to_owned();
415 self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
416 .await
417 .optional()
418 }
419
420 async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
422 let Some(bytes) = self.get_kv(key).await? else {
423 return Ok(None);
424 };
425
426 Ok(Some(rmp_serde::from_slice(&bytes)?))
427 }
428
429 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;
431
432 async fn set_serialized_kv<T: Serialize + Send + 'static>(
434 &self,
435 key: &str,
436 value: T,
437 ) -> Result<()>;
438
439 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
441
442 async fn db_version(&self) -> Result<u8, OpenStoreError> {
444 let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;
445
446 if kv_exists {
447 match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
448 Some([v]) => Ok(*v),
449 Some(_) => Err(OpenStoreError::InvalidVersion),
450 None => Err(OpenStoreError::MissingVersion),
451 }
452 } else {
453 Ok(0)
454 }
455 }
456
457 async fn get_or_create_store_cipher(
459 &self,
460 passphrase: &str,
461 ) -> Result<StoreCipher, OpenStoreError> {
462 let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;
463
464 let cipher = if let Some(encrypted) = encrypted_cipher {
465 StoreCipher::import(passphrase, &encrypted)?
466 } else {
467 let cipher = StoreCipher::new()?;
468 #[cfg(not(test))]
469 let export = cipher.export(passphrase);
470 #[cfg(test)]
471 let export = cipher._insecure_export_fast_for_testing(passphrase);
472 self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
473 cipher
474 };
475
476 Ok(cipher)
477 }
478}
479
480#[async_trait]
481impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
482 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
483 let key = key.to_owned();
484 self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;
485
486 Ok(())
487 }
488
489 async fn set_serialized_kv<T: Serialize + Send + 'static>(
490 &self,
491 key: &str,
492 value: T,
493 ) -> Result<()> {
494 let key = key.to_owned();
495 self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;
496
497 Ok(())
498 }
499
500 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
501 let key = key.to_owned();
502 self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;
503
504 Ok(())
505 }
506}
507
508pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
510 assert_ne!(count, 0, "Can't generate zero repeated vars");
511
512 iter::repeat_n("?", count).format(",")
513}
514
/// Converts `time` into a Unix timestamp in whole seconds.
///
/// Times before the Unix epoch map to `0`, as do times whose second count
/// does not fit in an `i64`.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(0),
        Err(_before_epoch) => 0,
    }
}
527
528pub(crate) trait EncryptableStore {
537 fn get_cypher(&self) -> Option<&StoreCipher>;
538
539 fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
544 let bytes = key.as_ref();
545 if let Some(store_cipher) = self.get_cypher() {
546 Key::Hashed(store_cipher.hash_key(table_name, bytes))
547 } else {
548 Key::Plain(bytes.to_owned())
549 }
550 }
551
552 fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
553 if let Some(key) = self.get_cypher() {
554 let encrypted = key.encrypt_value_data(value)?;
555 Ok(rmp_serde::to_vec_named(&encrypted)?)
556 } else {
557 Ok(value)
558 }
559 }
560
561 fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
562 if let Some(key) = self.get_cypher() {
563 let encrypted = rmp_serde::from_slice(value)?;
564 let decrypted = key.decrypt_value_data(encrypted)?;
565 Ok(Cow::Owned(decrypted))
566 } else {
567 Ok(Cow::Borrowed(value))
568 }
569 }
570
571 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
572 let serialized = rmp_serde::to_vec_named(value)?;
573 self.encode_value(serialized)
574 }
575
576 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
577 let decoded = self.decode_value(value)?;
578 Ok(rmp_serde::from_slice(&decoded)?)
579 }
580
581 fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
582 let serialized = serde_json::to_vec(value)?;
583 self.encode_value(serialized)
584 }
585
586 fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
587 let decoded = self.decode_value(data)?;
588
589 let json_deserializer = &mut serde_json::Deserializer::from_slice(&decoded);
590
591 serde_path_to_error::deserialize(json_deserializer).map_err(|err| {
592 let raw_json: Option<Raw<serde_json::Value>> = serde_json::from_slice(&decoded).ok();
593
594 let target_type = std::any::type_name::<T>();
595 let serde_path = err.path().to_string();
596
597 error!(
598 sentry = true,
599 %err,
600 "Failed to deserialize {target_type} in a store: {serde_path}",
601 );
602
603 if let Some(raw) = raw_json {
604 if let Some(room_id) = raw.get_field::<OwnedRoomId>("room_id").ok().flatten() {
605 warn!("Found a room id in the source data to deserialize: {room_id}");
606 }
607 if let Some(event_id) = raw.get_field::<OwnedEventId>("event_id").ok().flatten() {
608 warn!("Found an event id in the source data to deserialize: {event_id}");
609 }
610 }
611
612 err.into_inner().into()
613 })
614 }
615}
616
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn can_generate_repeated_vars() {
        let cases = [(1, "?"), (2, "?,?"), (5, "?,?,?,?,?")];
        for (count, expected) in cases {
            assert_eq!(repeat_vars(count).to_string(), expected, "count = {count}");
        }
    }

    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        let _ = repeat_vars(0);
    }

    #[test]
    fn test_time_to_timestamp() {
        let one_minute = Duration::from_secs(60);

        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH), 0);
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH + one_minute), 60);

        // Times before the Unix epoch are clamped to zero.
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH - one_minute), 0);
    }
}