Skip to main content

gatekeep_sqlx/
lib.rs

1//! `SQLx` lowering for gatekeep residual policies.
2//!
3//! This crate lowers a `gatekeep::ResidualPolicy` into trusted SQL fragments
4//! that can be appended to a `sqlx::QueryBuilder`.
5
6#![forbid(unsafe_code)]
7
8#[cfg(not(any(feature = "postgres", feature = "sqlite", feature = "mysql")))]
9compile_error!(
10    "gatekeep-sqlx requires at least one SQLx backend feature: postgres, sqlite, or mysql"
11);
12
13use std::marker::PhantomData;
14
15use gatekeep::{
16    Condition, Context, FactId, LowerError, Lowered, QueryLowering, ResidualPolicy,
17    ResidualPolicyBranch, ResidualPolicyNode,
18};
19
20mod fragment;
21
22#[cfg(feature = "mysql")]
23pub use fragment::MySqlBackend;
24#[cfg(feature = "sqlite")]
25pub use fragment::SqliteBackend;
26pub use fragment::{
27    GatekeepSqlxBackend, SqlxDriver, SqlxDriverError, SqlxFragment, SqlxValue,
28    infer_enabled_driver_from_url, validate_database_url_for_backend,
29};
30#[cfg(feature = "postgres")]
31pub use fragment::{PgFragment, PgValue, PostgresBackend};
32
33/// Maps a residual fact to a trusted predicate over the candidate row.
34pub trait SqlxFactPredicates<B>
35where
36    B: GatekeepSqlxBackend,
37{
38    /// Returns a predicate for the given fact, or `None` when the fact cannot be
39    /// represented by this backend.
40    fn predicate(&self, fact: &FactId, cx: &Context) -> Option<SqlxFragment<B>>;
41}
42
/// Postgres-specific counterpart of [`SqlxFactPredicates`].
///
/// Implementing this trait is enough for a Postgres lowerer: the blanket impl
/// below forwards it to `SqlxFactPredicates<PostgresBackend>`.
#[cfg(feature = "postgres")]
pub trait PgFactPredicates {
    /// Produces the Postgres predicate for `fact` in context `cx`, or `None`
    /// when the fact has no Postgres representation.
    fn predicate(&self, fact: &FactId, cx: &Context) -> Option<PgFragment>;
}

#[cfg(feature = "postgres")]
impl<T: PgFactPredicates> SqlxFactPredicates<PostgresBackend> for T {
    fn predicate(&self, fact: &FactId, cx: &Context) -> Option<SqlxFragment<PostgresBackend>> {
        // Fully qualified so the call cannot resolve back to this impl.
        <T as PgFactPredicates>::predicate(self, fact, cx)
    }
}
60
/// Places each policy outcome in a total order by giving it an `i64` that SQL
/// can compare.
pub trait SqlOutcome {
    /// Ordinal used when projecting this outcome as a SQL grade.
    fn to_sql_ordinal(&self) -> i64;
}

