1use core::fmt;
16use std::{
17 borrow::{Borrow, Cow},
18 cmp::min,
19 iter,
20 ops::Deref,
21};
22
23use async_trait::async_trait;
24use itertools::Itertools;
25use matrix_sdk_store_encryption::StoreCipher;
26use ruma::{serde::Raw, time::SystemTime, OwnedEventId, OwnedRoomId};
27use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
28use serde::{de::DeserializeOwned, Serialize};
29use tracing::{error, trace, warn};
30use zeroize::Zeroize;
31
32use crate::{
33 connection::Connection as SqliteAsyncConn,
34 error::{Error, Result},
35 OpenStoreError, RuntimeConfig, Secret,
36};
37
/// A database key: either the caller-supplied bytes, or a fixed 32-byte
/// hashed form of them (produced by `EncryptableStore::encode_key` when a
/// store cipher is configured).
///
/// The variant order matters: the derived `Ord` sorts every `Plain` key
/// before every `Hashed` key.
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum Key {
    /// The key bytes, used verbatim.
    Plain(Vec<u8>),
    /// A 32-byte hash of the key bytes.
    Hashed([u8; 32]),
}

impl Deref for Key {
    type Target = [u8];

    /// Views either variant as a plain byte slice.
    fn deref(&self) -> &[u8] {
        match self {
            Self::Plain(bytes) => bytes.as_slice(),
            Self::Hashed(digest) => digest.as_slice(),
        }
    }
}

