use std::sync::Arc;
use crate::dialect::Dialect;
use crate::driver::ExecuteResult;
use crate::error::OrmError;
use crate::row::Row;
use crate::value::Value;
/// Driver pool held by a [`Database`], one variant per supported backend.
///
/// A variant is only compiled in when its Cargo feature is enabled.
#[derive(Clone)]
enum Backend {
    /// Pool provided by the SQLite driver.
    #[cfg(feature = "sqlite")]
    Sqlite(crate::driver::sqlite::SqlitePool),

    /// Pool provided by the PostgreSQL driver.
    #[cfg(feature = "postgres")]
    Postgres(crate::driver::postgres::PostgresPool),

    /// Pool provided by the MySQL driver.
    #[cfg(feature = "mysql")]
    Mysql(crate::driver::mysql::MysqlPool),
}
/// Database handle: a driver pool paired with the SQL dialect that matches it.
///
/// Cloning clones the backend pool value and the `Arc` that points at the
/// dialect.
#[derive(Clone)]
pub struct Database {
    /// Driver-specific pool that every operation is dispatched to.
    backend: Backend,
    /// Dialect picked alongside the backend when the handle is created.
    dialect: Arc<dyn Dialect>,
}
impl Database {
    /// Creates a pool for the backend selected by the scheme of `url`.
    ///
    /// The scheme is the text before the first `:` in `url`. A URL without any
    /// `:` (or one that starts with `:`) has an empty scheme and is treated as
    /// SQLite. Recognised schemes are `sqlite`, `postgres` / `postgresql` and
    /// `mysql` / `mariadb`; matching is case-sensitive.
    ///
    /// `url` and `max_connections` are forwarded unchanged to the selected
    /// driver's pool constructor.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the scheme is unknown, or when it
    /// names a backend whose Cargo feature is disabled in this build. Errors
    /// from the driver's pool constructor are propagated as-is.
    pub async fn connect(url: &str, max_connections: u32) -> crate::Result<Self> {
        let scheme = url.split_once(':').map_or("", |(scheme, _)| scheme);
        match scheme {
            #[cfg(feature = "sqlite")]
            "sqlite" | "" => {
                let pool = crate::driver::sqlite::SqlitePool::new(url, max_connections)?;
                Ok(Self {
                    backend: Backend::Sqlite(pool),
                    dialect: Arc::new(crate::dialect::SqliteDialect::new()),
                })
            }
            // Mirrors the PostgreSQL/MySQL arms below so that a disabled SQLite
            // driver yields the same actionable hint instead of the generic
            // "no compiled-in backend" message (which would show an empty
            // scheme for bare-path URLs).
            #[cfg(not(feature = "sqlite"))]
            "sqlite" | "" => Err(OrmError::configuration(
                "this build cannot connect to SQLite; enable the `sqlite` feature to \
                 compile the driver",
            )),
            #[cfg(feature = "postgres")]
            "postgres" | "postgresql" => {
                let pool = crate::driver::postgres::PostgresPool::new(url, max_connections)?;
                Ok(Self {
                    backend: Backend::Postgres(pool),
                    dialect: Arc::new(crate::dialect::PostgresDialect::new()),
                })
            }
            #[cfg(not(feature = "postgres"))]
            "postgres" | "postgresql" => Err(OrmError::configuration(
                "this build cannot connect to PostgreSQL; enable the `postgres` \
                 feature to compile the driver",
            )),
            #[cfg(feature = "mysql")]
            "mysql" | "mariadb" => {
                let pool = crate::driver::mysql::MysqlPool::new(url, max_connections)?;
                Ok(Self {
                    backend: Backend::Mysql(pool),
                    dialect: Arc::new(crate::dialect::MySqlDialect::new()),
                })
            }
            #[cfg(not(feature = "mysql"))]
            "mysql" | "mariadb" => Err(OrmError::configuration(
                "this build cannot connect to MySQL; enable the `mysql` feature to \
                 compile the driver",
            )),
            other => Err(OrmError::configuration(format!(
                "no compiled-in backend for url scheme `{other}`"
            ))),
        }
    }

