Skip to main content

gatekeep_sqlx/
lib.rs

1//! `SQLx` lowering for gatekeep residual policies.
2//!
3//! This crate lowers a `gatekeep::ResidualPolicy` into trusted Postgres SQL
4//! fragments that can be appended to a `sqlx::QueryBuilder`.
5
6#![forbid(unsafe_code)]
7
8use gatekeep::{
9    Condition, Context, FactId, LowerError, Lowered, QueryLowering, ResidualPolicy,
10    ResidualPolicyBranch, ResidualPolicyNode,
11};
12
13mod fragment;
14
15pub use fragment::{PgFragment, PgValue};
16
17/// Maps a residual fact to a trusted Postgres predicate over the candidate row.
18pub trait PgFactPredicates {
19    /// Returns a predicate for the given fact, or `None` when the fact cannot be
20    /// represented by this backend.
21    fn predicate(&self, fact: &FactId, cx: &Context) -> Option<PgFragment>;
22}
23
/// Outcomes that can be ranked by a single SQL integer.
///
/// Grades are combined in SQL with `LEAST` (for `All`) and `GREATEST` (for
/// `Any`), so the ordinal should preserve the outcome order.
pub trait SqlOutcome {
    /// Returns this outcome's ordinal as the scalar compared by SQL
    /// `LEAST`/`GREATEST`.
    fn to_sql_ordinal(&self) -> i64;
}

