use std::marker::PhantomData;
use sqlx::{
QueryBuilder,
types::{
Uuid,
time::{Date, OffsetDateTime, PrimitiveDateTime, Time},
},
};
/// A SQL backend that [`SqlxFragment`]s can be rendered for and bound against.
///
/// Each implementation in this crate is a unit struct gated behind the matching Cargo
/// feature (`postgres`, `sqlite`, `mysql`).
pub trait GatekeepSqlxBackend: Clone + Copy + core::fmt::Debug + Send + Sync + 'static {
    /// The sqlx database type that [`QueryBuilder`]s for this backend are built over.
    type Database: sqlx::Database;
    /// The driver this backend corresponds to; compared with the driver inferred from a
    /// URL by [`validate_database_url_for_backend`].
    const DRIVER: SqlxDriver;
    /// Backend name reported as `expected` in [`SqlxDriverError::BackendMismatch`].
    const NAME: &'static str;
    /// Appends the placeholder for the `index`-th bind parameter to `sql`.
    ///
    /// [`SqlxFragment::to_sql`] passes 1-based indices in order of appearance; backends
    /// whose placeholders are anonymous (`?`) may ignore `index`.
    fn push_placeholder(sql: &mut String, index: usize);
    /// Binds `value` onto `builder` using the value's native Rust type.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue);
    /// Name of the SQL function used for "minimum of" (`LEAST` on Postgres/MySQL,
    /// `min` on SQLite).
    const MIN_FUNCTION: &'static str;
    /// Name of the SQL function used for "maximum of" (`GREATEST` on Postgres/MySQL,
    /// `max` on SQLite).
    const MAX_FUNCTION: &'static str;
    /// Presumably: whether `MIN_FUNCTION`/`MAX_FUNCTION` yield NULL when any argument is
    /// NULL. It is `false` for Postgres and `true` for SQLite and MySQL, which matches how
    /// those databases' LEAST/GREATEST and min/max treat NULL arguments.
    /// NOTE(review): confirm the intended meaning against the callers that read it.
    const GRADE_FUNCTION_PROPAGATES_NULL: bool;
}
/// Marker type selecting the PostgreSQL dialect (positional `$N` placeholders).
#[cfg(feature = "postgres")]
#[derive(Clone, Copy, Debug)]
pub struct PostgresBackend;

#[cfg(feature = "postgres")]
impl GatekeepSqlxBackend for PostgresBackend {
    type Database = sqlx::Postgres;

    const DRIVER: SqlxDriver = SqlxDriver::Postgres;
    const NAME: &'static str = "postgres";
    const MIN_FUNCTION: &'static str = "LEAST";
    const MAX_FUNCTION: &'static str = "GREATEST";
    // PostgreSQL's LEAST/GREATEST skip NULL arguments instead of returning NULL.
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = false;

    /// Writes the positional placeholder `$index`.
    fn push_placeholder(sql: &mut String, index: usize) {
        let ordinal = index.to_string();
        sql.push('$');
        sql.push_str(&ordinal);
    }

    /// Binds `value` with its native Rust type.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        // The builder's argument lifetime is unrelated to `value`'s, so owned copies are
        // bound. Whatever `push_bind` returns is not needed and is discarded.
        match value {
            SqlxValue::Bool(flag) => builder.push_bind(*flag),
            SqlxValue::I16(number) => builder.push_bind(*number),
            SqlxValue::I32(number) => builder.push_bind(*number),
            SqlxValue::I64(number) => builder.push_bind(*number),
            SqlxValue::Text(text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(id) => builder.push_bind(*id),
            SqlxValue::Date(date) => builder.push_bind(*date),
            SqlxValue::Time(time) => builder.push_bind(*time),
            SqlxValue::Timestamp(timestamp) => builder.push_bind(*timestamp),
            SqlxValue::TimestampTz(timestamp) => builder.push_bind(*timestamp),
        };
    }
}
/// Marker type selecting the SQLite dialect (anonymous `?` placeholders).
#[cfg(feature = "sqlite")]
#[derive(Clone, Copy, Debug)]
pub struct SqliteBackend;

