// gatekeep_sqlx/fragment.rs

1use std::marker::PhantomData;
2
3use sqlx::{
4    QueryBuilder,
5    types::{
6        Uuid,
7        time::{Date, OffsetDateTime, PrimitiveDateTime, Time},
8    },
9};
10
11/// `SQLx` backend supported by gatekeep lowering.
12pub trait GatekeepSqlxBackend: Clone + Copy + core::fmt::Debug + Send + Sync + 'static {
13    /// `SQLx` database driver for this backend.
14    type Database: sqlx::Database;
15
16    /// Database driver represented by this backend.
17    const DRIVER: SqlxDriver;
18
19    /// Stable backend name.
20    const NAME: &'static str;
21
22    /// Appends one bind placeholder to rendered SQL.
23    fn push_placeholder(sql: &mut String, index: usize);
24
25    /// Appends one typed bind value to a `SQLx` query builder.
26    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue);
27
28    /// Name of the SQL function that returns the lower of two non-null grades.
29    const MIN_FUNCTION: &'static str;
30
31    /// Name of the SQL function that returns the higher of two non-null grades.
32    const MAX_FUNCTION: &'static str;
33
34    /// Whether the backend's grade functions return `NULL` when any input is
35    /// `NULL`.
36    const GRADE_FUNCTION_PROPAGATES_NULL: bool;
37}
38
/// Marker type selecting the Postgres backend.
#[cfg(feature = "postgres")]
#[derive(Clone, Copy, Debug)]
pub struct PostgresBackend;

#[cfg(feature = "postgres")]
impl GatekeepSqlxBackend for PostgresBackend {
    type Database = sqlx::Postgres;

    const DRIVER: SqlxDriver = SqlxDriver::Postgres;
    const NAME: &'static str = "postgres";
    // Postgres `LEAST`/`GREATEST` skip NULL arguments (the result is NULL only
    // when every argument is NULL), hence `GRADE_FUNCTION_PROPAGATES_NULL = false`.
    const MIN_FUNCTION: &'static str = "LEAST";
    const MAX_FUNCTION: &'static str = "GREATEST";
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = false;

    /// Appends the numbered Postgres placeholder `$index`.
    fn push_placeholder(sql: &mut String, index: usize) {
        use std::fmt::Write as _;

        // Format the number straight into `sql` instead of allocating a
        // temporary `String` per placeholder. Writing into a `String` never
        // returns `Err`, so the result can be ignored safely.
        let _ = write!(sql, "${index}");
    }

    /// Binds `value` using the Rust type that matches its variant.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        // `SQLx` binds need owned values, so text and byte payloads are cloned.
        match value {
            SqlxValue::Bool(flag) => builder.push_bind(*flag),
            SqlxValue::I16(number) => builder.push_bind(*number),
            SqlxValue::I32(number) => builder.push_bind(*number),
            SqlxValue::I64(number) => builder.push_bind(*number),
            SqlxValue::Text(text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(uuid) => builder.push_bind(*uuid),
            SqlxValue::Date(date) => builder.push_bind(*date),
            SqlxValue::Time(time) => builder.push_bind(*time),
            SqlxValue::Timestamp(timestamp) => builder.push_bind(*timestamp),
            SqlxValue::TimestampTz(timestamp) => builder.push_bind(*timestamp),
        };
    }
}
97
/// Marker type selecting the `SQLite` backend.
#[cfg(feature = "sqlite")]
#[derive(Clone, Copy, Debug)]
pub struct SqliteBackend;

#[cfg(feature = "sqlite")]
impl GatekeepSqlxBackend for SqliteBackend {
    type Database = sqlx::Sqlite;

    const DRIVER: SqlxDriver = SqlxDriver::Sqlite;
    const NAME: &'static str = "sqlite";
    // The multi-argument scalar `min`/`max` return NULL as soon as any argument
    // is NULL, hence `GRADE_FUNCTION_PROPAGATES_NULL = true`.
    const MIN_FUNCTION: &'static str = "min";
    const MAX_FUNCTION: &'static str = "max";
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = true;

    /// Appends an anonymous `?` placeholder; the position is not rendered.
    fn push_placeholder(sql: &mut String, _index: usize) {
        sql.push('?');
    }

