Skip to main content

gatekeep_sqlx/
fragment.rs

1use sqlx::{
2    Postgres, QueryBuilder,
3    types::{
4        Uuid,
5        time::{Date, OffsetDateTime, PrimitiveDateTime, Time},
6    },
7};
8
/// Postgres scalar value carried by a lowered SQL fragment.
///
/// Every variant wraps an owned Rust value that is handed to `SQLx`'s
/// `QueryBuilder::push_bind` as that same Rust type. The bound parameter's
/// Postgres type therefore follows `SQLx`'s mapping for the wrapped type.
/// There is no variant for SQL `NULL`.
///
/// The enum is `#[non_exhaustive]`, so matches on it outside this crate need
/// a wildcard arm.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum PgValue {
    /// Boolean bind value (`bool`).
    Bool(bool),
    /// Signed 16-bit integer bind value (`i16`).
    I16(i16),
    /// Signed 32-bit integer bind value (`i32`).
    I32(i32),
    /// Signed 64-bit integer bind value (`i64`).
    I64(i64),
    /// Text bind value (owned `String`).
    Text(String),
    /// Binary bind value (owned byte vector).
    Bytes(Vec<u8>),
    /// UUID bind value.
    Uuid(Uuid),
    /// Calendar date bind value (`time::Date`).
    Date(Date),
    /// Time-of-day bind value (`time::Time`).
    Time(Time),
    /// Timestamp without time zone bind value (`time::PrimitiveDateTime`).
    Timestamp(PrimitiveDateTime),
    /// Timestamp with time zone bind value (`time::OffsetDateTime`).
    TimestampTz(OffsetDateTime),
}
36
// Generates `impl From<$source> for PgValue` that moves the value, unchanged,
// into the `PgValue::$ctor` variant.
macro_rules! impl_pg_value_from {
    ($source:ty, $ctor:ident) => {
        impl From<$source> for PgValue {
            fn from(inner: $source) -> Self {
                PgValue::$ctor(inner)
            }
        }
    };
}
46
47impl_pg_value_from!(bool, Bool);
48impl_pg_value_from!(i16, I16);
49impl_pg_value_from!(i32, I32);
50impl_pg_value_from!(i64, I64);
51impl_pg_value_from!(String, Text);
52impl_pg_value_from!(Vec<u8>, Bytes);
53impl_pg_value_from!(Uuid, Uuid);
54impl_pg_value_from!(Date, Date);
55impl_pg_value_from!(Time, Time);
56impl_pg_value_from!(PrimitiveDateTime, Timestamp);
57impl_pg_value_from!(OffsetDateTime, TimestampTz);
58
59impl From<&str> for PgValue {
60    fn from(value: &str) -> Self {
61        Self::Text(value.to_owned())
62    }
63}
64
65impl From<&[u8]> for PgValue {
66    fn from(value: &[u8]) -> Self {
67        Self::Bytes(value.to_vec())
68    }
69}
70
/// One piece of a [`PgFragment`]: trusted SQL text or a value to bind.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SqlPart {
    /// Application-owned SQL, emitted verbatim.
    Text(String),
    /// Value emitted as a positional placeholder and bound separately.
    Bind(PgValue),
}
76
/// Trusted Postgres SQL plus ordered bind values.
///
/// Internally this is an ordered list of SQL text parts and bind parts. Text
/// parts are emitted verbatim. Bind parts are rendered as `$N` placeholders
/// by `to_postgres_sql`, or passed to `QueryBuilder::push_bind` by `push_to`.
/// The `Default` value is the empty fragment.
///
/// Equality is structural over the parts. Two fragments that render the same
/// SQL text can therefore compare unequal if their text was split differently.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PgFragment {
    /// Text and bind parts in emission order.
    parts: Vec<SqlPart>,
}
82
83impl PgFragment {
84    /// Builds a fragment from SQL owned by the application.
85    ///
86    /// Callers must not pass user-supplied text here. Dynamic values belong in
87    /// bind fragments built with [`Self::bind`].
88    #[must_use]
89    pub fn trusted(sql: impl Into<String>) -> Self {
90        let sql = sql.into();
91        if sql.is_empty() {
92            Self::default()
93        } else {
94            Self {
95                parts: vec![SqlPart::Text(sql)],
96            }
97        }
98    }
99
100    /// Builds a bind fragment from a supported Postgres scalar value.
101    #[must_use]
102    pub fn bind(value: impl Into<PgValue>) -> Self {
103        Self {
104            parts: vec![SqlPart::Bind(value.into())],
105        }
106    }
107
108    /// Returns the ordered bind values.
109    pub fn binds(&self) -> impl Iterator<Item = &PgValue> {
110        self.parts.iter().filter_map(|part| match part {
111            SqlPart::Text(_) => None,
112            SqlPart::Bind(value) => Some(value),
113        })
114    }
115
116    /// Appends another fragment to this one.
117    pub fn push_fragment(&mut self, fragment: Self) {
118        self.parts.extend(fragment.parts);
119    }
120
121    /// Converts the fragment to Postgres placeholders (`$1`, `$2`, ...).
122    #[must_use]
123    pub fn to_postgres_sql(&self) -> String {
124        let mut sql = String::new();
125        let mut placeholders = 0usize;
126
127        for part in &self.parts {
128            match part {
129                SqlPart::Text(text) => sql.push_str(text),
130                SqlPart::Bind(_) => {
131                    placeholders += 1;
132                    sql.push('$');
133                    sql.push_str(&placeholders.to_string());
134                }
135            }
136        }
137        sql
138    }
139
140    /// Appends this fragment to a `SQLx` Postgres query builder.
141    pub fn push_to(&self, builder: &mut QueryBuilder<Postgres>) {
142        for part in &self.parts {
143            match part {
144                SqlPart::Text(text) => {
145                    builder.push(text);
146                }
147                SqlPart::Bind(value) => push_bind(builder, value),
148            }
149        }
150    }
151
152    pub(crate) fn push_sql(&mut self, sql: impl Into<String>) {
153        let sql = sql.into();
154        if !sql.is_empty() {
155            self.parts.push(SqlPart::Text(sql));
156        }
157    }
158
159    #[must_use]
160    pub(crate) fn wrapped(self) -> Self {
161        let mut fragment = Self::trusted("(");
162        fragment.push_fragment(self);
163        fragment.push_sql(")");
164        fragment
165    }
166
167    #[must_use]
168    pub(crate) fn unary(prefix: &str, inner: Self) -> Self {
169        let mut fragment = Self::trusted(prefix);
170        fragment.push_fragment(inner.wrapped());
171        fragment
172    }
173
174    #[must_use]
175    pub(crate) fn binary(separator: &str, fragments: Vec<Self>) -> Self {
176        let mut iter = fragments.into_iter();
177        let Some(first) = iter.next() else {
178            return Self::trusted("FALSE");
179        };
180
181        let mut fragment = first.wrapped();
182        for next in iter {
183            fragment.push_sql(separator);
184            fragment.push_fragment(next.wrapped());
185        }
186        fragment
187    }
188
189    #[must_use]
190    pub(crate) fn function(name: &str, fragments: Vec<Self>) -> Self {
191        let mut fragment = Self::trusted(name);
192        fragment.push_sql("(");
193
194        let mut iter = fragments.into_iter();
195        if let Some(first) = iter.next() {
196            fragment.push_fragment(first);
197            for next in iter {
198                fragment.push_sql(", ");
199                fragment.push_fragment(next);
200            }
201        }
202
203        fragment.push_sql(")");
204        fragment
205    }
206}
207
208fn push_bind(builder: &mut QueryBuilder<Postgres>, value: &PgValue) {
209    match value {
210        PgValue::Bool(value) => {
211            builder.push_bind(*value);
212        }
213        PgValue::I16(value) => {
214            builder.push_bind(*value);
215        }
216        PgValue::I32(value) => {
217            builder.push_bind(*value);
218        }
219        PgValue::I64(value) => {
220            builder.push_bind(*value);
221        }
222        PgValue::Text(value) => {
223            builder.push_bind(value.clone());
224        }
225        PgValue::Bytes(value) => {
226            builder.push_bind(value.clone());
227        }
228        PgValue::Uuid(value) => {
229            builder.push_bind(*value);
230        }
231        PgValue::Date(value) => {
232            builder.push_bind(*value);
233        }
234        PgValue::Time(value) => {
235            builder.push_bind(*value);
236        }
237        PgValue::Timestamp(value) => {
238            builder.push_bind(*value);
239        }
240        PgValue::TimestampTz(value) => {
241            builder.push_bind(*value);
242        }
243    }
244}