1use core::fmt;
16use std::{
17 borrow::{Borrow, Cow},
18 cmp::min,
19 iter,
20 ops::Deref,
21};
22
23use async_trait::async_trait;
24use deadpool_sync::InteractError;
25use itertools::Itertools;
26use matrix_sdk_store_encryption::StoreCipher;
27use ruma::{OwnedEventId, OwnedRoomId, serde::Raw, time::SystemTime};
28use rusqlite::{OptionalExtension, Params, Row, Statement, Transaction, limits::Limit};
29use serde::{Serialize, de::DeserializeOwned};
30use tracing::{error, trace, warn};
31use zeroize::Zeroize;
32
33use crate::{
34 OpenStoreError, RuntimeConfig, Secret,
35 connection::Connection as SqliteAsyncConn,
36 error::{Error, Result},
37};
38
/// A key used to address rows in the store.
///
/// A key is either kept as its raw bytes ([`Key::Plain`]) or replaced by a fixed-size
/// 32-byte hash ([`Key::Hashed`]). In both cases the value bound into SQL is just the byte
/// string.
///
/// Equality and ordering are defined on those bytes alone, independent of the variant.
/// This keeps them consistent with `Borrow<[u8]>`, which requires that comparing two keys
/// gives the same result as comparing their borrowed byte slices.
#[derive(Clone, Debug)]
pub(crate) enum Key {
    /// The key bytes, stored unchanged.
    Plain(Vec<u8>),
    /// A 32-byte hash standing in for the original key.
    Hashed([u8; 32]),
}

impl Key {
    /// Returns the bytes this key is represented by, whichever variant holds them.
    fn key_bytes(&self) -> &[u8] {
        match self {
            Key::Plain(bytes) => bytes,
            Key::Hashed(hash) => hash,
        }
    }
}

impl PartialEq for Key {
    fn eq(&self, other: &Self) -> bool {
        self.key_bytes() == other.key_bytes()
    }
}

impl Eq for Key {}