    /// Binds `value` using the Rust type that matches its variant.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        // `SQLx` binds need owned values, so text and byte payloads are cloned.
        match value {
            SqlxValue::Bool(flag) => builder.push_bind(*flag),
            SqlxValue::I16(number) => builder.push_bind(*number),
            SqlxValue::I32(number) => builder.push_bind(*number),
            SqlxValue::I64(number) => builder.push_bind(*number),
            SqlxValue::Text(text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(uuid) => builder.push_bind(*uuid),
            SqlxValue::Date(date) => builder.push_bind(*date),
            SqlxValue::Time(time) => builder.push_bind(*time),
            SqlxValue::Timestamp(timestamp) => builder.push_bind(*timestamp),
            SqlxValue::TimestampTz(timestamp) => builder.push_bind(*timestamp),
        };
    }
}
155
/// Marker type selecting the `MySQL` backend.
#[cfg(feature = "mysql")]
#[derive(Clone, Copy, Debug)]
pub struct MySqlBackend;

#[cfg(feature = "mysql")]
impl GatekeepSqlxBackend for MySqlBackend {
    type Database = sqlx::MySql;

    const DRIVER: SqlxDriver = SqlxDriver::MySql;
    const NAME: &'static str = "mysql";
    // `MySQL` `LEAST`/`GREATEST` return NULL when any argument is NULL,
    // hence `GRADE_FUNCTION_PROPAGATES_NULL = true`.
    const MIN_FUNCTION: &'static str = "LEAST";
    const MAX_FUNCTION: &'static str = "GREATEST";
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = true;

    /// Appends an anonymous `?` placeholder; the position is not rendered.
    fn push_placeholder(sql: &mut String, _index: usize) {
        sql.push('?');
    }

    /// Binds `value` using the Rust type that matches its variant.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        // `SQLx` binds need owned values, so text and byte payloads are cloned.
        match value {
            SqlxValue::Bool(flag) => builder.push_bind(*flag),
            SqlxValue::I16(number) => builder.push_bind(*number),
            SqlxValue::I32(number) => builder.push_bind(*number),
            SqlxValue::I64(number) => builder.push_bind(*number),
            SqlxValue::Text(text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(uuid) => builder.push_bind(*uuid),
            SqlxValue::Date(date) => builder.push_bind(*date),
            SqlxValue::Time(time) => builder.push_bind(*time),
            SqlxValue::Timestamp(timestamp) => builder.push_bind(*timestamp),
            SqlxValue::TimestampTz(timestamp) => builder.push_bind(*timestamp),
        };
    }
}
213
/// A `SQLx` database driver that gatekeep can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SqlxDriver {
    /// The Postgres driver.
    Postgres,
    /// The `SQLite` driver.
    Sqlite,
    /// The `MySQL` driver.
    MySql,
}