#[cfg(feature = "sqlite")]
impl GatekeepSqlxBackend for SqliteBackend {
    type Database = sqlx::Sqlite;

    const DRIVER: SqlxDriver = SqlxDriver::Sqlite;
    const NAME: &'static str = "sqlite";
    const MIN_FUNCTION: &'static str = "min";
    const MAX_FUNCTION: &'static str = "max";
    // SQLite's multi-argument min()/max() return NULL if any argument is NULL.
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = true;

    /// Writes an anonymous `?` placeholder; the index is not needed.
    fn push_placeholder(sql: &mut String, _index: usize) {
        sql.push('?');
    }

    /// Binds `value` with its native Rust type.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        // The builder's argument lifetime is unrelated to `value`'s, so owned copies are
        // bound. Whatever `push_bind` returns is not needed and is discarded.
        // NOTE(review): sqlx encodes `Uuid` for SQLite as a 16-byte BLOB rather than
        // hyphenated text; confirm that matches how UUID columns are stored.
        match value {
            SqlxValue::Bool(flag) => builder.push_bind(*flag),
            SqlxValue::I16(number) => builder.push_bind(*number),
            SqlxValue::I32(number) => builder.push_bind(*number),
            SqlxValue::I64(number) => builder.push_bind(*number),
            SqlxValue::Text(text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(id) => builder.push_bind(*id),
            SqlxValue::Date(date) => builder.push_bind(*date),
            SqlxValue::Time(time) => builder.push_bind(*time),
            SqlxValue::Timestamp(timestamp) => builder.push_bind(*timestamp),
            SqlxValue::TimestampTz(timestamp) => builder.push_bind(*timestamp),
        };
    }
}
/// Marker type selecting the MySQL/MariaDB dialect (anonymous `?` placeholders).
#[cfg(feature = "mysql")]
#[derive(Clone, Copy, Debug)]
pub struct MySqlBackend;

#[cfg(feature = "mysql")]
impl GatekeepSqlxBackend for MySqlBackend {
    type Database = sqlx::MySql;

    const DRIVER: SqlxDriver = SqlxDriver::MySql;
    const NAME: &'static str = "mysql";
    const MIN_FUNCTION: &'static str = "LEAST";
    const MAX_FUNCTION: &'static str = "GREATEST";
    // MySQL's LEAST/GREATEST return NULL if any argument is NULL.
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = true;

    /// Writes an anonymous `?` placeholder; the index is not needed.
    fn push_placeholder(sql: &mut String, _index: usize) {
        sql.push('?');
    }

    /// Binds `value` with its native Rust type.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        // The builder's argument lifetime is unrelated to `value`'s, so owned copies are
        // bound. Whatever `push_bind` returns is not needed and is discarded.
        // NOTE(review): sqlx encodes `Uuid` for MySQL as 16 raw bytes (BINARY(16)) rather
        // than hyphenated text; confirm that matches how UUID columns are stored.
        match value {
            SqlxValue::Bool(flag) => builder.push_bind(*flag),
            SqlxValue::I16(number) => builder.push_bind(*number),
            SqlxValue::I32(number) => builder.push_bind(*number),
            SqlxValue::I64(number) => builder.push_bind(*number),
            SqlxValue::Text(text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(id) => builder.push_bind(*id),
            SqlxValue::Date(date) => builder.push_bind(*date),
            SqlxValue::Time(time) => builder.push_bind(*time),
            SqlxValue::Timestamp(timestamp) => builder.push_bind(*timestamp),
            SqlxValue::TimestampTz(timestamp) => builder.push_bind(*timestamp),
        };
    }
}
/// The database drivers gatekeep-sqlx knows how to target.
///
/// Every variant exists regardless of which Cargo features are enabled; use
/// [`SqlxDriver::is_enabled`] to find out whether support for one was compiled in.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum SqlxDriver {
    /// PostgreSQL (`postgres:` / `postgresql:` URLs).
    Postgres,
    /// SQLite (`sqlite:` URLs).
    Sqlite,
    /// MySQL or MariaDB (`mysql:` / `mariadb:` URLs).
    MySql,
}

