use std::time::Duration;
use crate::env::{database_url_from_env, EnvError};
use super::Dialect;
/// Errors returned by the [`Pool`] connection constructors.
#[derive(Debug, thiserror::Error)]
pub enum PoolError {
/// The backend driver failed while opening the pool. The payload is the
/// driver's error rendered with `to_string()`.
#[error("connect: {0}")]
Connect(String),
/// The text before the first `:` of the URL is not a recognised scheme.
/// The payload is the URL exactly as it was passed in.
#[error("unsupported scheme in URL `{0}` — expected postgres://, mysql://, or sqlite:")]
UnsupportedScheme(String),
/// The URL scheme is recognised, but the Cargo feature that provides that
/// backend was not enabled when this crate was compiled.
#[error(
"URL scheme `{scheme}` requires the `{feature}` Cargo feature on rustango \
— add `features = [\"{feature}\"]` to your dependency"
)]
FeatureNotEnabled {
// Backend name as reported to the user (e.g. "mysql").
scheme: &'static str,
// Name of the Cargo feature that must be enabled for `scheme`.
feature: &'static str,
},
/// Looking up the database URL in the environment failed; wraps the
/// underlying [`EnvError`] unchanged (`transparent`), converted via `From`.
#[error(transparent)]
Env(#[from] EnvError),
}
/// A connection pool for one of the supported database backends.
///
/// Each variant wraps the corresponding sqlx pool type and only exists when
/// the matching Cargo feature (`postgres`, `mysql`, `sqlite`) is enabled.
#[derive(Clone)]
pub enum Pool {
/// Pool backed by `sqlx::PgPool`.
#[cfg(feature = "postgres")]
Postgres(sqlx::PgPool),
/// Pool backed by `sqlx::MySqlPool`.
#[cfg(feature = "mysql")]
Mysql(sqlx::MySqlPool),
/// Pool backed by `sqlx::SqlitePool`.
#[cfg(feature = "sqlite")]
Sqlite(sqlx::SqlitePool),
}
impl Pool {
    /// Opens a pool for `url`, choosing the backend from the URL scheme.
    ///
    /// The scheme is the text before the first `:`, compared
    /// case-insensitively. Recognised schemes are `postgres` / `postgresql`,
    /// `mysql` and `sqlite`. The pool is created with sqlx's default pool
    /// options.
    ///
    /// # Errors
    ///
    /// * [`PoolError::UnsupportedScheme`] — the scheme is not recognised. The
    ///   error carries the URL exactly as passed, so it may contain embedded
    ///   credentials; take care when logging it.
    /// * [`PoolError::FeatureNotEnabled`] — the scheme is recognised but the
    ///   matching Cargo feature is not compiled in.
    /// * [`PoolError::Connect`] — the driver failed to open the pool.
    pub async fn connect(url: &str) -> Result<Self, PoolError> {
        Self::connect_dispatch(url, None).await
    }

    /// Like [`Pool::connect`], but sets `timeout` as sqlx's `acquire_timeout`
    /// on the pool options before connecting.
    ///
    /// # Errors
    ///
    /// Same as [`Pool::connect`].
    pub async fn connect_with_timeout(url: &str, timeout: Duration) -> Result<Self, PoolError> {
        Self::connect_dispatch(url, Some(timeout)).await
    }

    /// Reads the database URL via `database_url_from_env` and connects to it
    /// with [`Pool::connect`].
    ///
    /// # Errors
    ///
    /// [`PoolError::Env`] if the URL cannot be obtained from the environment,
    /// otherwise the same errors as [`Pool::connect`].
    pub async fn connect_from_env() -> Result<Self, PoolError> {
        let url = database_url_from_env()?;
        Self::connect(&url).await
    }

