use core::fmt;
use std::{
borrow::{Borrow, Cow},
cmp::min,
iter,
ops::Deref,
};
use async_trait::async_trait;
use itertools::Itertools;
use matrix_sdk_store_encryption::StoreCipher;
use ruma::{serde::Raw, time::SystemTime, OwnedEventId, OwnedRoomId};
use rusqlite::{limits::Limit, OptionalExtension, Params, Row, Statement, Transaction};
use serde::{de::DeserializeOwned, Serialize};
use tracing::{error, trace, warn};
use zeroize::Zeroize;
use crate::{
connection::Connection as SqliteAsyncConn,
error::{Error, Result},
OpenStoreError, RuntimeConfig, Secret,
};
/// A table key, stored either as the plain key bytes or — when the store is
/// encrypted — as a 32-byte hash produced by the store cipher (see
/// `EncryptableStore::encode_key`).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum Key {
    // The key bytes as-is (unencrypted store).
    Plain(Vec<u8>),
    // A 32-byte hash of the key (encrypted store), hiding the original value.
    Hashed([u8; 32]),
}
impl Deref for Key {
type Target = [u8];
fn deref(&self) -> &Self::Target {
match self {
Key::Plain(slice) => slice,
Key::Hashed(bytes) => bytes,
}
}
}
impl Borrow<[u8]> for Key {
fn borrow(&self) -> &[u8] {
self.deref()
}
}
impl rusqlite::ToSql for Key {
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
        // Bind the key through the byte-slice `ToSql` implementation.
        let bytes: &[u8] = self;
        bytes.to_sql()
    }
}
/// Async extension helpers for an async SQLite connection.
///
/// The required methods forward SQL onto the connection; the provided
/// methods implement common maintenance tasks (PRAGMA tuning, `VACUUM`,
/// size inspection) on top of them.
#[async_trait]
pub(crate) trait SqliteAsyncConnExt {
    /// Executes `sql` with the given parameters, returning the number of
    /// affected rows.
    async fn execute<P>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
    ) -> rusqlite::Result<usize>
    where
        P: Params + Send + 'static;

    /// Executes a batch of semicolon-separated SQL statements.
    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()>;

    /// Prepares `sql` and passes the resulting statement to `f`.
    async fn prepare<T, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static;

    /// Runs `sql`, expecting a single result row, and maps it with `f`.
    async fn query_row<T, P, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        P: Params + Send + 'static,
        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static;

    /// Runs `f` inside a transaction that is committed when `f` returns
    /// `Ok`, and rolled back (dropped) otherwise.
    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
    where
        T: Send + 'static,
        E: From<rusqlite::Error> + Send + 'static,
        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static;

    /// Runs `do_query` over `keys_to_chunk` inside a transaction, splitting
    /// the keys into chunks if they would exceed SQLite's bound-variable
    /// limit, and concatenating the per-chunk results.
    async fn chunk_large_query_over<Query, Res>(
        &self,
        mut keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;

    /// Applies the given runtime configuration (optimize, cache size,
    /// journal size limit) to this connection.
    async fn apply_runtime_config(&self, runtime_config: RuntimeConfig) -> Result<()> {
        let RuntimeConfig { optimize, cache_size, journal_size_limit } = runtime_config;

        if optimize {
            self.optimize().await?;
        }

        self.cache_size(cache_size).await?;
        self.journal_size_limit(journal_size_limit).await?;

        Ok(())
    }

    /// Runs `PRAGMA optimize`.
    ///
    /// NOTE(review): 0x10002 appears to be the mask described in the SQLite
    /// PRAGMA documentation (run ANALYZE where useful, with a scan limit) —
    /// confirm against the SQLite docs.
    async fn optimize(&self) -> Result<()> {
        self.execute_batch("PRAGMA optimize = 0x10002;").await?;
        Ok(())
    }

    /// Sets the page-cache size. `cache_size` is in bytes; a negative value
    /// in `PRAGMA cache_size` makes SQLite interpret it as KiB, hence the
    /// division by 1024.
    async fn cache_size(&self, cache_size: u32) -> Result<()> {
        let n = cache_size / 1024;
        self.execute_batch(format!("PRAGMA cache_size = -{n};")).await?;
        Ok(())
    }

    /// Caps the journal file size to `limit` bytes.
    async fn journal_size_limit(&self, limit: u32) -> Result<()> {
        self.execute_batch(format!("PRAGMA journal_size_limit = {limit};")).await?;
        Ok(())
    }

    /// Defragments the database and frees unused disk space.
    async fn vacuum(&self) -> Result<()> {
        if let Err(error) = self.execute_batch("VACUUM").await {
            // In release builds, a failed VACUUM is only a missed
            // optimization: log it and carry on.
            #[cfg(not(any(test, debug_assertions)))]
            tracing::warn!("Failed to vacuum database: {error}");

            // In tests and debug builds, surface the failure so it gets
            // noticed and fixed.
            #[cfg(any(test, debug_assertions))]
            return Err(error.into());
        } else {
            trace!("VACUUM complete");
        }
        Ok(())
    }

    /// Returns the total database size in bytes (page size × page count).
    async fn get_db_size(&self) -> Result<usize> {
        let page_size =
            self.query_row("PRAGMA page_size;", (), |row| row.get::<_, usize>(0)).await?;
        let total_pages =
            self.query_row("PRAGMA page_count;", (), |row| row.get::<_, usize>(0)).await?;

        Ok(total_pages * page_size)
    }
}
#[async_trait]
impl SqliteAsyncConnExt for SqliteAsyncConn {
    async fn execute<P>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
    ) -> rusqlite::Result<usize>
    where
        P: Params + Send + 'static,
    {
        // `interact` runs the closure on the connection's own thread; the
        // `unwrap` propagates a panic from that thread.
        self.interact(move |connection| connection.execute(sql.as_ref(), params)).await.unwrap()
    }

    async fn execute_batch(&self, sql: impl AsRef<str> + Send + 'static) -> rusqlite::Result<()> {
        self.interact(move |connection| connection.execute_batch(sql.as_ref())).await.unwrap()
    }

    async fn prepare<T, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(Statement<'_>) -> rusqlite::Result<T> + Send + 'static,
    {
        self.interact(move |connection| {
            let statement = connection.prepare(sql.as_ref())?;
            f(statement)
        })
        .await
        .unwrap()
    }

    async fn query_row<T, P, F>(
        &self,
        sql: impl AsRef<str> + Send + 'static,
        params: P,
        f: F,
    ) -> rusqlite::Result<T>
    where
        T: Send + 'static,
        P: Params + Send + 'static,
        F: FnOnce(&Row<'_>) -> rusqlite::Result<T> + Send + 'static,
    {
        self.interact(move |connection| connection.query_row(sql.as_ref(), params, f))
            .await
            .unwrap()
    }

    async fn with_transaction<T, E, F>(&self, f: F) -> Result<T, E>
    where
        T: Send + 'static,
        E: From<rusqlite::Error> + Send + 'static,
        F: FnOnce(&Transaction<'_>) -> Result<T, E> + Send + 'static,
    {
        self.interact(move |connection| {
            let transaction = connection.transaction()?;
            let value = f(&transaction)?;
            // Only reached when `f` succeeded; an early `?` return drops
            // (rolls back) the transaction instead.
            transaction.commit()?;
            Ok(value)
        })
        .await
        .unwrap()
    }

    async fn chunk_large_query_over<Query, Res>(
        &self,
        keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
    {
        // Delegate to the transaction-level implementation inside a fresh
        // transaction.
        self.with_transaction(move |transaction| {
            transaction.chunk_large_query_over(keys_to_chunk, result_capacity, do_query)
        })
        .await
    }
}
/// Extension trait adding chunked-query support to a SQLite transaction.
pub(crate) trait SqliteTransactionExt {
    /// Runs `do_query` over `keys_to_chunk`, splitting the keys into chunks
    /// if they would exceed SQLite's maximum number of bound variables per
    /// statement, and concatenating the per-chunk results in order.
    ///
    /// `result_capacity` is an optional hint used to pre-allocate the result
    /// vector when chunking is required.
    fn chunk_large_query_over<Key, Query, Res>(
        &self,
        keys_to_chunk: Vec<Key>,
        result_capacity: Option<usize>,
        do_query: Query,
    ) -> Result<Vec<Res>>
    where
        Res: Send + 'static,
        Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static;
}
impl SqliteTransactionExt for Transaction<'_> {
fn chunk_large_query_over<Key, Query, Res>(
&self,
mut keys_to_chunk: Vec<Key>,
result_capacity: Option<usize>,
do_query: Query,
) -> Result<Vec<Res>>
where
Res: Send + 'static,
Query: Fn(&Transaction<'_>, Vec<Key>) -> Result<Vec<Res>> + Send + 'static,
{
let maximum_chunk_size = self.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER)? / 2;
let maximum_chunk_size: usize = maximum_chunk_size
.try_into()
.map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
if keys_to_chunk.len() < maximum_chunk_size {
let chunk = keys_to_chunk;
Ok(do_query(self, chunk)?)
} else {
let capacity = result_capacity.unwrap_or_default();
let mut all_results = Vec::with_capacity(capacity);
while !keys_to_chunk.is_empty() {
let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
let chunk = keys_to_chunk;
keys_to_chunk = tail;
all_results.extend(do_query(self, chunk)?);
}
Ok(all_results)
}
}
}
/// Blocking key-value helpers for a [`rusqlite::Connection`], backed by the
/// `kv` table.
pub(crate) trait SqliteKeyValueStoreConnExt {
    /// Stores the raw `value` under `key`.
    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>;

    /// Serializes `value` with MessagePack and stores it under `key`.
    fn set_serialized_kv<T: Serialize + Send>(&self, key: &str, value: T) -> Result<()> {
        let bytes = rmp_serde::to_vec_named(&value)?;
        self.set_kv(key, &bytes)?;
        Ok(())
    }

    /// Removes the entry under `key`, if any.
    fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;

    /// Stores the schema version as a single byte under the "version" key.
    fn set_db_version(&self, version: u8) -> rusqlite::Result<()> {
        self.set_kv("version", &[version])
    }
}
impl SqliteKeyValueStoreConnExt for rusqlite::Connection {
    fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> {
        // Upsert: insert the pair, or overwrite the value when the key
        // already exists.
        self.execute(
            "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2",
            (key, value),
        )
        .map(|_| ())
    }

    fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
        self.execute("DELETE FROM kv WHERE key = ?1", (key,)).map(|_| ())
    }
}
/// Async key-value helpers on top of [`SqliteAsyncConnExt`], backed by the
/// `kv` table.
#[async_trait]
pub(crate) trait SqliteKeyValueStoreAsyncConnExt: SqliteAsyncConnExt {
    /// Returns whether the `kv` table exists in this database.
    async fn kv_table_exists(&self) -> rusqlite::Result<bool> {
        self.query_row(
            "SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'kv')",
            (),
            |row| row.get(0),
        )
        .await
    }

    /// Loads the raw bytes stored under `key`, if any.
    async fn get_kv(&self, key: &str) -> rusqlite::Result<Option<Vec<u8>>> {
        let key = key.to_owned();
        self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0))
            .await
            .optional()
    }

    /// Loads and MessagePack-deserializes the value stored under `key`.
    async fn get_serialized_kv<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
        let Some(bytes) = self.get_kv(key).await? else {
            return Ok(None);
        };

        Ok(Some(rmp_serde::from_slice(&bytes)?))
    }

    /// Stores raw bytes under `key`, replacing any previous value.
    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()>;

    /// MessagePack-serializes `value` and stores it under `key`.
    async fn set_serialized_kv<T: Serialize + Send + 'static>(
        &self,
        key: &str,
        value: T,
    ) -> Result<()>;

    /// Removes the entry stored under `key`, if any.
    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()>;

    /// Reads the schema version stored in the `kv` table.
    ///
    /// Returns 0 when the `kv` table does not exist yet (fresh database);
    /// errors when the table exists but holds no, or an invalid, version.
    async fn db_version(&self) -> Result<u8, OpenStoreError> {
        let kv_exists = self.kv_table_exists().await.map_err(OpenStoreError::LoadVersion)?;

        if kv_exists {
            match self.get_kv("version").await.map_err(OpenStoreError::LoadVersion)?.as_deref() {
                // The version is stored as a single byte.
                Some([v]) => Ok(*v),
                Some(_) => Err(OpenStoreError::InvalidVersion),
                None => Err(OpenStoreError::MissingVersion),
            }
        } else {
            Ok(0)
        }
    }

    /// Loads the store cipher persisted in the `kv` table, or creates and
    /// persists a new one on first use.
    ///
    /// `secret` protects the cipher export and is zeroized before returning.
    async fn get_or_create_store_cipher(
        &self,
        mut secret: Secret,
    ) -> Result<StoreCipher, OpenStoreError> {
        let encrypted_cipher = self.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?;

        let cipher = if let Some(encrypted) = encrypted_cipher {
            // A persisted cipher exists: import it using the secret.
            match secret {
                Secret::PassPhrase(ref passphrase) => StoreCipher::import(passphrase, &encrypted)?,
                Secret::Key(ref key) => StoreCipher::import_with_key(key, &encrypted)?,
            }
        } else {
            // First run: create a fresh cipher and persist its export,
            // protected by the secret.
            let cipher = StoreCipher::new()?;
            let export = match secret {
                Secret::PassPhrase(ref passphrase) => {
                    #[cfg(not(test))]
                    {
                        cipher.export(passphrase)
                    }
                    // Tests use a faster (insecure) export to stay quick.
                    #[cfg(test)]
                    {
                        cipher._insecure_export_fast_for_testing(passphrase)
                    }
                }
                Secret::Key(ref key) => cipher.export_with_key(key),
            };
            self.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?;
            cipher
        };

        // Wipe the secret from memory now that it's no longer needed.
        secret.zeroize();

        Ok(cipher)
    }
}
#[async_trait]
impl SqliteKeyValueStoreAsyncConnExt for SqliteAsyncConn {
    async fn set_kv(&self, key: &str, value: Vec<u8>) -> rusqlite::Result<()> {
        // Move an owned copy of the key into the blocking closure.
        let owned_key = key.to_owned();
        self.interact(move |connection| connection.set_kv(&owned_key, &value)).await.unwrap()
    }