impl SqlxDriver {
    /// Returns the driver's lowercase name, as used in [`SqlxDriverError`] messages.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            SqlxDriver::MySql => "mysql",
            SqlxDriver::Sqlite => "sqlite",
            SqlxDriver::Postgres => "postgres",
        }
    }

    /// Reports whether this crate was compiled with the Cargo feature for this driver.
    #[must_use]
    pub const fn is_enabled(self) -> bool {
        match self {
            SqlxDriver::MySql => cfg!(feature = "mysql"),
            SqlxDriver::Sqlite => cfg!(feature = "sqlite"),
            SqlxDriver::Postgres => cfg!(feature = "postgres"),
        }
    }
}
/// Errors produced while inferring or validating the SQLx driver for a database URL.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum SqlxDriverError {
    /// The URL's scheme is not one this crate recognises.
    // NOTE(review): `{scheme:?}` uses `Debug` formatting on an `Option`, so the message reads
    // e.g. `... scheme Some("http")` or `... scheme None`; a friendlier rendering may be wanted.
    #[error("unsupported SQLx database URL scheme {scheme:?}")]
    UnsupportedUrlScheme {
        /// The text before the first `:`, or `None` when the URL contains no `:` at all.
        scheme: Option<String>,
    },
    /// The URL names a driver whose Cargo feature was not enabled for this build.
    #[error("SQLx driver {driver} is not enabled for gatekeep-sqlx")]
    DriverNotEnabled {
        /// The driver's [`SqlxDriver::name`].
        driver: &'static str,
    },
    /// The URL names an enabled driver other than the one the backend expects.
    #[error("SQLx backend mismatch: expected {expected}, found {actual}")]
    BackendMismatch {
        /// The backend's [`GatekeepSqlxBackend::NAME`].
        expected: &'static str,
        /// The [`SqlxDriver::name`] inferred from the URL.
        actual: &'static str,
    },
}
/// Infers the driver from `database_url`'s scheme and checks that it was compiled in.
///
/// # Errors
///
/// Returns [`SqlxDriverError::UnsupportedUrlScheme`] when the scheme is not recognised, and
/// [`SqlxDriverError::DriverNotEnabled`] when it names a driver whose feature is disabled.
pub fn infer_enabled_driver_from_url(database_url: &str) -> Result<SqlxDriver, SqlxDriverError> {
    let driver = infer_driver_from_url(database_url)?;
    driver
        .is_enabled()
        .then_some(driver)
        .ok_or_else(|| SqlxDriverError::DriverNotEnabled {
            driver: driver.name(),
        })
}
/// Checks that `database_url` targets the same driver as backend `B`.
///
/// # Errors
///
/// Propagates any error from [`infer_enabled_driver_from_url`], and returns
/// [`SqlxDriverError::BackendMismatch`] when the URL names a different enabled driver.
pub fn validate_database_url_for_backend<B>(database_url: &str) -> Result<(), SqlxDriverError>
where
    B: GatekeepSqlxBackend,
{
    let actual = infer_enabled_driver_from_url(database_url)?;
    if actual != B::DRIVER {
        return Err(SqlxDriverError::BackendMismatch {
            expected: B::NAME,
            actual: actual.name(),
        });
    }
    Ok(())
}
/// Maps the URL scheme (the text before the first `:`) to a driver.
///
/// Matching is case-sensitive. Because only the first `:` is split on, SQLite URLs such as
/// `sqlite::memory:` resolve to [`SqlxDriver::Sqlite`] like any other `sqlite:` URL.
fn infer_driver_from_url(database_url: &str) -> Result<SqlxDriver, SqlxDriverError> {
    let (scheme, _) = database_url
        .split_once(':')
        .ok_or(SqlxDriverError::UnsupportedUrlScheme { scheme: None })?;
    match scheme {
        "sqlite" => Ok(SqlxDriver::Sqlite),
        "postgres" | "postgresql" => Ok(SqlxDriver::Postgres),
        "mysql" | "mariadb" => Ok(SqlxDriver::MySql),
        unsupported => Err(SqlxDriverError::UnsupportedUrlScheme {
            scheme: Some(unsupported.to_owned()),
        }),
    }
}
/// A bind parameter value, independent of any particular backend.
///
/// Each backend's `push_bind` binds the inner value using its native Rust type.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum SqlxValue {
    /// A boolean.
    Bool(bool),
    /// A 16-bit signed integer.
    I16(i16),
    /// A 32-bit signed integer.
    I32(i32),
    /// A 64-bit signed integer.
    I64(i64),
    /// A string.
    Text(String),
    /// A byte string.
    Bytes(Vec<u8>),
    /// A UUID.
    Uuid(Uuid),
    /// A calendar date.
    Date(Date),
    /// A time of day.
    Time(Time),
    /// A date and time without a UTC offset.
    Timestamp(PrimitiveDateTime),
    /// A date and time with a UTC offset.
    TimestampTz(OffsetDateTime),
}
// Implements `From<$ty> for SqlxValue` by moving the value into `SqlxValue::$variant`, so
// plain Rust values can be passed wherever `impl Into<SqlxValue>` is accepted.
macro_rules! impl_sqlx_value_from {
    ($ty:ty, $variant:ident) => {
        impl From<$ty> for SqlxValue {
            fn from(value: $ty) -> Self {
                Self::$variant(value)
            }
        }
    };
}
// One owned conversion per `SqlxValue` variant.
impl_sqlx_value_from!(bool, Bool);
impl_sqlx_value_from!(i16, I16);
impl_sqlx_value_from!(i32, I32);
impl_sqlx_value_from!(i64, I64);
impl_sqlx_value_from!(String, Text);
impl_sqlx_value_from!(Vec<u8>, Bytes);
impl_sqlx_value_from!(Uuid, Uuid);
impl_sqlx_value_from!(Date, Date);
impl_sqlx_value_from!(Time, Time);
impl_sqlx_value_from!(PrimitiveDateTime, Timestamp);
impl_sqlx_value_from!(OffsetDateTime, TimestampTz);
impl From<&str> for SqlxValue {
    /// Copies the string slice into an owned [`SqlxValue::Text`].
    fn from(value: &str) -> Self {
        Self::Text(String::from(value))
    }
}