    /// Returns the SQL dialect that corresponds to the active backend variant.
    #[must_use]
    pub fn dialect(&self) -> &'static dyn Dialect {
        match self {
            #[cfg(feature = "postgres")]
            Pool::Postgres(_) => super::postgres::DIALECT,
            #[cfg(feature = "mysql")]
            Pool::Mysql(_) => super::mysql::DIALECT,
            #[cfg(feature = "sqlite")]
            Pool::Sqlite(_) => super::sqlite::DIALECT,
        }
    }

    /// Returns the backend name as reported by the dialect's `name()`.
    #[must_use]
    pub fn backend_name(&self) -> &'static str {
        self.dialect().name()
    }

    /// Borrows the inner Postgres pool, or `None` for any other backend.
    #[must_use]
    #[cfg(feature = "postgres")]
    pub fn as_postgres(&self) -> Option<&sqlx::PgPool> {
        match self {
            Pool::Postgres(p) => Some(p),
            #[cfg(feature = "mysql")]
            Pool::Mysql(_) => None,
            #[cfg(feature = "sqlite")]
            Pool::Sqlite(_) => None,
        }
    }

    /// Borrows the inner MySQL pool, or `None` for any other backend.
    #[must_use]
    #[cfg(feature = "mysql")]
    pub fn as_mysql(&self) -> Option<&sqlx::MySqlPool> {
        match self {
            #[cfg(feature = "postgres")]
            Pool::Postgres(_) => None,
            Pool::Mysql(p) => Some(p),
            #[cfg(feature = "sqlite")]
            Pool::Sqlite(_) => None,
        }
    }

    /// Borrows the inner SQLite pool, or `None` for any other backend.
    #[must_use]
    #[cfg(feature = "sqlite")]
    pub fn as_sqlite(&self) -> Option<&sqlx::SqlitePool> {
        match self {
            #[cfg(feature = "postgres")]
            Pool::Postgres(_) => None,
            #[cfg(feature = "mysql")]
            Pool::Mysql(_) => None,
            Pool::Sqlite(p) => Some(p),
        }
    }

    /// Single scheme dispatcher shared by [`Pool::connect`] and
    /// [`Pool::connect_with_timeout`], so both entry points always agree on
    /// which schemes are accepted and which errors are produced.
    ///
    /// `acquire_timeout` is `None` for sqlx's default pool options.
    async fn connect_dispatch(
        url: &str,
        acquire_timeout: Option<Duration>,
    ) -> Result<Self, PoolError> {
        // Everything before the first ':' is treated as the scheme; this also
        // covers colon-only forms such as `sqlite::memory:`.
        let scheme = url.split(':').next().unwrap_or("").to_ascii_lowercase();
        match scheme.as_str() {
            "postgres" | "postgresql" => Self::connect_postgres_inner(url, acquire_timeout).await,
            "mysql" => Self::connect_mysql_inner(url, acquire_timeout).await,
            "sqlite" => Self::connect_sqlite_inner(url, acquire_timeout).await,
            _ => Err(PoolError::UnsupportedScheme(url.to_owned())),
        }
    }

    /// Opens a Postgres pool, applying `acquire_timeout` when one is given.
    #[cfg(feature = "postgres")]
    async fn connect_postgres_inner(
        url: &str,
        acquire_timeout: Option<Duration>,
    ) -> Result<Self, PoolError> {
        let options = sqlx::postgres::PgPoolOptions::new();
        let options = match acquire_timeout {
            Some(timeout) => options.acquire_timeout(timeout),
            None => options,
        };
        let pool = options
            .connect(url)
            .await
            .map_err(|e| PoolError::Connect(e.to_string()))?;
        Ok(Self::Postgres(pool))
    }

    /// Stub used when the `postgres` feature is disabled.
    #[cfg(not(feature = "postgres"))]
    // Must stay `async` so the dispatcher can `.await` it under any feature set.
    #[allow(clippy::unused_async)]
    async fn connect_postgres_inner(
        _url: &str,
        _acquire_timeout: Option<Duration>,
    ) -> Result<Self, PoolError> {
        Err(PoolError::FeatureNotEnabled {
            scheme: "postgres",
            feature: "postgres",
        })
    }

    /// Opens a MySQL pool, applying `acquire_timeout` when one is given.
    #[cfg(feature = "mysql")]
    async fn connect_mysql_inner(
        url: &str,
        acquire_timeout: Option<Duration>,
    ) -> Result<Self, PoolError> {
        let options = sqlx::mysql::MySqlPoolOptions::new();
        let options = match acquire_timeout {
            Some(timeout) => options.acquire_timeout(timeout),
            None => options,
        };
        let pool = options
            .connect(url)
            .await
            .map_err(|e| PoolError::Connect(e.to_string()))?;
        Ok(Self::Mysql(pool))
    }

    /// Stub used when the `mysql` feature is disabled.
    #[cfg(not(feature = "mysql"))]
    // Must stay `async` so the dispatcher can `.await` it under any feature set.
    #[allow(clippy::unused_async)]
    async fn connect_mysql_inner(
        _url: &str,
        _acquire_timeout: Option<Duration>,
    ) -> Result<Self, PoolError> {
        Err(PoolError::FeatureNotEnabled {
            scheme: "mysql",
            feature: "mysql",
        })
    }

    /// Opens a SQLite pool, applying `acquire_timeout` when one is given.
    #[cfg(feature = "sqlite")]
    async fn connect_sqlite_inner(
        url: &str,
        acquire_timeout: Option<Duration>,
    ) -> Result<Self, PoolError> {
        let options = sqlx::sqlite::SqlitePoolOptions::new();
        let options = match acquire_timeout {
            Some(timeout) => options.acquire_timeout(timeout),
            None => options,
        };
        let pool = options
            .connect(url)
            .await
            .map_err(|e| PoolError::Connect(e.to_string()))?;
        Ok(Self::Sqlite(pool))
    }

    /// Stub used when the `sqlite` feature is disabled.
    #[cfg(not(feature = "sqlite"))]
    // Must stay `async` so the dispatcher can `.await` it under any feature set.
    #[allow(clippy::unused_async)]
    async fn connect_sqlite_inner(
        _url: &str,
        _acquire_timeout: Option<Duration>,
    ) -> Result<Self, PoolError> {
        Err(PoolError::FeatureNotEnabled {
            scheme: "sqlite",
            feature: "sqlite",
        })
    }
}
impl std::fmt::Debug for Pool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(feature = "postgres")]
Pool::Postgres(_) => f.write_str("Pool::Postgres(<sqlx::PgPool>)"),
#[cfg(feature = "mysql")]
Pool::Mysql(_) => f.write_str("Pool::Mysql(<sqlx::MySqlPool>)"),
#[cfg(feature = "sqlite")]
Pool::Sqlite(_) => f.write_str("Pool::Sqlite(<sqlx::SqlitePool>)"),
}
}
}
#[cfg(feature = "postgres")]
impl From<sqlx::PgPool> for Pool {
    /// Wraps an already-built sqlx Postgres pool in [`Pool::Postgres`].
    fn from(pg_pool: sqlx::PgPool) -> Self {
        Self::Postgres(pg_pool)
    }
}
#[cfg(feature = "mysql")]
impl From<sqlx::MySqlPool> for Pool {
    /// Wraps an already-built sqlx MySQL pool in [`Pool::Mysql`].
    fn from(mysql_pool: sqlx::MySqlPool) -> Self {
        Self::Mysql(mysql_pool)
    }
}
#[cfg(feature = "sqlite")]
impl From<sqlx::SqlitePool> for Pool {
    /// Wraps an already-built sqlx SQLite pool in [`Pool::Sqlite`].
    fn from(sqlite_pool: sqlx::SqlitePool) -> Self {
        Self::Sqlite(sqlite_pool)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A URL whose scheme is not supported is rejected, and the error carries
    /// the URL that was passed in.
    #[tokio::test]
    async fn unrecognized_scheme_errors_clearly() {
        let outcome = Pool::connect("oracle://user@host/db").await;
        match outcome.unwrap_err() {
            PoolError::UnsupportedScheme(rejected) => assert!(rejected.starts_with("oracle://")),
            other => panic!("wrong error variant: {other:?}"),
        }
    }

    /// The empty string has no recognisable scheme.
    #[tokio::test]
    async fn empty_url_errors_clearly() {
        let failure = Pool::connect("").await.unwrap_err();
        assert!(matches!(failure, PoolError::UnsupportedScheme(_)));
    }

    /// With `postgres` on and `mysql` off, a MySQL URL reports the missing
    /// feature rather than an unsupported scheme.
    #[cfg(all(feature = "postgres", not(feature = "mysql")))]
    #[tokio::test]
    async fn mysql_url_errors_when_feature_not_enabled() {
        let outcome = Pool::connect("mysql://user:pass@host:3306/db").await;
        match outcome.unwrap_err() {
            PoolError::FeatureNotEnabled { scheme, feature } => {
                assert_eq!(scheme, "mysql");
                assert_eq!(feature, "mysql");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    /// `From<sqlx::PgPool>` produces the Postgres variant.
    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn from_pg_pool_wraps() {
        // Lazy pool: no connection is attempted, so no server is needed.
        let lazy_pg = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost:1/none")
            .unwrap();
        let wrapped = Pool::from(lazy_pg);
        assert_eq!(wrapped.backend_name(), "postgres");
        assert!(wrapped.as_postgres().is_some());
        #[cfg(feature = "mysql")]
        assert!(wrapped.as_mysql().is_none());
    }

    /// `From<sqlx::MySqlPool>` produces the MySQL variant.
    #[cfg(feature = "mysql")]
    #[tokio::test]
    async fn from_mysql_pool_wraps() {
        // Lazy pool: no connection is attempted, so no server is needed.
        let lazy_mysql = sqlx::mysql::MySqlPoolOptions::new()
            .max_connections(1)
            .connect_lazy("mysql://user:pass@localhost:1/none")
            .unwrap();
        let wrapped = Pool::from(lazy_mysql);
        assert_eq!(wrapped.backend_name(), "mysql");
        assert!(wrapped.as_mysql().is_some());
        #[cfg(feature = "postgres")]
        assert!(wrapped.as_postgres().is_none());
    }

    /// Connecting to an in-memory SQLite URL succeeds end to end.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_url_connect_succeeds_in_memory() {
        let connected = Pool::connect("sqlite::memory:").await.unwrap();
        assert_eq!(connected.backend_name(), "sqlite");
        assert!(connected.as_sqlite().is_some());
    }

    /// A wrapped SQLite pool exposes the SQLite dialect and only the SQLite
    /// accessor returns `Some`.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_from_pool_dispatches_to_sqlite_dialect() {
        let lazy_sqlite = sqlx::sqlite::SqlitePoolOptions::new()
            .max_connections(1)
            .connect_lazy("sqlite::memory:")
            .unwrap();
        let wrapped = Pool::from(lazy_sqlite);
        assert_eq!(wrapped.backend_name(), "sqlite");

        let dialect = wrapped.dialect();
        assert_eq!(dialect.name(), "sqlite");
        assert!(dialect.supports_returning());
        assert_eq!(dialect.bool_literal(true), "1");

        assert!(wrapped.as_sqlite().is_some());
        #[cfg(feature = "postgres")]
        assert!(wrapped.as_postgres().is_none());
        #[cfg(feature = "mysql")]
        assert!(wrapped.as_mysql().is_none());
    }

    /// A wrapped MySQL pool exposes the MySQL dialect's quoting, placeholder
    /// and RETURNING behaviour.
    #[cfg(feature = "mysql")]
    #[tokio::test]
    async fn mysql_pool_dialect_is_mysql() {
        let lazy_mysql = sqlx::mysql::MySqlPoolOptions::new()
            .max_connections(1)
            .connect_lazy("mysql://user:pass@localhost:1/none")
            .unwrap();
        let wrapped = Pool::from(lazy_mysql);
        let dialect = wrapped.dialect();
        assert_eq!(dialect.name(), "mysql");
        assert_eq!(dialect.quote_ident("col"), "`col`");
        assert_eq!(dialect.placeholder(1), "?");
        assert!(!dialect.supports_returning());
    }
}