1use std::marker::PhantomData;
2
3use sqlx::{
4 QueryBuilder,
5 types::{
6 Uuid,
7 time::{Date, OffsetDateTime, PrimitiveDateTime, Time},
8 },
9};
10
11pub trait GatekeepSqlxBackend: Clone + Copy + core::fmt::Debug + Send + Sync + 'static {
13 type Database: sqlx::Database;
15
16 const DRIVER: SqlxDriver;
18
19 const NAME: &'static str;
21
22 fn push_placeholder(sql: &mut String, index: usize);
24
25 fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue);
27
28 const MIN_FUNCTION: &'static str;
30
31 const MAX_FUNCTION: &'static str;
33
34 const GRADE_FUNCTION_PROPAGATES_NULL: bool;
37}
38
/// Backend marker for PostgreSQL (numbered `$n` placeholders, `LEAST`/`GREATEST`).
#[cfg(feature = "postgres")]
#[derive(Debug, Clone, Copy)]
pub struct PostgresBackend;
43
#[cfg(feature = "postgres")]
impl GatekeepSqlxBackend for PostgresBackend {
    type Database = sqlx::Postgres;

    const DRIVER: SqlxDriver = SqlxDriver::Postgres;
    const NAME: &'static str = "postgres";
    const MIN_FUNCTION: &'static str = "LEAST";
    const MAX_FUNCTION: &'static str = "GREATEST";
    // PostgreSQL's LEAST/GREATEST ignore NULL arguments instead of returning NULL.
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = false;

    /// Emits a numbered `$<index>` placeholder.
    fn push_placeholder(sql: &mut String, index: usize) {
        *sql += "$";
        *sql += &index.to_string();
    }

    /// Binds each `SqlxValue` variant directly.
    ///
    /// `Copy` payloads are passed by value; text and byte payloads are cloned so the
    /// builder owns them.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        match *value {
            SqlxValue::Bool(flag) => builder.push_bind(flag),
            SqlxValue::I16(number) => builder.push_bind(number),
            SqlxValue::I32(number) => builder.push_bind(number),
            SqlxValue::I64(number) => builder.push_bind(number),
            SqlxValue::Text(ref text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(ref bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(id) => builder.push_bind(id),
            SqlxValue::Date(day) => builder.push_bind(day),
            SqlxValue::Time(clock) => builder.push_bind(clock),
            SqlxValue::Timestamp(stamp) => builder.push_bind(stamp),
            SqlxValue::TimestampTz(stamp) => builder.push_bind(stamp),
        };
    }
}
97
/// Backend marker for SQLite (anonymous `?` placeholders, scalar `min`/`max`).
#[cfg(feature = "sqlite")]
#[derive(Debug, Clone, Copy)]
pub struct SqliteBackend;
102
#[cfg(feature = "sqlite")]
impl GatekeepSqlxBackend for SqliteBackend {
    type Database = sqlx::Sqlite;

    const DRIVER: SqlxDriver = SqlxDriver::Sqlite;
    const NAME: &'static str = "sqlite";
    const MIN_FUNCTION: &'static str = "min";
    const MAX_FUNCTION: &'static str = "max";
    // SQLite's multi-argument min()/max() return NULL if any argument is NULL.
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = true;

    /// Emits an anonymous `?` placeholder; the index is not needed.
    fn push_placeholder(sql: &mut String, _index: usize) {
        sql.push_str("?");
    }

    /// Binds each `SqlxValue` variant directly.
    ///
    /// `Copy` payloads are passed by value; text and byte payloads are cloned so the
    /// builder owns them.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        match *value {
            SqlxValue::Bool(flag) => builder.push_bind(flag),
            SqlxValue::I16(number) => builder.push_bind(number),
            SqlxValue::I32(number) => builder.push_bind(number),
            SqlxValue::I64(number) => builder.push_bind(number),
            SqlxValue::Text(ref text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(ref bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(id) => builder.push_bind(id),
            SqlxValue::Date(day) => builder.push_bind(day),
            SqlxValue::Time(clock) => builder.push_bind(clock),
            SqlxValue::Timestamp(stamp) => builder.push_bind(stamp),
            SqlxValue::TimestampTz(stamp) => builder.push_bind(stamp),
        };
    }
}
155
/// Backend marker for MySQL/MariaDB (anonymous `?` placeholders, `LEAST`/`GREATEST`).
#[cfg(feature = "mysql")]
#[derive(Debug, Clone, Copy)]
pub struct MySqlBackend;
160
#[cfg(feature = "mysql")]
impl GatekeepSqlxBackend for MySqlBackend {
    type Database = sqlx::MySql;