impl From<&[u8]> for SqlxValue {
    /// Copies the byte slice into an owned [`SqlxValue::Bytes`].
    fn from(value: &[u8]) -> Self {
        Self::Bytes(Vec::from(value))
    }
}
/// One element of a [`SqlxFragment`]: raw SQL text or a value to be bound.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SqlPart {
    /// SQL text emitted verbatim.
    Text(String),
    /// A bind parameter, rendered as a backend placeholder.
    Bind(SqlxValue),
}
#[derive(Debug, PartialEq, Eq)]
pub struct SqlxFragment<B> {
parts: Vec<SqlPart>,
backend: PhantomData<fn() -> B>,
}
impl<B> Clone for SqlxFragment<B> {
fn clone(&self) -> Self {
Self {
parts: self.parts.clone(),
backend: PhantomData,
}
}
}
impl<B> Default for SqlxFragment<B> {
fn default() -> Self {
Self {
parts: Vec::new(),
backend: PhantomData,
}
}
}
impl<B> SqlxFragment<B> {
    /// Creates a fragment from raw SQL text that is emitted verbatim.
    ///
    /// The text is never escaped, so it must not contain untrusted input; use
    /// [`SqlxFragment::bind`] for values. An empty string yields an empty fragment.
    #[must_use]
    pub fn trusted(sql: impl Into<String>) -> Self {
        let text: String = sql.into();
        let parts = if text.is_empty() {
            Vec::new()
        } else {
            vec![SqlPart::Text(text)]
        };
        Self {
            parts,
            backend: PhantomData,
        }
    }

    /// Creates a fragment consisting of a single bind parameter.
    #[must_use]
    pub fn bind(value: impl Into<SqlxValue>) -> Self {
        let part = SqlPart::Bind(value.into());
        Self {
            parts: vec![part],
            backend: PhantomData,
        }
    }

    /// Iterates over the bind parameters in the order they appear in the fragment.
    pub fn binds(&self) -> impl Iterator<Item = &SqlxValue> {
        self.parts.iter().filter_map(|part| {
            if let SqlPart::Bind(value) = part {
                Some(value)
            } else {
                None
            }
        })
    }

