use async_trait::async_trait;
use std::{
collections::BTreeMap,
sync::{Arc, Mutex},
};
use crate::{
Database, DatabaseError, Row,
query::{
DeleteStatement, InsertStatement, SelectQuery, UpdateStatement, UpsertMultiStatement,
UpsertStatement,
},
rusqlite::RusqliteDatabase,
};
/// Map from a caller-supplied path to the in-memory database created for it.
type PathRegistry = Mutex<BTreeMap<String, Arc<RusqliteDatabase>>>;

/// Process-wide registry consulted by [`SimulationDatabase::new_for_path`], so that
/// every call made with the same path is handed the same underlying database.
/// Starts out empty and is initialized on first access.
static DATABASE_REGISTRY: std::sync::LazyLock<PathRegistry> =
    std::sync::LazyLock::new(PathRegistry::default);
/// A [`Database`] backed by an in-memory SQLite database.
///
/// Handles that should observe the same data are obtained by passing the same path to
/// [`SimulationDatabase::new_for_path`]; every other constructor call yields an
/// independent database.
#[derive(Debug)]
// NOTE(review): presumably allowed because the type name repeats its module's name
// (clippy pedantic `module_name_repetitions`) — confirm and record the reason here.
#[allow(clippy::module_name_repetitions)]
pub struct SimulationDatabase {
    /// Connection-pooled SQLite database that every operation is forwarded to.
    /// It sits behind an `Arc` so the path registry and other handles for the same
    /// path can share it.
    inner: Arc<RusqliteDatabase>,
}
impl SimulationDatabase {
    /// Returns a database for `path`, sharing one in-memory instance per distinct path.
    ///
    /// When `path` is `Some`, the first call for that path creates a fresh in-memory
    /// database and records it in a process-wide registry. Later calls with the same
    /// path return a handle to that same database.
    ///
    /// When `path` is `None`, a fresh, unregistered database is returned every time.
    /// This is equivalent to [`Self::new`].
    ///
    /// # Errors
    ///
    /// * If opening or configuring one of the underlying SQLite connections fails
    pub fn new_for_path(path: Option<&str>) -> Result<Self, DatabaseError> {
        // Anonymous databases are never registered, so they don't need the global lock.
        let Some(path) = path else {
            return Self::create_new_database();
        };

        // The registry is only ever touched through a single `get` or `insert` below, so
        // a panic while the lock was held cannot leave it half-updated. Recover from
        // poisoning instead of turning one panic into a panic on every later call.
        let mut registry = DATABASE_REGISTRY
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if let Some(existing) = registry.get(path) {
            return Ok(Self {
                inner: Arc::clone(existing),
            });
        }

        // Keep the lock while creating the database. Otherwise two concurrent callers
        // for the same path could each create and register a separate database.
        let db = Self::create_new_database()?;
        registry.insert(path.to_owned(), Arc::clone(&db.inner));
        drop(registry);

        Ok(db)
    }

    /// Creates a fresh, isolated in-memory database that is not shared by path.
    ///
    /// # Errors
    ///
    /// * If opening or configuring one of the underlying SQLite connections fails
    pub fn new() -> Result<Self, DatabaseError> {
        Self::create_new_database()
    }

    /// Opens a pool of connections to a newly named, shared-cache, in-memory SQLite
    /// database and wraps them in a [`RusqliteDatabase`].
    fn create_new_database() -> Result<Self, DatabaseError> {
        use std::sync::atomic::{AtomicU64, Ordering};

        /// Number of pooled connections opened against each in-memory database.
        const CONNECTION_COUNT: usize = 5;
        /// How long a connection waits for a lock before giving up with "busy".
        const BUSY_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);

        // Each call gets a distinct database name, so separate calls never share data.
        // Only uniqueness matters here, and `fetch_add` guarantees it under any ordering.
        static ID: AtomicU64 = AtomicU64::new(0);
        let id = ID.fetch_add(1, Ordering::Relaxed);

        // `mode=memory&cache=shared` makes every connection opened with this same name
        // see one shared in-memory database. This assumes URI filenames are enabled by
        // the open flags. NOTE(review): `uri=true` is not a documented SQLite URI
        // parameter and is presumably ignored — kept unchanged.
        let db_url = format!("file:sqlx_memdb_{id}:?mode=memory&cache=shared&uri=true");