/// The unit outcome has exactly one value, so its ordinal is always `0`.
impl SqlOutcome for () {
    fn to_sql_ordinal(&self) -> i64 {
        i64::default()
    }
}
35
36/// Projection strategy for turning outcomes into SQL fragments.
37pub trait OutcomeProjection<O> {
38    /// Builds a SQL fragment for a constant outcome.
39    fn constant(&self, outcome: &O) -> Result<PgFragment, LowerError>;
40}
41
42/// Outcome projection backed by [`SqlOutcome`].
43#[derive(Clone, Copy, Debug, Default)]
44pub struct OrdinalProjection;
45
46impl<O: SqlOutcome> OutcomeProjection<O> for OrdinalProjection {
47    fn constant(&self, outcome: &O) -> Result<PgFragment, LowerError> {
48        Ok(PgFragment::bind(outcome.to_sql_ordinal()))
49    }
50}
51
52/// Projection that rejects grade lowering.
53#[derive(Clone, Copy, Debug, Default)]
54pub struct NoGradeProjection;
55
56impl<O> OutcomeProjection<O> for NoGradeProjection {
57    fn constant(&self, _outcome: &O) -> Result<PgFragment, LowerError> {
58        Err(LowerError::NonTotalGrade)
59    }
60}
61
62/// Postgres lowerer for gatekeep residual policies.
63#[derive(Clone, Debug)]
64pub struct PgLowerer<P, M = OrdinalProjection> {
65    predicates: P,
66    projection: M,
67}
68
69#[derive(Clone, Debug, PartialEq, Eq)]
70struct PgLowered {
71    filter: PgFragment,
72    grade: PgFragment,
73}
74
75impl<P> PgLowerer<P, OrdinalProjection> {
76    /// Builds a lowerer using ordinal grade projection.
77    #[must_use]
78    pub const fn new(predicates: P) -> Self {
79        Self::with_projection(predicates, OrdinalProjection)
80    }
81}
82
83impl<P, M> PgLowerer<P, M> {
84    /// Builds a lowerer using a caller-supplied projection strategy.
85    #[must_use]
86    pub const fn with_projection(predicates: P, projection: M) -> Self {
87        Self {
88            predicates,
89            projection,
90        }
91    }
92
93    /// Lowers only the Boolean filter. This works for every outcome lattice.
94    pub fn lower_filter<O>(
95        &self,
96        residual: &ResidualPolicy<O>,
97        cx: &Context,
98    ) -> Result<PgFragment, LowerError>
99    where
100        P: PgFactPredicates,
101    {
102        residual.try_fold_pruned(
103            &mut |branch| match branch {
104                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
105                    !fallback.carries_obligation()
106                }
107            },
108            &mut |node| self.lower_filter_node(node, cx),
109        )
110    }
111
112    fn lower_filter_node<O>(
113        &self,
114        node: ResidualPolicyNode<'_, O, PgFragment>,
115        cx: &Context,
116    ) -> Result<PgFragment, LowerError>
117    where
118        P: PgFactPredicates,
119    {
120        match node {
121            ResidualPolicyNode::Permit(_) | ResidualPolicyNode::PermitWithTrace { .. } => {
122                Ok(PgFragment::trusted("TRUE"))
123            }
124            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => {
125                Ok(PgFragment::trusted("FALSE"))
126            }
127            ResidualPolicyNode::Grant { condition, .. } => self.lower_condition(condition, cx),
128            ResidualPolicyNode::All { arms, .. } => Ok(fragment_set(arms, " AND ", "FALSE")),
129            ResidualPolicyNode::Any { arms, .. } => Ok(fragment_set(arms, " OR ", "FALSE")),
130            ResidualPolicyNode::OrElse {
131                fallback_policy,
132                primary,
133                fallback,
134                ..
135            } => {
136                if fallback_policy.carries_obligation() {
137                    Ok(primary)
138                } else {
139                    Ok(match fallback {
140                        Some(fallback) => PgFragment::binary(" OR ", vec![primary, fallback]),
141                        None => primary,
142                    })
143                }
144            }
145        }
146    }
147
148    fn lower_condition(&self, condition: &Condition, cx: &Context) -> Result<PgFragment, LowerError>
149    where
150        P: PgFactPredicates,
151    {
152        match condition {
153            Condition::Always => Ok(PgFragment::trusted("TRUE")),
154            Condition::Never => Ok(PgFragment::trusted("FALSE")),
155            Condition::Has(fact) => self
156                .predicates
157                .predicate(fact, cx)
158                .map(is_true)
159                .ok_or_else(|| LowerError::Unlowerable(fact.clone())),
160            Condition::Not(inner) => {
161                Ok(PgFragment::unary("NOT ", self.lower_condition(inner, cx)?))
162            }
163            Condition::All(conditions) => {
164                lower_condition_set(conditions, " AND ", "FALSE", |item| {
165                    self.lower_condition(item, cx)
166                })
167            }
168            Condition::Any(conditions) => {
169                lower_condition_set(conditions, " OR ", "FALSE", |item| {
170                    self.lower_condition(item, cx)
171                })
172            }
173        }
174    }
175
176    fn lower_policy<O>(
177        &self,
178        residual: &ResidualPolicy<O>,
179        cx: &Context,
180    ) -> Result<PgLowered, LowerError>
181    where
182        P: PgFactPredicates,
183        M: OutcomeProjection<O>,
184    {
185        residual.try_fold_pruned(
186            &mut |branch| match branch {
187                ResidualPolicyBranch::OrElseFallback { fallback, .. } => {
188                    !fallback.carries_obligation()
189                }
190            },
191            &mut |node| self.lower_node(node, cx),
192        )
193    }
194
195    fn lower_node<O>(
196        &self,
197        node: ResidualPolicyNode<'_, O, PgLowered>,
198        cx: &Context,
199    ) -> Result<PgLowered, LowerError>
200    where
201        P: PgFactPredicates,
202        M: OutcomeProjection<O>,
203    {
204        match node {
205            ResidualPolicyNode::Permit(outcome)
206            | ResidualPolicyNode::PermitWithTrace { outcome, .. } => Ok(PgLowered {
207                filter: PgFragment::trusted("TRUE"),
208                grade: self.projection.constant(outcome)?,
209            }),
210            ResidualPolicyNode::Deny | ResidualPolicyNode::DenyWithTrace { .. } => Ok(PgLowered {
211                filter: PgFragment::trusted("FALSE"),
212                grade: PgFragment::trusted("NULL"),
213            }),
214            ResidualPolicyNode::Grant {
215                outcome, condition, ..
216            } => {
217                let filter = self.lower_condition(condition, cx)?;
218                let outcome = self.projection.constant(outcome)?;
219                Ok(PgLowered {
220                    filter: filter.clone(),
221                    grade: case_when(filter, outcome, PgFragment::trusted("NULL")),
222                })
223            }
224            ResidualPolicyNode::All { arms, .. } => {
225                let (filters, grades) = unzip_lowered(arms);
226                Ok(PgLowered {
227                    filter: fragment_set(filters, " AND ", "FALSE"),
228                    grade: grade_set(grades, "LEAST"),
229                })
230            }
231            ResidualPolicyNode::Any { arms, .. } => {
232                let (filters, grades) = unzip_lowered(arms);
233                Ok(PgLowered {
234                    filter: fragment_set(filters, " OR ", "FALSE"),
235                    grade: grade_set(grades, "GREATEST"),
236                })
237            }
238            ResidualPolicyNode::OrElse {
239                fallback_policy,
240                primary,
241                fallback,
242                ..
243            } => {
244                if fallback_policy.carries_obligation() {
245                    return Ok(primary);
246                }
247
248                Ok(match fallback {
249                    Some(fallback) => PgLowered {
250                        filter: PgFragment::binary(
251                            " OR ",
252                            vec![primary.filter.clone(), fallback.filter],
253                        ),
254                        grade: case_when(primary.filter, primary.grade, fallback.grade),
255                    },
256                    None => primary,
257                })
258            }
259        }
260    }
261}
262
263impl<O, P, M> QueryLowering<O> for PgLowerer<P, M>
264where
265    P: PgFactPredicates,
266    M: OutcomeProjection<O>,
267{
268    type Filter = PgFragment;
269    type Projection = PgFragment;
270
271    fn lower(
272        &self,
273        residual: &ResidualPolicy<O>,
274        cx: &Context,
275    ) -> Result<Lowered<Self::Filter, Self::Projection>, LowerError> {
276        let lowered = self.lower_policy(residual, cx)?;
277        Ok(Lowered {
278            filter: lowered.filter,
279            grade: lowered.grade,
280        })
281    }
282}
283
284fn lower_condition_set(
285    conditions: &[Condition],
286    separator: &str,
287    empty: &str,
288    lower: impl FnMut(&Condition) -> Result<PgFragment, LowerError>,
289) -> Result<PgFragment, LowerError> {
290    if conditions.is_empty() {
291        return Ok(PgFragment::trusted(empty));
292    }
293    let fragments = conditions
294        .iter()
295        .map(lower)
296        .collect::<Result<Vec<_>, _>>()?;
297    Ok(PgFragment::binary(separator, fragments))
298}
299
300fn fragment_set(fragments: Vec<PgFragment>, separator: &str, empty: &str) -> PgFragment {
301    if fragments.is_empty() {
302        PgFragment::trusted(empty)
303    } else {
304        PgFragment::binary(separator, fragments)
305    }
306}
307
308fn grade_set(grades: Vec<PgFragment>, function: &str) -> PgFragment {
309    if grades.is_empty() {
310        PgFragment::trusted("NULL")
311    } else {
312        PgFragment::function(function, grades)
313    }
314}
315
316fn unzip_lowered(lowered: Vec<PgLowered>) -> (Vec<PgFragment>, Vec<PgFragment>) {
317    lowered
318        .into_iter()
319        .map(|lowered| (lowered.filter, lowered.grade))
320        .unzip()
321}
322
323fn case_when(condition: PgFragment, then_expr: PgFragment, else_expr: PgFragment) -> PgFragment {
324    let mut fragment = PgFragment::trusted("CASE WHEN ");
325    fragment.push_fragment(condition);
326    fragment.push_sql(" THEN ");
327    fragment.push_fragment(then_expr);
328    fragment.push_sql(" ELSE ");
329    fragment.push_fragment(else_expr);
330    fragment.push_sql(" END");
331    fragment
332}
333
334fn is_true(predicate: PgFragment) -> PgFragment {
335    let mut fragment = PgFragment::trusted("(");
336    fragment.push_fragment(predicate);
337    fragment.push_sql(") IS TRUE");
338    fragment
339}