matrix_sdk_sqlite/
utils.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::fmt;
16use std::{
17    borrow::{Borrow, Cow},
18    cmp::min,
19    iter,
20    ops::Deref,
21};
22
23use async_trait::async_trait;
24use deadpool_sqlite::Object as SqliteAsyncConn;
25use itertools::Itertools;
26use matrix_sdk_store_encryption::StoreCipher;
27use ruma::{serde::Raw, time::SystemTime, OwnedEventId, OwnedRoomId};
28use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
29use serde::{de::DeserializeOwned, Serialize};
30use tracing::{error, warn};
31
32use crate::{
33    error::{Error, Result},
34    OpenStoreError, RuntimeConfig,
35};
36
/// A lookup key handed to SQLite as a blob.
///
/// A key is either the caller's bytes kept verbatim, or a fixed-size 32-byte
/// hash that stands in for them (see `EncryptableStore::encode_key`).
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum Key {
    /// The key bytes, kept as they were given.
    Plain(Vec<u8>),
    /// A 32-byte hash replacing the original key bytes.
    Hashed([u8; 32]),
}
42
43impl Deref for Key {
44    type Target = [u8];
45
46    fn deref(&self) -> &Self::Target {
47        match self {
48            Key::Plain(slice) => slice,
49            Key::Hashed(bytes) => bytes,
50        }
51    }
52}
53
54impl Borrow<[u8]> for Key {
55    fn borrow(&self) -> &[u8] {
56        self.deref()
57    }
58}
59
60impl rusqlite::ToSql for Key {
61    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
62        self.deref().to_sql()
63    }
64}
65
66#[async_trait]
67pub(crate) trait SqliteAsyncConnExt {
68    async fn execute<P>(
69        &self,
70        sql: impl AsRef<str> + Send + 'static,
71        params: P,
72    ) -> rusqlite::Result<usize>
73    where
74        P: Params + Send + 'static;
75
76    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;
77
78    async fn prepare<T, F>(
79        &self,
80        sql: impl AsRef<str> + Send + 'static,
81        f: F,
82    ) -> rusqlite::Result<T>
83    where
84        T: Send + 'static,
85        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;
86
87    async fn query_row<T, P, F>(
88        &self,
89        sql: impl AsRef<str> + Send + 'static,
90        params: P,
91        f: F,
92    ) -> rusqlite::Result<T>
93    where
94        T: Send + 'static,
95        P: Params + Send + 'static,
96        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;
97
98    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
99    where
100        T: Send + 'static,
101        E: From<rusqlite::Error> + Send + 'static,
102        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;
103
104    async fn chunk_large_query_over<Query, Res>(
105        &self,
106        mut keys_to_chunk: Vec<Key>,
107        result_capacity: Option<usize>,
108        do_query: Query,
109    ) -> Result<Vec<Res>>
110    where
111        Res: Send + 'static,
112        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
113
114    /// Apply the [`RuntimeConfig`].
115    ///
116    /// It will call the `Self::optimize`, `Self::cache_size` or
117    /// `Self::journal_size_limit` methods automatically based on the
118    /// `RuntimeConfig` values.
119    ///
120    /// It is possible to call these methods individually though. This
121    /// `apply_runtime_config` method allows to automate this process.
122    async fn apply_runtime_config(&self, runtime_config: RuntimeConfig) -> Result<()> {
123        let RuntimeConfig { optimize, cache_size, journal_size_limit } = runtime_config;
124
125        if optimize {
126            self.optimize().await?;
127        }
128
129        self.cache_size(cache_size).await?;
130        self.journal_size_limit(journal_size_limit).await?;
131
132        Ok(())
133    }
134
135    /// Optimize the database.
136    ///
137    /// The SQLite documentation recommends to run this regularly and after any
138    /// schema change. The easiest is to do it consistently when the store is
139    /// constructed, after eventual migrations.
140    ///
141    /// See [`PRAGMA optimize`] to learn more.
142    ///
143    /// [`PRAGMA cache_size`]: https://www.sqlite.org/pragma.html#pragma_optimize
144    async fn optimize(&self) -> Result<()> {
145        self.execute_batch("PRAGMA optimize = 0x10002;").await?;
146        Ok(())
147    }
148
149    /// Define the maximum size in **bytes** the SQLite cache can use.
150    ///
151    /// See [`PRAGMA cache_size`] to learn more.
152    ///
153    /// [`PRAGMA cache_size`]: https://www.sqlite.org/pragma.html#pragma_cache_size
154    async fn cache_size(&self, cache_size: u32) -> Result<()> {
155        // `N` in `PRAGMA cache_size = -N` is expressed in kibibytes.
156        // `cache_size` is expressed in bytes. Let's convert.
157        let n = cache_size / 1024;
158
159        self.execute_batch(format!("PRAGMA cache_size = -{n};")).await?;
160        Ok(())
161    }
162
163    /// Limit the size of the WAL file, in **bytes**.
164    ///
165    /// By default, while the DB connections of the databases are open, [the
166    /// size of the WAL file can keep increasing][size_wal_file] depending on
167    /// the size needed for the transactions. A critical case is `VACUUM`
168    /// which basically writes the content of the DB file to the WAL file
169    /// before writing it back to the DB file, so we end up taking twice the
170    /// size of the database.
171    ///
172    /// By setting this limit, the WAL file is truncated after its content is
173    /// written to the database, if it is bigger than the limit.
174    ///
175    /// See [`PRAGMA journal_size_limit`] to learn more. The value `limit`
176    /// corresponds to `N` in `PRAGMA journal_size_limit = N`.
177    ///
178    /// [size_wal_file]: https://www.sqlite.org/wal.html#avoiding_excessively_large_wal_files
179    /// [`PRAGMA journal_size_limit`]: https://www.sqlite.org/pragma.html#pragma_journal_size_limit
180    async fn journal_size_limit(&self, limit: u32) -> Result<()> {
181        self.execute_batch(format!("PRAGMA journal_size_limit = {limit};")).await?;
182        Ok(())
183    }
184
185    /// Defragment the database and free space on the filesystem.
186    ///
187    /// Only returns an error in tests, otherwise the error is only logged.
188    async fn vacuum(&self) -> Result<()> {
189        if let Err(error) = self.execute_batch("VACUUM").await {
190            // Since this is an optimisation step, do not propagate the error
191            // but log it.
192            #[cfg(not(any(test, debug_assertions)))]
193            tracing::warn!("Failed to vacuum database: {error}");
194
195            // We want to know if there is an error with this step during tests.
196            #[cfg(any(test, debug_assertions))]
197            return Err(error.into());
198        }
199
200        Ok(())
201    }
202}
203
204#[async_trait]
205impl SqliteAsyncConnExt for SqliteAsyncConn {
206    async fn execute<P>(
207        &self,
208        sql: impl AsRef<str> + Send + 'static,
209        params: P,
210    ) -> rusqlite::Result<usize>
211    where
212        P: Params + Send + 'static,
213    {
214        self.interact(move |conn| conn.execute(sql.as_ref(), params)).await.unwrap()
215    }
216
217    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
218        self.interact(move |conn| conn.execute_batch(sql.as_ref())).await.unwrap()
219    }
220
221    async fn prepare<T, F>(
222        &self,
223        sql: impl AsRef<str> + Send + 'static,
224        f: F,
225    ) -> rusqlite::Result<T>
226    where
227        T: Send + 'static,
228        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
229    {
230        self.interact(move |conn| f(conn.prepare(sql.as_ref())?)).await.unwrap()
231    }
232
233    async fn query_row<T, P, F>(
234        &self,
235        sql: impl AsRef<str> + Send + 'static,
236        params: P,
237        f: F,
238    ) -> rusqlite::Result<T>
239    where
240        T: Send + 'static,
241        P: Params + Send + 'static,
242        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
243    {
244        self.interact(move |conn| conn.query_row(sql.as_ref(), params, f)).await.unwrap()
245    }
246
247    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
248    where
249        T: Send + 'static,
250        E: From<rusqlite::Error> + Send + 'static,
251        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
252    {
253        self.interact(move |conn| {
254            let txn = conn.transaction()?;
255            let result = f(&txn)?;
256            txn.commit()?;
257            Ok(result)
258        })
259        .await
260        .unwrap()
261    }
262
263    /// Chunk a large query over some keys.
264    ///
265    /// Imagine there is a _dynamic_ query that runs potentially large number of
266    /// parameters, so much that the maximum number of parameters can be hit.
267    /// Then, this helper is for you. It will execute the query on chunks of
268    /// parameters.
269    async fn chunk_large_query_over<Query, Res>(
270        &self,
271        keys_to_chunk: Vec<Key>,
272        result_capacity: Option<usize>,
273        do_query: Query,
274    ) -> Result<Vec<Res>>
275    where
276        Res: Send + 'static,
277        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
278    {
279        self.with_transaction(move |txn| {
280            txn.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
281        })
282        .await
283    }
284}
285
286pub(crate) trait SqliteTransactionExt {
287    fn chunk_large_query_over<Key, Query, Res>(
288        &self,
289        keys_to_chunk: Vec<Key>,
290        result_capacity: Option<usize>,
291        do_query: Query,
292    ) -> Result<Vec<Res>>
293    where
294        Res: Send + 'static,
295        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
296}
297
298impl SqliteTransactionExt for Transaction<'_> {
299    fn chunk_large_query_over<Key, Query, Res>(
300        &self,
301        mut keys_to_chunk: Vec<Key>,
302        result_capacity: Option<usize>,
303        do_query: Query,
304    ) -> Result<Vec<Res>>
305    where
306        Res: Send + 'static,
307        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
308    {
309        // Divide by 2 to allow space for more static parameters (not part of
310        // `keys_to_chunk`).
311        let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)? / 2;
312        let maximum_chunk_size: usize = maximum_chunk_size
313            .try_into()
314            .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
315
316        if keys_to_chunk.len() < maximum_chunk_size {
317            // Chunking isn't necessary.
318            let chunk = keys_to_chunk;
319
320            Ok(do_query(self, chunk)?)
321        } else {
322            // Chunking _is_ necessary.
323
324            // Define the accumulator.
325            let capacity = result_capacity.unwrap_or_default();
326            let mut all_results = Vec::with_capacity(capacity);
327
328            while !keys_to_chunk.is_empty() {
329                // Chunk and run the query.
330                let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
331                let chunk = keys_to_chunk;
332                keys_to_chunk = tail;
333
334                all_results.extend(do_query(self, chunk)?);
335            }
336
337            Ok(all_results)
338        }
339    }
340}
341
342/// Extension trait for a [`rusqlite::Connection`] that contains a key-value
343/// table named `kv`.
344///
345/// The table should be created like this:
346///
347/// ```sql
348/// CREATE TABLE "kv" (
349///     "key" TEXT PRIMARY KEY NOT NULL,
350///     "value" BLOB NOT NULL
351/// );
352/// ```
353pub(crate) trait SqliteKeyValueStoreConnExt {
354    /// Store the given value for the given key.
355    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;
356
357    /// Store the given value for the given key by serializing it.
358    fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
359        let serialized_value = rmp_serde::to_vec_named(&value)?;
360        self.set_kv(key, &serialized_value)?;
361
362        Ok(())
363    }
364
365    /// Removes the current key and value if exists.
366    fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
367
368    /// Set the version of the database.
369    fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
370        self.set_kv("version", &[version])
371    }
372}
373
374impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
375    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
376        self.execute(
377            "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
378            (key, value),
379        )?;
380        Ok(())
381    }
382
383    fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
384        self.execute("DELETE FROM kv WHERE key = ?1", (key,))?;
385        Ok(())
386    }
387}
388
389/// Extension trait for an [`SqliteAsyncConn`] that contains a key-value
390/// table named `kv`.
391///
392/// The table should be created like this:
393///
394/// ```sql
395/// CREATE TABLE "kv" (
396///     "key" TEXT PRIMARY KEY NOT NULL,
397///     "value" BLOB NOT NULL
398/// );
399/// ```
400#[async_trait]
401pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
402    /// Whether the `kv` table exists in this database.
403    async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
404        self.query_row(
405            "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
406            (),
407            |row| row.get(0),
408        )
409        .await
410    }
411
412    /// Get the stored value for the given key.
413    async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
414        let key = key.to_owned();
415        self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
416            .await
417            .optional()
418    }
419
420    /// Get the stored serialized value for the given key.
421    async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
422        let Some(bytes) = self.get_kv(key).await? else {
423            return Ok(None);
424        };
425
426        Ok(Some(rmp_serde::from_slice(&bytes)?))
427    }
428
429    /// Store the given value for the given key.
430    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;
431
432    /// Store the given value for the given key by serializing it.
433    async fn set_serialized_kv<T: Serialize + Send + 'static>(
434        &self,
435        key: &str,
436        value: T,
437    ) -> Result<()>;
438
439    /// Clears the given value for the given key.
440    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;
441
442    /// Get the version of the database.
443    async fn db_version(&self) -> Result<u8, OpenStoreError> {
444        let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;
445
446        if kv_exists {
447            match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
448                Some([v]) => Ok(*v),
449                Some(_) => Err(OpenStoreError::InvalidVersion),
450                None => Err(OpenStoreError::MissingVersion),
451            }
452        } else {
453            Ok(0)
454        }
455    }
456
457    /// Get the [`StoreCipher`] of the database or create it.
458    async fn get_or_create_store_cipher(
459        &self,
460        passphrase: &str,
461    ) -> Result<StoreCipher, OpenStoreError> {
462        let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;
463
464        let cipher = if let Some(encrypted) = encrypted_cipher {
465            StoreCipher::import(passphrase, &encrypted)?
466        } else {
467            let cipher = StoreCipher::new()?;
468            #[cfg(not(test))]
469            let export = cipher.export(passphrase);
470            #[cfg(test)]
471            let export = cipher._insecure_export_fast_for_testing(passphrase);
472            self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
473            cipher
474        };
475
476        Ok(cipher)
477    }
478}
479
480#[async_trait]
481impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
482    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
483        let key = key.to_owned();
484        self.interact(move |conn| conn.set_kv(&key, &value)).await.unwrap()?;
485
486        Ok(())
487    }
488
489    async fn set_serialized_kv<T: Serialize + Send + 'static>(
490        &self,
491        key: &str,
492        value: T,
493    ) -> Result<()> {
494        let key = key.to_owned();
495        self.interact(move |conn| conn.set_serialized_kv(&key, value)).await.unwrap()?;
496
497        Ok(())
498    }
499
500    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
501        let key = key.to_owned();
502        self.interact(move |conn| conn.clear_kv(&key)).await.unwrap()?;
503
504        Ok(())
505    }
506}
507
508/// Repeat `?` n times, where n is defined by `count`. `?` are comma-separated.
509pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
510    assert_ne!(count, 0, "Can't generate zero repeated vars");
511
512    iter::repeat_n("?", count).format(",")
513}
514
/// Turn a `SystemTime` into a Unix timestamp, in whole seconds.
///
/// The result is an `i64` because that is the numeric type used by SQLite.
/// A time that is before the Unix Epoch, or whose number of seconds does not
/// fit in an `i64`, yields `0`.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    // Both failure cases are unlikely unless the time on the system is
    // seriously wrong, but we always need a value.
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => i64::try_from(elapsed.as_secs()).unwrap_or(0),
        Err(_) => 0,
    }
}
527
528/// Trait for a store that can encrypt its values, based on the presence of a
529/// cipher or not.
530///
531/// A single method must be implemented: `get_cypher`, which returns an optional
532/// cipher.
533///
534/// All the other methods come for free, based on the implementation of
535/// `get_cypher`.
536pub(crate) trait EncryptableStore {
537    fn get_cypher(&self) -> Option<&StoreCipher>;
538
539    /// If the store is using encryption, this will hash the given key. This is
540    /// useful when we need to do queries against a given key, but we don't
541    /// need to store the key in plain text (i.e. it's not both a key and a
542    /// value).
543    fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
544        let bytes = key.as_ref();
545        if let Some(store_cipher) = self.get_cypher() {
546            Key::Hashed(store_cipher.hash_key(table_name, bytes))
547        } else {
548            Key::Plain(bytes.to_owned())
549        }
550    }
551
552    fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
553        if let Some(key) = self.get_cypher() {
554            let encrypted = key.encrypt_value_data(value)?;
555            Ok(rmp_serde::to_vec_named(&encrypted)?)
556        } else {
557            Ok(value)
558        }
559    }
560
561    fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
562        if let Some(key) = self.get_cypher() {
563            let encrypted = rmp_serde::from_slice(value)?;
564            let decrypted = key.decrypt_value_data(encrypted)?;
565            Ok(Cow::Owned(decrypted))
566        } else {
567            Ok(Cow::Borrowed(value))
568        }
569    }
570
571    fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
572        let serialized = rmp_serde::to_vec_named(value)?;
573        self.encode_value(serialized)
574    }
575
576    fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
577        let decoded = self.decode_value(value)?;
578        Ok(rmp_serde::from_slice(&decoded)?)
579    }
580
581    fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
582        let serialized = serde_json::to_vec(value)?;
583        self.encode_value(serialized)
584    }
585
586    fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
587        let decoded = self.decode_value(data)?;
588
589        let json_deserializer = &mut serde_json::Deserializer::from_slice(&decoded);
590
591        serde_path_to_error::deserialize(json_deserializer).map_err(|err| {
592            let raw_json: Option<Raw<serde_json::Value>> = serde_json::from_slice(&decoded).ok();
593
594            let target_type = std::any::type_name::<T>();
595            let serde_path = err.path().to_string();
596
597            error!(
598                sentry = true,
599                %err,
600                "Failed to deserialize {target_type} in a store: {serde_path}",
601            );
602
603            if let Some(raw) = raw_json {
604                if let Some(room_id) = raw.get_field::<OwnedRoomId>("room_id").ok().flatten() {
605                    warn!("Found a room id in the source data to deserialize: {room_id}");
606                }
607                if let Some(event_id) = raw.get_field::<OwnedEventId>("event_id").ok().flatten() {
608                    warn!("Found an event id in the source data to deserialize: {event_id}");
609                }
610            }
611
612            err.into_inner().into()
613        })
614    }
615}
616
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    /// `repeat_vars` renders `count` placeholders separated by commas.
    #[test]
    fn can_generate_repeated_vars() {
        for (count, expected) in [(1, "?"), (2, "?,?"), (5, "?,?,?,?,?")] {
            assert_eq!(repeat_vars(count).to_string(), expected);
        }
    }

    /// Asking for zero placeholders is a programming error.
    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        repeat_vars(0);
    }

    /// Timestamps are whole seconds since the Unix Epoch.
    #[test]
    fn test_time_to_timestamp() {
        let epoch = SystemTime::UNIX_EPOCH;
        let one_minute = Duration::from_secs(60);

        assert_eq!(time_to_timestamp(epoch), 0);
        assert_eq!(time_to_timestamp(epoch + one_minute), 60);

        // Fallback value for a time that is before the epoch.
        assert_eq!(time_to_timestamp(epoch - one_minute), 0);
    }
}