impl SqlxDriver {
    /// Returns the driver's stable lowercase name: `"postgres"`, `"sqlite"` or `"mysql"`.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            SqlxDriver::Postgres => "postgres",
            SqlxDriver::Sqlite => "sqlite",
            SqlxDriver::MySql => "mysql",
        }
    }

    /// Returns `true` when this driver's Cargo feature (`postgres`, `sqlite`
    /// or `mysql`) was enabled when the crate was compiled.
    #[must_use]
    pub const fn is_enabled(self) -> bool {
        match self {
            SqlxDriver::Postgres => cfg!(feature = "postgres"),
            SqlxDriver::Sqlite => cfg!(feature = "sqlite"),
            SqlxDriver::MySql => cfg!(feature = "mysql"),
        }
    }
}
247
248/// Database driver configuration error.
249#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
250#[non_exhaustive]
251pub enum SqlxDriverError {
252    /// The URL scheme is not recognized as a `SQLx` database driver.
253    #[error("unsupported SQLx database URL scheme {scheme:?}")]
254    UnsupportedUrlScheme {
255        /// URL scheme, if one could be parsed.
256        scheme: Option<String>,
257    },
258
259    /// The URL selects a driver whose feature was not enabled.
260    #[error("SQLx driver {driver} is not enabled for gatekeep-sqlx")]
261    DriverNotEnabled {
262        /// Driver inferred from the URL.
263        driver: &'static str,
264    },
265
266    /// The configured driver does not match the selected backend.
267    #[error("SQLx backend mismatch: expected {expected}, found {actual}")]
268    BackendMismatch {
269        /// Backend expected by the selected lowerer.
270        expected: &'static str,
271        /// Driver inferred from runtime configuration.
272        actual: &'static str,
273    },
274}
275
276/// Infers the `SQLx` driver from a database URL or `SQLx`-style `SQLite` memory URL.
277///
278/// # Errors
279///
280/// Returns [`SqlxDriverError`] when the URL scheme is unsupported or when the
281/// inferred driver was not enabled at compile time.
282pub fn infer_enabled_driver_from_url(database_url: &str) -> Result<SqlxDriver, SqlxDriverError> {
283    let driver = infer_driver_from_url(database_url)?;
284    if driver.is_enabled() {
285        Ok(driver)
286    } else {
287        Err(SqlxDriverError::DriverNotEnabled {
288            driver: driver.name(),
289        })
290    }
291}
292
293/// Validates that a database URL matches a selected backend.
294///
295/// # Errors
296///
297/// Returns [`SqlxDriverError`] when the URL is unsupported, names a disabled
298/// driver, or names a different driver from `B`.
299pub fn validate_database_url_for_backend<B>(database_url: &str) -> Result<(), SqlxDriverError>
300where
301    B: GatekeepSqlxBackend,
302{
303    let actual = infer_enabled_driver_from_url(database_url)?;
304    if actual == B::DRIVER {
305        Ok(())
306    } else {
307        Err(SqlxDriverError::BackendMismatch {
308            expected: B::NAME,
309            actual: actual.name(),
310        })
311    }
312}
313
314fn infer_driver_from_url(database_url: &str) -> Result<SqlxDriver, SqlxDriverError> {
315    if database_url.starts_with("sqlite:") {
316        return Ok(SqlxDriver::Sqlite);
317    }
318
319    let Some((scheme, _rest)) = database_url.split_once(':') else {
320        return Err(SqlxDriverError::UnsupportedUrlScheme { scheme: None });
321    };
322
323    match scheme {
324        "postgres" | "postgresql" => Ok(SqlxDriver::Postgres),
325        "mysql" | "mariadb" => Ok(SqlxDriver::MySql),
326        "sqlite" => Ok(SqlxDriver::Sqlite),
327        other => Err(SqlxDriverError::UnsupportedUrlScheme {
328            scheme: Some(other.to_owned()),
329        }),
330    }
331}
332
333/// Scalar value carried by a lowered SQL fragment.
334#[derive(Clone, Debug, PartialEq, Eq)]
335#[non_exhaustive]
336pub enum SqlxValue {
337    /// Boolean bind value.
338    Bool(bool),
339    /// Signed 16-bit integer bind value.
340    I16(i16),
341    /// Signed 32-bit integer bind value.
342    I32(i32),
343    /// Signed 64-bit integer bind value.
344    I64(i64),
345    /// Text bind value.
346    Text(String),
347    /// Binary bind value.
348    Bytes(Vec<u8>),
349    /// UUID bind value.
350    Uuid(Uuid),
351    /// Date bind value.
352    Date(Date),
353    /// Time bind value.
354    Time(Time),
355    /// Timestamp without time zone bind value.
356    Timestamp(PrimitiveDateTime),
357    /// Timestamp with time zone bind value.
358    TimestampTz(OffsetDateTime),
359}
360
361macro_rules! impl_sqlx_value_from {
362    ($ty:ty, $variant:ident) => {
363        impl From<$ty> for SqlxValue {
364            fn from(value: $ty) -> Self {
365                Self::$variant(value)
366            }
367        }
368    };
369}
370
371impl_sqlx_value_from!(bool, Bool);
372impl_sqlx_value_from!(i16, I16);
373impl_sqlx_value_from!(i32, I32);
374impl_sqlx_value_from!(i64, I64);
375impl_sqlx_value_from!(String, Text);
376impl_sqlx_value_from!(Vec<u8>, Bytes);
377impl_sqlx_value_from!(Uuid, Uuid);
378impl_sqlx_value_from!(Date, Date);
379impl_sqlx_value_from!(Time, Time);
380impl_sqlx_value_from!(PrimitiveDateTime, Timestamp);
381impl_sqlx_value_from!(OffsetDateTime, TimestampTz);
382
383impl From<&str> for SqlxValue {
384    fn from(value: &str) -> Self {
385        Self::Text(value.to_owned())
386    }
387}
388
389impl From<&[u8]> for SqlxValue {
390    fn from(value: &[u8]) -> Self {
391        Self::Bytes(value.to_vec())
392    }
393}
394
395#[derive(Clone, Debug, PartialEq, Eq)]
396enum SqlPart {
397    Text(String),
398    Bind(SqlxValue),
399}
400
401/// Trusted SQL plus ordered bind values for one `SQLx` backend.
402#[derive(Debug, PartialEq, Eq)]
403pub struct SqlxFragment<B> {
404    parts: Vec<SqlPart>,
405    backend: PhantomData<fn() -> B>,
406}
407
408impl<B> Clone for SqlxFragment<B> {
409    fn clone(&self) -> Self {
410        Self {
411            parts: self.parts.clone(),
412            backend: PhantomData,
413        }
414    }
415}
416
417impl<B> Default for SqlxFragment<B> {
418    fn default() -> Self {
419        Self {
420            parts: Vec::new(),
421            backend: PhantomData,
422        }
423    }
424}
425
426impl<B> SqlxFragment<B> {
427    /// Builds a fragment from SQL owned by the application.
428    ///
429    /// Callers must not pass user-supplied text here. Dynamic values belong in
430    /// bind fragments built with [`Self::bind`].
431    #[must_use]
432    pub fn trusted(sql: impl Into<String>) -> Self {
433        let sql = sql.into();
434        if sql.is_empty() {
435            Self::default()
436        } else {
437            Self {
438                parts: vec![SqlPart::Text(sql)],
439                backend: PhantomData,
440            }
441        }
442    }
443
444    /// Builds a bind fragment from a supported `SQLx` scalar value.
445    #[must_use]
446    pub fn bind(value: impl Into<SqlxValue>) -> Self {
447        Self {
448            parts: vec![SqlPart::Bind(value.into())],
449            backend: PhantomData,
450        }
451    }
452
453    /// Returns the ordered bind values.
454    pub fn binds(&self) -> impl Iterator<Item = &SqlxValue> {
455        self.parts.iter().filter_map(|part| match part {
456            SqlPart::Text(_) => None,
457            SqlPart::Bind(value) => Some(value),
458        })
459    }
460
461    /// Appends another fragment to this one.
462    pub fn push_fragment(&mut self, fragment: Self) {
463        self.parts.extend(fragment.parts);
464    }
465
466    pub(crate) fn push_sql(&mut self, sql: impl Into<String>) {
467        let sql = sql.into();
468        if !sql.is_empty() {
469            self.parts.push(SqlPart::Text(sql));
470        }
471    }
472
473    #[must_use]
474    pub(crate) fn wrapped(self) -> Self {
475        let mut fragment = Self::trusted("(");
476        fragment.push_fragment(self);
477        fragment.push_sql(")");
478        fragment
479    }
480
481    #[must_use]
482    pub(crate) fn unary(prefix: &str, inner: Self) -> Self {
483        let mut fragment = Self::trusted(prefix);
484        fragment.push_fragment(inner.wrapped());
485        fragment
486    }
487
488    #[must_use]
489    pub(crate) fn binary(separator: &str, fragments: Vec<Self>) -> Self {
490        let mut iter = fragments.into_iter();
491        let Some(first) = iter.next() else {
492            return Self::trusted("FALSE");
493        };
494
495        let mut fragment = first.wrapped();
496        for next in iter {
497            fragment.push_sql(separator);
498            fragment.push_fragment(next.wrapped());
499        }
500        fragment
501    }
502
503    #[must_use]
504    pub(crate) fn function(name: &str, fragments: Vec<Self>) -> Self {
505        let mut fragment = Self::trusted(name);
506        fragment.push_sql("(");
507
508        let mut iter = fragments.into_iter();
509        if let Some(first) = iter.next() {
510            fragment.push_fragment(first);
511            for next in iter {
512                fragment.push_sql(", ");
513                fragment.push_fragment(next);
514            }
515        }
516
517        fragment.push_sql(")");
518        fragment
519    }
520}
521
522impl<B> SqlxFragment<B>
523where
524    B: GatekeepSqlxBackend,
525{
526    /// Converts the fragment to SQL with this backend's placeholder syntax.
527    #[must_use]
528    pub fn to_sql(&self) -> String {
529        let mut sql = String::new();
530        let mut placeholders = 0usize;
531
532        for part in &self.parts {
533            match part {
534                SqlPart::Text(text) => sql.push_str(text),
535                SqlPart::Bind(_) => {
536                    placeholders += 1;
537                    B::push_placeholder(&mut sql, placeholders);
538                }
539            }
540        }
541        sql
542    }
543
544    /// Appends this fragment to a `SQLx` query builder.
545    pub fn push_to(&self, builder: &mut QueryBuilder<B::Database>) {
546        for part in &self.parts {
547            match part {
548                SqlPart::Text(text) => {
549                    builder.push(text);
550                }
551                SqlPart::Bind(value) => B::push_bind(builder, value),
552            }
553        }
554    }
555}
556
557/// Postgres scalar value carried by a lowered SQL fragment.
558pub type PgValue = SqlxValue;
559
560/// Trusted Postgres SQL plus ordered bind values.
561#[cfg(feature = "postgres")]
562pub type PgFragment = SqlxFragment<PostgresBackend>;
563
564#[cfg(feature = "postgres")]
565impl SqlxFragment<PostgresBackend> {
566    /// Converts the fragment to Postgres placeholders (`$1`, `$2`, ...).
567    #[must_use]
568    pub fn to_postgres_sql(&self) -> String {
569        self.to_sql()
570    }
571}