use std::sync::Arc;
use serde::{de::DeserializeOwned, ser::Serialize};
use tokio::sync::Mutex;
use crate::{
repository::{
RepositoryItem, RepositoryMigrationStep, RepositoryMigrations, validate_registry_name,
},
sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
};
/// SQLite-backed implementation of [`Database`].
///
/// Wraps a single `rusqlite` connection in an async (`tokio`) mutex shared through an `Arc`.
/// Cloning a `SqliteDatabase` therefore yields another handle to the same connection; it does
/// not open a new one.
#[derive(Clone)]
pub struct SqliteDatabase(Arc<Mutex<rusqlite::Connection>>);
fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
if validate_registry_name(name) {
Ok(name)
} else {
Err(DatabaseError::Internal(
rusqlite::Error::InvalidParameterName(name.to_string()).to_string(),
))
}
}
impl SqliteDatabase {
fn initialize_internal(
mut db: rusqlite::Connection,
migrations: RepositoryMigrations,
) -> Result<Self, DatabaseError> {
db.pragma_update(None, "journal_mode", "WAL")?;
let transaction = db.transaction()?;
for step in &migrations.steps {
match step {
RepositoryMigrationStep::Add(data) => {
transaction.execute(
&format!(
"CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
validate_identifier(data.name())?,
),
[],
)?;
}
RepositoryMigrationStep::Remove(data) => {
transaction.execute(
&format!(
"DROP TABLE IF EXISTS \"{}\";",
validate_identifier(data.name())?,
),
[],
)?;
}
}
}
transaction.commit()?;
Ok(SqliteDatabase(Arc::new(Mutex::new(db))))
}
}
impl Database for SqliteDatabase {
    /// Opens `<folder_path>/<db_name>.sqlite` (rusqlite creates it if absent) and runs the
    /// migrations against it.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::UnsupportedConfiguration`] for non-SQLite configurations. Also
    /// propagates failures from opening the file or applying the migrations.
    async fn initialize(
        configuration: DatabaseConfiguration,
        migrations: RepositoryMigrations,
    ) -> Result<Self, DatabaseError> {
        let DatabaseConfiguration::Sqlite {
            db_name,
            folder_path: mut path,
        } = configuration
        else {
            return Err(DatabaseError::UnsupportedConfiguration(configuration));
        };
        path.push(format!("{db_name}.sqlite"));
        // NOTE(review): opening the file and running migrations are synchronous rusqlite calls
        // executed directly on the calling async task.
        let db = rusqlite::Connection::open(path)?;
        Self::initialize_internal(db, migrations)
    }

