use std::time::Duration;
use crate::env::{database_url_from_env, EnvError};
use super::Dialect;
/// Everything that can go wrong while opening a [`Pool`].
#[derive(Debug, thiserror::Error)]
pub enum PoolError {
    /// sqlx failed to build the pool; holds sqlx's error rendered as text.
    #[error("connect: {0}")]
    Connect(String),

    /// The URL's scheme is not one that `Pool` knows how to dispatch on.
    #[error("unsupported scheme in URL `{0}` — expected postgres:// or mysql://")]
    UnsupportedScheme(String),

    /// The scheme was recognised, but the backend that serves it was compiled out.
    #[error(
        "URL scheme `{scheme}` requires the `{feature}` Cargo feature on rustango \
         — add `features = [\"{feature}\"]` to your dependency"
    )]
    FeatureNotEnabled {
        /// Canonical scheme family (`postgres` is also reported for `postgresql://`).
        scheme: &'static str,
        /// Name of the Cargo feature that has to be turned on.
        feature: &'static str,
    },

    /// Propagated from `database_url_from_env` (used by `Pool::connect_from_env`).
    #[error(transparent)]
    Env(#[from] EnvError),
}
/// A database connection pool for whichever backend the connection URL selected.
///
/// Each variant only exists when its Cargo feature is enabled; the derived
/// `Clone` simply clones the wrapped sqlx pool handle.
#[derive(Clone)]
pub enum Pool {
    /// PostgreSQL pool, available with the `postgres` feature.
    #[cfg(feature = "postgres")]
    Postgres(sqlx::PgPool),

    /// MySQL pool, available with the `mysql` feature.
    #[cfg(feature = "mysql")]
    Mysql(sqlx::MySqlPool),
}
impl Pool {
    /// Opens a pool for `url`, picking the backend from the URL scheme.
    ///
    /// Recognised schemes (matched case-insensitively) are `postgres://`,
    /// `postgresql://` and `mysql://`. The pool is built with sqlx's default
    /// options; use [`Pool::connect_with_timeout`] to set an acquire timeout.
    ///
    /// # Errors
    ///
    /// * [`PoolError::UnsupportedScheme`] — `url` has no `scheme://` prefix or the
    ///   scheme is not recognised. Any password in the reported URL is replaced
    ///   by `***` so the error can be logged without leaking credentials.
    /// * [`PoolError::FeatureNotEnabled`] — the scheme's backend feature is off.
    /// * [`PoolError::Connect`] — sqlx could not create the pool.
    pub async fn connect(url: &str) -> Result<Self, PoolError> {
        Self::connect_dispatch(url, None).await
    }

    /// Like [`Pool::connect`], but configures the pool's acquire timeout
    /// (sqlx `PoolOptions::acquire_timeout`) to `timeout`.
    ///
    /// # Errors
    ///
    /// Same as [`Pool::connect`].
    pub async fn connect_with_timeout(url: &str, timeout: Duration) -> Result<Self, PoolError> {
        Self::connect_dispatch(url, Some(timeout)).await
    }

    /// Reads the URL with `database_url_from_env` and hands it to [`Pool::connect`].
    ///
    /// # Errors
    ///
    /// [`PoolError::Env`] if the URL cannot be obtained; otherwise as [`Pool::connect`].
    pub async fn connect_from_env() -> Result<Self, PoolError> {
        let url = database_url_from_env()?;
        Self::connect(&url).await
    }