        let mut connections = Vec::with_capacity(CONNECTION_COUNT);
        for _ in 0..CONNECTION_COUNT {
            let conn = ::rusqlite::Connection::open(&db_url)
                .map_err(|e| DatabaseError::Rusqlite(e.into()))?;
            conn.busy_timeout(BUSY_TIMEOUT)
                .map_err(|e| DatabaseError::Rusqlite(e.into()))?;
            connections.push(Arc::new(switchy_async::sync::Mutex::new(conn)));
        }

        Ok(Self {
            inner: Arc::new(RusqliteDatabase::new(connections)),
        })
    }
}
#[async_trait]
impl Database for SimulationDatabase {
async fn query(&self, query: &SelectQuery<'_>) -> Result<Vec<Row>, DatabaseError> {
self.inner.query(query).await
}
async fn query_first(&self, query: &SelectQuery<'_>) -> Result<Option<Row>, DatabaseError> {
self.inner.query_first(query).await
}
async fn exec_update(
&self,
statement: &UpdateStatement<'_>,
) -> Result<Vec<Row>, DatabaseError> {
self.inner.exec_update(statement).await
}
async fn exec_update_first(
&self,
statement: &UpdateStatement<'_>,
) -> Result<Option<Row>, DatabaseError> {
self.inner.exec_update_first(statement).await
}
async fn exec_insert(&self, statement: &InsertStatement<'_>) -> Result<Row, DatabaseError> {
self.inner.exec_insert(statement).await
}
async fn exec_upsert(
&self,
statement: &UpsertStatement<'_>,
) -> Result<Vec<Row>, DatabaseError> {
self.inner.exec_upsert(statement).await
}
async fn exec_upsert_first(
&self,
statement: &UpsertStatement<'_>,
) -> Result<Row, DatabaseError> {
self.inner.exec_upsert_first(statement).await
}
async fn exec_upsert_multi(
&self,
statement: &UpsertMultiStatement<'_>,
) -> Result<Vec<Row>, DatabaseError> {
self.inner.exec_upsert_multi(statement).await
}
async fn exec_delete(
&self,
statement: &DeleteStatement<'_>,
) -> Result<Vec<Row>, DatabaseError> {
self.inner.exec_delete(statement).await
}
async fn exec_delete_first(
&self,
statement: &DeleteStatement<'_>,
) -> Result<Option<Row>, DatabaseError> {
self.inner.exec_delete_first(statement).await
}
async fn exec_raw(&self, statement: &str) -> Result<(), DatabaseError> {
self.inner.exec_raw(statement).await
}
#[cfg(feature = "schema")]
async fn exec_create_table(
&self,
statement: &crate::schema::CreateTableStatement<'_>,
) -> Result<(), DatabaseError> {
self.inner.exec_create_table(statement).await
}
#[cfg(feature = "schema")]
async fn exec_drop_table(
&self,
statement: &crate::schema::DropTableStatement<'_>,
) -> Result<(), DatabaseError> {
self.inner.exec_drop_table(statement).await
}
#[cfg(feature = "schema")]
async fn exec_create_index(
&self,
statement: &crate::schema::CreateIndexStatement<'_>,
) -> Result<(), DatabaseError> {
self.inner.exec_create_index(statement).await
}
#[cfg(feature = "schema")]
async fn exec_drop_index(
&self,
statement: &crate::schema::DropIndexStatement<'_>,
) -> Result<(), DatabaseError> {
self.inner.exec_drop_index(statement).await
}
#[cfg(feature = "schema")]
async fn exec_alter_table(
&self,
statement: &crate::schema::AlterTableStatement<'_>,
) -> Result<(), DatabaseError> {
self.inner.exec_alter_table(statement).await
}
#[cfg(feature = "schema")]
async fn table_exists(&self, table_name: &str) -> Result<bool, DatabaseError> {
self.inner.table_exists(table_name).await
}
#[cfg(feature = "schema")]
async fn list_tables(&self) -> Result<Vec<String>, DatabaseError> {
self.inner.list_tables().await
}
#[cfg(feature = "schema")]
async fn get_table_info(
&self,
table_name: &str,
) -> Result<Option<crate::schema::TableInfo>, DatabaseError> {
self.inner.get_table_info(table_name).await
}
#[cfg(feature = "schema")]
async fn get_table_columns(
&self,
table_name: &str,
) -> Result<Vec<crate::schema::ColumnInfo>, DatabaseError> {
self.inner.get_table_columns(table_name).await
}
#[cfg(feature = "schema")]
async fn column_exists(
&self,
table_name: &str,
column_name: &str,
) -> Result<bool, DatabaseError> {
self.inner.column_exists(table_name, column_name).await
}
async fn query_raw(&self, query: &str) -> Result<Vec<crate::Row>, DatabaseError> {
self.inner.query_raw(query).await
}
async fn begin_transaction(
&self,
) -> Result<Box<dyn crate::DatabaseTransaction>, DatabaseError> {
self.inner.begin_transaction().await
}
async fn exec_raw_params(
&self,
query: &str,
params: &[crate::DatabaseValue],
) -> Result<u64, DatabaseError> {
self.inner.exec_raw_params(query, params).await
}
async fn query_raw_params(
&self,
query: &str,
params: &[crate::DatabaseValue],
) -> Result<Vec<crate::Row>, DatabaseError> {
self.inner.query_raw_params(query, params).await
}
}
#[cfg(test)]
mod tests {
    //! Tests for path-based sharing and isolation, transaction delegation, and (with the
    //! `schema` feature) introspection of `SimulationDatabase`.