    const DRIVER: SqlxDriver = SqlxDriver::MySql;
    const NAME: &'static str = "mysql";
    const MIN_FUNCTION: &'static str = "LEAST";
    const MAX_FUNCTION: &'static str = "GREATEST";
    // MySQL's LEAST/GREATEST return NULL if any argument is NULL.
    const GRADE_FUNCTION_PROPAGATES_NULL: bool = true;

    /// Emits an anonymous `?` placeholder; the index is not needed.
    fn push_placeholder(sql: &mut String, _index: usize) {
        sql.push_str("?");
    }

    /// Binds each `SqlxValue` variant directly.
    ///
    /// `Copy` payloads are passed by value; text and byte payloads are cloned so the
    /// builder owns them.
    fn push_bind(builder: &mut QueryBuilder<Self::Database>, value: &SqlxValue) {
        match *value {
            SqlxValue::Bool(flag) => builder.push_bind(flag),
            SqlxValue::I16(number) => builder.push_bind(number),
            SqlxValue::I32(number) => builder.push_bind(number),
            SqlxValue::I64(number) => builder.push_bind(number),
            SqlxValue::Text(ref text) => builder.push_bind(text.clone()),
            SqlxValue::Bytes(ref bytes) => builder.push_bind(bytes.clone()),
            SqlxValue::Uuid(id) => builder.push_bind(id),
            SqlxValue::Date(day) => builder.push_bind(day),
            SqlxValue::Time(clock) => builder.push_bind(clock),
            SqlxValue::Timestamp(stamp) => builder.push_bind(stamp),
            SqlxValue::TimestampTz(stamp) => builder.push_bind(stamp),
        };
    }
}
213
/// SQLx database drivers that gatekeep-sqlx can target.
///
/// Whether a driver is usable in a given build depends on Cargo features;
/// see `SqlxDriver::is_enabled`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlxDriver {
    /// PostgreSQL (`postgres:` / `postgresql:` URLs).
    Postgres,
    /// SQLite (`sqlite:` URLs).
    Sqlite,
    /// MySQL or MariaDB (`mysql:` / `mariadb:` URLs).
    MySql,
}
225
226impl SqlxDriver {
227 #[must_use]
229 pub const fn name(self) -> &'static str {
230 match self {
231 Self::Postgres => "postgres",
232 Self::Sqlite => "sqlite",
233 Self::MySql => "mysql",
234 }
235 }
236
237 #[must_use]
239 pub const fn is_enabled(self) -> bool {
240 match self {
241 Self::Postgres => cfg!(feature = "postgres"),
242 Self::Sqlite => cfg!(feature = "sqlite"),
243 Self::MySql => cfg!(feature = "mysql"),
244 }
245 }
246}
247
248#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
250#[non_exhaustive]
251pub enum SqlxDriverError {
252 #[error("unsupported SQLx database URL scheme {scheme:?}")]
254 UnsupportedUrlScheme {
255 scheme: Option<String>,
257 },
258
259 #[error("SQLx driver {driver} is not enabled for gatekeep-sqlx")]
261 DriverNotEnabled {
262 driver: &'static str,
264 },
265
266 #[error("SQLx backend mismatch: expected {expected}, found {actual}")]
268 BackendMismatch {
269 expected: &'static str,
271 actual: &'static str,
273 },
274}
275
276pub fn infer_enabled_driver_from_url(database_url: &str) -> Result<SqlxDriver, SqlxDriverError> {
283 let driver = infer_driver_from_url(database_url)?;
284 if driver.is_enabled() {
285 Ok(driver)
286 } else {
287 Err(SqlxDriverError::DriverNotEnabled {
288 driver: driver.name(),
289 })
290 }
291}
292
293pub fn validate_database_url_for_backend<B>(database_url: &str) -> Result<(), SqlxDriverError>
300where
301 B: GatekeepSqlxBackend,
302{
303 let actual = infer_enabled_driver_from_url(database_url)?;
304 if actual == B::DRIVER {
305 Ok(())
306 } else {
307 Err(SqlxDriverError::BackendMismatch {
308 expected: B::NAME,
309 actual: actual.name(),
310 })
311 }
312}
313
314fn infer_driver_from_url(database_url: &str) -> Result<SqlxDriver, SqlxDriverError> {
315 if database_url.starts_with("sqlite:") {
316 return Ok(SqlxDriver::Sqlite);
317 }
318
319 let Some((scheme, _rest)) = database_url.split_once(':') else {
320 return Err(SqlxDriverError::UnsupportedUrlScheme { scheme: None });
321 };
322
323 match scheme {
324 "postgres" | "postgresql" => Ok(SqlxDriver::Postgres),
325 "mysql" | "mariadb" => Ok(SqlxDriver::MySql),
326 "sqlite" => Ok(SqlxDriver::Sqlite),
327 other => Err(SqlxDriverError::UnsupportedUrlScheme {
328 scheme: Some(other.to_owned()),
329 }),
330 }
331}
332
333#[derive(Clone, Debug, PartialEq, Eq)]
335#[non_exhaustive]
336pub enum SqlxValue {
337 Bool(bool),
339 I16(i16),
341 I32(i32),
343 I64(i64),
345 Text(String),
347 Bytes(Vec<u8>),
349 Uuid(Uuid),
351 Date(Date),
353 Time(Time),
355 Timestamp(PrimitiveDateTime),
357 TimestampTz(OffsetDateTime),
359}
360
/// Generates `impl From<$source> for SqlxValue` that wraps the value in `SqlxValue::$target`.
macro_rules! impl_sqlx_value_from {
    ($source:ty, $target:ident) => {
        impl ::core::convert::From<$source> for SqlxValue {
            fn from(inner: $source) -> Self {
                SqlxValue::$target(inner)
            }
        }
    };
}
370
371impl_sqlx_value_from!(bool, Bool);
372impl_sqlx_value_from!(i16, I16);
373impl_sqlx_value_from!(i32, I32);
374impl_sqlx_value_from!(i64, I64);
375impl_sqlx_value_from!(String, Text);
376impl_sqlx_value_from!(Vec<u8>, Bytes);
377impl_sqlx_value_from!(Uuid, Uuid);
378impl_sqlx_value_from!(Date, Date);
379impl_sqlx_value_from!(Time, Time);
380impl_sqlx_value_from!(PrimitiveDateTime, Timestamp);
381impl_sqlx_value_from!(OffsetDateTime, TimestampTz);
382
383impl From<&str> for SqlxValue {
384 fn from(value: &str) -> Self {
385 Self::Text(value.to_owned())
386 }
387}
388
389impl From<&[u8]> for SqlxValue {
390 fn from(value: &[u8]) -> Self {
391 Self::Bytes(value.to_vec())
392 }
393}
394
395#[derive(Clone, Debug, PartialEq, Eq)]
396enum SqlPart {
397 Text(String),
398 Bind(SqlxValue),
399}
400
401#[derive(Debug, PartialEq, Eq)]
403pub struct SqlxFragment<B> {
404 parts: Vec<SqlPart>,
405 backend: PhantomData<fn() -> B>,
406}
407
408impl<B> Clone for SqlxFragment<B> {
409 fn clone(&self) -> Self {
410 Self {
411 parts: self.parts.clone(),
412 backend: PhantomData,
413 }
414 }
415}
416
417impl<B> Default for SqlxFragment<B> {
418 fn default() -> Self {
419 Self {
420 parts: Vec::new(),
421 backend: PhantomData,
422 }
423 }
424}
425
426impl<B> SqlxFragment<B> {
427 #[must_use]
432 pub fn trusted(sql: impl Into<String>) -> Self {
433 let sql = sql.into();
434 if sql.is_empty() {
435 Self::default()
436 } else {
437 Self {
438 parts: vec![SqlPart::Text(sql)],
439 backend: PhantomData,
440 }
441 }
442 }
443
444 #[must_use]
446 pub fn bind(value: impl Into<SqlxValue>) -> Self {
447 Self {
448 parts: vec![SqlPart::Bind(value.into())],
449 backend: PhantomData,
450 }
451 }
452
453 pub fn binds(&self) -> impl Iterator<Item = &SqlxValue> {
455 self.parts.iter().filter_map(|part| match part {
456 SqlPart::Text(_) => None,
457 SqlPart::Bind(value) => Some(value),
458 })
459 }
460
461 pub fn push_fragment(&mut self, fragment: Self) {
463 self.parts.extend(fragment.parts);
464 }
465
466 pub(crate) fn push_sql(&mut self, sql: impl Into<String>) {
467 let sql = sql.into();
468 if !sql.is_empty() {
469 self.parts.push(SqlPart::Text(sql));
470 }
471 }
472
473 #[must_use]
474 pub(crate) fn wrapped(self) -> Self {
475 let mut fragment = Self::trusted("(");
476 fragment.push_fragment(self);
477 fragment.push_sql(")");
478 fragment
479 }
480
481 #[must_use]
482 pub(crate) fn unary(prefix: &str, inner: Self) -> Self {
483 let mut fragment = Self::trusted(prefix);
484 fragment.push_fragment(inner.wrapped());
485 fragment
486 }
487
488 #[must_use]
489 pub(crate) fn binary(separator: &str, fragments: Vec<Self>) -> Self {
490 let mut iter = fragments.into_iter();
491 let Some(first) = iter.next() else {
492 return Self::trusted("FALSE");
493 };
494
495 let mut fragment = first.wrapped();
496 for next in iter {
497 fragment.push_sql(separator);
498 fragment.push_fragment(next.wrapped());
499 }
500 fragment
501 }
502
503 #[must_use]
504 pub(crate) fn function(name: &str, fragments: Vec<Self>) -> Self {
505 let mut fragment = Self::trusted(name);
506 fragment.push_sql("(");
507
508 let mut iter = fragments.into_iter();
509 if let Some(first) = iter.next() {
510 fragment.push_fragment(first);
511 for next in iter {
512 fragment.push_sql(", ");
513 fragment.push_fragment(next);
514 }
515 }
516
517 fragment.push_sql(")");
518 fragment
519 }
520}
521
522impl<B> SqlxFragment<B>
523where
524 B: GatekeepSqlxBackend,
525{
526 #[must_use]
528 pub fn to_sql(&self) -> String {
529 let mut sql = String::new();
530 let mut placeholders = 0usize;
531
532 for part in &self.parts {
533 match part {
534 SqlPart::Text(text) => sql.push_str(text),
535 SqlPart::Bind(_) => {
536 placeholders += 1;
537 B::push_placeholder(&mut sql, placeholders);
538 }
539 }
540 }
541 sql
542 }
543
544 pub fn push_to(&self, builder: &mut QueryBuilder<B::Database>) {
546 for part in &self.parts {
547 match part {
548 SqlPart::Text(text) => {
549 builder.push(text);
550 }
551 SqlPart::Bind(value) => B::push_bind(builder, value),
552 }
553 }
554 }
555}
556
557pub type PgValue = SqlxValue;
559
/// `SqlxFragment` rendered with the PostgreSQL backend.
#[cfg(feature = "postgres")]
pub type PgFragment = self::SqlxFragment<self::PostgresBackend>;
563
#[cfg(feature = "postgres")]
impl SqlxFragment<PostgresBackend> {
    /// Renders the fragment with PostgreSQL `$n` placeholders; identical to `to_sql`.
    #[must_use]
    pub fn to_postgres_sql(&self) -> String {
        Self::to_sql(self)
    }
}