/// The unit outcome has exactly one value, so it always maps to ordinal `0`.
impl SqlOutcome for () {
    fn to_sql_ordinal(&self) -> i64 {
        0_i64
    }
}
72
73/// Projection strategy for turning outcomes into SQL fragments.
74pub trait OutcomeProjection<B, O>
75where
76    B: GatekeepSqlxBackend,
77{
78    /// Builds a SQL fragment for a constant outcome.
79    fn constant(&self, outcome: &O) -> Result<SqlxFragment<B>, LowerError>;
80}
81
82/// Outcome projection backed by [`SqlOutcome`].
83#[derive(Clone, Copy, Debug, Default)]
84pub struct OrdinalProjection;
85
86impl<B, O> OutcomeProjection<B, O> for OrdinalProjection
87where
88    B: GatekeepSqlxBackend,
89    O: SqlOutcome,
90{
91    fn constant(&self, outcome: &O) -> Result<SqlxFragment<B>, LowerError> {
92        Ok(SqlxFragment::bind(outcome.to_sql_ordinal()))
93    }
94}
95
96/// Projection that rejects grade lowering.
97#[derive(Clone, Copy, Debug, Default)]
98pub struct NoGradeProjection;
99
100impl<B, O> OutcomeProjection<B, O> for NoGradeProjection
101where
102    B: GatekeepSqlxBackend,
103{
104    fn constant(&self, _outcome: &O) -> Result<SqlxFragment<B>, LowerError> {
105        Err(LowerError::NonTotalGrade)
106    }
107}
108
109/// `SQLx` lowerer for gatekeep residual policies.
110#[derive(Clone, Debug)]
111pub struct SqlxLowerer<B, P, M = OrdinalProjection> {
112    predicates: P,
113    projection: M,
114    backend: PhantomData<fn() -> B>,
115}
116
117/// Postgres lowerer for gatekeep residual policies.
118#[cfg(feature = "postgres")]
119pub type PgLowerer<P, M = OrdinalProjection> = SqlxLowerer<PostgresBackend, P, M>;
120
121#[derive(Clone, Debug, PartialEq, Eq)]
122struct SqlxLowered<B> {
123    filter: SqlxFragment<B>,
124    grade: SqlxFragment<B>,
125}
126
127impl<B, P> SqlxLowerer<B, P, OrdinalProjection>
128where
129    B: GatekeepSqlxBackend,
130{
131    /// Builds a lowerer using ordinal grade projection.
132    #[must_use]
133    pub const fn new(predicates: P) -> Self {
134        Self::with_projection(predicates, OrdinalProjection)
135    }
136}
137
138impl<B, P, M> SqlxLowerer<B, P, M>
139where
140    B: GatekeepSqlxBackend,
141{
142    /// Builds a lowerer using a caller-supplied projection strategy.
143    #[must_use]
144    pub const fn with_projection(predicates: P, projection: M) -> Self {
145        Self {
146            predicates,
147            projection,
148            backend: PhantomData,
149        }
150    }
151
152    /// Lowers only the Boolean filter. This works for every outcome lattice.
153    pub fn lower_filter<O>(
154        &self,
155        residual: &ResidualPolicy<O>,
156        cx: &Context,
157    ) -> Result<SqlxFragment<B>, LowerError>
158    where
159        P: SqlxFactPredicates<B>,
160    {
161        residual.try_fold_pruned(
162            &mut |branch| match branch {
163                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
164                    !fallback.carries_obligation()
165                }
166            },
167            &mut |node| self.lower_filter_node(node, cx),
168        )
169    }
170
171    fn lower_filter_node<O>(
172        &self,
173        node: ResidualPolicyNode<'_, O, SqlxFragment<B>>,
174        cx: &Context,
175    ) -> Result<SqlxFragment<B>, LowerError>
176    where
177        P: SqlxFactPredicates<B>,
178    {
179        match node {
180            ResidualPolicyNode::Permit(_) | ResidualPolicyNode::PermitWithTrace { .. } => {
181                Ok(SqlxFragment::trusted("TRUE"))
182            }
183            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
184                Ok(SqlxFragment::trusted("FALSE"))
185            }
186            ResidualPolicyNode::Grant { condition, .. } => self.lower_condition(condition, cx),
187            ResidualPolicyNode::All { arms, .. } => Ok(fragment_set(arms, " AND ", "FALSE")),
188            ResidualPolicyNode::Any { arms, .. } => Ok(fragment_set(arms, " OR ", "FALSE")),
189            ResidualPolicyNode::OrElse {
190                fallback_policy,
191                primary,
192                fallback,
193                ..
194            } => {
195                if fallback_policy.carries_obligation() {
196                    Ok(primary)
197                } else {
198                    Ok(match fallback {
199                        Some(fallback) => SqlxFragment::binary(" OR ", vec![primary, fallback]),
200                        None => primary,
201                    })
202                }
203            }
204        }
205    }
206
207    fn lower_condition(
208        &self,
209        condition: &Condition,
210        cx: &Context,
211    ) -> Result<SqlxFragment<B>, LowerError>
212    where
213        P: SqlxFactPredicates<B>,
214    {
215        match condition {
216            Condition::Always => Ok(SqlxFragment::trusted("TRUE")),
217            Condition::Never => Ok(SqlxFragment::trusted("FALSE")),
218            Condition::Has(fact) => self
219                .predicates
220                .predicate(fact, cx)
221                .map(is_true)
222                .ok_or_else(|| LowerError::Unlowerable(fact.clone())),
223            Condition::Not(inner) => Ok(SqlxFragment::unary(
224                "NOT ",
225                self.lower_condition(inner, cx)?,
226            )),
227            Condition::All(conditions) => {
228                lower_condition_set(conditions, " AND ", "FALSE", |item| {
229                    self.lower_condition(item, cx)
230                })
231            }
232            Condition::Any(conditions) => {
233                lower_condition_set(conditions, " OR ", "FALSE", |item| {
234                    self.lower_condition(item, cx)
235                })
236            }
237        }
238    }
239
240    fn lower_policy<O>(
241        &self,
242        residual: &ResidualPolicy<O>,
243        cx: &Context,
244    ) -> Result<SqlxLowered<B>, LowerError>
245    where
246        P: SqlxFactPredicates<B>,
247        M: OutcomeProjection<B, O>,
248    {
249        residual.try_fold_pruned(
250            &mut |branch| match branch {
251                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
252                    !fallback.carries_obligation()
253                }
254            },
255            &mut |node| self.lower_node(node, cx),
256        )
257    }
258
259    fn lower_node<O>(
260        &self,
261        node: ResidualPolicyNode<'_, O, SqlxLowered<B>>,
262        cx: &Context,
263    ) -> Result<SqlxLowered<B>, LowerError>
264    where
265        P: SqlxFactPredicates<B>,
266        M: OutcomeProjection<B, O>,
267    {
268        match node {
269            ResidualPolicyNode::Permit(outcome)
270            | ResidualPolicyNode::PermitWithTrace { outcome, .. } => Ok(SqlxLowered {
271                filter: SqlxFragment::trusted("TRUE"),
272                grade: self.projection.constant(outcome)?,
273            }),
274            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
275                Ok(SqlxLowered {
276                    filter: SqlxFragment::trusted("FALSE"),
277                    grade: SqlxFragment::trusted("NULL"),
278                })
279            }
280            ResidualPolicyNode::Grant {
281                outcome, condition, ..
282            } => {
283                let filter = self.lower_condition(condition, cx)?;
284                let outcome = self.projection.constant(outcome)?;
285                Ok(SqlxLowered {
286                    filter: filter.clone(),
287                    grade: case_when(filter, outcome, SqlxFragment::trusted("NULL")),
288                })
289            }
290            ResidualPolicyNode::All { arms, .. } => {
291                let (filters, grades) = unzip_lowered(arms);
292                Ok(SqlxLowered {
293                    filter: fragment_set(filters, " AND ", "FALSE"),
294                    grade: grade_set::<B>(grades, B::MIN_FUNCTION),
295                })
296            }
297            ResidualPolicyNode::Any { arms, .. } => {
298                let (filters, grades) = unzip_lowered(arms);
299                Ok(SqlxLowered {
300                    filter: fragment_set(filters, " OR ", "FALSE"),
301                    grade: grade_set::<B>(grades, B::MAX_FUNCTION),
302                })
303            }
304            ResidualPolicyNode::OrElse {
305                fallback_policy,
306                primary,
307                fallback,
308                ..
309            } => {
310                if fallback_policy.carries_obligation() {
311                    return Ok(primary);
312                }
313
314                Ok(match fallback {
315                    Some(fallback) => SqlxLowered {
316                        filter: SqlxFragment::binary(
317                            " OR ",
318                            vec![primary.filter.clone(), fallback.filter],
319                        ),
320                        grade: case_when(primary.filter, primary.grade, fallback.grade),
321                    },
322                    None => primary,
323                })
324            }
325        }
326    }
327}
328
329impl<O, B, P, M> QueryLowering<O> for SqlxLowerer<B, P, M>
330where
331    B: GatekeepSqlxBackend,
332    P: SqlxFactPredicates<B>,
333    M: OutcomeProjection<B, O>,
334{
335    type Filter = SqlxFragment<B>;
336    type Projection = SqlxFragment<B>;
337
338    fn lower(
339        &self,
340        residual: &ResidualPolicy<O>,
341        cx: &Context,
342    ) -> Result<Lowered<Self::Filter, Self::Projection>, LowerError> {
343        let lowered = self.lower_policy(residual, cx)?;
344        Ok(Lowered {
345            filter: lowered.filter,
346            grade: lowered.grade,
347        })
348    }
349}
350
351fn lower_condition_set<B>(
352    conditions: &[Condition],
353    separator: &str,
354    empty: &str,
355    lower: impl FnMut(&Condition) -> Result<SqlxFragment<B>, LowerError>,
356) -> Result<SqlxFragment<B>, LowerError> {
357    if conditions.is_empty() {
358        return Ok(SqlxFragment::trusted(empty));
359    }
360    let fragments = conditions
361        .iter()
362        .map(lower)
363        .collect::<Result<Vec<_>, _>>()?;
364    Ok(SqlxFragment::binary(separator, fragments))
365}
366
367fn fragment_set<B>(
368    fragments: Vec<SqlxFragment<B>>,
369    separator: &str,
370    empty: &str,
371) -> SqlxFragment<B> {
372    if fragments.is_empty() {
373        SqlxFragment::trusted(empty)
374    } else {
375        SqlxFragment::binary(separator, fragments)
376    }
377}
378
379fn grade_set<B>(grades: Vec<SqlxFragment<B>>, function: &str) -> SqlxFragment<B>
380where
381    B: GatekeepSqlxBackend,
382{
383    match grades.len() {
384        0 => SqlxFragment::trusted("NULL"),
385        1 => grades
386            .into_iter()
387            .next()
388            .unwrap_or_else(|| SqlxFragment::trusted("NULL")),
389        _ if B::GRADE_FUNCTION_PROPAGATES_NULL => {
390            let mut iter = grades.into_iter();
391            let mut combined = iter.next().unwrap_or_else(|| SqlxFragment::trusted("NULL"));
392            for grade in iter {
393                combined = null_safe_grade_pair(function, combined, grade);
394            }
395            combined
396        }
397        _ => SqlxFragment::function(function, grades),
398    }
399}
400
401fn null_safe_grade_pair<B>(
402    function: &str,
403    left: SqlxFragment<B>,
404    right: SqlxFragment<B>,
405) -> SqlxFragment<B> {
406    let mut fragment = SqlxFragment::trusted("CASE WHEN ");
407    fragment.push_fragment(left.clone().wrapped());
408    fragment.push_sql(" IS NULL THEN ");
409    fragment.push_fragment(right.clone());
410    fragment.push_sql(" WHEN ");
411    fragment.push_fragment(right.clone().wrapped());
412    fragment.push_sql(" IS NULL THEN ");
413    fragment.push_fragment(left.clone());
414    fragment.push_sql(" ELSE ");
415    fragment.push_fragment(SqlxFragment::function(function, vec![left, right]));
416    fragment.push_sql(" END");
417    fragment
418}
419
420fn unzip_lowered<B>(lowered: Vec<SqlxLowered<B>>) -> (Vec<SqlxFragment<B>>, Vec<SqlxFragment<B>>) {
421    lowered
422        .into_iter()
423        .map(|lowered| (lowered.filter, lowered.grade))
424        .unzip()
425}
426
427fn case_when<B>(
428    condition: SqlxFragment<B>,
429    then_expr: SqlxFragment<B>,
430    else_expr: SqlxFragment<B>,
431) -> SqlxFragment<B> {
432    let mut fragment = SqlxFragment::trusted("CASE WHEN ");
433    fragment.push_fragment(condition);
434    fragment.push_sql(" THEN ");
435    fragment.push_fragment(then_expr);
436    fragment.push_sql(" ELSE ");
437    fragment.push_fragment(else_expr);
438    fragment.push_sql(" END");
439    fragment
440}
441
442fn is_true<B>(predicate: SqlxFragment<B>) -> SqlxFragment<B> {
443    let mut fragment = SqlxFragment::trusted("(");
444    fragment.push_fragment(predicate);
445    fragment.push_sql(") IS TRUE");
446    fragment
447}