    async fn set_serialized_kv<T: Serialize + Send + 'static>(
        &self,
        key: &str,
        value: T,
    ) -> Result<()> {
        let owned_key = key.to_owned();
        self.interact(move |connection| connection.set_serialized_kv(&owned_key, value))
            .await
            .unwrap()
    }

    async fn clear_kv(&self, key: &str) -> rusqlite::Result<()> {
        let owned_key = key.to_owned();
        self.interact(move |connection| connection.clear_kv(&owned_key)).await.unwrap()
    }
}
/// Returns a displayable sequence of `count` comma-separated SQL
/// placeholders, e.g. `?,?,?` for `count == 3`.
///
/// # Panics
///
/// Panics if `count` is zero, since an empty placeholder list would never be
/// a valid SQL fragment.
pub(crate) fn repeat_vars(count: usize) -> impl fmt::Display {
    assert_ne!(count, 0, "Can't generate zero repeated vars");
    vec!["?"; count].join(",")
}
/// Converts a `SystemTime` to a Unix timestamp in whole seconds, suitable
/// for an SQLite integer column.
///
/// Times before the Unix epoch, and times whose second count does not fit in
/// an `i64`, are clamped to 0.
pub(crate) fn time_to_timestamp(time: SystemTime) -> i64 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs().try_into().unwrap_or(0),
        Err(_) => 0,
    }
}
/// Trait for stores whose content may be encrypted with an optional
/// [`StoreCipher`].
///
/// All helpers transparently fall back to plain (unencrypted) handling when
/// [`EncryptableStore::get_cypher`] returns `None`.
pub(crate) trait EncryptableStore {
    /// Returns the store cipher, if this store is encrypted.
    fn get_cypher(&self) -> Option<&StoreCipher>;