impl PartialOrd for Key {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Key {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Lexicographic byte order, matching the ordering of the borrowed `[u8]`.
        self.key_bytes().cmp(other.key_bytes())
    }
}
44
45impl Deref for Key {
46 type Target = [u8];
47
48 fn deref(&self) -> &Self::Target {
49 match self {
50 Key::Plain(slice) => slice,
51 Key::Hashed(bytes) => bytes,
52 }
53 }
54}
55
56impl Borrow<[u8]> for Key {
57 fn borrow(&self) -> &[u8] {
58 self.deref()
59 }
60}
61
62impl rusqlite::ToSql for Key {
63 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
64 self.deref().to_sql()
65 }
66}
67
68#[async_trait]
69pub(crate) trait SqliteAsyncConnExt {
70 async fn execute<P>(
71 &self,
72 sql: impl AsRef<str> + Send + 'static,
73 params: P,
74 ) -> rusqlite::Result<usize>
75 where
76 P: Params + Send + 'static;
77
78 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;
79
80 async fn prepare<T, F>(
81 &self,
82 sql: impl AsRef<str> + Send + 'static,
83 f: F,
84 ) -> rusqlite::Result<T>
85 where
86 T: Send + 'static,
87 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;
88
89 async fn query_row<T, P, F>(
90 &self,
91 sql: impl AsRef<str> + Send + 'static,
92 params: P,
93 f: F,
94 ) -> rusqlite::Result<T>
95 where
96 T: Send + 'static,
97 P: Params + Send + 'static,
98 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;
99
100 async fn query_many<T, P, F>(
101 &self,
102 sql: impl AsRef<str> + Send + 'static,
103 params: P,
104 f: F,
105 ) -> rusqlite::Result<Vec<T>>
106 where
107 T: Send + 'static,
108 P: Params + Send + 'static,
109 F: FnMut(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;
110
111 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
112 where
113 T: Send + 'static,
114 E: From<rusqlite::Error> + Send + 'static,
115 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;
116
117 async fn chunk_large_query_over<Query, Res>(
118 &self,
119 mut keys_to_chunk: Vec<Key>,
120 result_capacity: Option<usize>,
121 do_query: Query,
122 ) -> Result<Vec<Res>>
123 where
124 Res: Send + 'static,
125 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
126
127 async fn apply_runtime_config(&self, runtime_config: RuntimeConfig) -> Result<()> {
136 let RuntimeConfig { optimize, cache_size, journal_size_limit } = runtime_config;
137
138 if optimize {
139 self.optimize().await?;
140 }
141
142 self.cache_size(cache_size).await?;
143 self.journal_size_limit(journal_size_limit).await?;
144
145 Ok(())
146 }
147
148 async fn optimize(&self) -> Result<()> {
158 self.execute_batch("PRAGMA optimize = 0x10002;").await?;
159 Ok(())
160 }
161
162 async fn cache_size(&self, cache_size: u32) -> Result<()> {
168 let n = cache_size / 1024;
171
172 self.execute_batch(format!("PRAGMA cache_size = -{n};")).await?;
173 Ok(())
174 }
175
176 async fn journal_size_limit(&self, limit: u32) -> Result<()> {
194 self.execute_batch(format!("PRAGMA journal_size_limit = {limit};")).await?;
195 Ok(())
196 }
197
198 async fn vacuum(&self) -> Result<()> {
202 self.wal_checkpoint().await;
204 if let Err(error) = self.execute_batch("VACUUM").await {
205 #[cfg(not(any(test, debug_assertions)))]
208 tracing::warn!("Failed to vacuum database: {error}");
209
210 #[cfg(any(test, debug_assertions))]
212 return Err(error.into());
213 } else {
214 trace!("VACUUM complete");
215 self.wal_checkpoint().await;
217 }
218
219 Ok(())
220 }
221
222 async fn wal_checkpoint(&self) {
227 match self.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);").await {
228 Ok(_) => trace!("WAL checkpoint completed"),
229 Err(error) => error!(?error, "WAL checkpoint error"),
230 }
231 }
232
233 async fn get_db_size(&self) -> Result<usize> {
234 let page_size =
235 self.query_row("PRAGMA page_size;", (), |row| row.get::<_, usize>(0)).await?;
236 let total_pages =
237 self.query_row("PRAGMA page_count;", (), |row| row.get::<_, usize>(0)).await?;
238
239 Ok(total_pages * page_size)
240 }
241}
242
243#[async_trait]
244impl SqliteAsyncConnExt for SqliteAsyncConn {
245 async fn execute<P>(
246 &self,
247 sql: impl AsRef<str> + Send + 'static,
248 params: P,
249 ) -> rusqlite::Result<usize>
250 where
251 P: Params + Send + 'static,
252 {
253 self.interact(move |conn| conn.execute(sql.as_ref(), params))
254 .await
255 .map_err(map_interact_err)?
256 }
257
258 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
259 self.interact(move |conn| conn.execute_batch(sql.as_ref()))
260 .await
261 .map_err(map_interact_err)?
262 }
263
264 async fn prepare<T, F>(
265 &self,
266 sql: impl AsRef<str> + Send + 'static,
267 f: F,
268 ) -> rusqlite::Result<T>
269 where
270 T: Send + 'static,
271 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
272 {
273 self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.map_err(map_interact_err)?
274 }
275
276 async fn query_row<T, P, F>(
277 &self,
278 sql: impl AsRef<str> + Send + 'static,
279 params: P,
280 f: F,
281 ) -> rusqlite::Result<T>
282 where
283 T: Send + 'static,
284 P: Params + Send + 'static,
285 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
286 {
287 self.interact(move |conn| conn.query_row(sql.as_ref(), params, f))
288 .await
289 .map_err(map_interact_err)?
290 }
291
292 async fn query_many<T, P, F>(
293 &self,
294 sql: impl AsRef<str> + Send + 'static,
295 params: P,
296 f: F,
297 ) -> rusqlite::Result<Vec<T>>
298 where
299 T: Send + 'static,
300 P: Params + Send + 'static,
301 F: FnMut(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
302 {
303 self.interact(move |conn| {
304 let mut stmt = conn.prepare(sql.as_ref())?;
305 stmt.query_and_then(params, f)?.collect()
306 })
307 .await
308 .map_err(map_interact_err)?
309 }
310
311 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
312 where
313 T: Send + 'static,
314 E: From<rusqlite::Error> + Send + 'static,
315 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
316 {
317 self.interact(move |conn| {
318 let txn = conn.transaction()?;
319 let result = f(&txn)?;
320 txn.commit()?;
321 Ok(result)
322 })
323 .await
324 .map_err(map_interact_err)
325 .map_err(E::from)?
326 }
327
328 async fn chunk_large_query_over<Query, Res>(
335 &self,
336 keys_to_chunk: Vec<Key>,
337 result_capacity: Option<usize>,
338 do_query: Query,
339 ) -> Result<Vec<Res>>
340 where
341 Res: Send + 'static,
342 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
343 {
344 self.with_transaction(move |txn| {
345 txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
346 })
347 .await
348 }
349}
350
351fn map_interact_err(error: InteractError) -> rusqlite::Error {
357 match error {
358 InteractError::Panic(p) => panic!("{p:?}"),
359 InteractError::Cancelled => rusqlite::Error::SqliteFailure(
360 rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_ABORT),
361 None,
362 ),
363 }
364}
365
366pub(crate) trait SqliteTransactionExt {
367 fn chunk_large_query_over<Key, Query, Res>(
368 &self,
369 keys_to_chunk: Vec<Key>,
370 result_capacity: Option<usize>,
371 do_query: Query,
372 ) -> Result<Vec<Res>>
373 where
374 Res: Send + 'static,
375 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
376}
377
378impl SqliteTransactionExt for Transaction<'_> {
379 fn chunk_large_query_over<Key, Query, Res>(
380 &self,
381 mut keys_to_chunk: Vec<Key>,
382 result_capacity: Option<usize>,
383 do_query: Query,
384 ) -> Result<Vec<Res>>
385 where
386 Res: Send + 'static,
387 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
388 {
389 let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)? / 2;
392 let maximum_chunk_size: usize = maximum_chunk_size
393 .try_into()
394 .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
395
396 if keys_to_chunk.len() < maximum_chunk_size {
397 let chunk = keys_to_chunk;
399
400 Ok(do_query(self, chunk)?)
401 } else {
402 let capacity = result_capacity.unwrap_or_default();
406 let mut all_results = Vec::with_capacity(capacity);
407
408 while !keys_to_chunk.is_empty() {
409 let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
411 let chunk = keys_to_chunk;
412 keys_to_chunk = tail;
413
414 all_results.extend(do_query(self, chunk)?);
415 }
416
417 Ok(all_results)
418 }
419 }
420}
421
422pub(crate) trait SqliteKeyValueStoreConnExt {
434 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;
436
437 fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
439 let serialized_value = rmp_serde::to_vec_named(&value)?;
440 self.set_kv(key, &serialized_value)?;
441
442 Ok(())
443 }
444
445 fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
447
448 fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
450 self.set_kv("version", &[version])
451 }
452}
453
454impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
455 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
456 self.execute(
457 "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
458 (key, value),
459 )?;
460 Ok(())
461 }
462
463 fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
464 self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
465 Ok(())
466 }
467}
468
469#[async_trait]
481pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
482 async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
484 self.query_row(
485 "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
486 (),
487 |row| row.get(0),
488 )
489 .await
490 }
491
492 async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
494 let key = key.to_owned();
495 self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
496 .await
497 .optional()
498 }
499
500 async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
502 let Some(bytes) = self.get_kv(key).await? else {
503 return Ok(None);
504 };
505
506 Ok(Some(rmp_serde::from_slice(&bytes)?))
507 }
508
509 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;
511
512 async fn set_serialized_kv<T: Serialize + Send + 'static>(
514 &self,
515 key: &str,
516 value: T,
517 ) -> Result<()>;
518
519 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
521
522 async fn db_version(&self) -> Result<u8, OpenStoreError> {
524 let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;
525
526 if kv_exists {
527 match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
528 Some([v]) => Ok(*v),
529 Some(_) => Err(OpenStoreError::InvalidVersion),
530 None => Err(OpenStoreError::MissingVersion),
531 }
532 } else {
533 Ok(0)
534 }
535 }
536
537 async fn get_or_create_store_cipher(
539 &self,
540 mut secret: Secret,
541 ) -> Result<StoreCipher, OpenStoreError> {
542 let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;
543
544 let cipher = if let Some(encrypted) = encrypted_cipher {
545 match secret {
546 Secret::PassPhrase(ref passphrase) => StoreCipher::import(passphrase, &encrypted)?,
547 Secret::Key(ref key) => StoreCipher::import_with_key(key, &encrypted)?,
548 }
549 } else {
550 let cipher = StoreCipher::new()?;
551 let export = match secret {
552 Secret::PassPhrase(ref passphrase) => {
553 #[cfg(not(test))]
554 {
555 cipher.export(passphrase)
556 }
557 #[cfg(test)]
558 {
559 cipher._insecure_export_fast_for_testing(passphrase)
560 }
561 }
562 Secret::Key(ref key) => cipher.export_with_key(key),
563 };
564 self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
565 cipher
566 };
567 secret.zeroize();
568 Ok(cipher)
569 }
570}
571
572#[async_trait]
573impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
574 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
575 let key = key.to_owned();
576 self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;
577
578 Ok(())
579 }
580
581 async fn set_serialized_kv<T: Serialize + Send + 'static>(
582 &self,
583 key: &str,
584 value: T,
585 ) -> Result<()> {
586 let key = key.to_owned();
587 self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;
588
589 Ok(())
590 }
591
592 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
593 let key = key.to_owned();
594 self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;
595
596 Ok(())
597 }
598}
599
600pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
602 assert_ne!(count, 0, "Can't generate zero repeated vars");
603
604 iter::repeat_n("?", count).format(",")
605}
606
/// Converts `time` into whole seconds since the Unix epoch.
///
/// Returns 0 for times before the epoch, and for second counts too large for an `i64`.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        // Sub-second precision is dropped; an out-of-range count falls back to 0.
        Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(0),
        // `time` lies before the Unix epoch.
        Err(_) => 0,
    }
}
619
620pub(crate) trait EncryptableStore {
629 fn get_cypher(&self) -> Option<&StoreCipher>;
630
631 fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
636 let bytes = key.as_ref();
637 if let Some(store_cipher) = self.get_cypher() {
638 Key::Hashed(store_cipher.hash_key(table_name, bytes))
639 } else {
640 Key::Plain(bytes.to_owned())
641 }
642 }
643
644 fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
645 if let Some(key) = self.get_cypher() {
646 let encrypted = key.encrypt_value_data(value)?;
647 Ok(rmp_serde::to_vec_named(&encrypted)?)
648 } else {
649 Ok(value)
650 }
651 }
652
653 fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
654 if let Some(key) = self.get_cypher() {
655 let encrypted = rmp_serde::from_slice(value)?;
656 let decrypted = key.decrypt_value_data(encrypted)?;
657 Ok(Cow::Owned(decrypted))
658 } else {
659 Ok(Cow::Borrowed(value))
660 }
661 }
662
663 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
664 let serialized = rmp_serde::to_vec_named(value)?;
665 self.encode_value(serialized)
666 }
667
668 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
669 let decoded = self.decode_value(value)?;
670 Ok(rmp_serde::from_slice(&decoded)?)
671 }
672
673 fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
674 let serialized = serde_json::to_vec(value)?;
675 self.encode_value(serialized)
676 }
677
678 fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
679 let decoded = self.decode_value(data)?;
680
681 let json_deserializer = &mut serde_json::Deserializer::from_slice(&decoded);
682
683 serde_path_to_error::deserialize(json_deserializer).map_err(|err| {
684 let raw_json: Option<Raw<serde_json::Value>> = serde_json::from_slice(&decoded).ok();
685
686 let target_type = std::any::type_name::<T>();
687 let serde_path = err.path().to_string();
688
689 error!(
690 sentry = true,
691 %err,
692 "Failed to deserialize {target_type} in a store: {serde_path}",
693 );
694
695 if let Some(raw) = raw_json {
696 if let Some(room_id) = raw.get_field::<OwnedRoomId>("room_id").ok().flatten() {
697 warn!("Found a room id in the source data to deserialize: {room_id}");
698 }
699 if let Some(event_id) = raw.get_field::<OwnedEventId>("event_id").ok().flatten() {
700 warn!("Found an event id in the source data to deserialize: {event_id}");
701 }
702 }
703
704 err.into_inner().into()
705 })
706 }
707}
708
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn can_generate_repeated_vars() {
        let cases = [(1, "?"), (2, "?,?"), (5, "?,?,?,?,?")];
        for (count, expected) in cases {
            assert_eq!(repeat_vars(count).to_string(), expected, "count = {count}");
        }
    }

    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        repeat_vars(0);
    }

    #[test]
    fn test_time_to_timestamp() {
        let one_minute = Duration::from_secs(60);

        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH), 0);
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH + one_minute), 60);
        // Times before the epoch clamp to zero.
        assert_eq!(time_to_timestamp(SystemTime::UNIX_EPOCH - one_minute), 0);
    }
}