    /// Fetches and JSON-deserializes the value stored under `key` in `T`'s table.
    ///
    /// Returns `Ok(None)` when no row has that key.
    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        key: &str,
    ) -> Result<Option<T>, DatabaseError> {
        let conn = self.0.lock().await;
        let mut stmt = conn.prepare(&format!(
            "SELECT value FROM \"{}\" WHERE key = ?1",
            validate_identifier(T::NAME)?
        ))?;
        let mut rows = stmt.query([key])?;
        if let Some(row) = rows.next()? {
            let value = row.get::<_, String>(0)?;
            Ok(Some(serde_json::from_str(&value)?))
        } else {
            Ok(None)
        }
    }

    /// Returns every value in `T`'s table, JSON-deserialized.
    ///
    /// No `ORDER BY` is applied, so callers must not rely on the order of the results.
    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
    ) -> Result<Vec<T>, DatabaseError> {
        let conn = self.0.lock().await;
        // Only the value column is read, so only the value column is fetched.
        let mut stmt = conn.prepare(&format!(
            "SELECT value FROM \"{}\"",
            validate_identifier(T::NAME)?
        ))?;
        let mut results: Vec<T> = Vec::new();
        for row in stmt.query_map([], |row| row.get::<_, String>(0))? {
            results.push(serde_json::from_str(&row?)?);
        }
        Ok(results)
    }

    /// Stores `value` as JSON under `key`, replacing any existing row with the same key.
    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        key: &str,
        value: T,
    ) -> Result<(), DatabaseError> {
        let mut conn = self.0.lock().await;
        let transaction = conn.transaction()?;
        let value = serde_json::to_string(&value)?;
        transaction.execute(
            &format!(
                "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
                validate_identifier(T::NAME)?,
            ),
            [key, &value],
        )?;
        transaction.commit()?;
        Ok(())
    }

    /// Stores every `(key, value)` pair as JSON in one transaction, replacing existing keys.
    ///
    /// If serialization or any insert fails, the uncommitted transaction is rolled back and no
    /// pair is written.
    async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        values: Vec<(String, T)>,
    ) -> Result<(), DatabaseError> {
        let mut conn = self.0.lock().await;
        let transaction = conn.transaction()?;
        let sql = format!(
            "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
            validate_identifier(T::NAME)?,
        );
        {
            // Prepare once and rebind per row. `Transaction::execute` would re-parse the same
            // SQL on every iteration. The scope ends the statement's borrow of `transaction`
            // before `commit` consumes it.
            let mut stmt = transaction.prepare(&sql)?;
            for (key, value) in values {
                let value = serde_json::to_string(&value)?;
                stmt.execute([&key, &value])?;
            }
        }
        transaction.commit()?;
        Ok(())
    }

    /// Deletes the row stored under `key` in `T`'s table. A missing key is not an error.
    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        key: &str,
    ) -> Result<(), DatabaseError> {
        let mut conn = self.0.lock().await;
        let transaction = conn.transaction()?;
        transaction.execute(
            &format!(
                "DELETE FROM \"{}\" WHERE key = ?1",
                validate_identifier(T::NAME)?
            ),
            [key],
        )?;
        transaction.commit()?;
        Ok(())
    }

    /// Deletes the rows for all `keys` in one transaction. Missing keys are ignored.
    async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        keys: Vec<String>,
    ) -> Result<(), DatabaseError> {
        let mut conn = self.0.lock().await;
        let transaction = conn.transaction()?;
        let sql = format!(
            "DELETE FROM \"{}\" WHERE key = ?1",
            validate_identifier(T::NAME)?
        );
        {
            // Same reasoning as `set_bulk`: one prepared statement reused for every key,
            // released before `commit`.
            let mut stmt = transaction.prepare(&sql)?;
            for key in keys {
                stmt.execute([&key])?;
            }
        }
        transaction.commit()?;
        Ok(())
    }

    /// Deletes every row in `T`'s table. The table itself is kept.
    async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
    ) -> Result<(), DatabaseError> {
        let mut conn = self.0.lock().await;
        let transaction = conn.transaction()?;
        transaction.execute(
            &format!("DELETE FROM \"{}\"", validate_identifier(T::NAME)?),
            [],
        )?;
        transaction.commit()?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// Exercises migrations plus every CRUD operation against an in-memory database.
    #[tokio::test]
    async fn test_sqlite_integration() {
        let db = rusqlite::Connection::open_in_memory().unwrap();

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(String => TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(String => TestB, "TestItem_B");

        let steps = vec![
            RepositoryMigrationStep::Remove(TestB::data()),
            RepositoryMigrationStep::Add(TestA::data()),
            RepositoryMigrationStep::Add(TestB::data()),
            RepositoryMigrationStep::Remove(TestB::data()),
        ];
        let migrations = RepositoryMigrations::new(steps);
        let db = SqliteDatabase::initialize_internal(db, migrations).unwrap();

        // The last migration step dropped TestB's table, so querying it must fail.
        assert!(db.list::<TestB>().await.is_err());

        // Single-item operations.
        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());
        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));
        // Writing the same key again replaces the stored value.
        db.set("key1", TestA(43)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(43)));
        db.remove::<TestA>("key1").await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);

        // Bulk operations.
        db.set_bulk(vec![
            ("a".to_string(), TestA(1)),
            ("b".to_string(), TestA(2)),
            ("c".to_string(), TestA(3)),
        ])
        .await
        .unwrap();
        // `list` has no defined order, so compare sorted payloads.
        let mut listed: Vec<usize> = db
            .list::<TestA>()
            .await
            .unwrap()
            .into_iter()
            .map(|item| item.0)
            .collect();
        listed.sort_unstable();
        assert_eq!(listed, vec![1, 2, 3]);

        db.remove_bulk::<TestA>(vec!["a".to_string(), "c".to_string()])
            .await
            .unwrap();
        assert_eq!(db.list::<TestA>().await.unwrap(), vec![TestA(2)]);

        db.remove_all::<TestA>().await.unwrap();
        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());
    }

    /// Verifies that `initialize` creates `<folder_path>/<db_name>.sqlite` on disk.
    #[tokio::test]
    async fn test_sqlite_database_path_construction() {
        // A per-process directory keeps concurrent test runs from seeing each other's files.
        let temp_dir =
            std::env::temp_dir().join(format!("bitwarden_state_test_{}", std::process::id()));
        // Clear leftovers from an earlier run. Otherwise the existence check below could pass
        // without `initialize` having created anything.
        let _ = std::fs::remove_dir_all(&temp_dir);
        std::fs::create_dir_all(&temp_dir).unwrap();
        let db_file = temp_dir.join("test_db.sqlite");
        assert!(!db_file.exists());

        let config = DatabaseConfiguration::Sqlite {
            db_name: "test_db".to_string(),
            folder_path: temp_dir.clone(),
        };
        let db = SqliteDatabase::initialize(config, RepositoryMigrations::new(vec![]))
            .await
            .unwrap();
        assert!(db_file.exists());

        // Close the connection before deleting the directory: open files cannot be removed on
        // some platforms (e.g. Windows).
        drop(db);
        std::fs::remove_dir_all(&temp_dir).ok();
    }
}