use super::BaseDatabaseSchemaEditor;
#[cfg(feature = "postgres")]
use crate::backends::drivers::postgresql::schema::PostgreSQLSchemaEditor;
#[cfg(feature = "mysql")]
use crate::backends::drivers::mysql::schema::MySQLSchemaEditor;
#[cfg(feature = "sqlite")]
use crate::backends::drivers::sqlite::schema::SQLiteSchemaEditor;
use std::sync::Arc;
#[cfg(feature = "postgres")]
use sqlx::PgPool;
#[cfg(feature = "mysql")]
use sqlx::MySqlPool;
#[cfg(feature = "sqlite")]
use sqlx::SqlitePool;
/// Database backends a schema editor can be created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseType {
    /// PostgreSQL (`postgres:` / `postgresql:` URLs).
    PostgreSQL,
    /// MySQL, including MariaDB (`mysql:` / `mariadb:` URLs).
    MySQL,
    /// SQLite (`sqlite:` URLs, e.g. `sqlite:///data.db` or `sqlite::memory:`).
    SQLite,
}
impl DatabaseType {
    /// Infers the backend from the URL scheme of a connection string.
    ///
    /// The scheme is the text before the first `:` and is compared
    /// ASCII-case-insensitively, as URL schemes are case-insensitive.
    ///
    /// Recognised schemes:
    /// - `postgres` and `postgresql` map to PostgreSQL.
    /// - `mysql` and `mariadb` map to MySQL.
    /// - `sqlite` maps to SQLite.
    ///
    /// Matching on the bare scheme also accepts forms without `//`, such as
    /// `sqlite::memory:` or `sqlite:data.db`.
    ///
    /// Returns `None` when the string has no `:` or the scheme is not recognised.
    pub fn from_connection_string(conn_str: &str) -> Option<Self> {
        let (scheme, _) = conn_str.split_once(':')?;
        let is = |name: &str| scheme.eq_ignore_ascii_case(name);
        if is("postgres") || is("postgresql") {
            Some(DatabaseType::PostgreSQL)
        } else if is("mysql") || is("mariadb") {
            Some(DatabaseType::MySQL)
        } else if is("sqlite") {
            Some(DatabaseType::SQLite)
        } else {
            None
        }
    }
    /// Returns the lowercase identifier of this backend
    /// (`"postgresql"`, `"mysql"` or `"sqlite"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            DatabaseType::PostgreSQL => "postgresql",
            DatabaseType::MySQL => "mysql",
            DatabaseType::SQLite => "sqlite",
        }
    }
}
/// Builds backend-specific schema editors from the connection pool(s) it holds.
///
/// Only the pool fields for enabled cargo features (`postgres`, `mysql`,
/// `sqlite`) exist in a given build.
#[derive(Debug)]
pub struct SchemaEditorFactory {
    /// Pool shared with every PostgreSQL editor; `None` unless set by a constructor.
    #[cfg(feature = "postgres")]
    pg_pool: Option<Arc<PgPool>>,
    /// MySQL pool supplied at construction.
    #[cfg(feature = "mysql")]
    // Written by the constructors but never read: the MySQL editor is currently
    // built without a pool. NOTE(review): wire it through or drop the field.
    #[allow(dead_code)]
    mysql_pool: Option<Arc<MySqlPool>>,
    /// SQLite pool supplied at construction.
    #[cfg(feature = "sqlite")]
    // Written by the constructors but never read: the SQLite editor is currently
    // built without a pool. NOTE(review): wire it through or drop the field.
    #[allow(dead_code)]
    sqlite_pool: Option<Arc<SqlitePool>>,
}
impl SchemaEditorFactory {
    /// Creates a factory with no connection pool attached.
    ///
    /// Asking such a factory for a PostgreSQL editor panics, because that
    /// editor is built from the stored PostgreSQL pool.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "postgres")]
            pg_pool: None,
            #[cfg(feature = "mysql")]
            mysql_pool: None,
            #[cfg(feature = "sqlite")]
            sqlite_pool: None,
        }
    }

    /// Creates a factory whose PostgreSQL editors share `pool`.
    #[cfg(feature = "postgres")]
    pub fn new_postgres(pool: PgPool) -> Self {
        // Start from the empty factory so the other backends' fields stay `None`
        // without repeating their `#[cfg]` attributes here.
        let mut factory = Self::new();
        factory.pg_pool = Some(Arc::new(pool));
        factory
    }

    /// Creates a factory that stores `pool` as its MySQL pool.
    #[cfg(feature = "mysql")]
    pub fn new_mysql(pool: MySqlPool) -> Self {
        let mut factory = Self::new();
        factory.mysql_pool = Some(Arc::new(pool));
        factory
    }

    /// Creates a factory that stores `pool` as its SQLite pool.
    #[cfg(feature = "sqlite")]
    pub fn new_sqlite(pool: SqlitePool) -> Self {
        let mut factory = Self::new();
        factory.sqlite_pool = Some(Arc::new(pool));
        factory
    }

    /// Creates a boxed schema editor for `db_type`.
    ///
    /// # Panics
    ///
    /// - If `db_type` is `PostgreSQL` and no PostgreSQL pool was set (use
    ///   `SchemaEditorFactory::new_postgres()`).
    /// - If the cargo feature for `db_type` (`postgres`, `mysql`, `sqlite`)
    ///   is not enabled.
    pub fn create_for_database(&self, db_type: DatabaseType) -> Box<dyn BaseDatabaseSchemaEditor> {
        // Coerces away the `Send + Sync` bounds carried by the shared builder.
        self.build_editor(db_type)
    }

    /// Creates a reference-counted, thread-shareable schema editor for `db_type`.
    ///
    /// # Panics
    ///
    /// - If `db_type` is `PostgreSQL` and no PostgreSQL pool was set (use
    ///   `SchemaEditorFactory::new_postgres()`).
    /// - If the cargo feature for `db_type` (`postgres`, `mysql`, `sqlite`)
    ///   is not enabled.
    pub fn create_shared(
        &self,
        db_type: DatabaseType,
    ) -> Arc<dyn BaseDatabaseSchemaEditor + Send + Sync> {
        Arc::from(self.build_editor(db_type))
    }

    /// Builds the editor for `db_type`; shared by both public `create_*` methods.
    ///
    /// The `Send + Sync` bound adds no new requirement: `create_shared` has always
    /// needed every enabled editor type to be `Send + Sync`.
    fn build_editor(
        &self,
        db_type: DatabaseType,
    ) -> Box<dyn BaseDatabaseSchemaEditor + Send + Sync> {
        match db_type {
            #[cfg(feature = "postgres")]
            DatabaseType::PostgreSQL => {
                let pool = self
                    .pg_pool
                    .as_ref()
                    .expect("PostgreSQL pool not set. Use SchemaEditorFactory::new_postgres()");
                Box::new(PostgreSQLSchemaEditor::from_pool_arc(Arc::clone(pool)))
            }
            #[cfg(not(feature = "postgres"))]
            DatabaseType::PostgreSQL => {
                panic!("PostgreSQL support not enabled. Enable 'postgres' feature.")
            }
            // NOTE(review): the pool passed to `new_mysql` is not handed to this
            // editor; confirm whether `MySQLSchemaEditor` should receive it.
            #[cfg(feature = "mysql")]
            DatabaseType::MySQL => Box::new(MySQLSchemaEditor::new()),
            #[cfg(not(feature = "mysql"))]
            DatabaseType::MySQL => {
                panic!("MySQL support not enabled. Enable 'mysql' feature.")
            }
            // NOTE(review): the pool passed to `new_sqlite` is not handed to this
            // editor; confirm whether `SQLiteSchemaEditor` should receive it.
            #[cfg(feature = "sqlite")]
            DatabaseType::SQLite => Box::new(SQLiteSchemaEditor::new()),
            #[cfg(not(feature = "sqlite"))]
            DatabaseType::SQLite => {
                panic!("SQLite support not enabled. Enable 'sqlite' feature.")
            }
        }
    }
}
impl Default for SchemaEditorFactory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Each known scheme prefix maps to its backend; an unknown scheme is rejected.
    #[test]
    fn test_database_type_from_connection_string() {
        let expectations = [
            ("postgres://localhost/mydb", Some(DatabaseType::PostgreSQL)),
            ("postgresql://localhost/mydb", Some(DatabaseType::PostgreSQL)),
            ("mysql://localhost/mydb", Some(DatabaseType::MySQL)),
            ("sqlite:///data.db", Some(DatabaseType::SQLite)),
            ("unknown://localhost/mydb", None),
        ];
        for &(conn_str, expected) in expectations.iter() {
            assert_eq!(
                DatabaseType::from_connection_string(conn_str),
                expected,
                "connection string: {}",
                conn_str
            );
        }
    }

    /// `as_str` yields the lowercase backend identifier for every variant.
    #[test]
    fn test_database_type_as_str() {
        let expectations = [
            (DatabaseType::PostgreSQL, "postgresql"),
            (DatabaseType::MySQL, "mysql"),
            (DatabaseType::SQLite, "sqlite"),
        ];
        for &(db_type, label) in expectations.iter() {
            assert_eq!(db_type.as_str(), label);
        }
    }

    /// Both the explicit and the `Default` constructor build a factory.
    #[test]
    fn test_factory_creation() {
        let explicit = SchemaEditorFactory::new();
        let _defaulted = SchemaEditorFactory::default();
        drop(explicit);
    }

    /// Pool created with `connect_lazy`; the tests below never issue a query.
    #[cfg(feature = "postgres")]
    #[fixture]
    async fn pg_pool() -> PgPool {
        PgPool::connect_lazy("postgresql://localhost/test_db").expect("Failed to create test pool")
    }

    /// A factory holding a PostgreSQL pool can build a boxed PostgreSQL editor.
    #[cfg(feature = "postgres")]
    #[rstest]
    #[tokio::test]
    async fn test_create_postgresql_editor(#[future] pg_pool: PgPool) {
        let factory = SchemaEditorFactory::new_postgres(pg_pool.await);
        let _editor = factory.create_for_database(DatabaseType::PostgreSQL);
    }

    /// The shared editor is an ordinary `Arc`: cloning it bumps the strong count.
    #[cfg(feature = "postgres")]
    #[rstest]
    #[tokio::test]
    async fn test_create_shared_editor(#[future] pg_pool: PgPool) {
        let factory = SchemaEditorFactory::new_postgres(pg_pool.await);
        let shared = factory.create_shared(DatabaseType::PostgreSQL);
        let second_handle = Arc::clone(&shared);
        assert_eq!(Arc::strong_count(&second_handle), 2);
    }
}