use std::error::Error as StdError;
use std::fmt;
use sqlx::{AnyConnection, Connection};
/// Database backend a connection is expected to target.
///
/// Used by [`connect`] to reject a URL whose scheme belongs to a different
/// backend before any connection attempt is made.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DB {
    /// MySQL / MariaDB (`mysql://` or `mariadb://`).
    Mysql,
    /// PostgreSQL (`postgres://` or `postgresql://`).
    Psql,
    /// SQLite (`sqlite:`).
    Sqlite,
}

impl DB {
    /// URL prefixes accepted for this backend.
    ///
    /// `mariadb://` is accepted alongside `mysql://` because sqlx routes both
    /// schemes to its MySQL driver.
    fn url_prefixes(self) -> &'static [&'static str] {
        match self {
            Self::Mysql => &["mysql://", "mariadb://"],
            Self::Psql => &["postgres://", "postgresql://"],
            Self::Sqlite => &["sqlite:"],
        }
    }

    /// Returns `true` if `url` starts with one of this backend's prefixes.
    ///
    /// The comparison is ASCII case-insensitive because URL schemes are
    /// case-insensitive (RFC 3986 §3.1), so `POSTGRES://…` is still a
    /// PostgreSQL URL. Comparing raw bytes means a prefix length that falls
    /// inside a multi-byte UTF-8 character cannot panic.
    fn matches_url(self, url: &str) -> bool {
        let url = url.as_bytes();
        self.url_prefixes().iter().any(|prefix| {
            matches!(
                url.get(..prefix.len()),
                Some(head) if head.eq_ignore_ascii_case(prefix.as_bytes())
            )
        })
    }

    /// Human-readable description of the accepted URL prefixes, used in
    /// error messages.
    fn expected(self) -> &'static str {
        match self {
            Self::Mysql => "mysql:// or mariadb://",
            Self::Psql => "postgres:// or postgresql://",
            Self::Sqlite => "sqlite: (e.g. sqlite://file.db or sqlite::memory:)",
        }
    }
}
/// Errors produced by this module's connection helpers.
#[derive(Debug)]
pub enum Error {
    /// The URL's scheme does not belong to the requested backend.
    UrlBackendMismatch {
        /// Backend the caller asked for.
        db: DB,
        /// The URL exactly as supplied by the caller (it may contain credentials).
        url: String,
    },
    /// An error reported by sqlx (for example while opening the connection).
    Sqlx(sqlx::Error),
}
/// Returns `url` with any userinfo (`user:password@`) replaced by `***@`, so
/// credentials do not end up in error messages or logs.
///
/// Only the authority section is inspected: the text after `://` (or the
/// whole string when there is no `://`) up to the first `/`, `?` or `#`.
/// URLs without userinfo are returned unchanged.
fn redact_url(url: &str) -> String {
    let (scheme, rest) = match url.find("://") {
        Some(idx) => url.split_at(idx + 3),
        None => ("", url),
    };
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    // `rfind` so that an '@' inside the (unencoded) password is also covered.
    match rest[..authority_end].rfind('@') {
        Some(at) => format!("{scheme}***@{}", &rest[at + 1..]),
        None => url.to_owned(),
    }
}

// User-facing message. The URL is passed through `redact_url` because it is
// caller-supplied and may embed a password.
// NOTE(review): the derived `Debug` on `Error` still prints the raw URL;
// consider redacting there as well if errors are logged with `{:?}`.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UrlBackendMismatch { db, url } => write!(
                f,
                "database URL '{}' does not match {:?} (expected {})",
                redact_url(url),
                db,
                db.expected()
            ),
            Self::Sqlx(err) => write!(f, "sqlx error: {err}"),
        }
    }
}
// Exposes the wrapped `sqlx::Error` as the underlying cause so error
// reporters can walk the chain; a URL/backend mismatch has no inner cause.
impl StdError for Error {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Sqlx(inner) => Some(inner),
            Self::UrlBackendMismatch { .. } => None,
        }
    }
}
impl From<sqlx::Error> for Error {
fn from(value: sqlx::Error) -> Self {
Self::Sqlx(value)
}
}
pub type Result<T> = std::result::Result<T, Error>;
fn ensure_backend_url(db: DB, database_url: &str) -> Result<()> {
if !db.matches_url(database_url) {
return Err(Error::UrlBackendMismatch {
db,
url: database_url.to_string(),
});
}
Ok(())
}
/// Opens a connection to `database_url` after verifying that its scheme
/// belongs to the requested backend `db`.
///
/// The scheme check runs first, so a mismatched URL is rejected before any
/// connection attempt is made.
///
/// # Errors
///
/// - [`Error::UrlBackendMismatch`] if `database_url` does not match `db`.
/// - [`Error::Sqlx`] if sqlx fails to establish the connection.
pub async fn connect(db: DB, database_url: &str) -> Result<AnyConnection> {
    ensure_backend_url(db, database_url)?;
    // NOTE(review): on sqlx 0.7+, `AnyConnection` needs its drivers installed
    // (e.g. `sqlx::any::install_default_drivers()`) before the first connect;
    // confirm the caller does this.
    let conn = AnyConnection::connect(database_url).await?;
    Ok(conn)
}