use std::{
collections::{HashMap, HashSet},
sync::{
Mutex, OnceLock,
atomic::{AtomicU64, Ordering},
},
};
use log::debug;
use serde::{Deserialize, Serialize};
use surfpool_db::diesel::{
self, QueryableByName, RunQueryDsl,
connection::SimpleConnection,
r2d2::{ConnectionManager, Pool},
sql_query,
sql_types::Text,
};
use crate::storage::{Storage, StorageConstructor, StorageError, StorageResult};
#[derive(Debug)]
struct SqlitePragmaCustomizer {
is_file_based: bool,
}
impl diesel::r2d2::CustomizeConnection<diesel::SqliteConnection, diesel::r2d2::Error>
for SqlitePragmaCustomizer
{
fn on_acquire(&self, conn: &mut diesel::SqliteConnection) -> Result<(), diesel::r2d2::Error> {
let pragmas = if self.is_file_based {
"PRAGMA synchronous=NORMAL; PRAGMA temp_store=MEMORY; PRAGMA mmap_size=268435456; PRAGMA cache_size=-64000; PRAGMA busy_timeout=5000;"
} else {
"PRAGMA synchronous=OFF; PRAGMA temp_store=MEMORY; PRAGMA cache_size=-64000; PRAGMA busy_timeout=5000;"
};
conn.batch_execute(pragmas)
.map_err(diesel::r2d2::Error::QueryError)
}
}
fn checkpointed_databases() -> &'static Mutex<HashSet<String>> {
static CHECKPOINTED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
CHECKPOINTED.get_or_init(|| Mutex::new(HashSet::new()))
}
static SHARED_POOLS: OnceLock<
Mutex<HashMap<String, Pool<ConnectionManager<diesel::SqliteConnection>>>>,
> = OnceLock::new();
static MEMORY_DB_COUNTER: AtomicU64 = AtomicU64::new(0);
fn get_or_create_shared_pool(
connection_string: &str,
is_file_based: bool,
) -> StorageResult<Pool<ConnectionManager<diesel::SqliteConnection>>> {
if !is_file_based {
let manager = ConnectionManager::<diesel::SqliteConnection>::new(connection_string);
return Pool::builder()
.max_size(10)
.connection_customizer(Box::new(SqlitePragmaCustomizer { is_file_based }))
.build(manager)
.map_err(|e| StorageError::PooledConnectionError(NAME.into(), e));
}
let pools = SHARED_POOLS.get_or_init(|| Mutex::new(HashMap::new()));
let mut guard = pools.lock().map_err(|_| StorageError::LockError)?;
if let Some(pool) = guard.get(connection_string) {
debug!("Reusing shared SQLite pool for {}", connection_string);
return Ok(pool.clone());
}
debug!("Creating shared SQLite pool for {}", connection_string);
let manager = ConnectionManager::<diesel::SqliteConnection>::new(connection_string);
let pool = Pool::builder()
.max_size(10)
.connection_customizer(Box::new(SqlitePragmaCustomizer { is_file_based }))
.build(manager)
.map_err(|e| StorageError::PooledConnectionError(NAME.into(), e))?;
{
let mut conn = pool.get().map_err(|_| StorageError::LockError)?;
conn.batch_execute("PRAGMA journal_mode=WAL; PRAGMA wal_autocheckpoint=1000;")
.map_err(|e| StorageError::create_table("pragma_init", NAME, e))?;
}
guard.insert(connection_string.to_string(), pool.clone());
Ok(pool)
}
#[derive(QueryableByName, Debug)]
struct KvRecord {
#[diesel(sql_type = Text)]
key: String,
#[diesel(sql_type = Text)]
value: String,
}
#[derive(QueryableByName, Debug)]
struct ValueRecord {
#[diesel(sql_type = Text)]
value: String,
}
#[derive(QueryableByName, Debug)]
struct KeyRecord {
#[diesel(sql_type = Text)]
key: String,
}
#[derive(QueryableByName, Debug)]
struct CountRecord {
#[diesel(sql_type = diesel::sql_types::BigInt)]
count: i64,
}
#[derive(Clone)]
pub struct SqliteStorage<K, V> {
pool: Pool<ConnectionManager<diesel::SqliteConnection>>,
_phantom: std::marker::PhantomData<(K, V)>,
table_name: String,
surfnet_id: String,
is_file_based: bool,
connection_string: String,
}
const NAME: &str = "SQLite";
impl<K, V> SqliteStorage<K, V> {
fn checkpoint(&self) {
if !self.is_file_based {
return;
}
let db_path = self
.connection_string
.strip_prefix("file:")
.and_then(|s| s.split('?').next())
.unwrap_or(&self.connection_string)
.to_string();
{
let mut checkpointed = checkpointed_databases().lock().unwrap();
if checkpointed.contains(&db_path) {
debug!(
"Database {} already checkpointed, skipping for table '{}'",
db_path, self.table_name
);
return;
}
checkpointed.insert(db_path.clone());
}
debug!(
"Checkpointing WAL for database '{}' (table '{}')",
db_path, self.table_name
);
if let Ok(mut conn) = self.pool.get() {
if let Err(e) = conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);") {
debug!("WAL checkpoint failed: {}", e);
return;
}
}
let wal_path = format!("{}-wal", db_path);
let shm_path = format!("{}-shm", db_path);
if std::path::Path::new(&wal_path).exists() {
if let Err(e) = std::fs::remove_file(&wal_path) {
debug!("Failed to remove WAL file {}: {}", wal_path, e);
} else {
debug!("Removed WAL file: {}", wal_path);
}
}
if std::path::Path::new(&shm_path).exists() {
if let Err(e) = std::fs::remove_file(&shm_path) {
debug!("Failed to remove SHM file {}: {}", shm_path, e);
} else {
debug!("Removed SHM file: {}", shm_path);
}
}
}
}
impl<K, V> SqliteStorage<K, V>
where
K: Serialize + for<'de> Deserialize<'de>,
V: Serialize + for<'de> Deserialize<'de> + Clone,
{
fn ensure_table_exists(&self) -> StorageResult<()> {
debug!("Ensuring table '{}' exists", self.table_name);
let create_table_sql = format!(
"
CREATE TABLE IF NOT EXISTS {} (
surfnet_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (surfnet_id, key)
)
",
self.table_name
);
debug!("Getting connection from pool for table creation");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
conn.batch_execute(&create_table_sql)
.map_err(|e| StorageError::create_table(&self.table_name, NAME, e))?;
debug!("Successfully ensured table '{}' exists", self.table_name);
Ok(())
}
fn serialize_key(&self, key: &K) -> StorageResult<String> {
trace!("Serializing key for table '{}'", self.table_name);
let result =
serde_json::to_string(key).map_err(|e| StorageError::SerializeKeyError(NAME.into(), e));
if let Ok(ref serialized) = result {
trace!("Key serialized successfully: {}", serialized);
}
result
}
fn serialize_value(&self, value: &V) -> StorageResult<String> {
trace!("Serializing value for table '{}'", self.table_name);
let result = serde_json::to_string(value)
.map_err(|e| StorageError::SerializeValueError(NAME.into(), e));
if let Ok(ref serialized) = result {
trace!(
"Value serialized successfully, length: {} chars",
serialized.len()
);
}
result
}
fn deserialize_value(&self, value_str: &str) -> StorageResult<V> {
trace!(
"Deserializing value from table '{}', input length: {} chars",
self.table_name,
value_str.len()
);
let result = serde_json::from_str(value_str)
.map_err(|e| StorageError::DeserializeValueError(NAME.into(), e));
if result.is_ok() {
trace!("Value deserialized successfully");
}
result
}
fn load_value_from_db(&self, key_str: &str) -> StorageResult<Option<V>> {
debug!("Loading value from DB for key: {}", key_str);
let query = sql_query(format!(
"SELECT value FROM {} WHERE surfnet_id = ? AND key = ?",
self.table_name
))
.bind::<Text, _>(&self.surfnet_id)
.bind::<Text, _>(key_str);
trace!("Getting connection from pool for loading value");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
let records = query
.load::<ValueRecord>(&mut *conn)
.map_err(|e| StorageError::get(&self.table_name, NAME, key_str, e))?;
if let Some(record) = records.into_iter().next() {
debug!("Found record for key: {}", key_str);
let value = self.deserialize_value(&record.value)?;
Ok(Some(value))
} else {
debug!("No record found for key: {}", key_str);
Ok(None)
}
}
}
impl<K, V> Storage<K, V> for SqliteStorage<K, V>
where
K: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
V: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
{
fn store(&mut self, key: K, value: V) -> StorageResult<()> {
debug!("Storing value in table '{}", self.table_name);
let key_str = self.serialize_key(&key)?;
let value_str = self.serialize_value(&value)?;
let query = sql_query(format!(
"INSERT OR REPLACE INTO {} (surfnet_id, key, value, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)",
self.table_name
))
.bind::<Text, _>(&self.surfnet_id)
.bind::<Text, _>(&key_str)
.bind::<Text, _>(&value_str);
trace!("Getting connection from pool for store operation");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
query
.execute(&mut *conn)
.map_err(|e| StorageError::store(&self.table_name, NAME, &key_str, e))?;
debug!("Value stored successfully in table '{}'", self.table_name);
Ok(())
}
fn get(&self, key: &K) -> StorageResult<Option<V>> {
debug!("Getting value from table '{}", self.table_name);
let key_str = self.serialize_key(key)?;
self.load_value_from_db(&key_str)
}
fn take(&mut self, key: &K) -> StorageResult<Option<V>> {
debug!("Taking value from table '{}'", self.table_name);
let key_str = self.serialize_key(key)?;
if let Some(value) = self.load_value_from_db(&key_str)? {
debug!("Value found, removing from database");
let delete_query = sql_query(format!(
"DELETE FROM {} WHERE surfnet_id = ? AND key = ?",
self.table_name
))
.bind::<Text, _>(&self.surfnet_id)
.bind::<Text, _>(&key_str);
trace!("Getting connection from pool for delete operation");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
delete_query
.execute(&mut *conn)
.map_err(|e| StorageError::delete(&self.table_name, NAME, &key_str, e))?;
debug!(
"Value taken and removed successfully from table '{}'",
self.table_name
);
Ok(Some(value))
} else {
debug!("No value found to take from table '{}'", self.table_name);
Ok(None)
}
}
fn clear(&mut self) -> StorageResult<()> {
debug!("Clearing all data from table '{}'", self.table_name);
let delete_query = sql_query(format!(
"DELETE FROM {} WHERE surfnet_id = ?",
self.table_name
))
.bind::<Text, _>(&self.surfnet_id);
trace!("Getting connection from pool for clear operation");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
delete_query
.execute(&mut *conn)
.map_err(|e| StorageError::delete(&self.table_name, NAME, "*all*", e))?;
debug!("Table '{}' cleared successfully", self.table_name);
Ok(())
}
fn keys(&self) -> StorageResult<Vec<K>> {
debug!("Fetching all keys from table '{}'", self.table_name);
let query = sql_query(format!(
"SELECT key FROM {} WHERE surfnet_id = ?",
self.table_name
))
.bind::<Text, _>(&self.surfnet_id);
trace!("Getting connection from pool for keys operation");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
let records = query
.load::<KeyRecord>(&mut *conn)
.map_err(|e| StorageError::get_all_keys(&self.table_name, NAME, e))?;
let mut keys = Vec::new();
for record in records {
let key: K = serde_json::from_str(&record.key)
.map_err(|e| StorageError::DeserializeValueError(NAME.into(), e))?;
keys.push(key);
}
debug!(
"Retrieved {} keys from table '{}'",
keys.len(),
self.table_name
);
Ok(keys)
}
fn clone_box(&self) -> Box<dyn Storage<K, V>> {
Box::new(self.clone())
}
fn shutdown(&self) {
self.checkpoint();
}
fn count(&self) -> StorageResult<u64> {
debug!("Counting entries in table '{}'", self.table_name);
let query = sql_query(format!(
"SELECT COUNT(*) as count FROM {} WHERE surfnet_id = ?",
self.table_name
))
.bind::<Text, _>(&self.surfnet_id);
trace!("Getting connection from pool for count operation");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
let records = query
.load::<CountRecord>(&mut *conn)
.map_err(|e| StorageError::count(&self.table_name, NAME, e))?;
let count = records.first().map(|r| r.count as u64).unwrap_or(0);
debug!("Table '{}' has {} entries", self.table_name, count);
Ok(count)
}
fn into_iter(&self) -> StorageResult<Box<dyn Iterator<Item = (K, V)> + '_>> {
debug!(
"Creating iterator for all key-value pairs in table '{}'",
self.table_name
);
let query = sql_query(format!(
"SELECT key, value FROM {} WHERE surfnet_id = ?",
self.table_name
))
.bind::<Text, _>(&self.surfnet_id);
trace!("Getting connection from pool for into_iter operation");
let mut conn = self.pool.get().map_err(|_| StorageError::LockError)?;
let records = query
.load::<KvRecord>(&mut *conn)
.map_err(|e| StorageError::get_all_key_value_pairs(&self.table_name, NAME, e))?;
let iter = records.into_iter().filter_map(move |record| {
let key: K = match serde_json::from_str(&record.key) {
Ok(k) => k,
Err(e) => {
debug!("Failed to deserialize key: {}", e);
return None;
}
};
let value: V = match serde_json::from_str(&record.value) {
Ok(v) => v,
Err(e) => {
debug!("Failed to deserialize value: {}", e);
return None;
}
};
Some((key, value))
});
debug!(
"Iterator created successfully for table '{}'",
self.table_name
);
Ok(Box::new(iter))
}
}
impl<K, V> StorageConstructor<K, V> for SqliteStorage<K, V>
where
K: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
V: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
{
fn connect(database_url: &str, table_name: &str, surfnet_id: &str) -> StorageResult<Self> {
debug!(
"Connecting to SQLite database: {} with table: {} and surfnet_id: {}",
database_url, table_name, surfnet_id
);
let connection_string = if database_url == ":memory:" {
let id = MEMORY_DB_COUNTER.fetch_add(1, Ordering::Relaxed);
format!("file:memdb{}?mode=memory&cache=shared", id)
} else if database_url.starts_with("file:") {
if database_url.contains('?') {
format!("{}&mode=rwc", database_url)
} else {
format!("{}?mode=rwc", database_url)
}
} else {
format!("file:{}?mode=rwc", database_url)
};
let is_file_based = database_url != ":memory:";
let pool = get_or_create_shared_pool(&connection_string, is_file_based)?;
let storage = SqliteStorage {
pool,
_phantom: std::marker::PhantomData,
table_name: table_name.to_string(),
surfnet_id: surfnet_id.to_string(),
is_file_based,
connection_string,
};
storage.ensure_table_exists()?;
debug!(
"SQLite storage connected successfully for table: {}",
table_name
);
Ok(storage)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(diesel::QueryableByName, Debug)]
struct PragmaInt {
#[diesel(sql_type = diesel::sql_types::BigInt)]
value: i64,
}
#[test]
fn test_shared_pool_across_multiple_storages() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let db_path = temp_file.path().to_str().unwrap();
let storage1: SqliteStorage<String, String> =
SqliteStorage::connect(db_path, "table1", "surfnet1").unwrap();
let storage2: SqliteStorage<String, String> =
SqliteStorage::connect(db_path, "table2", "surfnet1").unwrap();
let storage3: SqliteStorage<String, String> =
SqliteStorage::connect(db_path, "table3", "surfnet1").unwrap();
let _conn1 = storage1.pool.get().unwrap();
let state1 = storage1.pool.state();
let state2 = storage2.pool.state();
let state3 = storage3.pool.state();
assert_eq!(
state1.connections, state2.connections,
"pools should be shared"
);
assert_eq!(
state2.connections, state3.connections,
"pools should be shared"
);
assert_eq!(
state1.idle_connections, state2.idle_connections,
"pools should be shared"
);
}
#[test]
fn test_pragmas_applied_to_pooled_connections() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let db_path = temp_file.path().to_str().unwrap();
let storage: SqliteStorage<String, String> =
SqliteStorage::connect(db_path, "pragma_test", "surfnet1").unwrap();
let _conn1 = storage.pool.get().unwrap();
let mut conn2 = storage.pool.get().unwrap();
let cache_result: Vec<PragmaInt> =
diesel::sql_query("SELECT cache_size as value FROM pragma_cache_size")
.load(&mut *conn2)
.unwrap();
assert_eq!(cache_result[0].value, -64000, "cache_size should be -64000");
let timeout_result: Vec<PragmaInt> =
diesel::sql_query("SELECT timeout as value FROM pragma_busy_timeout")
.load(&mut *conn2)
.unwrap();
assert_eq!(timeout_result[0].value, 5000, "busy_timeout should be 5000");
}
#[test]
fn test_pragmas_applied_to_file_database() {
let temp_file = tempfile::NamedTempFile::new().unwrap();
let db_path = temp_file.path().to_str().unwrap();
let storage: SqliteStorage<String, String> =
SqliteStorage::connect(db_path, "pragma_test", "surfnet1").unwrap();
let _conn1 = storage.pool.get().unwrap();
let mut conn2 = storage.pool.get().unwrap();
let result: Vec<PragmaInt> =
diesel::sql_query("SELECT synchronous as value FROM pragma_synchronous")
.load(&mut *conn2)
.unwrap();
assert_eq!(result[0].value, 1, "synchronous should be NORMAL (1)");
let temp_result: Vec<PragmaInt> =
diesel::sql_query("SELECT temp_store as value FROM pragma_temp_store")
.load(&mut *conn2)
.unwrap();
assert_eq!(temp_result[0].value, 2, "temp_store should be MEMORY (2)");
}
#[test]
fn test_in_memory_databases_are_isolated() {
let mut storage1: SqliteStorage<String, String> =
SqliteStorage::connect(":memory:", "test_table", "surfnet1").unwrap();
let storage2: SqliteStorage<String, String> =
SqliteStorage::connect(":memory:", "test_table", "surfnet1").unwrap();
storage1
.store("key1".to_string(), "value1".to_string())
.unwrap();
let result = storage2.get(&"key1".to_string()).unwrap();
assert!(result.is_none(), "in-memory databases should be isolated");
}
}