    use super::*;
    use crate::{Database, query::FilterableQuery};

    #[switchy_async::test]
    async fn test_path_based_database_isolation() {
        let db1 = SimulationDatabase::new_for_path(Some("path1.db")).unwrap();
        let db2 = SimulationDatabase::new_for_path(Some("path2.db")).unwrap();
        db1.exec_raw("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)")
            .await
            .unwrap();
        db2.exec_raw("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)")
            .await
            .unwrap();
        db1.insert("test")
            .value("value", "db1_data")
            .execute(&db1)
            .await
            .unwrap();
        db2.insert("test")
            .value("value", "db2_data")
            .execute(&db2)
            .await
            .unwrap();
        let rows1 = db1.select("test").execute(&db1).await.unwrap();
        let rows2 = db2.select("test").execute(&db2).await.unwrap();
        assert_eq!(rows1.len(), 1);
        assert_eq!(rows2.len(), 1);
        assert_eq!(rows1[0].columns[1].1, "db1_data".into());
        assert_eq!(rows2[0].columns[1].1, "db2_data".into());
    }

    #[switchy_async::test]
    async fn test_same_path_returns_same_database() {
        let db1 = SimulationDatabase::new_for_path(Some("same_path.db")).unwrap();
        let db2 = SimulationDatabase::new_for_path(Some("same_path.db")).unwrap();
        // Both handles must point at the very same registered database instance.
        assert!(Arc::ptr_eq(&db1.inner, &db2.inner));
        db1.exec_raw("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)")
            .await
            .unwrap();
        db1.insert("test")
            .value("value", "shared_data")
            .execute(&db1)
            .await
            .unwrap();
        let rows = db2.select("test").execute(&db2).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].columns[1].1, "shared_data".into());
    }

    #[switchy_async::test]
    async fn test_none_path_returns_isolated_databases() {
        let db1 = SimulationDatabase::new_for_path(None).unwrap();
        let db2 = SimulationDatabase::new_for_path(None).unwrap();
        assert!(!Arc::ptr_eq(&db1.inner, &db2.inner));
        // Creating the same table in both only succeeds if they are separate databases.
        db1.exec_raw("CREATE TABLE test (id INTEGER PRIMARY KEY)")
            .await
            .unwrap();
        db2.exec_raw("CREATE TABLE test (id INTEGER PRIMARY KEY)")
            .await
            .unwrap();
    }

    #[switchy_async::test]
    async fn test_simulator_transaction_delegation() {
        let db = SimulationDatabase::new().unwrap();
        db.exec_raw("CREATE TABLE test_users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
            .await
            .unwrap();
        let transaction = db.begin_transaction().await.unwrap();
        transaction
            .insert("test_users")
            .value("name", "TestUser")
            .execute(&*transaction)
            .await
            .unwrap();
        let rows = transaction
            .select("test_users")
            .where_eq("name", "TestUser")
            .execute(&*transaction)
            .await
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows,
            vec![Row {
                columns: vec![
                    ("id".into(), i64::from(1).into()),
                    ("name".into(), "TestUser".into())
                ]
            }]
        );
        transaction.commit().await.unwrap();
        let rows_after_commit = db
            .select("test_users")
            .where_eq("name", "TestUser")
            .execute(&db)
            .await
            .unwrap();
        assert_eq!(rows_after_commit.len(), 1);
        assert_eq!(
            rows_after_commit,
            vec![Row {
                columns: vec![
                    ("id".into(), i64::from(1).into()),
                    ("name".into(), "TestUser".into())
                ]
            }]
        );
    }

    #[switchy_async::test]
    async fn test_simulator_transaction_rollback() {
        let db = SimulationDatabase::new().unwrap();
        db.exec_raw("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
            .await
            .unwrap();
        db.insert("test_rollback")
            .value("value", "initial")
            .execute(&db)
            .await
            .unwrap();
        let transaction = db.begin_transaction().await.unwrap();
        transaction
            .insert("test_rollback")
            .value("value", "transactional")
            .execute(&*transaction)
            .await
            .unwrap();
        let rows_in_tx = transaction
            .select("test_rollback")
            .execute(&*transaction)
            .await
            .unwrap();
        assert_eq!(rows_in_tx.len(), 2);
        transaction.rollback().await.unwrap();
        let rows_after_rollback = db.select("test_rollback").execute(&db).await.unwrap();
        assert_eq!(rows_after_rollback.len(), 1);
        assert_eq!(
            rows_after_rollback,
            vec![Row {
                columns: vec![
                    ("id".into(), i64::from(1).into()),
                    ("value".into(), "initial".into())
                ]
            }]
        );
    }

    #[cfg(feature = "schema")]
    #[switchy_async::test]
    async fn test_simulator_introspection_delegation() {
        let db = SimulationDatabase::new().unwrap();
        db.exec_raw(
            "CREATE TABLE test_introspection (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            age INTEGER,
            score REAL DEFAULT 0.0
        )",
        )
        .await
        .unwrap();
        assert!(db.table_exists("test_introspection").await.unwrap());
        assert!(!db.table_exists("nonexistent_table").await.unwrap());
        assert!(db.column_exists("test_introspection", "id").await.unwrap());
        assert!(
            db.column_exists("test_introspection", "name")
                .await
                .unwrap()
        );
        assert!(
            !db.column_exists("test_introspection", "nonexistent")
                .await
                .unwrap()
        );
        assert!(!db.column_exists("nonexistent_table", "id").await.unwrap());
        let columns = db.get_table_columns("test_introspection").await.unwrap();
        assert_eq!(columns.len(), 4);
        assert_eq!(columns[0].name, "id");
        assert!(columns[0].is_primary_key);
        assert_eq!(columns[1].name, "name");
        assert!(!columns[1].nullable);
        assert_eq!(columns[2].name, "age");
        assert!(columns[2].nullable);
        assert_eq!(columns[3].name, "score");
        assert!(columns[3].nullable);
        let table_info = db.get_table_info("test_introspection").await.unwrap();
        assert!(table_info.is_some());
        let info = table_info.unwrap();
        assert_eq!(info.name, "test_introspection");
        assert_eq!(info.columns.len(), 4);
        assert!(info.columns.contains_key("id"));
        assert!(info.columns.contains_key("name"));
        assert!(info.columns.contains_key("age"));
        assert!(info.columns.contains_key("score"));
        let empty_columns = db.get_table_columns("nonexistent").await.unwrap();
        assert!(empty_columns.is_empty());
        let no_table_info = db.get_table_info("nonexistent").await.unwrap();
        assert!(no_table_info.is_none());
    }

    #[cfg(feature = "schema")]
    #[switchy_async::test]
    async fn test_simulator_transaction_introspection() {
        let db = SimulationDatabase::new().unwrap();
        db.exec_raw("CREATE TABLE tx_test (id INTEGER PRIMARY KEY, data TEXT)")
            .await
            .unwrap();
        let transaction = db.begin_transaction().await.unwrap();
        assert!(transaction.table_exists("tx_test").await.unwrap());
        assert!(transaction.column_exists("tx_test", "id").await.unwrap());
        let columns = transaction.get_table_columns("tx_test").await.unwrap();
        assert_eq!(columns.len(), 2);
        let table_info = transaction.get_table_info("tx_test").await.unwrap();
        assert!(table_info.is_some());
        transaction.commit().await.unwrap();
    }

    #[cfg(feature = "schema")]
    #[switchy_async::test]
    async fn test_simulator_path_isolation() {
        let db1 = SimulationDatabase::new_for_path(Some("introspection_path1.db")).unwrap();
        let db2 = SimulationDatabase::new_for_path(Some("introspection_path2.db")).unwrap();
        db1.exec_raw("CREATE TABLE path1_table (id INTEGER, name TEXT)")
            .await
            .unwrap();
        db2.exec_raw("CREATE TABLE path2_table (id INTEGER, value TEXT)")
            .await
            .unwrap();
        assert!(db1.table_exists("path1_table").await.unwrap());
        assert!(!db1.table_exists("path2_table").await.unwrap());
        assert!(db2.table_exists("path2_table").await.unwrap());
        assert!(!db2.table_exists("path1_table").await.unwrap());
        assert!(db1.column_exists("path1_table", "name").await.unwrap());
        assert!(!db1.column_exists("path1_table", "value").await.unwrap());
        assert!(db2.column_exists("path2_table", "value").await.unwrap());
        assert!(!db2.column_exists("path2_table", "name").await.unwrap());
        let info1 = db1.get_table_info("path1_table").await.unwrap();
        let info2 = db2.get_table_info("path2_table").await.unwrap();
        assert!(info1.is_some());
        assert!(info2.is_some());
        assert!(info1.unwrap().columns.contains_key("name"));
        assert!(info2.unwrap().columns.contains_key("value"));
    }

    #[cfg(feature = "schema")]
    #[switchy_async::test]
    async fn test_list_tables_basic() {
        let db = SimulationDatabase::new().unwrap();
        let tables = db.list_tables().await.unwrap();
        assert!(tables.is_empty(), "New database should have no tables");
        db.exec_raw("CREATE TABLE table1 (id INTEGER PRIMARY KEY, name TEXT)")
            .await
            .unwrap();
        db.exec_raw("CREATE TABLE table2 (id INTEGER PRIMARY KEY, value REAL)")
            .await
            .unwrap();
        let mut tables = db.list_tables().await.unwrap();
        tables.sort();
        assert_eq!(tables, vec!["table1", "table2"]);
        db.exec_raw("DROP TABLE table1").await.unwrap();
        let tables = db.list_tables().await.unwrap();
        assert_eq!(tables, vec!["table2"]);
        assert!(!tables.contains(&"table1".to_string()));
    }

    #[cfg(feature = "schema")]
    #[switchy_async::test]
    async fn test_list_tables_with_transactions() {
        let db = SimulationDatabase::new().unwrap();
        db.exec_raw("CREATE TABLE base_table (id INTEGER)")
            .await
            .unwrap();
        let tables = db.list_tables().await.unwrap();
        assert_eq!(tables.len(), 1);
        assert!(tables.contains(&"base_table".to_string()));
        let tx = db.begin_transaction().await.unwrap();
        tx.exec_raw("CREATE TABLE tx_table (id INTEGER)")
            .await
            .unwrap();
        let tables_in_tx = tx.list_tables().await.unwrap();
        assert_eq!(tables_in_tx.len(), 2);
        assert!(tables_in_tx.contains(&"base_table".to_string()));
        assert!(tables_in_tx.contains(&"tx_table".to_string()));
        tx.rollback().await.unwrap();
        let tables_after_rollback = db.list_tables().await.unwrap();
        assert_eq!(tables_after_rollback.len(), 1);
        assert!(tables_after_rollback.contains(&"base_table".to_string()));
        assert!(!tables_after_rollback.contains(&"tx_table".to_string()));
    }

    #[cfg(feature = "schema")]
    #[switchy_async::test]
    async fn test_list_tables_isolation() {
        let db1 = SimulationDatabase::new_for_path(Some("isolation1.db")).unwrap();
        let db2 = SimulationDatabase::new_for_path(Some("isolation2.db")).unwrap();
        db1.exec_raw("CREATE TABLE db1_table (id INTEGER)")
            .await
            .unwrap();
        db2.exec_raw("CREATE TABLE db2_table (id INTEGER)")
            .await
            .unwrap();
        let tables1 = db1.list_tables().await.unwrap();
        let tables2 = db2.list_tables().await.unwrap();
        assert_eq!(tables1.len(), 1);
        assert_eq!(tables2.len(), 1);
        assert!(tables1.contains(&"db1_table".to_string()));
        assert!(tables2.contains(&"db2_table".to_string()));
        assert!(!tables1.contains(&"db2_table".to_string()));
        assert!(!tables2.contains(&"db1_table".to_string()));
    }

    #[cfg(feature = "schema")]
    #[switchy_async::test]
    async fn test_list_tables_after_commit() {
        let db = SimulationDatabase::new().unwrap();
        let tx = db.begin_transaction().await.unwrap();
        tx.exec_raw("CREATE TABLE committed_table (id INTEGER)")
            .await
            .unwrap();
        let tables_in_tx = tx.list_tables().await.unwrap();
        assert!(tables_in_tx.contains(&"committed_table".to_string()));
        tx.commit().await.unwrap();
        let tables_after_commit = db.list_tables().await.unwrap();
        assert_eq!(tables_after_commit.len(), 1);
        assert!(tables_after_commit.contains(&"committed_table".to_string()));
    }
}