#![forbid(unsafe_code)]

// Every backend marker type re-exported from `fragment` is feature-gated; with
// no backend feature enabled none of them would exist, so fail the build with
// an explicit message instead of a cascade of unresolved-name errors.
#[cfg(not(any(feature = "postgres", feature = "sqlite", feature = "mysql")))]
compile_error!(
    "gatekeep-sqlx requires at least one SQLx backend feature: postgres, sqlite, or mysql"
);
12
13use std::marker::PhantomData;
14
15use gatekeep::{
16 Condition, Context, FactId, LowerError, Lowered, QueryLowering, ResidualPolicy,
17 ResidualPolicyBranch, ResidualPolicyNode,
18};
19
20mod fragment;
21
22#[cfg(feature = "mysql")]
23pub use fragment::MySqlBackend;
24#[cfg(feature = "sqlite")]
25pub use fragment::SqliteBackend;
26pub use fragment::{
27 GatekeepSqlxBackend, SqlxDriver, SqlxDriverError, SqlxFragment, SqlxValue,
28 infer_enabled_driver_from_url, validate_database_url_for_backend,
29};
30#[cfg(feature = "postgres")]
31pub use fragment::{PgFragment, PgValue, PostgresBackend};
32
33pub trait SqlxFactPredicates<B>
35where
36 B: GatekeepSqlxBackend,
37{
38 fn predicate(&self, fact: &FactId, cx: &Context) -> Option<SqlxFragment<B>>;
41}
42
/// Postgres-specific shorthand for [`SqlxFactPredicates`].
///
/// Anything implementing this trait also implements
/// [`SqlxFactPredicates<PostgresBackend>`] through the blanket impl that
/// follows it.
#[cfg(feature = "postgres")]
pub trait PgFactPredicates {
    /// Returns the Postgres predicate for `fact` under `cx`, or `None` when the
    /// fact cannot be expressed in SQL.
    fn predicate(&self, fact: &FactId, cx: &Context) -> Option<PgFragment>;
}

// Forward the Postgres-specific trait to the backend-generic one so callers can
// hand a `PgFactPredicates` implementor straight to a Postgres `SqlxLowerer`.
#[cfg(feature = "postgres")]
impl<T: PgFactPredicates> SqlxFactPredicates<PostgresBackend> for T {
    fn predicate(&self, fact: &FactId, cx: &Context) -> Option<SqlxFragment<PostgresBackend>> {
        <T as PgFactPredicates>::predicate(self, fact, cx)
    }
}
60
/// Converts a policy outcome into the integer ordinal that
/// [`OrdinalProjection`] binds into the grade expression.
pub trait SqlOutcome {
    /// Returns the ordinal used to represent this outcome in SQL.
    fn to_sql_ordinal(&self) -> i64;
}

/// `()` has a single value, so its only ordinal is zero.
impl SqlOutcome for () {
    fn to_sql_ordinal(&self) -> i64 {
        const UNIT_ORDINAL: i64 = 0;
        UNIT_ORDINAL
    }
}
72
73pub trait OutcomeProjection<B, O>
75where
76 B: GatekeepSqlxBackend,
77{
78 fn constant(&self, outcome: &O) -> Result<SqlxFragment<B>, LowerError>;
80}
81
82#[derive(Clone, Copy, Debug, Default)]
84pub struct OrdinalProjection;
85
86impl<B, O> OutcomeProjection<B, O> for OrdinalProjection
87where
88 B: GatekeepSqlxBackend,
89 O: SqlOutcome,
90{
91 fn constant(&self, outcome: &O) -> Result<SqlxFragment<B>, LowerError> {
92 Ok(SqlxFragment::bind(outcome.to_sql_ordinal()))
93 }
94}
95
96#[derive(Clone, Copy, Debug, Default)]
98pub struct NoGradeProjection;
99
100impl<B, O> OutcomeProjection<B, O> for NoGradeProjection
101where
102 B: GatekeepSqlxBackend,
103{
104 fn constant(&self, _outcome: &O) -> Result<SqlxFragment<B>, LowerError> {
105 Err(LowerError::NonTotalGrade)
106 }
107}
108
/// Lowers gatekeep residual policies into SQLx fragments for backend `B`.
///
/// `P` supplies per-fact SQL predicates ([`SqlxFactPredicates`]) and `M` turns
/// outcomes into SQL grade constants ([`OutcomeProjection`]); `M` defaults to
/// [`OrdinalProjection`].
#[derive(Clone, Debug)]
pub struct SqlxLowerer<B, P, M = OrdinalProjection> {
    // Consulted for every `Condition::Has(fact)`.
    predicates: P,
    // Consulted for the outcome of `Permit`/`Grant` nodes when grading.
    projection: M,
    // Backend marker only; no `B` value is stored. The `fn() -> B` form keeps
    // auto traits such as `Send`/`Sync` independent of `B`.
    backend: PhantomData<fn() -> B>,
}