    /// Encodes `key` for use in `table_name`: hashed with the store cipher
    /// when encryption is enabled, otherwise kept as plain bytes.
    fn encode_key(&self, table_name: &str, key: impl AsRef<[u8]>) -> Key {
        let bytes = key.as_ref();
        if let Some(store_cipher) = self.get_cypher() {
            Key::Hashed(store_cipher.hash_key(table_name, bytes))
        } else {
            Key::Plain(bytes.to_owned())
        }
    }

    /// Encrypts `value` and wraps it in a MessagePack envelope when
    /// encryption is enabled; returns the bytes unchanged otherwise.
    fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {
        if let Some(key) = self.get_cypher() {
            let encrypted = key.encrypt_value_data(value)?;
            Ok(rmp_serde::to_vec_named(&encrypted)?)
        } else {
            Ok(value)
        }
    }

    /// Inverse of [`EncryptableStore::encode_value`]; borrows the input when
    /// no decryption is needed.
    fn decode_value<'a>(&self, value: &'a [u8]) -> Result<Cow<'a, [u8]>> {
        if let Some(key) = self.get_cypher() {
            let encrypted = rmp_serde::from_slice(value)?;
            let decrypted = key.decrypt_value_data(encrypted)?;
            Ok(Cow::Owned(decrypted))
        } else {
            Ok(Cow::Borrowed(value))
        }
    }

    /// Serializes `value` with MessagePack, then encrypts it if needed.
    fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
        let serialized = rmp_serde::to_vec_named(value)?;
        self.encode_value(serialized)
    }

    /// Inverse of [`EncryptableStore::serialize_value`].
    fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
        let decoded = self.decode_value(value)?;
        Ok(rmp_serde::from_slice(&decoded)?)
    }

    /// Serializes `value` as JSON, then encrypts it if needed.
    fn serialize_json(&self, value: &impl Serialize) -> Result<Vec<u8>> {
        let serialized = serde_json::to_vec(value)?;
        self.encode_value(serialized)
    }

    /// Inverse of [`EncryptableStore::serialize_json`].
    ///
    /// On failure, logs the serde error path and any room/event id found in
    /// the raw JSON, to ease debugging without dumping the whole (possibly
    /// sensitive) payload.
    fn deserialize_json<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T> {
        let decoded = self.decode_value(data)?;

        let json_deserializer = &mut serde_json::Deserializer::from_slice(&decoded);

        serde_path_to_error::deserialize(json_deserializer).map_err(|err| {
            // Best-effort re-parse to extract identifying fields for the log.
            let raw_json: Option<Raw<serde_json::Value>> = serde_json::from_slice(&decoded).ok();

            let target_type = std::any::type_name::<T>();
            let serde_path = err.path().to_string();

            error!(
                sentry = true,
                %err,
                "Failed to deserialize {target_type} in a store: {serde_path}",
            );

            if let Some(raw) = raw_json {
                if let Some(room_id) = raw.get_field::<OwnedRoomId>("room_id").ok().flatten() {
                    warn!("Found a room id in the source data to deserialize: {room_id}");
                }
                if let Some(event_id) = raw.get_field::<OwnedEventId>("event_id").ok().flatten() {
                    warn!("Found an event id in the source data to deserialize: {event_id}");
                }
            }

            err.into_inner().into()
        })
    }
}
#[cfg(test)]
mod unit_tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn can_generate_repeated_vars() {
        // Check a few representative placeholder counts.
        for (count, expected) in [(1, "?"), (2, "?,?"), (5, "?,?,?,?,?")] {
            assert_eq!(repeat_vars(count).to_string(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Can't generate zero repeated vars")]
    fn generating_zero_vars_panics() {
        repeat_vars(0);
    }

    #[test]
    fn test_time_to_timestamp() {
        let epoch = SystemTime::UNIX_EPOCH;
        assert_eq!(time_to_timestamp(epoch), 0);
        assert_eq!(time_to_timestamp(epoch + Duration::from_secs(60)), 60);
        // Pre-epoch times are clamped to 0.
        assert_eq!(time_to_timestamp(epoch - Duration::from_secs(60)), 0);
    }
}