    /// Appends every part of `fragment` to the end of this one.
    pub fn push_fragment(&mut self, mut fragment: Self) {
        self.parts.append(&mut fragment.parts);
    }

    /// Appends raw SQL text. Empty strings are skipped so no empty text parts are stored.
    pub(crate) fn push_sql(&mut self, sql: impl Into<String>) {
        let text: String = sql.into();
        if text.is_empty() {
            return;
        }
        self.parts.push(SqlPart::Text(text));
    }

    /// Wraps this fragment in parentheses.
    #[must_use]
    pub(crate) fn wrapped(self) -> Self {
        let mut grouped = Self::trusted("(");
        grouped.push_fragment(self);
        grouped.push_sql(")");
        grouped
    }

    /// Emits `prefix` verbatim (no separator added) followed by the parenthesised `inner`.
    #[must_use]
    pub(crate) fn unary(prefix: &str, inner: Self) -> Self {
        let mut combined = Self::trusted(prefix);
        combined.push_fragment(inner.wrapped());
        combined
    }

    /// Joins `fragments` with the raw `separator` text, parenthesising each operand.
    ///
    /// An empty list yields the literal `FALSE` whatever the separator is.
    /// NOTE(review): `FALSE` is the identity for an OR-join but not for an AND-join (which
    /// would be `TRUE`); confirm callers never pass an empty AND list, or that failing
    /// closed is intended.
    #[must_use]
    pub(crate) fn binary(separator: &str, fragments: Vec<Self>) -> Self {
        let mut rest = fragments.into_iter();
        match rest.next() {
            None => Self::trusted("FALSE"),
            Some(head) => rest.fold(head.wrapped(), |mut joined, operand| {
                joined.push_sql(separator);
                joined.push_fragment(operand.wrapped());
                joined
            }),
        }
    }

    /// Renders a call `name(arg, arg, ...)`. Arguments are comma-separated and are not
    /// parenthesised individually; an empty list renders as `name()`.
    #[must_use]
    pub(crate) fn function(name: &str, fragments: Vec<Self>) -> Self {
        let mut call = Self::trusted(name);
        call.push_sql("(");
        for (position, argument) in fragments.into_iter().enumerate() {
            if position != 0 {
                call.push_sql(", ");
            }
            call.push_fragment(argument);
        }
        call.push_sql(")");
        call
    }
}
impl<B> SqlxFragment<B>
where
    B: GatekeepSqlxBackend,
{
    /// Renders the fragment as SQL text, numbering bind placeholders from 1 in order.
    ///
    /// Placeholder syntax comes from [`GatekeepSqlxBackend::push_placeholder`]. Bind values
    /// are not included; obtain them with [`SqlxFragment::binds`].
    #[must_use]
    pub fn to_sql(&self) -> String {
        let mut rendered = String::new();
        let mut bind_ordinal = 0usize;
        for part in &self.parts {
            match part {
                SqlPart::Bind(_) => {
                    bind_ordinal += 1;
                    B::push_placeholder(&mut rendered, bind_ordinal);
                }
                SqlPart::Text(text) => rendered.push_str(text),
            }
        }
        rendered
    }

    /// Appends the fragment to `builder`: text is pushed verbatim and each value is bound
    /// through [`GatekeepSqlxBackend::push_bind`].
    pub fn push_to(&self, builder: &mut QueryBuilder<B::Database>) {
        self.parts.iter().for_each(|part| match part {
            SqlPart::Bind(value) => B::push_bind(builder, value),
            SqlPart::Text(text) => {
                builder.push(text);
            }
        });
    }
}
/// Postgres-flavoured name for [`SqlxValue`]. Bind values are backend-agnostic, so this
/// alias is available even without the `postgres` feature.
pub type PgValue = SqlxValue;

/// A [`SqlxFragment`] rendered with PostgreSQL `$N` placeholders.
#[cfg(feature = "postgres")]
pub type PgFragment = SqlxFragment<PostgresBackend>;

#[cfg(feature = "postgres")]
impl PgFragment {
    /// Renders the fragment as PostgreSQL SQL; identical to [`SqlxFragment::to_sql`].
    #[must_use]
    pub fn to_postgres_sql(&self) -> String {
        Self::to_sql(self)
    }
}