/// [`SqlxLowerer`] fixed to the Postgres backend.
#[cfg(feature = "postgres")]
pub type PgLowerer<P, M = OrdinalProjection> = SqlxLowerer<PostgresBackend, P, M>;
120
/// Intermediate result of lowering a residual policy; converted into
/// `Lowered` by the `QueryLowering` impl.
#[derive(Clone, Debug, PartialEq, Eq)]
struct SqlxLowered<B> {
    /// Boolean SQL expression selecting the rows the policy can permit.
    filter: SqlxFragment<B>,
    /// SQL expression for the projected outcome; `NULL` for denied branches and
    /// for `Grant` nodes whose condition does not hold.
    grade: SqlxFragment<B>,
}
126
127impl<B, P> SqlxLowerer<B, P, OrdinalProjection>
128where
129 B: GatekeepSqlxBackend,
130{
131 #[must_use]
133 pub const fn new(predicates: P) -> Self {
134 Self::with_projection(predicates, OrdinalProjection)
135 }
136}
137
impl<B, P, M> SqlxLowerer<B, P, M>
where
    B: GatekeepSqlxBackend,
{
    /// Creates a lowerer from a fact-predicate source and an outcome projection.
    #[must_use]
    pub const fn with_projection(predicates: P, projection: M) -> Self {
        Self {
            predicates,
            projection,
            backend: PhantomData,
        }
    }

    /// Lowers `residual` to a boolean row filter only.
    ///
    /// Unlike [`QueryLowering::lower`], this never consults the outcome
    /// projection, so it places no bound on `M` and works with
    /// [`NoGradeProjection`].
    ///
    /// # Errors
    ///
    /// Fails when lowering any visited node fails, e.g. with
    /// [`LowerError::Unlowerable`] for a fact that has no SQL predicate.
    pub fn lower_filter<O>(
        &self,
        residual: &ResidualPolicy<O>,
        cx: &Context,
    ) -> Result<SqlxFragment<B>, LowerError>
    where
        P: SqlxFactPredicates<B>,
    {
        residual.try_fold_pruned(
            // Branch visitor: descend into an `OrElse` fallback only when it
            // carries no obligation (presumably `false` prunes the branch; the
            // `OrElse` arm below ignores obligated fallbacks either way).
            &mut |branch| match branch {
                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
                    !fallback.carries_obligation()
                }
            },
            &mut |node| self.lower_filter_node(node, cx),
        )
    }

    // Lowers one node to its filter expression. Child results (`arms`,
    // `primary`, `fallback`) arrive already lowered as `SqlxFragment<B>`.
    fn lower_filter_node<O>(
        &self,
        node: ResidualPolicyNode<'_, O, SqlxFragment<B>>,
        cx: &Context,
    ) -> Result<SqlxFragment<B>, LowerError>
    where
        P: SqlxFactPredicates<B>,
    {
        match node {
            ResidualPolicyNode::Permit(_) | ResidualPolicyNode::PermitWithTrace { .. } => {
                Ok(SqlxFragment::trusted("TRUE"))
            }
            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
                Ok(SqlxFragment::trusted("FALSE"))
            }
            ResidualPolicyNode::Grant { condition, .. } => self.lower_condition(condition, cx),
            // Both combinators lower an empty arm list to FALSE (fail closed),
            // the same choice `lower_condition` and `lower_node` make.
            ResidualPolicyNode::All { arms, .. } => Ok(fragment_set(arms, " AND ", "FALSE")),
            ResidualPolicyNode::Any { arms, .. } => Ok(fragment_set(arms, " OR ", "FALSE")),
            ResidualPolicyNode::OrElse {
                fallback_policy,
                primary,
                fallback,
                ..
            } => {
                // A fallback that carries an obligation must never widen the
                // filter, so only the primary survives.
                if fallback_policy.carries_obligation() {
                    Ok(primary)
                } else {
                    Ok(match fallback {
                        Some(fallback) => SqlxFragment::binary(" OR ", vec![primary, fallback]),
                        None => primary,
                    })
                }
            }
        }
    }

    // Lowers a `Condition` tree to a boolean SQL expression.
    //
    // Fact predicates are wrapped by `is_true`, so every leaf is TRUE or FALSE
    // and never NULL; this keeps the `NOT` arm from turning an unknown into an
    // unknown. Empty `All`/`Any` sets lower to FALSE (fail closed).
    fn lower_condition(
        &self,
        condition: &Condition,
        cx: &Context,
    ) -> Result<SqlxFragment<B>, LowerError>
    where
        P: SqlxFactPredicates<B>,
    {
        match condition {
            Condition::Always => Ok(SqlxFragment::trusted("TRUE")),
            Condition::Never => Ok(SqlxFragment::trusted("FALSE")),
            Condition::Has(fact) => self
                .predicates
                .predicate(fact, cx)
                .map(is_true)
                .ok_or_else(|| LowerError::Unlowerable(fact.clone())),
            Condition::Not(inner) => Ok(SqlxFragment::unary(
                "NOT ",
                self.lower_condition(inner, cx)?,
            )),
            Condition::All(conditions) => {
                lower_condition_set(conditions, " AND ", "FALSE", |item| {
                    self.lower_condition(item, cx)
                })
            }
            Condition::Any(conditions) => {
                lower_condition_set(conditions, " OR ", "FALSE", |item| {
                    self.lower_condition(item, cx)
                })
            }
        }
    }

    // Lowers `residual` to both a filter and a grade expression, using the same
    // fallback pruning rule as `lower_filter`.
    fn lower_policy<O>(
        &self,
        residual: &ResidualPolicy<O>,
        cx: &Context,
    ) -> Result<SqlxLowered<B>, LowerError>
    where
        P: SqlxFactPredicates<B>,
        M: OutcomeProjection<B, O>,
    {
        residual.try_fold_pruned(
            &mut |branch| match branch {
                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
                    !fallback.carries_obligation()
                }
            },
            &mut |node| self.lower_node(node, cx),
        )
    }

    // Lowers one node to a (filter, grade) pair. The filter part mirrors
    // `lower_filter_node` exactly; the grade part is NULL wherever the node
    // grants nothing.
    fn lower_node<O>(
        &self,
        node: ResidualPolicyNode<'_, O, SqlxLowered<B>>,
        cx: &Context,
    ) -> Result<SqlxLowered<B>, LowerError>
    where
        P: SqlxFactPredicates<B>,
        M: OutcomeProjection<B, O>,
    {
        match node {
            ResidualPolicyNode::Permit(outcome)
            | ResidualPolicyNode::PermitWithTrace { outcome, .. } => Ok(SqlxLowered {
                filter: SqlxFragment::trusted("TRUE"),
                grade: self.projection.constant(outcome)?,
            }),
            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
                Ok(SqlxLowered {
                    filter: SqlxFragment::trusted("FALSE"),
                    grade: SqlxFragment::trusted("NULL"),
                })
            }
            ResidualPolicyNode::Grant {
                outcome, condition, ..
            } => {
                // The condition SQL is emitted twice: as the filter and as the
                // guard of `CASE WHEN condition THEN outcome ELSE NULL END`.
                let filter = self.lower_condition(condition, cx)?;
                let outcome = self.projection.constant(outcome)?;
                Ok(SqlxLowered {
                    filter: filter.clone(),
                    grade: case_when(filter, outcome, SqlxFragment::trusted("NULL")),
                })
            }
            // Conjunction: every arm must hold; the grade combines arm grades
            // with the backend's MIN function.
            ResidualPolicyNode::All { arms, .. } => {
                let (filters, grades) = unzip_lowered(arms);
                Ok(SqlxLowered {
                    filter: fragment_set(filters, " AND ", "FALSE"),
                    grade: grade_set::<B>(grades, B::MIN_FUNCTION),
                })
            }
            // Disjunction: any arm may hold; the grade combines arm grades
            // with the backend's MAX function.
            ResidualPolicyNode::Any { arms, .. } => {
                let (filters, grades) = unzip_lowered(arms);
                Ok(SqlxLowered {
                    filter: fragment_set(filters, " OR ", "FALSE"),
                    grade: grade_set::<B>(grades, B::MAX_FUNCTION),
                })
            }
            ResidualPolicyNode::OrElse {
                fallback_policy,
                primary,
                fallback,
                ..
            } => {
                // Obligated fallbacks never widen the filter or the grade.
                if fallback_policy.carries_obligation() {
                    return Ok(primary);
                }

                // The primary's grade wins wherever the primary filter holds;
                // otherwise the fallback's grade applies. Filters are never
                // NULL (see `lower_condition`), so the CASE cannot fall
                // through on an unknown.
                Ok(match fallback {
                    Some(fallback) => SqlxLowered {
                        filter: SqlxFragment::binary(
                            " OR ",
                            vec![primary.filter.clone(), fallback.filter],
                        ),
                        grade: case_when(primary.filter, primary.grade, fallback.grade),
                    },
                    None => primary,
                })
            }
        }
    }
}
328
329impl<O, B, P, M> QueryLowering<O> for SqlxLowerer<B, P, M>
330where
331 B: GatekeepSqlxBackend,
332 P: SqlxFactPredicates<B>,
333 M: OutcomeProjection<B, O>,
334{
335 type Filter = SqlxFragment<B>;
336 type Projection = SqlxFragment<B>;
337
338 fn lower(
339 &self,
340 residual: &ResidualPolicy<O>,
341 cx: &Context,
342 ) -> Result<Lowered<Self::Filter, Self::Projection>, LowerError> {
343 let lowered = self.lower_policy(residual, cx)?;
344 Ok(Lowered {
345 filter: lowered.filter,
346 grade: lowered.grade,
347 })
348 }
349}
350
351fn lower_condition_set<B>(
352 conditions: &[Condition],
353 separator: &str,
354 empty: &str,
355 lower: impl FnMut(&Condition) -> Result<SqlxFragment<B>, LowerError>,
356) -> Result<SqlxFragment<B>, LowerError> {
357 if conditions.is_empty() {
358 return Ok(SqlxFragment::trusted(empty));
359 }
360 let fragments = conditions
361 .iter()
362 .map(lower)
363 .collect::<Result<Vec<_>, _>>()?;
364 Ok(SqlxFragment::binary(separator, fragments))
365}
366
367fn fragment_set<B>(
368 fragments: Vec<SqlxFragment<B>>,
369 separator: &str,
370 empty: &str,
371) -> SqlxFragment<B> {
372 if fragments.is_empty() {
373 SqlxFragment::trusted(empty)
374 } else {
375 SqlxFragment::binary(separator, fragments)
376 }
377}
378
379fn grade_set<B>(grades: Vec<SqlxFragment<B>>, function: &str) -> SqlxFragment<B>
380where
381 B: GatekeepSqlxBackend,
382{
383 match grades.len() {
384 0 => SqlxFragment::trusted("NULL"),
385 1 => grades
386 .into_iter()
387 .next()
388 .unwrap_or_else(|| SqlxFragment::trusted("NULL")),
389 _ if B::GRADE_FUNCTION_PROPAGATES_NULL => {
390 let mut iter = grades.into_iter();
391 let mut combined = iter.next().unwrap_or_else(|| SqlxFragment::trusted("NULL"));
392 for grade in iter {
393 combined = null_safe_grade_pair(function, combined, grade);
394 }
395 combined
396 }
397 _ => SqlxFragment::function(function, grades),
398 }
399}
400
401fn null_safe_grade_pair<B>(
402 function: &str,
403 left: SqlxFragment<B>,
404 right: SqlxFragment<B>,
405) -> SqlxFragment<B> {
406 let mut fragment = SqlxFragment::trusted("CASE WHEN ");
407 fragment.push_fragment(left.clone().wrapped());
408 fragment.push_sql(" IS NULL THEN ");
409 fragment.push_fragment(right.clone());
410 fragment.push_sql(" WHEN ");
411 fragment.push_fragment(right.clone().wrapped());
412 fragment.push_sql(" IS NULL THEN ");
413 fragment.push_fragment(left.clone());
414 fragment.push_sql(" ELSE ");
415 fragment.push_fragment(SqlxFragment::function(function, vec![left, right]));
416 fragment.push_sql(" END");
417 fragment
418}
419
420fn unzip_lowered<B>(lowered: Vec<SqlxLowered<B>>) -> (Vec<SqlxFragment<B>>, Vec<SqlxFragment<B>>) {
421 lowered
422 .into_iter()
423 .map(|lowered| (lowered.filter, lowered.grade))
424 .unzip()
425}
426
427fn case_when<B>(
428 condition: SqlxFragment<B>,
429 then_expr: SqlxFragment<B>,
430 else_expr: SqlxFragment<B>,
431) -> SqlxFragment<B> {
432 let mut fragment = SqlxFragment::trusted("CASE WHEN ");
433 fragment.push_fragment(condition);
434 fragment.push_sql(" THEN ");
435 fragment.push_fragment(then_expr);
436 fragment.push_sql(" ELSE ");
437 fragment.push_fragment(else_expr);
438 fragment.push_sql(" END");
439 fragment
440}
441
442fn is_true<B>(predicate: SqlxFragment<B>) -> SqlxFragment<B> {
443 let mut fragment = SqlxFragment::trusted("(");
444 fragment.push_fragment(predicate);
445 fragment.push_sql(") IS TRUE");
446 fragment
447}