    /// Returns the dialect that was selected together with the backend.
    pub fn dialect(&self) -> &Arc<dyn Dialect> {
        &self.dialect
    }

    /// Forwards `sql` and `params` to the backend pool's `fetch_all` and
    /// returns the rows it yields.
    ///
    /// # Errors
    ///
    /// Returns whatever error the driver reports.
    pub async fn fetch_all(&self, sql: String, params: Vec<Value>) -> crate::Result<Vec<Row>> {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            Backend::Sqlite(pool) => pool.fetch_all(sql, params).await,
            #[cfg(feature = "postgres")]
            Backend::Postgres(pool) => pool.fetch_all(sql, params).await,
            #[cfg(feature = "mysql")]
            Backend::Mysql(pool) => pool.fetch_all(sql, params).await,
        }
    }

    /// Forwards `sql` and `params` to the backend pool's `execute` and returns
    /// the driver's [`ExecuteResult`].
    ///
    /// # Errors
    ///
    /// Returns whatever error the driver reports.
    pub async fn execute(
        &self,
        sql: String,
        params: Vec<Value>,
    ) -> crate::Result<ExecuteResult> {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            Backend::Sqlite(pool) => pool.execute(sql, params).await,
            #[cfg(feature = "postgres")]
            Backend::Postgres(pool) => pool.execute(sql, params).await,
            #[cfg(feature = "mysql")]
            Backend::Mysql(pool) => pool.execute(sql, params).await,
        }
    }

    /// Forwards `sql` to the backend pool's `execute_batch`; no parameters are
    /// bound and nothing is returned on success.
    ///
    /// NOTE(review): presumably `sql` may contain several statements; confirm
    /// against the drivers.
    ///
    /// # Errors
    ///
    /// Returns whatever error the driver reports.
    pub async fn execute_batch(&self, sql: String) -> crate::Result<()> {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            Backend::Sqlite(pool) => pool.execute_batch(sql).await,
            #[cfg(feature = "postgres")]
            Backend::Postgres(pool) => pool.execute_batch(sql).await,
            #[cfg(feature = "mysql")]
            Backend::Mysql(pool) => pool.execute_batch(sql).await,
        }
    }

    /// Returns the statement counter reported by the backend pool.
    pub fn statement_count(&self) -> u64 {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            Backend::Sqlite(pool) => pool.statement_count(),
            #[cfg(feature = "postgres")]
            Backend::Postgres(pool) => pool.statement_count(),
            #[cfg(feature = "mysql")]
            Backend::Mysql(pool) => pool.statement_count(),
        }
    }

    /// Calls `close` on the backend pool and waits for it to finish.
    pub async fn close(&self) {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            Backend::Sqlite(pool) => pool.close().await,
            #[cfg(feature = "postgres")]
            Backend::Postgres(pool) => pool.close().await,
            #[cfg(feature = "mysql")]
            Backend::Mysql(pool) => pool.close().await,
        }
    }

    /// Obtains a pinned handle from the backend pool via `acquire_pinned` and
    /// wraps it, together with a shared reference to the dialect, in a
    /// [`Pinned`].
    ///
    /// NOTE(review): presumably the pinned handle is a single dedicated
    /// connection (e.g. for transactions); confirm against the drivers.
    ///
    /// # Errors
    ///
    /// Returns the driver's error when the pinned handle cannot be acquired.
    pub(crate) async fn pinned(&self) -> crate::Result<Pinned> {
        let backend = match &self.backend {
            #[cfg(feature = "sqlite")]
            Backend::Sqlite(pool) => PinnedBackend::Sqlite(pool.acquire_pinned().await?),
            #[cfg(feature = "postgres")]
            Backend::Postgres(pool) => PinnedBackend::Postgres(pool.acquire_pinned().await?),
            #[cfg(feature = "mysql")]
            Backend::Mysql(pool) => PinnedBackend::Mysql(pool.acquire_pinned().await?),
        };
        Ok(Pinned {
            backend,
            dialect: Arc::clone(&self.dialect),
        })
    }
}
/// Pinned driver handle bundled with the dialect of the [`Database`] it was
/// acquired from.
pub(crate) struct Pinned {
    /// Driver-specific pinned handle that operations are dispatched to.
    backend: PinnedBackend,
    /// Dialect shared with the originating [`Database`].
    dialect: Arc<dyn Dialect>,
}
/// Pinned handle of a specific driver, one variant per supported backend.
///
/// A variant is only compiled in when its Cargo feature is enabled.
enum PinnedBackend {
    /// Pinned handle from the SQLite driver.
    #[cfg(feature = "sqlite")]
    Sqlite(crate::driver::sqlite::PinnedSqlite),

