#![cfg_attr(
not(any(feature = "mariadb", feature = "postgres", feature = "sqlite")),
allow(unreachable_code, unused_variables)
)]
#[cfg(feature = "postgres")]
use std::borrow::Cow;
use crate::{ConnectOptions, Driver, Error, Result, Rows, Statement};
/// Backend-specific connection handle.
///
/// Only the variants whose Cargo feature is enabled are compiled in. `_Disabled` is always
/// present so the enum is never empty when every backend feature is turned off; every match in
/// this module treats reaching it as a bug (`unreachable!`).
pub(crate) enum ConnectionInner {
    /// MySQL / MariaDB connection (feature `mariadb`).
    #[cfg(feature = "mariadb")]
    Mysql(quex_driver::mysql::Connection),
    /// PostgreSQL connection (feature `postgres`).
    #[cfg(feature = "postgres")]
    Postgres(quex_driver::postgres::Connection),
    /// SQLite connection (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(quex_driver::sqlite::Connection),
    /// Placeholder that keeps the enum inhabited when no backend is enabled.
    _Disabled,
}
/// A live backend connection paired with the driver kind it was opened for.
pub(crate) struct ManagedConnection {
    /// Driver selected when the connection was opened.
    pub(crate) driver: Driver,
    /// The backend-specific connection handle.
    pub(crate) inner: ConnectionInner,
}
pub(crate) async fn query_inner<'a>(inner: &'a mut ConnectionInner, sql: &str) -> Result<Rows<'a>> {
match inner {
#[cfg(feature = "mariadb")]
ConnectionInner::Mysql(conn) => {
let rows = conn.query(sql).await?;
Ok(Rows::mysql(rows))
}
#[cfg(feature = "postgres")]
ConnectionInner::Postgres(conn) => {
let rows = conn.query(sql).await?;
Ok(Rows::postgres(rows))
}
#[cfg(feature = "sqlite")]
ConnectionInner::Sqlite(conn) => {
let rows = conn.query(sql).await?;
Ok(Rows::sqlite(rows))
}
ConnectionInner::_Disabled => unreachable!("disabled backend placeholder"),
}
}
pub(crate) async fn prepare_inner<'a>(
inner: &'a mut ConnectionInner,
sql: &str,
) -> Result<Statement<'a>> {
match inner {
#[cfg(feature = "mariadb")]
ConnectionInner::Mysql(conn) => Ok(Statement::Mysql(conn.prepare_cached(sql).await?)),
#[cfg(feature = "postgres")]
ConnectionInner::Postgres(conn) => {
let sql = rewrite_postgres_placeholders(sql);
Ok(Statement::Postgres(conn.prepare_cached(&sql).await?))
}
#[cfg(feature = "sqlite")]
ConnectionInner::Sqlite(conn) => Ok(Statement::Sqlite(conn.prepare_cached(sql).await?)),
ConnectionInner::_Disabled => unreachable!("disabled backend placeholder"),
}
}
/// Converts `?` bind markers into PostgreSQL's numbered `$1`, `$2`, ... placeholders.
///
/// Question marks inside single-quoted literals, double-quoted identifiers, `--` line comments,
/// `/* ... */` block comments and dollar-quoted bodies are copied through unchanged. When the
/// statement contains no placeholder to rewrite, the input is returned borrowed and nothing is
/// allocated.
///
/// Every other `?` is treated as a placeholder. That includes PostgreSQL's jsonb `?`, `?|` and
/// `?&` operators, so those operators cannot be written literally in statements passed through
/// here.
#[cfg(feature = "postgres")]
pub(crate) fn rewrite_postgres_placeholders(sql: &str) -> Cow<'_, str> {
    let src = sql.as_bytes();
    // Allocated lazily on the first placeholder so placeholder-free SQL stays borrowed.
    let mut rewritten: Option<String> = None;
    // Start of the input that has not been copied into `rewritten` yet.
    let mut copied_up_to = 0;
    let mut next_index: usize = 1;
    let mut pos = 0;

    while let Some(&byte) = src.get(pos) {
        let lookahead = src.get(pos + 1).copied();
        pos = match (byte, lookahead) {
            (b'?', _) => {
                let buf =
                    rewritten.get_or_insert_with(|| String::with_capacity(sql.len() + 4));
                buf.push_str(&sql[copied_up_to..pos]);
                buf.push('$');
                buf.push_str(&next_index.to_string());
                next_index += 1;
                copied_up_to = pos + 1;
                copied_up_to
            }
            (b'\'', _) => skip_single_quoted(src, pos),
            (b'"', _) => skip_double_quoted(src, pos),
            (b'-', Some(b'-')) => skip_line_comment(src, pos + 2),
            (b'/', Some(b'*')) => skip_block_comment(src, pos + 2),
            (b'$', _) => dollar_quote_delimiter_len(src, pos)
                .map_or(pos + 1, |len| skip_dollar_quoted(src, pos, len)),
            _ => pos + 1,
        };
    }

    match rewritten {
        None => Cow::Borrowed(sql),
        Some(mut buf) => {
            buf.push_str(&sql[copied_up_to..]);
            Cow::Owned(buf)
        }
    }
}
/// Returns the index just past the single-quoted literal whose opening `'` is at `i`.
///
/// A doubled `''` inside the literal is an escaped quote. When the literal is a PostgreSQL
/// escape string (`E'...'` or `e'...'`), a backslash also escapes the byte after it, so
/// `E'it\'s'` is consumed as a single literal. An unterminated literal consumes the rest of
/// the input, and the returned index never exceeds `bytes.len()`.
#[cfg(feature = "postgres")]
fn skip_single_quoted(bytes: &[u8], mut i: usize) -> usize {
    // An `E`/`e` directly before the quote marks an escape string, unless it is merely the
    // last letter of a longer word, e.g. the type name in `date'2024-01-01'`.
    let continues_word =
        |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80;
    let escape_string = match i.checked_sub(1).and_then(|p| bytes.get(p)) {
        Some(b'E' | b'e') => i < 2 || !continues_word(bytes[i - 2]),
        _ => false,
    };

    i += 1;
    while let Some(&byte) = bytes.get(i) {
        match byte {
            // Skip the backslash and whatever byte it escapes (possibly a quote).
            b'\\' if escape_string => i += 2,
            b'\'' if bytes.get(i + 1) == Some(&b'\'') => i += 2,
            b'\'' => return i + 1,
            _ => i += 1,
        }
    }
    // A trailing backslash can step one past the end; clamp to the input length.
    i.min(bytes.len())
}
/// Returns the index just past the double-quoted identifier whose opening `"` is at `i`.
///
/// A doubled `""` inside the identifier is an escaped quote. An unterminated identifier
/// consumes the rest of the input.
#[cfg(feature = "postgres")]
fn skip_double_quoted(bytes: &[u8], mut i: usize) -> usize {
    // Step past the opening quote, then look for the closing one.
    i += 1;
    loop {
        match bytes.get(i) {
            None => return i,
            Some(b'"') if bytes.get(i + 1) == Some(&b'"') => i += 2,
            Some(b'"') => return i + 1,
            Some(_) => i += 1,
        }
    }
}
/// Returns the index of the newline that ends a `--` comment whose body starts at `i`.
///
/// The newline itself is not consumed. If the comment runs to the end of the input, the input
/// length is returned.
#[cfg(feature = "postgres")]
fn skip_line_comment(bytes: &[u8], i: usize) -> usize {
    bytes
        .get(i..)
        .and_then(|rest| rest.iter().position(|&b| b == b'\n'))
        .map_or(bytes.len().max(i), |offset| i + offset)
}
/// Returns the index just past the `*/` that closes a block comment whose body starts at `i`
/// (i.e. right after the opening `/*`).
///
/// Nested `/* ... */` pairs are tracked, so the comment only ends when every nesting level is
/// closed. An unterminated comment consumes the rest of the input (returns `bytes.len()`), in
/// line with the other skip helpers. Otherwise the scanner would resume on the final byte, and
/// a trailing `?` inside the comment would be rewritten.
#[cfg(feature = "postgres")]
fn skip_block_comment(bytes: &[u8], mut i: usize) -> usize {
    let mut depth = 1usize;
    while let Some(pair) = bytes.get(i..i + 2) {
        match pair {
            [b'/', b'*'] => {
                depth += 1;
                i += 2;
            }
            [b'*', b'/'] => {
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}
/// Returns the total byte length of the dollar-quote delimiter (`$$` or `$tag$`) that starts
/// at `start`, or `None` when the `$` there does not open a dollar quote (e.g. the positional
/// parameter `$1`).
///
/// A tag must start with an ASCII letter or `_` and may continue with ASCII letters, digits or
/// `_`. The byte before `start` is not inspected, so a `$` inside an identifier such as `a$b$`
/// is also reported as a delimiter.
#[cfg(feature = "postgres")]
fn dollar_quote_delimiter_len(bytes: &[u8], start: usize) -> Option<usize> {
    let tag_start = start + 1;
    match bytes.get(tag_start) {
        Some(b'$') => return Some(2),
        Some(b) if b.is_ascii_alphabetic() || *b == b'_' => {}
        _ => return None,
    }
    let tag_len = bytes[tag_start..]
        .iter()
        .take_while(|b| b.is_ascii_alphanumeric() || **b == b'_')
        .count();
    let closing = tag_start + tag_len;
    (bytes.get(closing) == Some(&b'$')).then(|| closing - start + 1)
}
/// Returns the index just past the delimiter that closes the dollar-quoted body opened at
/// `start`.
///
/// The opening delimiter is the `delimiter_len` bytes at `start`, and the closing delimiter
/// must repeat it exactly. If no closing delimiter follows, the input length is returned.
#[cfg(feature = "postgres")]
fn skip_dollar_quoted(bytes: &[u8], start: usize, delimiter_len: usize) -> usize {
    let tag = &bytes[start..start + delimiter_len];
    let mut cursor = start + delimiter_len;
    while let Some(window) = bytes.get(cursor..cursor + delimiter_len) {
        if window == tag {
            return cursor + delimiter_len;
        }
        cursor += 1;
    }
    bytes.len()
}
/// Opens a transaction on the underlying connection.
///
/// MySQL/MariaDB run `start transaction` and PostgreSQL runs `begin`; the rows those queries
/// return are discarded immediately. SQLite runs `begin immediate` via `execute_batch`.
pub(crate) async fn start_transaction_inner(inner: &mut ConnectionInner) -> Result<()> {
    match inner {
        ConnectionInner::_Disabled => unreachable!("disabled backend placeholder"),
        #[cfg(feature = "mariadb")]
        ConnectionInner::Mysql(conn) => {
            let _ = conn.query("start transaction").await?;
            Ok(())
        }
        #[cfg(feature = "postgres")]
        ConnectionInner::Postgres(conn) => {
            let _ = conn.query("begin").await?;
            Ok(())
        }
        #[cfg(feature = "sqlite")]
        ConnectionInner::Sqlite(conn) => {
            // Per SQLite's BEGIN TRANSACTION docs, IMMEDIATE starts the write transaction
            // right away instead of on the first write.
            conn.execute_batch("begin immediate")
                .await
                .map_err(Into::into)
        }
    }
}
/// Commits the transaction currently open on the underlying connection.
///
/// MySQL/MariaDB and PostgreSQL use their driver's `commit()`; SQLite ends the transaction by
/// executing a plain `commit` statement.
pub(crate) async fn commit_inner(inner: &mut ConnectionInner) -> Result<()> {
    match inner {
        ConnectionInner::_Disabled => unreachable!("disabled backend placeholder"),
        #[cfg(feature = "mariadb")]
        ConnectionInner::Mysql(conn) => conn.commit().await.map_err(Into::into),
        #[cfg(feature = "postgres")]
        ConnectionInner::Postgres(conn) => conn.commit().await.map_err(Into::into),
        #[cfg(feature = "sqlite")]
        ConnectionInner::Sqlite(conn) => {
            conn.execute_batch("commit").await.map_err(Into::into)
        }
    }
}
/// Rolls back the transaction currently open on the underlying connection.
///
/// MySQL/MariaDB and PostgreSQL use their driver's `rollback()`; SQLite executes a plain
/// `rollback` statement.
pub(crate) async fn rollback_inner(inner: &mut ConnectionInner) -> Result<()> {
    match inner {
        ConnectionInner::_Disabled => unreachable!("disabled backend placeholder"),
        #[cfg(feature = "mariadb")]
        ConnectionInner::Mysql(conn) => conn.rollback().await.map_err(Into::into),
        #[cfg(feature = "postgres")]
        ConnectionInner::Postgres(conn) => conn.rollback().await.map_err(Into::into),
        #[cfg(feature = "sqlite")]
        ConnectionInner::Sqlite(conn) => {
            conn.execute_batch("rollback").await.map_err(Into::into)
        }
    }
}
/// Opens a connection for the backend selected by `options.driver`.
///
/// Unset options fall back to per-backend defaults:
/// - MySQL/MariaDB: host `127.0.0.1`, port `3306`, user `root`; password, database and unix
///   socket default to their types' `Default` values.
/// - PostgreSQL: host `127.0.0.1`, port `5432`, user `postgres`, database `postgres`; password
///   defaults to its type's `Default` value.
/// - SQLite: an in-memory database when `options.in_memory` is set, otherwise `options.path`,
///   otherwise `sqlite.db`. `read_only` and `create_if_missing` are always forwarded, and
///   `busy_timeout` only when set.
///
/// # Errors
///
/// - An `invalid_url` error when `options.driver` is `None`.
/// - `Error::Unsupported` when the selected backend's Cargo feature (`mariadb`, `postgres` or
///   `sqlite`) is not compiled in.
/// - Any error returned by the backend driver's `connect`.
pub(crate) async fn connect_managed(options: ConnectOptions) -> Result<ManagedConnection> {
    let driver = options
        .driver
        .ok_or_else(|| Error::invalid_url("missing driver"))?;
    let inner = match driver {
        Driver::Mysql => {
            #[cfg(feature = "mariadb")]
            {
                ConnectionInner::Mysql(
                    quex_driver::mysql::Connection::connect(
                        quex_driver::mysql::ConnectOptions::new()
                            .host(options.host.unwrap_or_else(|| "127.0.0.1".into()))
                            // NOTE(review): `as u32` assumes `options.port` is an integer
                            // type that fits in u32 (e.g. u16) — confirm.
                            .port(options.port.unwrap_or(3306) as u32)
                            .user(options.username.unwrap_or_else(|| "root".into()))
                            .password(options.password.unwrap_or_default())
                            .database(options.database.unwrap_or_default())
                            // An unset socket is passed as the default (empty) value;
                            // presumably the driver then connects over TCP — confirm.
                            .unix_socket(options.unix_socket.unwrap_or_default()),
                    )
                    .await?,
                )
            }
            #[cfg(not(feature = "mariadb"))]
            {
                // Name the feature that actually gates this backend (`mariadb`).
                return Err(Error::Unsupported(
                    "mysql support is not enabled; enable the `mariadb` feature".into(),
                ));
            }
        }
        Driver::Pgsql => {
            #[cfg(feature = "postgres")]
            {
                ConnectionInner::Postgres(
                    quex_driver::postgres::Connection::connect(
                        quex_driver::postgres::ConnectOptions::new()
                            .host(options.host.unwrap_or_else(|| "127.0.0.1".into()))
                            .port(options.port.unwrap_or(5432))
                            .user(options.username.unwrap_or_else(|| "postgres".into()))
                            .password(options.password.unwrap_or_default())
                            .database(options.database.unwrap_or_else(|| "postgres".into())),
                    )
                    .await?,
                )
            }
            #[cfg(not(feature = "postgres"))]
            {
                return Err(Error::Unsupported(
                    "postgres support is not enabled; enable the `postgres` feature".into(),
                ));
            }
        }
        Driver::Sqlite => {
            #[cfg(feature = "sqlite")]
            {
                let mut connect = quex_driver::sqlite::ConnectOptions::new()
                    .read_only(options.read_only)
                    .create_if_missing(options.create_if_missing);
                if let Some(timeout) = options.busy_timeout {
                    connect = connect.busy_timeout(timeout);
                }
                // `in_memory` takes precedence over an explicit path.
                connect = if options.in_memory {
                    connect.in_memory()
                } else if let Some(path) = options.path {
                    connect.path(path)
                } else {
                    connect.path("sqlite.db")
                };
                ConnectionInner::Sqlite(quex_driver::sqlite::Connection::connect(connect).await?)
            }
            #[cfg(not(feature = "sqlite"))]
            {
                return Err(Error::Unsupported(
                    "sqlite support is not enabled; enable the `sqlite` feature".into(),
                ));
            }
        }
    };
    Ok(ManagedConnection { driver, inner })
}