    /// The SQL dialect descriptor for the backend this pool wraps.
    #[must_use]
    pub fn dialect(&self) -> &'static dyn Dialect {
        match self {
            #[cfg(feature = "postgres")]
            Pool::Postgres(_) => super::postgres::DIALECT,
            #[cfg(feature = "mysql")]
            Pool::Mysql(_) => super::mysql::DIALECT,
        }
    }

    /// Backend name as reported by the dialect's `name()`.
    #[must_use]
    pub fn backend_name(&self) -> &'static str {
        self.dialect().name()
    }

    /// Borrows the inner Postgres pool, or `None` if this is a MySQL pool.
    #[must_use]
    #[cfg(feature = "postgres")]
    pub fn as_postgres(&self) -> Option<&sqlx::PgPool> {
        match self {
            Pool::Postgres(p) => Some(p),
            #[cfg(feature = "mysql")]
            Pool::Mysql(_) => None,
        }
    }

    /// Borrows the inner MySQL pool, or `None` if this is a Postgres pool.
    #[must_use]
    #[cfg(feature = "mysql")]
    pub fn as_mysql(&self) -> Option<&sqlx::MySqlPool> {
        match self {
            #[cfg(feature = "postgres")]
            Pool::Postgres(_) => None,
            Pool::Mysql(p) => Some(p),
        }
    }

    /// Shared entry point for `connect` / `connect_with_timeout`: routes on the
    /// scheme and forwards the optional acquire timeout to the backend helper.
    async fn connect_dispatch(url: &str, timeout: Option<Duration>) -> Result<Self, PoolError> {
        match Self::scheme_of(url).as_deref() {
            Some("postgres" | "postgresql") => Self::connect_postgres_inner(url, timeout).await,
            Some("mysql") => Self::connect_mysql_inner(url, timeout).await,
            _ => Err(PoolError::UnsupportedScheme(Self::redact_password(url))),
        }
    }

    /// Lower-cased scheme before `://`, or `None` when the URL has no `://`
    /// (so a bare `"postgres"` is rejected instead of being sent to sqlx).
    fn scheme_of(url: &str) -> Option<String> {
        url.split_once("://")
            .map(|(scheme, _)| scheme.to_ascii_lowercase())
    }

    /// Returns `url` with the password in its userinfo (`user:PASSWORD@host`)
    /// replaced by `***`; URLs without a `user:...@` part are returned as-is.
    fn redact_password(url: &str) -> String {
        // The authority starts after `://` (or at 0 if absent) and ends at the
        // first path, query or fragment delimiter.
        let authority_start = url.find("://").map_or(0, |i| i + 3);
        let rest = &url[authority_start..];
        let authority_end = rest
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .unwrap_or(rest.len());
        let authority = &rest[..authority_end];

        let credentials = authority
            .rfind('@')
            .and_then(|at| authority[..at].find(':').map(|colon| (colon, at)));
        match credentials {
            Some((colon, at)) => format!(
                "{}***{}",
                &url[..authority_start + colon + 1],
                &url[authority_start + at..]
            ),
            None => url.to_owned(),
        }
    }

    /// Builds a Postgres pool; default sqlx options unless `timeout` is given.
    #[cfg(feature = "postgres")]
    async fn connect_postgres_inner(url: &str, timeout: Option<Duration>) -> Result<Self, PoolError> {
        let options = sqlx::postgres::PgPoolOptions::new();
        let options = match timeout {
            Some(t) => options.acquire_timeout(t),
            None => options,
        };
        let pool = options
            .connect(url)
            .await
            .map_err(|e| PoolError::Connect(e.to_string()))?;
        Ok(Self::Postgres(pool))
    }

    // Async only to share the signature of the feature-enabled variant.
    #[cfg(not(feature = "postgres"))]
    #[allow(clippy::unused_async)]
    async fn connect_postgres_inner(_url: &str, _timeout: Option<Duration>) -> Result<Self, PoolError> {
        Err(PoolError::FeatureNotEnabled {
            scheme: "postgres",
            feature: "postgres",
        })
    }

    /// Builds a MySQL pool; default sqlx options unless `timeout` is given.
    #[cfg(feature = "mysql")]
    async fn connect_mysql_inner(url: &str, timeout: Option<Duration>) -> Result<Self, PoolError> {
        let options = sqlx::mysql::MySqlPoolOptions::new();
        let options = match timeout {
            Some(t) => options.acquire_timeout(t),
            None => options,
        };
        let pool = options
            .connect(url)
            .await
            .map_err(|e| PoolError::Connect(e.to_string()))?;
        Ok(Self::Mysql(pool))
    }

    // Async only to share the signature of the feature-enabled variant.
    #[cfg(not(feature = "mysql"))]
    #[allow(clippy::unused_async)]
    async fn connect_mysql_inner(_url: &str, _timeout: Option<Duration>) -> Result<Self, PoolError> {
        Err(PoolError::FeatureNotEnabled {
            scheme: "mysql",
            feature: "mysql",
        })
    }
}
impl std::fmt::Debug for Pool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(feature = "postgres")]
Pool::Postgres(_) => f.write_str("Pool::Postgres(<sqlx::PgPool>)"),
#[cfg(feature = "mysql")]
Pool::Mysql(_) => f.write_str("Pool::Mysql(<sqlx::MySqlPool>)"),
}
}
}
/// Wraps an already-built sqlx Postgres pool.
#[cfg(feature = "postgres")]
impl From<sqlx::PgPool> for Pool {
    fn from(pool: sqlx::PgPool) -> Self {
        Self::Postgres(pool)
    }
}
/// Wraps an already-built sqlx MySQL pool.
#[cfg(feature = "mysql")]
impl From<sqlx::MySqlPool> for Pool {
    fn from(pool: sqlx::MySqlPool) -> Self {
        Self::Mysql(pool)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// MySQL pool aimed at an unreachable port; `connect_lazy` defers
    /// connecting and these tests never acquire a connection.
    #[cfg(feature = "mysql")]
    fn unreachable_mysql_pool() -> sqlx::MySqlPool {
        sqlx::mysql::MySqlPoolOptions::new()
            .max_connections(1)
            .connect_lazy("mysql://user:pass@localhost:1/none")
            .unwrap()
    }

    #[tokio::test]
    async fn unrecognized_scheme_errors_clearly() {
        let err = Pool::connect("oracle://user@host/db").await.unwrap_err();
        if let PoolError::UnsupportedScheme(url) = &err {
            assert!(url.starts_with("oracle://"));
        } else {
            panic!("wrong error variant: {err:?}");
        }
    }

    #[tokio::test]
    async fn empty_url_errors_clearly() {
        let outcome = Pool::connect("").await;
        assert!(matches!(outcome.unwrap_err(), PoolError::UnsupportedScheme(_)));
    }

    #[cfg(all(feature = "postgres", not(feature = "mysql")))]
    #[tokio::test]
    async fn mysql_url_errors_when_feature_not_enabled() {
        let outcome = Pool::connect("mysql://user:pass@host:3306/db").await;
        match outcome.unwrap_err() {
            PoolError::FeatureNotEnabled { scheme, feature } => {
                assert_eq!((scheme, feature), ("mysql", "mysql"));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn from_pg_pool_wraps() {
        let pool = Pool::from(
            sqlx::postgres::PgPoolOptions::new()
                .max_connections(1)
                .connect_lazy("postgres://localhost:1/none")
                .unwrap(),
        );
        assert_eq!(pool.backend_name(), "postgres");
        assert!(pool.as_postgres().is_some());
        #[cfg(feature = "mysql")]
        assert!(pool.as_mysql().is_none());
    }

    #[cfg(feature = "mysql")]
    #[tokio::test]
    async fn from_mysql_pool_wraps() {
        let pool = Pool::from(unreachable_mysql_pool());
        assert_eq!(pool.backend_name(), "mysql");
        assert!(pool.as_mysql().is_some());
        #[cfg(feature = "postgres")]
        assert!(pool.as_postgres().is_none());
    }

    #[cfg(feature = "mysql")]
    #[tokio::test]
    async fn mysql_pool_dialect_is_mysql() {
        let dialect = Pool::from(unreachable_mysql_pool()).dialect();
        assert_eq!(dialect.name(), "mysql");
        assert_eq!(dialect.quote_ident("col"), "`col`");
        assert_eq!(dialect.placeholder(1), "?");
        assert!(!dialect.supports_returning());
    }
}