impl Borrow<[u8]> for Key {
    /// Borrows the key as bytes, so byte-slice lookups work on collections of `Key`.
    fn borrow(&self) -> &[u8] {
        &**self
    }
}
60
61impl rusqlite::ToSql for Key {
62 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
63 self.deref().to_sql()
64 }
65}
66
67#[async_trait]
68pub(crate) trait SqliteAsyncConnExt {
69 async fn execute<P>(
70 &self,
71 sql: impl AsRef<str> + Send + 'static,
72 params: P,
73 ) -> rusqlite::Result<usize>
74 where
75 P: Params + Send + 'static;
76
77 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;
78
79 async fn prepare<T, F>(
80 &self,
81 sql: impl AsRef<str> + Send + 'static,
82 f: F,
83 ) -> rusqlite::Result<T>
84 where
85 T: Send + 'static,
86 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;
87
88 async fn query_row<T, P, F>(
89 &self,
90 sql: impl AsRef<str> + Send + 'static,
91 params: P,
92 f: F,
93 ) -> rusqlite::Result<T>
94 where
95 T: Send + 'static,
96 P: Params + Send + 'static,
97 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;
98
99 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
100 where
101 T: Send + 'static,
102 E: From<rusqlite::Error> + Send + 'static,
103 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;
104
105 async fn chunk_large_query_over<Query, Res>(
106 &self,
107 mut keys_to_chunk: Vec<Key>,
108 result_capacity: Option<usize>,
109 do_query: Query,
110 ) -> Result<Vec<Res>>
111 where
112 Res: Send + 'static,
113 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
114
115 async fn apply_runtime_config(&self, runtime_config: RuntimeConfig) -> Result<()> {
124 let RuntimeConfig { optimize, cache_size, journal_size_limit } = runtime_config;
125
126 if optimize {
127 self.optimize().await?;
128 }
129
130 self.cache_size(cache_size).await?;
131 self.journal_size_limit(journal_size_limit).await?;
132
133 Ok(())
134 }
135
136 async fn optimize(&self) -> Result<()> {
146 self.execute_batch("PRAGMA optimize = 0x10002;").await?;
147 Ok(())
148 }
149
150 async fn cache_size(&self, cache_size: u32) -> Result<()> {
156 let n = cache_size / 1024;
159
160 self.execute_batch(format!("PRAGMA cache_size = -{n};")).await?;
161 Ok(())
162 }
163
164 async fn journal_size_limit(&self, limit: u32) -> Result<()> {
182 self.execute_batch(format!("PRAGMA journal_size_limit = {limit};")).await?;
183 Ok(())
184 }
185
186 async fn vacuum(&self) -> Result<()> {
190 if let Err(error) = self.execute_batch("VACUUM").await {
191 #[cfg(not(any(test, debug_assertions)))]
194 tracing::warn!("Failed to vacuum database: {error}");
195
196 #[cfg(any(test, debug_assertions))]
198 return Err(error.into());
199 } else {
200 trace!("VACUUM complete");
201 }
202
203 Ok(())
204 }
205
206 async fn get_db_size(&self) -> Result<usize> {
207 let page_size =
208 self.query_row("PRAGMA page_size;", (), |row| row.get::<_, usize>(0)).await?;
209 let total_pages =
210 self.query_row("PRAGMA page_count;", (), |row| row.get::<_, usize>(0)).await?;
211
212 Ok(total_pages * page_size)
213 }
214}
215
216#[async_trait]
217impl SqliteAsyncConnExt for SqliteAsyncConn {
218 async fn execute<P>(
219 &self,
220 sql: impl AsRef<str> + Send + 'static,
221 params: P,
222 ) -> rusqlite::Result<usize>
223 where
224 P: Params + Send + 'static,
225 {
226 self.interact(move |conn| conn.execute(sql.as_ref(), params)).await.unwrap()
227 }
228
229 async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
230 self.interact(move |conn| conn.execute_batch(sql.as_ref())).await.unwrap()
231 }
232
233 async fn prepare<T, F>(
234 &self,
235 sql: impl AsRef<str> + Send + 'static,
236 f: F,
237 ) -> rusqlite::Result<T>
238 where
239 T: Send + 'static,
240 F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
241 {
242 self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.unwrap()
243 }
244
245 async fn query_row<T, P, F>(
246 &self,
247 sql: impl AsRef<str> + Send + 'static,
248 params: P,
249 f: F,
250 ) -> rusqlite::Result<T>
251 where
252 T: Send + 'static,
253 P: Params + Send + 'static,
254 F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
255 {
256 self.interact(move |conn| conn.query_row(sql.as_ref(), params, f)).await.unwrap()
257 }
258
259 async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
260 where
261 T: Send + 'static,
262 E: From<rusqlite::Error> + Send + 'static,
263 F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
264 {
265 self.interact(move |conn| {
266 let txn = conn.transaction()?;
267 let result = f(&txn)?;
268 txn.commit()?;
269 Ok(result)
270 })
271 .await
272 .unwrap()
273 }
274
275 async fn chunk_large_query_over<Query, Res>(
282 &self,
283 keys_to_chunk: Vec<Key>,
284 result_capacity: Option<usize>,
285 do_query: Query,
286 ) -> Result<Vec<Res>>
287 where
288 Res: Send + 'static,
289 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
290 {
291 self.with_transaction(move |txn| {
292 txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
293 })
294 .await
295 }
296}
297
298pub(crate) trait SqliteTransactionExt {
299 fn chunk_large_query_over<Key, Query, Res>(
300 &self,
301 keys_to_chunk: Vec<Key>,
302 result_capacity: Option<usize>,
303 do_query: Query,
304 ) -> Result<Vec<Res>>
305 where
306 Res: Send + 'static,
307 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
308}
309
310impl SqliteTransactionExt for Transaction<'_> {
311 fn chunk_large_query_over<Key, Query, Res>(
312 &self,
313 mut keys_to_chunk: Vec<Key>,
314 result_capacity: Option<usize>,
315 do_query: Query,
316 ) -> Result<Vec<Res>>
317 where
318 Res: Send + 'static,
319 Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
320 {
321 let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)? / 2;
324 let maximum_chunk_size: usize = maximum_chunk_size
325 .try_into()
326 .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
327
328 if keys_to_chunk.len() < maximum_chunk_size {
329 let chunk = keys_to_chunk;
331
332 Ok(do_query(self, chunk)?)
333 } else {
334 let capacity = result_capacity.unwrap_or_default();
338 let mut all_results = Vec::with_capacity(capacity);
339
340 while !keys_to_chunk.is_empty() {
341 let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
343 let chunk = keys_to_chunk;
344 keys_to_chunk = tail;
345
346 all_results.extend(do_query(self, chunk)?);
347 }
348
349 Ok(all_results)
350 }
351 }
352}
353
354pub(crate) trait SqliteKeyValueStoreConnExt {
366 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;
368
369 fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
371 let serialized_value = rmp_serde::to_vec_named(&value)?;
372 self.set_kv(key, &serialized_value)?;
373
374 Ok(())
375 }
376
377 fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
379
380 fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
382 self.set_kv("version", &[version])
383 }
384}
385
386impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
387 fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
388 self.execute(
389 "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
390 (key, value),
391 )?;
392 Ok(())
393 }
394
395 fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
396 self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
397 Ok(())
398 }
399}
400
401#[async_trait]
413pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
414 async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
416 self.query_row(
417 "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
418 (),
419 |row| row.get(0),
420 )
421 .await
422 }
423
424 async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
426 let key = key.to_owned();
427 self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
428 .await
429 .optional()
430 }
431
432 async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
434 let Some(bytes) = self.get_kv(key).await? else {
435 return Ok(None);
436 };
437
438 Ok(Some(rmp_serde::from_slice(&bytes)?))
439 }
440
441 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;
443
444 async fn set_serialized_kv<T: Serialize + Send + 'static>(
446 &self,
447 key: &str,
448 value: T,
449 ) -> Result<()>;
450
451 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
453
454 async fn db_version(&self) -> Result<u8, OpenStoreError> {
456 let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;
457
458 if kv_exists {
459 match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
460 Some([v]) => Ok(*v),
461 Some(_) => Err(OpenStoreError::InvalidVersion),
462 None => Err(OpenStoreError::MissingVersion),
463 }
464 } else {
465 Ok(0)
466 }
467 }
468
469 async fn get_or_create_store_cipher(
471 &self,
472 mut secret: Secret,
473 ) -> Result<StoreCipher, OpenStoreError> {
474 let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;
475
476 let cipher = if let Some(encrypted) = encrypted_cipher {
477 match secret {
478 Secret::PassPhrase(ref passphrase) => StoreCipher::import(passphrase, &encrypted)?,
479 Secret::Key(ref key) => StoreCipher::import_with_key(key, &encrypted)?,
480 }
481 } else {
482 let cipher = StoreCipher::new()?;
483 let export = match secret {
484 Secret::PassPhrase(ref passphrase) => {
485 #[cfg(not(test))]
486 {
487 cipher.export(passphrase)
488 }
489 #[cfg(test)]
490 {
491 cipher._insecure_export_fast_for_testing(passphrase)
492 }
493 }
494 Secret::Key(ref key) => cipher.export_with_key(key),
495 };
496 self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
497 cipher
498 };
499 secret.zeroize();
500 Ok(cipher)
501 }
502}
503
504#[async_trait]
505impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
506 async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
507 let key = key.to_owned();
508 self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;
509
510 Ok(())
511 }
512
513 async fn set_serialized_kv<T: Serialize + Send + 'static>(
514 &self,
515 key: &str,
516 value: T,
517 ) -> Result<()> {
518 let key = key.to_owned();
519 self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;
520
521 Ok(())
522 }
523
524 async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
525 let key = key.to_owned();
526 self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;
527
528 Ok(())
529 }
530}
531
532pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
534 assert_ne!(count, 0, "Can't generate zero repeated vars");
535
536 iter::repeat_n("?", count).format(",")
537}
538
/// Converts `time` to whole seconds since the Unix epoch, dropping any
/// sub-second part.
///
/// Times before the epoch, and durations whose seconds do not fit in an
/// `i64`, both map to `0`.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => i64::try_from(since_epoch.as_secs()).unwrap_or(0),
        Err(_) => 0,
    }
}
551
552pub(crate) trait EncryptableStore {
561 fn get_cypher(&self) -> Option<&StoreCipher>;
562
563 fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
568 let bytes = key.as_ref();
569 if let Some(store_cipher) = self.get_cypher() {
570 Key::Hashed(store_cipher.hash_key(table_name, bytes))
571 } else {
572 Key::Plain(bytes.to_owned())
573 }
574 }
575
576 fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
577 if let Some(key) = self.get_cypher() {
578 let encrypted = key.encrypt_value_data(value)?;
579 Ok(rmp_serde::to_vec_named(&encrypted)?)
580 } else {
581 Ok(value)
582 }
583 }
584
585 fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
586 if let Some(key) = self.get_cypher() {
587 let encrypted = rmp_serde::from_slice(value)?;
588 let decrypted = key.decrypt_value_data(encrypted)?;
589 Ok(Cow::Owned(decrypted))
590 } else {
591 Ok(Cow::Borrowed(value))
592 }
593 }
594
595 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
596 let serialized = rmp_serde::to_vec_named(value)?;
597 self.encode_value(serialized)
598 }
599
600 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
601 let decoded = self.decode_value(value)?;
602 Ok(rmp_serde::from_slice(&decoded)?)
603 }
604
605 fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
606 let serialized = serde_json::to_vec(value)?;
607 self.encode_value(serialized)
608 }
609
610 fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
611 let decoded = self.decode_value(data)?;
612
613 let json_deserializer = &mut serde_json::Deserializer::from_slice(&decoded);
614
615 serde_path_to_error::deserialize(json_deserializer).map_err(|err| {
616 let raw_json: Option<Raw<serde_json::Value>> = serde_json::from_slice(&decoded).ok();
617
618 let target_type = std::any::type_name::<T>();
619 let serde_path = err.path().to_string();
620
621 error!(
622 sentry = true,
623 %err,
624 "Failed to deserialize {target_type} in a store: {serde_path}",
625 );
626
627 if let Some(raw) = raw_json {
628 if let Some(room_id) = raw.get_field::<OwnedRoomId>("room_id").ok().flatten() {
629 warn!("Found a room id in the source data to deserialize: {room_id}");
630 }
631 if let Some(event_id) = raw.get_field::<OwnedEventId>("event_id").ok().flatten() {
632 warn!("Found an event id in the source data to deserialize: {event_id}");
633 }
634 }
635
636 err.into_inner().into()
637 })
638 }
639}
640
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn can_generate_repeated_vars() {
        for (count, expected) in [(1, "?"), (2, "?,?"), (5, "?,?,?,?,?")] {
            assert_eq!(repeat_vars(count).to_string(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        repeat_vars(0);
    }

    #[test]
    fn test_time_to_timestamp() {
        let epoch = SystemTime::UNIX_EPOCH;
        let one_minute = Duration::from_secs(60);

        assert_eq!(time_to_timestamp(epoch), 0);
        assert_eq!(time_to_timestamp(epoch + one_minute), 60);
        // Times before the epoch are clamped to zero.
        assert_eq!(time_to_timestamp(epoch - one_minute), 0);
    }
}