    /// Pinned handle from the PostgreSQL driver.
    #[cfg(feature = "postgres")]
    Postgres(crate::driver::postgres::PinnedPostgres),

    /// Pinned handle from the MySQL driver.
    #[cfg(feature = "mysql")]
    Mysql(crate::driver::mysql::PinnedMysql),
}
/// Executor implementation that routes every call to the wrapped pinned
/// driver handle.
impl crate::executor::Executor for Pinned {
    /// Borrows the dialect stored alongside the pinned handle.
    fn dialect(&self) -> &dyn Dialect {
        &*self.dialect
    }

    /// Forwards `sql` and `params` to the pinned handle's `fetch_all` and
    /// returns the resulting rows or the driver's error.
    async fn fetch_all(&self, sql: String, params: Vec<Value>) -> crate::Result<Vec<Row>> {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            PinnedBackend::Sqlite(handle) => handle.fetch_all(sql, params).await,
            #[cfg(feature = "postgres")]
            PinnedBackend::Postgres(handle) => handle.fetch_all(sql, params).await,
            #[cfg(feature = "mysql")]
            PinnedBackend::Mysql(handle) => handle.fetch_all(sql, params).await,
        }
    }

    /// Forwards `sql` and `params` to the pinned handle's `execute` and
    /// returns the driver's result or error.
    async fn execute(&self, sql: String, params: Vec<Value>) -> crate::Result<ExecuteResult> {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            PinnedBackend::Sqlite(handle) => handle.execute(sql, params).await,
            #[cfg(feature = "postgres")]
            PinnedBackend::Postgres(handle) => handle.execute(sql, params).await,
            #[cfg(feature = "mysql")]
            PinnedBackend::Mysql(handle) => handle.execute(sql, params).await,
        }
    }
}
impl Pinned {
    /// Forwards `sql` to the pinned handle's `execute_batch`; no parameters
    /// are bound and nothing is returned on success.
    ///
    /// # Errors
    ///
    /// Returns whatever error the driver reports.
    pub(crate) async fn execute_batch(&self, sql: String) -> crate::Result<()> {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            PinnedBackend::Sqlite(handle) => handle.execute_batch(sql).await,
            #[cfg(feature = "postgres")]
            PinnedBackend::Postgres(handle) => handle.execute_batch(sql).await,
            #[cfg(feature = "mysql")]
            PinnedBackend::Mysql(handle) => handle.execute_batch(sql).await,
        }
    }

    /// Synchronously calls `rollback_now` on the pinned handle.
    ///
    /// NOTE(review): presumably this rolls back an open transaction without
    /// awaiting (e.g. for use from `Drop`); confirm against the drivers.
    pub(crate) fn rollback_now(&self) {
        match &self.backend {
            #[cfg(feature = "sqlite")]
            PinnedBackend::Sqlite(handle) => handle.rollback_now(),
            #[cfg(feature = "postgres")]
            PinnedBackend::Postgres(handle) => handle.rollback_now(),
            #[cfg(feature = "mysql")]
            PinnedBackend::Mysql(handle) => handle.rollback_now(),
        }
    }
}